Add concat aggfunc

Andrea Pedrotti 2023-02-10 12:58:26 +01:00
parent 3f3e4982e4
commit 2a42b21ac9
5 changed files with 67 additions and 36 deletions


@@ -35,12 +35,15 @@ class GeneralizedFunnelling:
device,
load_trained,
dataset_name,
probabilistic,
aggfunc,
):
# Setting VFGs -----------
self.posteriors_vgf = posterior
self.wce_vgf = wce
self.multilingual_vgf = multilingual
self.transformer_vgf = transformer
self.probabilistic = probabilistic
# ------------------------
self.langs = langs
self.embed_dir = embed_dir
@@ -62,13 +65,16 @@ class GeneralizedFunnelling:
self.n_jobs = n_jobs
self.first_tier_learners = []
self.metaclassifier = None
self.aggfunc = "mean"
self.aggfunc = aggfunc
self.load_trained = load_trained
self.dataset_name = dataset_name
self._init()
def _init(self):
print("[Init GeneralizedFunnelling]")
assert not (
self.aggfunc == "mean" and self.probabilistic is False
), "When using averaging aggreagation function probabilistic must be True"
if self.load_trained is not None:
print("- loading trained VGFs, metaclassifer and vectorizer")
self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load(
@@ -90,7 +96,7 @@ class GeneralizedFunnelling:
langs=self.langs,
n_jobs=self.n_jobs,
cached=self.cached,
probabilistic=True,
probabilistic=self.probabilistic,
)
self.first_tier_learners.append(multilingual_vgf)
@@ -100,6 +106,7 @@ class GeneralizedFunnelling:
if self.transformer_vgf:
transformer_vgf = TextualTransformerGen(
dataset_name=self.dataset_name,
model_name=self.transformer_name,
lr=self.lr_transformer,
epochs=self.epochs,
@@ -107,11 +114,10 @@ class GeneralizedFunnelling:
max_length=self.max_length,
device="cuda",
print_steps=50,
probabilistic=True,
probabilistic=self.probabilistic,
evaluate_step=self.evaluate_step,
verbose=True,
patience=self.patience,
dataset_name=self.dataset_name,
)
self.first_tier_learners.append(transformer_vgf)
@@ -174,10 +180,18 @@ class GeneralizedFunnelling:
def aggregate(self, first_tier_projections):
if self.aggfunc == "mean":
aggregated = self._aggregate_mean(first_tier_projections)
elif self.aggfunc == "concat":
aggregated = self._aggregate_concat(first_tier_projections)
else:
raise NotImplementedError
return aggregated
def _aggregate_concat(self, first_tier_projections):
aggregated = {}
for lang in self.langs:
aggregated[lang] = np.hstack([v[lang] for v in first_tier_projections])
return aggregated
def _aggregate_mean(self, first_tier_projections):
aggregated = {
lang: np.zeros(data.shape)
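
A minimal sketch of the two aggregation modes, assuming each first-tier projection is a dict mapping language codes to equally-sized 2-D NumPy arrays (as the hstack above implies). Mean keeps the posterior shape and therefore needs calibrated probabilities from every VGF, while concat simply widens the feature space, which is why the new assertion only ties probabilistic to the mean case:

import numpy as np

# Two hypothetical first-tier projections, one per VGF: {lang: (n_docs, n_dims)}.
p1 = {"en": np.array([[0.2, 0.8], [0.6, 0.4]])}
p2 = {"en": np.array([[0.1, 0.9], [0.5, 0.5]])}
projections = [p1, p2]

# aggfunc="mean": element-wise average, shape stays (n_docs, n_dims).
mean_agg = {lang: np.mean([p[lang] for p in projections], axis=0) for lang in p1}

# aggfunc="concat": horizontal stacking, shape widens to (n_docs, n_dims * n_vgfs).
concat_agg = {lang: np.hstack([p[lang] for p in projections]) for lang in p1}

print(mean_agg["en"].shape)    # (2, 2)
print(concat_agg["en"].shape)  # (2, 4)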


@@ -20,7 +20,6 @@ transformers.logging.set_verbosity_error()
# TODO: add support to loggers
# TODO: multiple inheritance - maybe define a superclass for TransformerGenerator, whether it is a Textual or a Visual one, implementing dataset creation functions
class TextualTransformerGen(ViewGen, TransformerGen):
@@ -42,7 +41,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
patience=5,
):
super().__init__(
model_name,
self._validate_model_name(model_name),
dataset_name,
epochs,
lr,
@@ -58,28 +57,19 @@ class TextualTransformerGen(ViewGen, TransformerGen):
patience,
)
self.fitted = False
self._init()
def _init(self):
if self.probabilistic:
self.feature2posterior_projector = FeatureSet2Posteriors(
n_jobs=self.n_jobs, verbose=False
)
self.model_name = self._get_model_name(self.model_name)
print(
f"- init TransformerModel model_name: {self.model_name}, device: {self.device}]"
f"- init Textual TransformerModel model_name: {self.model_name}, device: {self.device}]"
)
def _get_model_name(self, name):
if "bert" == name:
name_model = "bert-base-uncased"
elif "mbert" == name:
name_model = "bert-base-multilingual-uncased"
elif "xlm" == name:
name_model = "xlm-roberta-base"
def _validate_model_name(self, model_name):
if "bert" == model_name:
return "bert-base-uncased"
elif "mbert" == model_name:
return "bert-base-multilingual-uncased"
elif "xlm" == model_name:
return "xlm-roberta-base"
else:
raise NotImplementedError
return name_model
def load_pretrained_model(self, model_name, num_labels):
return AutoModelForSequenceClassification.from_pretrained(
@@ -192,6 +182,8 @@ class TextualTransformerGen(ViewGen, TransformerGen):
if self.probabilistic and self.fitted:
l_embeds = self.feature2posterior_projector.transform(l_embeds)
elif not self.probabilistic and self.fitted:
l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}
return l_embeds
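
The renamed helper now resolves the model shorthand once, before it reaches the superclass constructor, instead of lazily in _init. A dict-based sketch equivalent to the if/elif chain above (checkpoint names as in the diff; the standalone function name here is illustrative):

def validate_model_name(model_name: str) -> str:
    # Shorthand -> Hugging Face checkpoint; resolved once at construction time.
    mapping = {
        "bert": "bert-base-uncased",
        "mbert": "bert-base-multilingual-uncased",
        "xlm": "xlm-roberta-base",
    }
    if model_name not in mapping:
        raise NotImplementedError(f"Unknown model shorthand: {model_name}")
    return mapping[model_name]

assert validate_model_name("mbert") == "bert-base-multilingual-uncased"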


@@ -1,5 +1,6 @@
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from vgfs.learners.svms import FeatureSet2Posteriors
class TransformerGen:
@@ -45,9 +46,16 @@ class TransformerGen:
self.verbose = verbose
self.patience = patience
self.datasets = {}
self.feature2posterior_projector = (
self.make_probabilistic() if probabilistic else None
)
def make_probabilistic(self):
raise NotImplementedError
if self.probabilistic:
feature2posterior_projector = FeatureSet2Posteriors(
n_jobs=self.n_jobs, verbose=False
)
return feature2posterior_projector
def build_dataloader(
self,
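
With this change the superclass owns the posterior projector, so the Textual and Visual subclasses no longer each build their own FeatureSet2Posteriors in _init. A runnable sketch of the pattern (FeatureSet2Posteriors stubbed out for illustration); note that the inner `if self.probabilistic` guard in make_probabilistic is redundant, since the call site already checks the flag:

class FeatureSet2PosteriorsStub:
    # Stand-in for vgfs.learners.svms.FeatureSet2Posteriors.
    def __init__(self, n_jobs=1, verbose=False):
        self.n_jobs, self.verbose = n_jobs, verbose

class TransformerGenSketch:
    def __init__(self, probabilistic, n_jobs=1):
        self.probabilistic = probabilistic
        self.n_jobs = n_jobs
        # Built eagerly in the superclass; None when the view is not probabilistic.
        self.feature2posterior_projector = (
            self.make_probabilistic() if probabilistic else None
        )

    def make_probabilistic(self):
        return FeatureSet2PosteriorsStub(n_jobs=self.n_jobs, verbose=False)

assert TransformerGenSketch(True).feature2posterior_projector is not None
assert TransformerGenSketch(False).feature2posterior_projector is None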


@@ -40,6 +40,10 @@ class VisualTransformerGen(ViewGen, TransformerGen):
evaluate_step=evaluate_step,
patience=patience,
)
self.fitted = False
print(
f"- init Visual TransformerModel model_name: {self.model_name}, device: {self.device}]"
)
def _validate_model_name(self, model_name):
if "vit" == model_name:
@@ -128,6 +132,9 @@ class VisualTransformerGen(ViewGen, TransformerGen):
epochs=self.epochs,
)
if self.probabilistic:
self.feature2posterior_projector.fit(self.transform(lX), lY)
def transform(self, lX):
raise NotImplementedError
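
The visual VGF now mirrors the textual one: after the transformer is trained, the projector is fitted on the model's own training-set embeddings. A sketch of that flow (projector and transform stubbed so it runs; in this commit transform itself still raises NotImplementedError, so the probabilistic branch cannot actually execute until it is implemented):

import numpy as np

class ProjectorStub:
    # Stand-in for FeatureSet2Posteriors.
    def fit(self, l_feats, l_labels):
        return self

class VisualVGFSketch:
    def __init__(self, probabilistic):
        self.probabilistic = probabilistic
        self.fitted = False
        self.feature2posterior_projector = ProjectorStub() if probabilistic else None

    def transform(self, lX):
        # Stubbed so the sketch runs; the real method is not implemented yet.
        return {lang: np.zeros((len(docs), 4)) for lang, docs in lX.items()}

    def fit(self, lX, lY):
        # (visual transformer training elided)
        if self.probabilistic:
            # New in this commit: calibrate posteriors on training embeddings.
            self.feature2posterior_projector.fit(self.transform(lX), lY)
        self.fitted = True
        return self

VisualVGFSketch(True).fit({"en": ["doc"]}, {"en": [0]})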

main.py

@@ -41,7 +41,7 @@ def get_dataset(datasetname):
dataset = (
MultilingualDataset(dataset_name="rcv1-2")
.load(RCV_DATAPATH)
.reduce_data(langs=["en", "it", "fr"], maxn=500)
.reduce_data(langs=["en", "it", "fr"], maxn=args.nrows)
)
else:
raise NotImplementedError
@@ -54,9 +54,9 @@ def main(args):
dataset, MultiNewsDataset
):
lX, lY = dataset.training()
# lX_te, lY_te = dataset.test()
print("[NB: for debug purposes, training set is also used as test set]\n")
lX_te, lY_te = dataset.training()
lX_te, lY_te = dataset.test()
# print("[NB: for debug purposes, training set is also used as test set]\n")
# lX_te, lY_te = dataset.training()
else:
_lX = dataset.dX
_lY = dataset.dY
@@ -75,24 +75,32 @@ def main(args):
), "At least one of VGF must be True"
gfun = GeneralizedFunnelling(
# dataset params ----------------------
dataset_name=args.dataset,
posterior=args.posteriors,
multilingual=args.multilingual,
wce=args.wce,
transformer=args.transformer,
langs=dataset.langs(),
# Posterior VGF params ----------------
posterior=args.posteriors,
# Multilingual VGF params -------------
multilingual=args.multilingual,
embed_dir="~/resources/muse_embeddings",
n_jobs=args.n_jobs,
max_length=args.max_length,
# WCE VGF params ----------------------
wce=args.wce,
# Transformer VGF params --------------
transformer=args.transformer,
transformer_name=args.transformer_name,
batch_size=args.batch_size,
epochs=args.epochs,
lr=args.lr,
max_length=args.max_length,
patience=args.patience,
evaluate_step=args.evaluate_step,
transformer_name=args.transformer_name,
device="cuda",
# General params ----------------------
probabilistic=args.features,
aggfunc=args.aggfunc,
optimc=args.optimc,
load_trained=args.load_trained,
n_jobs=args.n_jobs,
)
# gfun.get_config()
@@ -125,7 +133,7 @@ if __name__ == "__main__":
# Dataset parameters -------------------
parser.add_argument("-d", "--dataset", type=str, default="multinews")
parser.add_argument("--domains", type=str, default="all")
parser.add_argument("--nrows", type=int, default=10000)
parser.add_argument("--nrows", type=int, default=100)
parser.add_argument("--min_count", type=int, default=10)
parser.add_argument("--max_labels", type=int, default=50)
# gFUN parameters ----------------------
@@ -135,6 +143,8 @@ if __name__ == "__main__":
parser.add_argument("-t", "--transformer", action="store_true")
parser.add_argument("--n_jobs", type=int, default=1)
parser.add_argument("--optimc", action="store_true")
parser.add_argument("--features", action="store_false")
parser.add_argument("--aggfunc", type=str, default="mean")
# transformer parameters ---------------
parser.add_argument("--transformer_name", type=str, default="mbert")
parser.add_argument("--batch_size", type=int, default=32)