From 234b6031b1979792506e3bc07b2492992847a704 Mon Sep 17 00:00:00 2001 From: andreapdr Date: Thu, 5 Oct 2023 15:39:49 +0200 Subject: [PATCH] branching for rai --- compute_results.py | 12 +-- csvlogger.py | 3 +- dataManager/gFunDataset.py | 135 ++++++++--------------------- dataManager/utils.py | 127 +++------------------------ gfun/generalizedFunnelling.py | 72 +++++---------- gfun/vgfs/learners/svms.py | 15 +++- gfun/vgfs/textualTransformerGen.py | 72 +++------------ hf_trainer.py | 92 ++++++++++---------- main.py | 30 ++----- 9 files changed, 147 insertions(+), 411 deletions(-) diff --git a/compute_results.py b/compute_results.py index a7b7518..31fa4bf 100644 --- a/compute_results.py +++ b/compute_results.py @@ -5,12 +5,6 @@ from sklearn.metrics import mean_absolute_error from os.path import join -""" -MEA and classification is meaningful only in "ordinal" tasks e.g., sentiment classification. -Otherwise the distance between the categories has no semantics! - -- NB: we want to get the macro-averaged class specific MAE! -""" def main(): # SETTINGS = ["p", "m", "w", "t", "mp", "mpw", "mpt", "mptw"] @@ -24,11 +18,9 @@ def main(): print(df) -def evalaute(setting): +def evalaute(setting, result_file="preds.gfun.csv"): result_dir = "results" - # result_file = f"lang-specific.gfun.{setting}.webis.csv" - result_file = f"lang-specific.mbert.webis.csv" - # print(f"- reading from: {result_file}") + print(f"- reading from: {result_file}") df = pd.read_csv(join(result_dir, result_file)) langs = df.langs.unique() res = [] diff --git a/csvlogger.py b/csvlogger.py index 561d819..7bbd87b 100644 --- a/csvlogger.py +++ b/csvlogger.py @@ -14,7 +14,7 @@ class CsvLogger: # os.makedirs(self.outfile.replace(".csv", ".lang.csv"), exist_ok=True) # return - def log_lang_results(self, results: dict, config="gfun-lello"): + def log_lang_results(self, results: dict, config="gfun-default", notes=None): df = pd.DataFrame.from_dict(results, orient="columns") df["config"] = config["gFun"]["simple_id"] df["aggfunc"] = config["gFun"]["aggfunc"] @@ -22,6 +22,7 @@ class CsvLogger: df["id"] = config["gFun"]["id"] df["optimc"] = config["gFun"]["optimc"] df["timing"] = config["gFun"]["timing"] + df["notes"] = notes with open(self.outfile, 'a') as f: df.to_csv(f, mode='a', header=f.tell()==0) diff --git a/dataManager/gFunDataset.py b/dataManager/gFunDataset.py index 9534b07..ab6c3c6 100644 --- a/dataManager/gFunDataset.py +++ b/dataManager/gFunDataset.py @@ -16,7 +16,17 @@ from dataManager.multilingualDataset import MultilingualDataset class SimpleGfunDataset: - def __init__(self, datadir="~/datasets/rai/csv/", textual=True, visual=False, multilabel=False, set_tr_langs=None, set_te_langs=None): + def __init__( + self, + dataset_name=None, + datadir="~/datasets/rai/csv/", + textual=True, + visual=False, + multilabel=False, + set_tr_langs=None, + set_te_langs=None + ): + self.name = dataset_name self.datadir = os.path.expanduser(datadir) self.textual = textual self.visual = visual @@ -41,11 +51,15 @@ class SimpleGfunDataset: print(f"tr: {tr} - va: {va} - te: {te}") def load_csv(self, set_tr_langs, set_te_langs): - # _data_tr = pd.read_csv(os.path.join(self.datadir, "train.small.csv")) - _data_tr = pd.read_csv(os.path.join(self.datadir, "train.balanced.csv")).sample(100, random_state=42) - train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.lang) # TODO stratify on lang or label? 
- # test = pd.read_csv(os.path.join(self.datadir, "test.small.csv")) - test = pd.read_csv(os.path.join(self.datadir, "test.balanced.csv")).sample(100, random_state=42) + _data_tr = pd.read_csv(os.path.join(self.datadir, "train.small.csv")) + try: + stratified = "class" + train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.label) + except: + stratified = "lang" + train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.lang) + print(f"- dataset stratified by {stratified}") + test = pd.read_csv(os.path.join(self.datadir, "test.small.csv")) self._set_langs (train, test, set_tr_langs, set_te_langs) self._set_labels(_data_tr) self.full_train = _data_tr @@ -56,8 +70,6 @@ class SimpleGfunDataset: return def _set_labels(self, data): - # self.labels = [i for i in range(28)] # todo hard-coded for rai - # self.labels = [i for i in range(3)] # TODO hard coded for sentimnet self.labels = sorted(list(data.label.unique())) def _set_langs(self, train, test, set_tr_langs=None, set_te_langs=None): @@ -70,17 +82,17 @@ class SimpleGfunDataset: print(f"-- [SETTING TESTING LANGS TO: {list(set_tr_langs)}]") self.te_langs = self.te_langs.intersection(set(set_te_langs)) self.all_langs = self.tr_langs.union(self.te_langs) - + return self.tr_langs, self.te_langs, self.all_langs + def _set_datalang(self, data: pd.DataFrame): return {lang: data[data.lang == lang] for lang in self.all_langs} def training(self, merge_validation=False, mask_number=False, target_as_csr=False): - # TODO some additional pre-processing on the textual data? apply_mask = lambda x: _mask_numbers(x) if _mask_numbers else x lXtr = { - lang: {"text": apply_mask(self.train_data[lang].text.tolist())} # TODO inserting dict for textual data - we still have to manage visual + lang: {"text": apply_mask(self.train_data[lang].text.tolist())} for lang in self.tr_langs } if merge_validation: @@ -103,7 +115,6 @@ class SimpleGfunDataset: return lXtr, lYtr def test(self, mask_number=False, target_as_csr=False): - # TODO some additional pre-processing on the textual data? 
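Note on the stratification fallback added to load_csv above: scikit-learn signals an impossible stratification (e.g. a class with a single member) by raising a ValueError, which is the case the try/except is meant to absorb. A minimal standalone sketch of the same label-then-language fallback, with the exception narrowed to ValueError (the narrowed exception type is an assumption, not taken from the patch; column names label and lang mirror the CSVs used above):

import pandas as pd
from sklearn.model_selection import train_test_split

def split_with_fallback(df: pd.DataFrame, test_size=0.2, seed=42):
    # Prefer stratifying on the class label; if a label is too rare,
    # sklearn raises ValueError and we fall back to stratifying on language.
    try:
        train, val = train_test_split(df, test_size=test_size, random_state=seed, stratify=df.label)
        return train, val, "class"
    except ValueError:
        train, val = train_test_split(df, test_size=test_size, random_state=seed, stratify=df.lang)
        return train, val, "lang"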
apply_mask = lambda x: _mask_numbers(x) if _mask_numbers else x lXte = { lang: {"text": apply_mask(self.test_data[lang].text.tolist())} @@ -162,62 +173,12 @@ class gFunDataset: return mlb def _load_dataset(self): - if "glami" in self.dataset_dir.lower(): - print(f"- Loading GLAMI dataset from {self.dataset_dir}") - self.dataset_name = "glami" - self.dataset, self.labels, self.data_langs = self._load_glami( - self.dataset_dir, self.nrows - ) - self.mlb = self.get_label_binarizer(self.labels) - - elif "rcv" in self.dataset_dir.lower(): - print(f"- Loading RCV1-2 dataset from {self.dataset_dir}") - self.dataset_name = "rcv1-2" - self.dataset, self.labels, self.data_langs = self._load_multilingual( - self.dataset_name, self.dataset_dir, self.nrows - ) - self.mlb = self.get_label_binarizer(self.labels) - - elif "jrc" in self.dataset_dir.lower(): - print(f"- Loading JRC dataset from {self.dataset_dir}") - self.dataset_name = "jrc" - self.dataset, self.labels, self.data_langs = self._load_multilingual( - self.dataset_name, self.dataset_dir, self.nrows - ) - self.mlb = self.get_label_binarizer(self.labels) - - # WEBIS-CLS (processed) - elif ( - "cls" in self.dataset_dir.lower() - and "unprocessed" not in self.dataset_dir.lower() - ): - print(f"- Loading WEBIS-CLS (processed) dataset from {self.dataset_dir}") - self.dataset_name = "cls" - self.dataset, self.labels, self.data_langs = self._load_multilingual( - self.dataset_name, self.dataset_dir, self.nrows - ) - self.mlb = self.get_label_binarizer(self.labels) - - # WEBIS-CLS (unprocessed) - elif ( - "cls" in self.dataset_dir.lower() - and "unprocessed" in self.dataset_dir.lower() - ): - print(f"- Loading WEBIS-CLS (unprocessed) dataset from {self.dataset_dir}") - self.dataset_name = "cls" - self.dataset, self.labels, self.data_langs = self._load_multilingual( - self.dataset_name, self.dataset_dir, self.nrows - ) - self.mlb = self.get_label_binarizer(self.labels) - - elif "rai" in self.dataset_dir.lower(): - print(f"- Loading RAI-CORPUS dataset from {self.dataset_dir}") - self.dataset_name = "rai" - self.dataset, self.labels, self.data_langs = self._load_multilingual(dataset_name=self.dataset_name, - dataset_dir="~/datasets/rai/csv/train-split-rai.csv", - nrows=self.nrows) - self.mlb = self.get_label_binarizer(self.labels) - + print(f"- Loading dataset from {self.dataset_dir}") + self.dataset_name = "rai" + self.dataset, self.labels, self.data_langs = self._load_multilingual(dataset_name=self.dataset_name, + dataset_dir=self.dataset_dir, + nrows=self.nrows) + self.mlb = self.get_label_binarizer(self.labels) self.show_dimension() return @@ -232,15 +193,12 @@ class gFunDataset: else: print(f"-- Labels: {len(self.labels)}") - def _load_multilingual(self, dataset_name, dataset_dir, nrows): + def _load_multilingual(self, dataset_dir, nrows, dataset_name="rai"): if "csv" in dataset_dir: - old_dataset = MultilingualDataset(dataset_name=dataset_name).from_csv( - # path_tr="~/datasets/rai/csv/train-rai-multilingual-2000.csv", - #path_te="~/datasets/rai/csv/test-rai-multilingual-2000.csv") - path_tr="~/datasets/rai/csv/train-split-rai.csv", - path_te="~/datasets/rai/csv/test-split-rai.csv") - else: - old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir) + old_dataset = MultilingualDataset(dataset_name="rai").from_csv( + path_tr=os.path.expanduser(os.path.join(dataset_dir, "train.small.csv")), + path_te=os.path.expanduser(os.path.join(dataset_dir, "test.small.csv")) + ) if nrows is not None: if dataset_name == "cls": 
old_dataset.reduce_data(langs=["de", "en", "fr"], maxn=nrows) @@ -365,27 +323,4 @@ if __name__ == "__main__": data_rai = SimpleGfunDataset() lXtr, lYtr = data_rai.training(mask_number=False) lXte, lYte = data_rai.test(mask_number=False) - exit() - # import os - - # GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset") - # RCV_DATAPATH = os.path.expanduser( - # "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle" - # ) - # JRC_DATAPATH = os.path.expanduser( - # "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle" - # ) - - # print("Hello gFunDataset") - # dataset = gFunDataset( - # dataset_dir=JRC_DATAPATH, - # data_langs=None, - # is_textual=True, - # is_visual=True, - # is_multilabel=False, - # labels=None, - # nrows=13, - # ) - # lXtr, lYtr = dataset.training() - # lXte, lYte = dataset.test() - # exit(0) + exit() \ No newline at end of file diff --git a/dataManager/utils.py b/dataManager/utils.py index 50e79a9..4da6257 100644 --- a/dataManager/utils.py +++ b/dataManager/utils.py @@ -1,7 +1,5 @@ from os.path import expanduser, join -from dataManager.gFunDataset import gFunDataset, SimpleGfunDataset -from dataManager.multiNewsDataset import MultiNewsDataset -from dataManager.amazonDataset import AmazonDataset +from dataManager.gFunDataset import SimpleGfunDataset def load_from_pickle(path, dataset_name, nrows): @@ -16,119 +14,14 @@ def load_from_pickle(path, dataset_name, nrows): return loaded -def get_dataset(dataset_name, args): - assert dataset_name in [ - "multinews", - "amazon", - "rcv1-2", - "glami", - "cls", - "webis", - "rai", - ], "dataset not supported" - - RCV_DATAPATH = expanduser( - "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle" +def get_dataset(datasetp_path, args): + dataset = SimpleGfunDataset( + dataset_name="rai", + datadir=datasetp_path, + textual=True, + visual=False, + multilabel=False, + set_tr_langs=args.tr_langs, + set_te_langs=args.te_langs ) - JRC_DATAPATH = expanduser( - "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle" - ) - CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl") - - MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/") - - GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset") - - WEBIS_CLS = expanduser( - # "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl" - "~/datasets/cls-acl10-unprocessed/csv" - ) - - RAI_DATAPATH = expanduser("~/datasets/rai/rai_corpus.pkl") - - if dataset_name == "multinews": - # TODO: convert to gFunDataset - raise NotImplementedError - dataset = MultiNewsDataset( - expanduser(MULTINEWS_DATAPATH), - excluded_langs=["ar", "pe", "pl", "tr", "ua"], - ) - elif dataset_name == "amazon": - # TODO: convert to gFunDataset - raise NotImplementedError - dataset = AmazonDataset( - domains=args.domains, - nrows=args.nrows, - min_count=args.min_count, - max_labels=args.max_labels, - ) - elif dataset_name == "jrc": - dataset = gFunDataset( - dataset_dir=JRC_DATAPATH, - is_textual=True, - is_visual=False, - is_multilabel=True, - nrows=args.nrows, - ) - elif dataset_name == "rcv1-2": - dataset = gFunDataset( - dataset_dir=RCV_DATAPATH, - is_textual=True, - is_visual=False, - is_multilabel=True, - nrows=args.nrows, - ) - elif dataset_name == "glami": - if args.save_dataset is False: - dataset = load_from_pickle(GLAMI_DATAPATH, dataset_name, args.nrows) - else: - dataset = gFunDataset( - dataset_dir=GLAMI_DATAPATH, - 
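For reference, the pared-down get_dataset above is now a thin wrapper around SimpleGfunDataset. A usage sketch of the resulting loader, including the optional language restrictions it forwards (the data directory is the default from gFunDataset.py; the language code is hypothetical; only the two args fields that get_dataset reads are filled in):

from argparse import Namespace
from dataManager.utils import get_dataset

args = Namespace(tr_langs=["it"], te_langs=None)
dataset = get_dataset("~/datasets/rai/csv", args)

lX, lY = dataset.training(merge_validation=True)   # {lang: {"text": [...]}} and per-language targets
lX_te, lY_te = dataset.test()
print(dataset.langs(), dataset.num_labels())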
is_textual=True, - is_visual=True, - is_multilabel=False, - nrows=args.nrows, - ) - - dataset.save_as_pickle(GLAMI_DATAPATH) - elif dataset_name == "cls": - dataset = gFunDataset( - dataset_dir=CLS_DATAPATH, - is_textual=True, - is_visual=False, - is_multilabel=False, - nrows=args.nrows, - ) - elif dataset_name == "webis": - dataset = SimpleGfunDataset( - datadir=WEBIS_CLS, - textual=True, - visual=False, - multilabel=False, - set_tr_langs=args.tr_langs, - set_te_langs=args.te_langs - ) - # dataset = gFunDataset( - # dataset_dir=WEBIS_CLS, - # is_textual=True, - # is_visual=False, - # is_multilabel=False, - # nrows=args.nrows, - # ) - elif dataset_name == "rai": - dataset = SimpleGfunDataset( - datadir="~/datasets/rai/csv", - textual=True, - visual=False, - multilabel=False - ) - # dataset = gFunDataset( - # dataset_dir=RAI_DATAPATH, - # is_textual=True, - # is_visual=False, - # is_multilabel=False, - # nrows=args.nrows - # ) - else: - raise NotImplementedError return dataset diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py index 9efe6ba..67b9632 100644 --- a/gfun/generalizedFunnelling.py +++ b/gfun/generalizedFunnelling.py @@ -8,7 +8,6 @@ from gfun.vgfs.learners.svms import MetaClassifier, get_learner from gfun.vgfs.multilingualGen import MultilingualGen from gfun.vgfs.textualTransformerGen import TextualTransformerGen from gfun.vgfs.vanillaFun import VanillaFunGen -from gfun.vgfs.visualTransformerGen import VisualTransformerGen from gfun.vgfs.wceGen import WceGen @@ -19,7 +18,6 @@ class GeneralizedFunnelling: wce, multilingual, textual_transformer, - visual_transformer, langs, num_labels, classification_type, @@ -29,12 +27,9 @@ class GeneralizedFunnelling: eval_batch_size, max_length, textual_lr, - visual_lr, epochs, patience, evaluate_step, - textual_transformer_name, - visual_transformer_name, optimc, device, load_trained, @@ -42,13 +37,14 @@ class GeneralizedFunnelling: probabilistic, aggfunc, load_meta, + trained_text_trf=None, + textual_transformer_name=None, ): # Setting VFGs ----------- self.posteriors_vgf = posterior self.wce_vgf = wce self.multilingual_vgf = multilingual self.textual_trf_vgf = textual_transformer - self.visual_trf_vgf = visual_transformer self.probabilistic = probabilistic self.num_labels = num_labels self.clf_type = classification_type @@ -58,6 +54,7 @@ class GeneralizedFunnelling: self.cached = True # Textual Transformer VGF params ---------- self.textual_trf_name = textual_transformer_name + self.trained_text_trf = trained_text_trf self.epochs = epochs self.textual_trf_lr = textual_lr self.textual_scheduler = "ReduceLROnPlateau" @@ -68,10 +65,6 @@ class GeneralizedFunnelling: self.patience = patience self.evaluate_step = evaluate_step self.device = device - # Visual Transformer VGF params ---------- - self.visual_trf_name = visual_transformer_name - self.visual_trf_lr = visual_lr - self.visual_scheduler = "ReduceLROnPlateau" # Metaclassifier params ------------ self.optimc = optimc # ------------------- @@ -131,7 +124,6 @@ class GeneralizedFunnelling: self.multilingual_vgf, self.wce_vgf, self.textual_trf_vgf, - self.visual_trf_vgf, self.aggfunc, ) return self @@ -174,26 +166,10 @@ class GeneralizedFunnelling: patience=self.patience, device=self.device, classification_type=self.clf_type, + saved_model=self.trained_text_trf, ) self.first_tier_learners.append(transformer_vgf) - if self.visual_trf_vgf: - visual_trasformer_vgf = VisualTransformerGen( - dataset_name=self.dataset_name, - model_name="vit", - lr=self.visual_trf_lr, - 
scheduler=self.visual_scheduler, - epochs=self.epochs, - batch_size=self.batch_size_trf, - batch_size_eval=self.eval_batch_size_trf, - probabilistic=self.probabilistic, - evaluate_step=self.evaluate_step, - patience=self.patience, - device=self.device, - classification_type=self.clf_type, - ) - self.first_tier_learners.append(visual_trasformer_vgf) - if "attn" in self.aggfunc: attn_stacking = self.aggfunc.split("_")[1] self.attn_aggregator = AttentionAggregator( @@ -219,7 +195,6 @@ class GeneralizedFunnelling: self.multilingual_vgf, self.wce_vgf, self.textual_trf_vgf, - self.visual_trf_vgf, self.aggfunc, ) print(f"- model id: {self._model_id}") @@ -309,16 +284,29 @@ class GeneralizedFunnelling: return aggregated def _aggregate_mean(self, first_tier_projections): - aggregated = { - lang: np.zeros(data.shape) - for lang, data in first_tier_projections[0].items() - } + # aggregated = { + # lang: np.zeros(data.shape) + # for lang, data in first_tier_projections[0].items() + # } + aggregated = {} for lang_projections in first_tier_projections: for lang, projection in lang_projections.items(): + if lang not in aggregated: + aggregated[lang] = np.zeros(projection.shape) aggregated[lang] += projection - for lang, projection in aggregated.items(): - aggregated[lang] /= len(first_tier_projections) + def get_denom(lang, projs): + den = 0 + for proj in projs: + if lang in proj: + den += 1 + return den + + for lang, _ in aggregated.items(): + # aggregated[lang] /= len(first_tier_projections) + aggregated[lang] /= get_denom(lang, first_tier_projections) + + return aggregated @@ -416,18 +404,6 @@ class GeneralizedFunnelling: ) as vgf: first_tier_learners.append(pickle.load(vgf)) print(f"- loaded trained Textual Transformer VGF") - if self.visual_trf_vgf: - with open( - os.path.join( - "models", - "vgfs", - "visual_transformer", - f"visualTransformerGen_{model_id}.pkl", - ), - "rb", - print(f"- loaded trained Visual Transformer VGF"), - ) as vgf: - first_tier_learners.append(pickle.load(vgf)) if load_meta: with open( @@ -482,7 +458,6 @@ def get_unique_id( multilingual, wce, textual_transformer, - visual_transformer, aggfunc, ): from datetime import datetime @@ -493,6 +468,5 @@ def get_unique_id( model_id += "m" if multilingual else "" model_id += "w" if wce else "" model_id += "t" if textual_transformer else "" - model_id += "v" if visual_transformer else "" model_id += f"_{aggfunc}" return f"{model_id}_{now}" diff --git a/gfun/vgfs/learners/svms.py b/gfun/vgfs/learners/svms.py index 086e3ff..3f81baf 100644 --- a/gfun/vgfs/learners/svms.py +++ b/gfun/vgfs/learners/svms.py @@ -173,9 +173,18 @@ class NaivePolylingualClassifier: :return: a dictionary of probabilities that each document belongs to each class """ assert self.model is not None, "predict called before fit" - assert set(lX.keys()).issubset( - set(self.model.keys()) - ), "unknown languages requested in decision function" + if not set(lX.keys()).issubset(set(self.model.keys())): + langs = set(lX.keys()).intersection(set(self.model.keys())) + scores = Parallel(n_jobs=self.n_jobs, max_nbytes=None)( + delayed(self.model[lang].predict_proba)(lX[lang]) for lang in langs) + # res = {lang: None for lang in lX.keys()} + # for i, lang in enumerate(langs): + # res[lang] = scores[i] + # return res + return {lang: scores[i] for i, lang in enumerate(langs)} + # assert set(lX.keys()).issubset( + # set(self.model.keys()) + # ), "unknown languages requested in decision function" langs = list(lX.keys()) scores = Parallel(n_jobs=self.n_jobs, max_nbytes=None)( 
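The reworked _aggregate_mean above no longer assumes that every view produced every language: each language is summed over the views that contain it and divided by that per-language count, which is what lets partially-covered (zero-shot) languages flow through the aggregation. An equivalent compact sketch (first_tier_projections is a list of {lang: ndarray} dicts, as in the class above):

import numpy as np

def aggregate_mean(first_tier_projections):
    sums, counts = {}, {}
    for lang_projections in first_tier_projections:
        for lang, projection in lang_projections.items():
            sums[lang] = sums.get(lang, np.zeros(projection.shape)) + projection
            counts[lang] = counts.get(lang, 0) + 1
    # divide by the number of views that actually covered the language,
    # not by the total number of views
    return {lang: sums[lang] / counts[lang] for lang in sums}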
delayed(self.model[lang].predict_proba)(lX[lang]) for lang in langs diff --git a/gfun/vgfs/textualTransformerGen.py b/gfun/vgfs/textualTransformerGen.py index 9573951..5fcc2da 100644 --- a/gfun/vgfs/textualTransformerGen.py +++ b/gfun/vgfs/textualTransformerGen.py @@ -73,6 +73,7 @@ class TextualTransformerGen(ViewGen, TransformerGen): patience=5, classification_type="multilabel", scheduler="ReduceLROnPlateau", + saved_model = None ): super().__init__( self._validate_model_name(model_name), @@ -91,6 +92,7 @@ class TextualTransformerGen(ViewGen, TransformerGen): n_jobs=n_jobs, verbose=verbose, ) + self.saved_model = saved_model self.clf_type = classification_type self.fitted = False print( @@ -109,26 +111,25 @@ class TextualTransformerGen(ViewGen, TransformerGen): else: raise NotImplementedError - def load_pretrained_model(self, model_name, num_labels): + def load_pretrained_model(self, model_name, num_labels, saved_model=None): if model_name == "google/mt5-small": return MT5ForSequenceClassification( model_name, num_labels=num_labels, output_hidden_states=True ) else: - # model_name = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500" # TODO hardcoded to pre-traiend mbert - # model_name = "hf_models/mbert-rai-fewshot-second/checkpoint-19000" # TODO hardcoded to pre-traiend mbert - # model_name = "hf_models/mbert-sentiment/checkpoint-1150" # TODO hardcoded to pre-traiend mbert - model_name = "hf_models/mbert-sentiment-balanced/checkpoint-1600" + if saved_model: + model_name = saved_model + else: + model_name = "google/bert-base-multilingual-cased" return AutoModelForSequenceClassification.from_pretrained( model_name, num_labels=num_labels, output_hidden_states=True ) def load_tokenizer(self, model_name): - # model_name = "mbert-rai-multi-2000/checkpoint-1500" # TODO hardcoded to pre-traiend mbert return AutoTokenizer.from_pretrained(model_name) - def init_model(self, model_name, num_labels): - return self.load_pretrained_model(model_name, num_labels), self.load_tokenizer( + def init_model(self, model_name, num_labels, saved_model): + return self.load_pretrained_model(model_name, num_labels, saved_model), self.load_tokenizer( model_name ) @@ -148,64 +149,11 @@ class TextualTransformerGen(ViewGen, TransformerGen): _l = list(lX.keys())[0] self.num_labels = lY[_l].shape[-1] self.model, self.tokenizer = self.init_model( - self.model_name, num_labels=self.num_labels + self.model_name, num_labels=self.num_labels, saved_model=self.saved_model, ) self.model.to("cuda") - # tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data( - # lX, lY, split=0.2, seed=42, modality="text" - # ) - # - # tra_dataloader = self.build_dataloader( - # tr_lX, - # tr_lY, - # processor_fn=self._tokenize, - # torchDataset=MultilingualDatasetTorch, - # batch_size=self.batch_size, - # split="train", - # shuffle=True, - # ) - # - # val_dataloader = self.build_dataloader( - # val_lX, - # val_lY, - # processor_fn=self._tokenize, - # torchDataset=MultilingualDatasetTorch, - # batch_size=self.batch_size_eval, - # split="val", - # shuffle=False, - # ) - # - # experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}" - # - # trainer = Trainer( - # model=self.model, - # optimizer_name="adamW", - # lr=self.lr, - # device=self.device, - # loss_fn=torch.nn.CrossEntropyLoss(), - # print_steps=self.print_steps, - # evaluate_step=self.evaluate_step, - # patience=self.patience, - # experiment_name=experiment_name, - # checkpoint_path=os.path.join( - # "models", - # "vgfs", - # 
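Related to the NaivePolylingualClassifier change just above: when the evaluation dictionary contains languages for which no per-language classifier was fitted, predict_proba now scores only the intersection of requested and trained languages instead of asserting. The same behaviour, stripped of the joblib parallelism, looks like this (model is the {lang: fitted classifier} dict held by the class):

def predict_proba_available(model, lX):
    # keep only the languages we actually have a trained classifier for;
    # unseen languages are silently dropped from the output
    available = [lang for lang in lX if lang in model]
    return {lang: model[lang].predict_proba(lX[lang]) for lang in available}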
"trained_transformer", - # self._format_model_name(self.model_name), - # ), - # vgf_name="textual_trf", - # classification_type=self.clf_type, - # n_jobs=self.n_jobs, - # scheduler_name=self.scheduler, - # ) - # trainer.train( - # train_dataloader=tra_dataloader, - # eval_dataloader=val_dataloader, - # epochs=self.epochs, - # ) - if self.probabilistic: transformed = self.transform(lX) self.feature2posterior_projector.fit(transformed, lY) diff --git a/hf_trainer.py b/hf_trainer.py index 0412ca1..a1457ce 100644 --- a/hf_trainer.py +++ b/hf_trainer.py @@ -1,4 +1,4 @@ -from os.path import expanduser +from os.path import expanduser, join import torch from transformers import ( @@ -19,13 +19,8 @@ import pandas as pd transformers.logging.set_verbosity_error() -IWSLT_D_COLUMNS = ["text", "category", "rating", "summary", "title"] -RAI_D_COLUMNS = ["id", "provider", "date", "title", "text", "label"] # "lang" -WEBIS_D_COLUMNS = ['Unnamed: 0', 'asin', 'category', 'original_rating', 'label', 'title', 'text', 'summary'] # "lang" +RAI_D_COLUMNS = ["id", "provider", "date", "title", "text", "label"] MAX_LEN = 128 -# DATASET_NAME = "rai" -# DATASET_NAME = "rai-multilingual-2000" -# DATASET_NAME = "webis-cls" def init_callbacks(patience=-1, nosave=False): @@ -35,13 +30,17 @@ def init_callbacks(patience=-1, nosave=False): return callbacks -def init_model(model_name, nlabels): +def init_model(model_name, nlabels, saved_model=None): if model_name == "mbert": - # hf_name = "bert-base-multilingual-cased" - hf_name = "hf_models/mbert-sentiment-balanced/checkpoint-1600" - # hf_name = "hf_models/mbert-rai-fewshot-second/checkpoint-9000" + if saved_model is None: + hf_name = "bert-base-multilingual-cased" + else: + hf_name = saved_model elif model_name == "xlm-roberta": - hf_name = "xlm-roberta-base" + if saved_model is None: + hf_name = "xlm-roberta-base" + else: + hf_name = saved_model else: raise NotImplementedError tokenizer = AutoTokenizer.from_pretrained(hf_name) @@ -50,43 +49,41 @@ def init_model(model_name, nlabels): def main(args): - tokenizer, model = init_model(args.model, args.nlabels) + saved_model = args.savedmodel + trainlang = args.trainlangs + datapath = args.datapath + + tokenizer, model = init_model(args.model, args.nlabels, saved_model=saved_model) data = load_dataset( "csv", data_files = { - "train": expanduser(f"~/datasets/cls-acl10-unprocessed/csv/train.balanced.csv"), - "test": expanduser(f"~/datasets/cls-acl10-unprocessed/csv/test.balanced.csv") - # "train": expanduser(f"~/datasets/rai/csv/train-{DATASET_NAME}.csv"), - # "test": expanduser(f"~/datasets/rai/csv/test-{DATASET_NAME}.csv") - # "train": expanduser(f"~/datasets/rai/csv/train.small.csv"), - # "test": expanduser(f"~/datasets/rai/csv/test.small.csv") + "train": expanduser(join(datapath, "train.csv")), + "test": expanduser(join(datapath, "test.small.csv")) } ) + def filter_dataset(dataset, lang): + indices = [i for i, l in enumerate(dataset["lang"]) if l == lang] + dataset = dataset.select(indices) + return dataset + + if trainlang is not None: + data["train"] = filter_dataset(data["train"], lang=trainlang) + def process_sample_rai(sample): inputs = [f"{title}. {text}" for title, text in zip(sample["title"], sample["text"])] labels = sample["label"] - model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True) # TODO pre-process text cause there's a lot of noise in there... 
+ model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True) model_inputs["labels"] = labels return model_inputs - def process_sample_webis(sample): - inputs = sample["text"] - labels = sample["label"] - model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True) # TODO pre-process text cause there's a lot of noise in there... - model_inputs["labels"] = labels - return model_inputs - - data = data.map( - # process_sample_rai, - process_sample_webis, + process_sample_rai, batched=True, num_proc=4, load_from_cache_file=True, - # remove_columns=RAI_D_COLUMNS, - remove_columns=WEBIS_D_COLUMNS, + remove_columns=RAI_D_COLUMNS, ) train_val_splits = data["train"].train_test_split(test_size=0.2, seed=42) data.set_format("torch") @@ -107,8 +104,8 @@ def main(args): recall_metric = evaluate.load("recall") training_args = TrainingArguments( - # output_dir=f"hf_models/{args.model}-rai", - output_dir=f"hf_models/{args.model}-sentiment-balanced", + output_dir=f"hf_models/{args.model}-fewshot-full" if trainlang is None else f"hf_models/{args.model}-zeroshot-full", + run_name="model-zeroshot" if trainlang is not None else "model-fewshot", do_train=True, evaluation_strategy="steps", per_device_train_batch_size=args.batch, @@ -130,8 +127,6 @@ def main(args): save_strategy="no" if args.nosave else "steps", save_total_limit=2, eval_steps=args.stepeval, - # run_name=f"{args.model}-rai-stratified", - run_name=f"{args.model}-sentiment", disable_tqdm=False, log_level="warning", report_to=["wandb"] if args.wandb else "none", @@ -142,7 +137,6 @@ def main(args): def compute_metrics(eval_preds): preds = eval_preds.predictions.argmax(-1) - # targets = eval_preds.label_ids.argmax(-1) targets = eval_preds.label_ids setting = "macro" f1_score_macro = f1_metric.compute( @@ -170,7 +164,9 @@ def main(args): if args.wandb: import wandb - wandb.init(entity="andreapdr", project=f"gfun-rai-hf", name="mbert-rai", config=vars(args)) + wandb.init(entity="andreapdr", project=f"gfun", + name="model-zeroshot-full" if trainlang is not None else "model-fewshot-full", + config=vars(args)) trainer = Trainer( model=model, @@ -188,17 +184,21 @@ def main(args): trainer.train() print("- Testing:") + test_results = trainer.evaluate(eval_dataset=data["test"], metric_key_prefix="test") test_results = trainer.predict(test_dataset=data["test"], metric_key_prefix="test") pprint(test_results.metrics) - save_preds(data["test"], test_results.predictions) + save_preds(data["test"], test_results.predictions, trainlang) exit() -def save_preds(dataset, predictions): +def save_preds(dataset, predictions, trainlang=None): df = pd.DataFrame() df["langs"] = dataset["lang"] df["labels"] = dataset["labels"] df["preds"] = predictions.argmax(axis=1) - df.to_csv("results/lang-specific.mbert.webis.csv", index=False) + if trainlang is not None: + df.to_csv(f"results/zeroshot.{trainlang}.model.csv", index=False) + else: + df.to_csv("results/fewshot.model.csv", index=False) return @@ -210,16 +210,18 @@ if __name__ == "__main__": parser.add_argument("--nlabels", type=int, metavar="", default=28) parser.add_argument("--lr", type=float, metavar="", default=5e-5, help="Set learning rate",) parser.add_argument("--scheduler", type=str, metavar="", default="cosine", help="Accepted: [\"cosine\", \"cosine-reset\", \"cosine-warmup\", \"cosine-warmup-reset\", \"constant\"]") - parser.add_argument("--batch", type=int, metavar="", default=8, help="Set batch size") + parser.add_argument("--batch", type=int, metavar="", default=8, help="Set batch size") 
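The prediction files written by save_preds above (columns langs, labels, preds) have the same shape as the ones consumed by evalaute in compute_results.py at the top of this patch. A per-language breakdown over such a file can be sketched as follows (file name taken from the few-shot branch of save_preds; sklearn metrics stand in for the evaluate library):

import pandas as pd
from sklearn.metrics import accuracy_score, f1_score

df = pd.read_csv("results/fewshot.model.csv")   # columns: langs, labels, preds
for lang in sorted(df.langs.unique()):
    sub = df[df.langs == lang]
    acc = accuracy_score(sub.labels, sub.preds)
    macro_f1 = f1_score(sub.labels, sub.preds, average="macro")
    print(f"{lang}: acc={acc:.3f} macro-F1={macro_f1:.3f}")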
parser.add_argument("--gradacc", type=int, metavar="", default=1, help="Gradient accumulation steps") - parser.add_argument("--epochs", type=int, metavar="", default=100, help="Set epochs") + parser.add_argument("--epochs", type=int, metavar="", default=10, help="Set epochs") parser.add_argument("--stepeval", type=int, metavar="", default=50, help="Run evaluation every n steps") - parser.add_argument("--steplog", type=int, metavar="", default=50, help="Log training every n steps") + parser.add_argument("--steplog", type=int, metavar="", default=50, help="Log training every n steps") parser.add_argument("--patience", type=int, metavar="", default=10, help="EarlyStopper patience") parser.add_argument("--fp16", action="store_true", help="Use fp16 precision") parser.add_argument("--wandb", action="store_true", help="Log to wandb") parser.add_argument("--nosave", action="store_true", help="Avoid saving model") parser.add_argument("--onlytest", action="store_true", help="Simply test model on test set") - # parser.add_argument("--sanity", action="store_true", help="Train and evaluate on the same reduced (1000) data") + parser.add_argument("--trainlang", default=None, type=str, help="set training language for zero-shot experiments" ) + parser.add_argument("--datapath", type=str, default="data", help="path to the csv dataset. Dir should contain both a train.csv and a test.csv file") + parser.add_argument("--savedmodel", type=str, default="hf_models/mbert-rai-fewshot-second/checkpoint-9000") args = parser.parse_args() main(args) diff --git a/main.py b/main.py index dc57157..ad4abea 100644 --- a/main.py +++ b/main.py @@ -10,8 +10,6 @@ import pandas as pd """ TODO: - - General: - [!] zero-shot setup - Docs: - add documentations sphinx """ @@ -27,13 +25,11 @@ def get_config_name(args): config_name += "M+" if args.textual_transformer: config_name += f"TT_{args.textual_trf_name}+" - if args.visual_transformer: - config_name += f"VT_{args.visual_trf_name}+" return config_name.rstrip("+") def main(args): - dataset = get_dataset(args.dataset, args) + dataset = get_dataset(args.datadir, args) lX, lY = dataset.training(merge_validation=True) lX_te, lY_te = dataset.test() @@ -47,13 +43,12 @@ def main(args): args.multilingual, args.multilingual, args.textual_transformer, - args.visual_transformer, ] ), "At least one of VGF must be True" gfun = GeneralizedFunnelling( # dataset params ---------------------- - dataset_name=args.dataset, + dataset_name=dataset, langs=dataset.langs(), num_labels=dataset.num_labels(), classification_type=args.clf_type, @@ -67,24 +62,15 @@ def main(args): # Transformer VGF params -------------- textual_transformer=args.textual_transformer, textual_transformer_name=args.textual_trf_name, + trained_text_trf="hf_models/mbert-zeroshot-rai/checkpoint-1350", batch_size=args.batch_size, eval_batch_size=args.eval_batch_size, epochs=args.epochs, textual_lr=args.textual_lr, - visual_lr=args.visual_lr, max_length=args.max_length, patience=args.patience, evaluate_step=args.evaluate_step, device=args.device, - # Visual Transformer VGF params -------------- - visual_transformer=args.visual_transformer, - visual_transformer_name=args.visual_trf_name, - # batch_size=args.batch_size, - # epochs=args.epochs, - # lr=args.lr, - # patience=args.patience, - # evaluate_step=args.evaluate_step, - # device="cuda", # General params --------------------- probabilistic=args.features, aggfunc=args.aggfunc, @@ -145,7 +131,7 @@ def main(args): log_barplot_wandb(lang_metrics_gfun, title_affix="per language") 
config["gFun"]["timing"] = f"{timeval - tinit:.2f}" - csvlogger = CsvLogger(outfile="results/random.log.csv").log_lang_results(lang_metrics_gfun, config) + csvlogger = CsvLogger(outfile="results/gfun.log.csv").log_lang_results(lang_metrics_gfun, config, notes="") save_preds(gfun_preds, lY_te, config=config["gFun"]["simple_id"], dataset=config["gFun"]["dataset"]) @@ -162,7 +148,7 @@ def save_preds(preds, targets, config="unk", dataset="unk"): df["langs"] = _langs df["labels"] = _targets df["preds"] = _preds - df.to_csv(f"results/lang-specific.gfun.{config}.{dataset}.csv", index=False) + df.to_csv(f"results/preds/preds.gfun.{config}.{dataset}.correct.csv", index=False) if __name__ == "__main__": @@ -174,7 +160,7 @@ if __name__ == "__main__": parser.add_argument("--tr_langs", nargs="+", default=None) parser.add_argument("--te_langs", nargs="+", default=None) # Dataset parameters ------------------- - parser.add_argument("-d", "--dataset", type=str, default="rcv1-2") + parser.add_argument("-d", "--datadir", type=str, default=None, help="dir to dataset. It should contain both a train.csv and a test.csv file") parser.add_argument("--domains", type=str, default="all") parser.add_argument("--nrows", type=int, default=None) parser.add_argument("--min_count", type=int, default=10) @@ -186,7 +172,6 @@ if __name__ == "__main__": parser.add_argument("-m", "--multilingual", action="store_true") parser.add_argument("-w", "--wce", action="store_true") parser.add_argument("-t", "--textual_transformer", action="store_true") - parser.add_argument("-v", "--visual_transformer", action="store_true") parser.add_argument("--n_jobs", type=int, default=-1) parser.add_argument("--optimc", action="store_true") parser.add_argument("--features", action="store_false") @@ -200,9 +185,6 @@ if __name__ == "__main__": parser.add_argument("--max_length", type=int, default=128) parser.add_argument("--patience", type=int, default=5) parser.add_argument("--evaluate_step", type=int, default=10) - # Visual Transformer parameters -------------- - parser.add_argument("--visual_trf_name", type=str, default="vit") - parser.add_argument("--visual_lr", type=float, default=1e-4) # logging parser.add_argument("--wandb", action="store_true")