concat aggfunc

This commit is contained in:
Andrea Pedrotti 2023-02-10 12:58:26 +01:00
parent 3f3e4982e4
commit 2a42b21ac9
5 changed files with 67 additions and 36 deletions

View File

@ -35,12 +35,15 @@ class GeneralizedFunnelling:
device, device,
load_trained, load_trained,
dataset_name, dataset_name,
probabilistic,
aggfunc,
): ):
# Setting VFGs ----------- # Setting VFGs -----------
self.posteriors_vgf = posterior self.posteriors_vgf = posterior
self.wce_vgf = wce self.wce_vgf = wce
self.multilingual_vgf = multilingual self.multilingual_vgf = multilingual
self.trasformer_vgf = transformer self.trasformer_vgf = transformer
self.probabilistic = probabilistic
# ------------------------ # ------------------------
self.langs = langs self.langs = langs
self.embed_dir = embed_dir self.embed_dir = embed_dir
@ -62,13 +65,16 @@ class GeneralizedFunnelling:
self.n_jobs = n_jobs self.n_jobs = n_jobs
self.first_tier_learners = [] self.first_tier_learners = []
self.metaclassifier = None self.metaclassifier = None
self.aggfunc = "mean" self.aggfunc = aggfunc
self.load_trained = load_trained self.load_trained = load_trained
self.dataset_name = dataset_name self.dataset_name = dataset_name
self._init() self._init()
def _init(self): def _init(self):
print("[Init GeneralizedFunnelling]") print("[Init GeneralizedFunnelling]")
assert not (
self.aggfunc == "mean" and self.probabilistic is False
), "When using averaging aggregation function probabilistic must be True"
if self.load_trained is not None: if self.load_trained is not None:
print("- loading trained VGFs, metaclassifier and vectorizer") print("- loading trained VGFs, metaclassifier and vectorizer")
self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load( self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load(
@ -90,7 +96,7 @@ class GeneralizedFunnelling:
langs=self.langs, langs=self.langs,
n_jobs=self.n_jobs, n_jobs=self.n_jobs,
cached=self.cached, cached=self.cached,
probabilistic=True, probabilistic=self.probabilistic,
) )
self.first_tier_learners.append(multilingual_vgf) self.first_tier_learners.append(multilingual_vgf)
@ -100,6 +106,7 @@ class GeneralizedFunnelling:
if self.trasformer_vgf: if self.trasformer_vgf:
transformer_vgf = TextualTransformerGen( transformer_vgf = TextualTransformerGen(
dataset_name=self.dataset_name,
model_name=self.transformer_name, model_name=self.transformer_name,
lr=self.lr_transformer, lr=self.lr_transformer,
epochs=self.epochs, epochs=self.epochs,
@ -107,11 +114,10 @@ class GeneralizedFunnelling:
max_length=self.max_length, max_length=self.max_length,
device="cuda", device="cuda",
print_steps=50, print_steps=50,
probabilistic=True, probabilistic=self.probabilistic,
evaluate_step=self.evaluate_step, evaluate_step=self.evaluate_step,
verbose=True, verbose=True,
patience=self.patience, patience=self.patience,
dataset_name=self.dataset_name,
) )
self.first_tier_learners.append(transformer_vgf) self.first_tier_learners.append(transformer_vgf)
@ -174,10 +180,18 @@ class GeneralizedFunnelling:
def aggregate(self, first_tier_projections): def aggregate(self, first_tier_projections):
if self.aggfunc == "mean": if self.aggfunc == "mean":
aggregated = self._aggregate_mean(first_tier_projections) aggregated = self._aggregate_mean(first_tier_projections)
elif self.aggfunc == "concat":
aggregated = self._aggregate_concat(first_tier_projections)
else: else:
raise NotImplementedError raise NotImplementedError
return aggregated return aggregated
def _aggregate_concat(self, first_tier_projections):
aggregated = {}
for lang in self.langs:
aggregated[lang] = np.hstack([v[lang] for v in first_tier_projections])
return aggregated
def _aggregate_mean(self, first_tier_projections): def _aggregate_mean(self, first_tier_projections):
aggregated = { aggregated = {
lang: np.zeros(data.shape) lang: np.zeros(data.shape)

View File

@ -20,7 +20,6 @@ transformers.logging.set_verbosity_error()
# TODO: add support to loggers # TODO: add support to loggers
# TODO: multiple inheritance - maybe define a superclass for TransformerGenerator, whether it is a Textual or a Visual one, implementing dataset creation functions
class TextualTransformerGen(ViewGen, TransformerGen): class TextualTransformerGen(ViewGen, TransformerGen):
@ -42,7 +41,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
patience=5, patience=5,
): ):
super().__init__( super().__init__(
model_name, self._validate_model_name(model_name),
dataset_name, dataset_name,
epochs, epochs,
lr, lr,
@ -58,28 +57,19 @@ class TextualTransformerGen(ViewGen, TransformerGen):
patience, patience,
) )
self.fitted = False self.fitted = False
self._init()
def _init(self):
if self.probabilistic:
self.feature2posterior_projector = FeatureSet2Posteriors(
n_jobs=self.n_jobs, verbose=False
)
self.model_name = self._get_model_name(self.model_name)
print( print(
f"- init TransformerModel model_name: {self.model_name}, device: {self.device}]" f"- init Textual TransformerModel model_name: {self.model_name}, device: {self.device}]"
) )
def _get_model_name(self, name): def _validate_model_name(self, model_name):
if "bert" == name: if "bert" == model_name:
name_model = "bert-base-uncased" return "bert-base-uncased"
elif "mbert" == name: elif "mbert" == model_name:
name_model = "bert-base-multilingual-uncased" return "bert-base-multilingual-uncased"
elif "xlm" == name: elif "xlm" == model_name:
name_model = "xlm-roberta-base" return "xlm-roberta-base"
else: else:
raise NotImplementedError raise NotImplementedError
return name_model
def load_pretrained_model(self, model_name, num_labels): def load_pretrained_model(self, model_name, num_labels):
return AutoModelForSequenceClassification.from_pretrained( return AutoModelForSequenceClassification.from_pretrained(
@ -192,6 +182,8 @@ class TextualTransformerGen(ViewGen, TransformerGen):
if self.probabilistic and self.fitted: if self.probabilistic and self.fitted:
l_embeds = self.feature2posterior_projector.transform(l_embeds) l_embeds = self.feature2posterior_projector.transform(l_embeds)
elif not self.probabilistic and self.fitted:
l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}
return l_embeds return l_embeds

View File

@ -1,5 +1,6 @@
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from vgfs.learners.svms import FeatureSet2Posteriors
class TransformerGen: class TransformerGen:
@ -45,9 +46,16 @@ class TransformerGen:
self.verbose = verbose self.verbose = verbose
self.patience = patience self.patience = patience
self.datasets = {} self.datasets = {}
self.feature2posterior_projector = (
self.make_probabilistic() if probabilistic else None
)
def make_probabilistic(self): def make_probabilistic(self):
raise NotImplementedError if self.probabilistic:
feature2posterior_projector = FeatureSet2Posteriors(
n_jobs=self.n_jobs, verbose=False
)
return feature2posterior_projector
def build_dataloader( def build_dataloader(
self, self,

View File

@ -40,6 +40,10 @@ class VisualTransformerGen(ViewGen, TransformerGen):
evaluate_step=evaluate_step, evaluate_step=evaluate_step,
patience=patience, patience=patience,
) )
self.fitted = False
print(
f"- init Visual TransformerModel model_name: {self.model_name}, device: {self.device}]"
)
def _validate_model_name(self, model_name): def _validate_model_name(self, model_name):
if "vit" == model_name: if "vit" == model_name:
@ -128,6 +132,9 @@ class VisualTransformerGen(ViewGen, TransformerGen):
epochs=self.epochs, epochs=self.epochs,
) )
if self.probabilistic:
self.feature2posterior_projector.fit(self.transform(lX), lY)
def transform(self, lX): def transform(self, lX):
raise NotImplementedError raise NotImplementedError

