concat aggfunc
This commit is contained in:
parent
3f3e4982e4
commit
2a42b21ac9
|
@ -35,12 +35,15 @@ class GeneralizedFunnelling:
|
|||
device,
|
||||
load_trained,
|
||||
dataset_name,
|
||||
probabilistic,
|
||||
aggfunc,
|
||||
):
|
||||
# Setting VFGs -----------
|
||||
self.posteriors_vgf = posterior
|
||||
self.wce_vgf = wce
|
||||
self.multilingual_vgf = multilingual
|
||||
self.trasformer_vgf = transformer
|
||||
self.probabilistic = probabilistic
|
||||
# ------------------------
|
||||
self.langs = langs
|
||||
self.embed_dir = embed_dir
|
||||
|
@ -62,13 +65,16 @@ class GeneralizedFunnelling:
|
|||
self.n_jobs = n_jobs
|
||||
self.first_tier_learners = []
|
||||
self.metaclassifier = None
|
||||
self.aggfunc = "mean"
|
||||
self.aggfunc = aggfunc
|
||||
self.load_trained = load_trained
|
||||
self.dataset_name = dataset_name
|
||||
self._init()
|
||||
|
||||
def _init(self):
|
||||
print("[Init GeneralizedFunnelling]")
|
||||
assert not (
|
||||
self.aggfunc == "mean" and self.probabilistic is False
|
||||
), "When using averaging aggreagation function probabilistic must be True"
|
||||
if self.load_trained is not None:
|
||||
print("- loading trained VGFs, metaclassifer and vectorizer")
|
||||
self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load(
|
||||
|
@ -90,7 +96,7 @@ class GeneralizedFunnelling:
|
|||
langs=self.langs,
|
||||
n_jobs=self.n_jobs,
|
||||
cached=self.cached,
|
||||
probabilistic=True,
|
||||
probabilistic=self.probabilistic,
|
||||
)
|
||||
self.first_tier_learners.append(multilingual_vgf)
|
||||
|
||||
|
@ -100,6 +106,7 @@ class GeneralizedFunnelling:
|
|||
|
||||
if self.trasformer_vgf:
|
||||
transformer_vgf = TextualTransformerGen(
|
||||
dataset_name=self.dataset_name,
|
||||
model_name=self.transformer_name,
|
||||
lr=self.lr_transformer,
|
||||
epochs=self.epochs,
|
||||
|
@ -107,11 +114,10 @@ class GeneralizedFunnelling:
|
|||
max_length=self.max_length,
|
||||
device="cuda",
|
||||
print_steps=50,
|
||||
probabilistic=True,
|
||||
probabilistic=self.probabilistic,
|
||||
evaluate_step=self.evaluate_step,
|
||||
verbose=True,
|
||||
patience=self.patience,
|
||||
dataset_name=self.dataset_name,
|
||||
)
|
||||
self.first_tier_learners.append(transformer_vgf)
|
||||
|
||||
|
@ -174,10 +180,18 @@ class GeneralizedFunnelling:
|
|||
def aggregate(self, first_tier_projections):
|
||||
if self.aggfunc == "mean":
|
||||
aggregated = self._aggregate_mean(first_tier_projections)
|
||||
elif self.aggfunc == "concat":
|
||||
aggregated = self._aggregate_concat(first_tier_projections)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return aggregated
|
||||
|
||||
def _aggregate_concat(self, first_tier_projections):
|
||||
aggregated = {}
|
||||
for lang in self.langs:
|
||||
aggregated[lang] = np.hstack([v[lang] for v in first_tier_projections])
|
||||
return aggregated
|
||||
|
||||
def _aggregate_mean(self, first_tier_projections):
|
||||
aggregated = {
|
||||
lang: np.zeros(data.shape)
|
||||
|
|
|
@ -20,7 +20,6 @@ transformers.logging.set_verbosity_error()
|
|||
|
||||
|
||||
# TODO: add support to loggers
|
||||
# TODO: multiple inheritance - maybe define a superclass for TransformerGenerator, whether it is a Textual or a Visual one, implementing dataset creation functions
|
||||
|
||||
|
||||
class TextualTransformerGen(ViewGen, TransformerGen):
|
||||
|
@ -42,7 +41,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
patience=5,
|
||||
):
|
||||
super().__init__(
|
||||
model_name,
|
||||
self._validate_model_name(model_name),
|
||||
dataset_name,
|
||||
epochs,
|
||||
lr,
|
||||
|
@ -58,28 +57,19 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
patience,
|
||||
)
|
||||
self.fitted = False
|
||||
self._init()
|
||||
|
||||
def _init(self):
|
||||
if self.probabilistic:
|
||||
self.feature2posterior_projector = FeatureSet2Posteriors(
|
||||
n_jobs=self.n_jobs, verbose=False
|
||||
)
|
||||
self.model_name = self._get_model_name(self.model_name)
|
||||
print(
|
||||
f"- init TransformerModel model_name: {self.model_name}, device: {self.device}]"
|
||||
f"- init Textual TransformerModel model_name: {self.model_name}, device: {self.device}]"
|
||||
)
|
||||
|
||||
def _get_model_name(self, name):
|
||||
if "bert" == name:
|
||||
name_model = "bert-base-uncased"
|
||||
elif "mbert" == name:
|
||||
name_model = "bert-base-multilingual-uncased"
|
||||
elif "xlm" == name:
|
||||
name_model = "xlm-roberta-base"
|
||||
def _validate_model_name(self, model_name):
|
||||
if "bert" == model_name:
|
||||
return "bert-base-uncased"
|
||||
elif "mbert" == model_name:
|
||||
return "bert-base-multilingual-uncased"
|
||||
elif "xlm" == model_name:
|
||||
return "xlm-roberta-base"
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return name_model
|
||||
|
||||
def load_pretrained_model(self, model_name, num_labels):
|
||||
return AutoModelForSequenceClassification.from_pretrained(
|
||||
|
@ -192,6 +182,8 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
|
||||
if self.probabilistic and self.fitted:
|
||||
l_embeds = self.feature2posterior_projector.transform(l_embeds)
|
||||
elif not self.probabilistic and self.fitted:
|
||||
l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}
|
||||
|
||||
return l_embeds
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from vgfs.learners.svms import FeatureSet2Posteriors
|
||||
|
||||
|
||||
class TransformerGen:
|
||||
|
@ -45,9 +46,16 @@ class TransformerGen:
|
|||
self.verbose = verbose
|
||||
self.patience = patience
|
||||
self.datasets = {}
|
||||
self.feature2posterior_projector = (
|
||||
self.make_probabilistic() if probabilistic else None
|
||||
)
|
||||
|
||||
def make_probabilistic(self):
|
||||
raise NotImplementedError
|
||||
if self.probabilistic:
|
||||
feature2posterior_projector = FeatureSet2Posteriors(
|
||||
n_jobs=self.n_jobs, verbose=False
|
||||
)
|
||||
return feature2posterior_projector
|
||||
|
||||
def build_dataloader(
|
||||
self,
|
||||
|
|
|
@ -40,6 +40,10 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
|||
evaluate_step=evaluate_step,
|
||||
patience=patience,
|
||||
)
|
||||
self.fitted = False
|
||||
print(
|
||||
f"- init Visual TransformerModel model_name: {self.model_name}, device: {self.device}]"
|
||||
)
|
||||
|
||||
def _validate_model_name(self, model_name):
|
||||
if "vit" == model_name:
|
||||
|
@ -128,6 +132,9 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
|||
epochs=self.epochs,
|
||||
)
|
||||
|
||||
if self.probabilistic:
|
||||
self.feature2posterior_projector.fit(self.transform(lX), lY)
|
||||
|
||||
def transform(self, lX):
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
34
main.py
34
main.py
|
@ -41,7 +41,7 @@ def get_dataset(datasetname):
|
|||
dataset = (
|
||||
MultilingualDataset(dataset_name="rcv1-2")
|
||||
.load(RCV_DATAPATH)
|
||||
.reduce_data(langs=["en", "it", "fr"], maxn=500)
|
||||
.reduce_data(langs=["en", "it", "fr"], maxn=args.nrows)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
@ -54,9 +54,9 @@ def main(args):
|
|||
dataset, MultiNewsDataset
|
||||
):
|
||||
lX, lY = dataset.training()
|
||||
# lX_te, lY_te = dataset.test()
|
||||
print("[NB: for debug purposes, training set is also used as test set]\n")
|
||||
lX_te, lY_te = dataset.training()
|
||||
lX_te, lY_te = dataset.test()
|
||||
# print("[NB: for debug purposes, training set is also used as test set]\n")
|
||||
# lX_te, lY_te = dataset.training()
|
||||
else:
|
||||
_lX = dataset.dX
|
||||
_lY = dataset.dY
|
||||
|
@ -75,24 +75,32 @@ def main(args):
|
|||
), "At least one of VGF must be True"
|
||||
|
||||
gfun = GeneralizedFunnelling(
|
||||
# dataset params ----------------------
|
||||
dataset_name=args.dataset,
|
||||
posterior=args.posteriors,
|
||||
multilingual=args.multilingual,
|
||||
wce=args.wce,
|
||||
transformer=args.transformer,
|
||||
langs=dataset.langs(),
|
||||
# Posterior VGF params ----------------
|
||||
posterior=args.posteriors,
|
||||
# Multilingual VGF params -------------
|
||||
multilingual=args.multilingual,
|
||||
embed_dir="~/resources/muse_embeddings",
|
||||
n_jobs=args.n_jobs,
|
||||
max_length=args.max_length,
|
||||
# WCE VGF params ----------------------
|
||||
wce=args.wce,
|
||||
# Transformer VGF params --------------
|
||||
transformer=args.transformer,
|
||||
transformer_name=args.transformer_name,
|
||||
batch_size=args.batch_size,
|
||||
epochs=args.epochs,
|
||||
lr=args.lr,
|
||||
max_length=args.max_length,
|
||||
patience=args.patience,
|
||||
evaluate_step=args.evaluate_step,
|
||||
transformer_name=args.transformer_name,
|
||||
device="cuda",
|
||||
# General params ----------------------
|
||||
probabilistic=args.features,
|
||||
aggfunc=args.aggfunc,
|
||||
optimc=args.optimc,
|
||||
load_trained=args.load_trained,
|
||||
n_jobs=args.n_jobs,
|
||||
)
|
||||
|
||||
# gfun.get_config()
|
||||
|
@ -125,7 +133,7 @@ if __name__ == "__main__":
|
|||
# Dataset parameters -------------------
|
||||
parser.add_argument("-d", "--dataset", type=str, default="multinews")
|
||||
parser.add_argument("--domains", type=str, default="all")
|
||||
parser.add_argument("--nrows", type=int, default=10000)
|
||||
parser.add_argument("--nrows", type=int, default=100)
|
||||
parser.add_argument("--min_count", type=int, default=10)
|
||||
parser.add_argument("--max_labels", type=int, default=50)
|
||||
# gFUN parameters ----------------------
|
||||
|
@ -135,6 +143,8 @@ if __name__ == "__main__":
|
|||
parser.add_argument("-t", "--transformer", action="store_true")
|
||||
parser.add_argument("--n_jobs", type=int, default=1)
|
||||
parser.add_argument("--optimc", action="store_true")
|
||||
parser.add_argument("--features", action="store_false")
|
||||
parser.add_argument("--aggfunc", type=str, default="mean")
|
||||
# transformer parameters ---------------
|
||||
parser.add_argument("--transformer_name", type=str, default="mbert")
|
||||
parser.add_argument("--batch_size", type=int, default=32)
|
||||
|
|
Loading…
Reference in New Issue