34
main.py
View File

@ -41,7 +41,7 @@ def get_dataset(datasetname):
dataset = ( dataset = (
MultilingualDataset(dataset_name="rcv1-2") MultilingualDataset(dataset_name="rcv1-2")
.load(RCV_DATAPATH) .load(RCV_DATAPATH)
.reduce_data(langs=["en", "it", "fr"], maxn=500) .reduce_data(langs=["en", "it", "fr"], maxn=args.nrows)
) )
else: else:
raise NotImplementedError raise NotImplementedError
@ -54,9 +54,9 @@ def main(args):
dataset, MultiNewsDataset dataset, MultiNewsDataset
): ):
lX, lY = dataset.training() lX, lY = dataset.training()
# lX_te, lY_te = dataset.test() lX_te, lY_te = dataset.test()
print("[NB: for debug purposes, training set is also used as test set]\n") # print("[NB: for debug purposes, training set is also used as test set]\n")
lX_te, lY_te = dataset.training() # lX_te, lY_te = dataset.training()
else: else:
_lX = dataset.dX _lX = dataset.dX
_lY = dataset.dY _lY = dataset.dY
@ -75,24 +75,32 @@ def main(args):
), "At least one of VGF must be True" ), "At least one of VGF must be True"
gfun = GeneralizedFunnelling( gfun = GeneralizedFunnelling(
# dataset params ----------------------
dataset_name=args.dataset, dataset_name=args.dataset,
posterior=args.posteriors,
multilingual=args.multilingual,
wce=args.wce,
transformer=args.transformer,
langs=dataset.langs(), langs=dataset.langs(),
# Posterior VGF params ----------------
posterior=args.posteriors,
# Multilingual VGF params -------------
multilingual=args.multilingual,
embed_dir="~/resources/muse_embeddings", embed_dir="~/resources/muse_embeddings",
n_jobs=args.n_jobs, # WCE VGF params ----------------------
max_length=args.max_length, wce=args.wce,
# Transformer VGF params --------------
transformer=args.transformer,
transformer_name=args.transformer_name,
batch_size=args.batch_size, batch_size=args.batch_size,
epochs=args.epochs, epochs=args.epochs,
lr=args.lr, lr=args.lr,
max_length=args.max_length,
patience=args.patience, patience=args.patience,
evaluate_step=args.evaluate_step, evaluate_step=args.evaluate_step,
transformer_name=args.transformer_name,
device="cuda", device="cuda",
# General params ----------------------
probabilistic=args.features,
aggfunc=args.aggfunc,
optimc=args.optimc, optimc=args.optimc,
load_trained=args.load_trained, load_trained=args.load_trained,
n_jobs=args.n_jobs,
) )
# gfun.get_config() # gfun.get_config()
@ -125,7 +133,7 @@ if __name__ == "__main__":
# Dataset parameters ------------------- # Dataset parameters -------------------
parser.add_argument("-d", "--dataset", type=str, default="multinews") parser.add_argument("-d", "--dataset", type=str, default="multinews")
parser.add_argument("--domains", type=str, default="all") parser.add_argument("--domains", type=str, default="all")
parser.add_argument("--nrows", type=int, default=10000) parser.add_argument("--nrows", type=int, default=100)
parser.add_argument("--min_count", type=int, default=10) parser.add_argument("--min_count", type=int, default=10)
parser.add_argument("--max_labels", type=int, default=50) parser.add_argument("--max_labels", type=int, default=50)
# gFUN parameters ---------------------- # gFUN parameters ----------------------
@ -135,6 +143,8 @@ if __name__ == "__main__":
parser.add_argument("-t", "--transformer", action="store_true") parser.add_argument("-t", "--transformer", action="store_true")
parser.add_argument("--n_jobs", type=int, default=1) parser.add_argument("--n_jobs", type=int, default=1)
parser.add_argument("--optimc", action="store_true") parser.add_argument("--optimc", action="store_true")
parser.add_argument("--features", action="store_false")
parser.add_argument("--aggfunc", type=str, default="mean")
# transformer parameters --------------- # transformer parameters ---------------
parser.add_argument("--transformer_name", type=str, default="mbert") parser.add_argument("--transformer_name", type=str, default="mbert")
parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--batch_size", type=int, default=32)