concat aggfunc
This commit is contained in:
parent
3f3e4982e4
commit
2a42b21ac9
|
@ -35,12 +35,15 @@ class GeneralizedFunnelling:
|
||||||
device,
|
device,
|
||||||
load_trained,
|
load_trained,
|
||||||
dataset_name,
|
dataset_name,
|
||||||
|
probabilistic,
|
||||||
|
aggfunc,
|
||||||
):
|
):
|
||||||
# Setting VFGs -----------
|
# Setting VFGs -----------
|
||||||
self.posteriors_vgf = posterior
|
self.posteriors_vgf = posterior
|
||||||
self.wce_vgf = wce
|
self.wce_vgf = wce
|
||||||
self.multilingual_vgf = multilingual
|
self.multilingual_vgf = multilingual
|
||||||
self.trasformer_vgf = transformer
|
self.trasformer_vgf = transformer
|
||||||
|
self.probabilistic = probabilistic
|
||||||
# ------------------------
|
# ------------------------
|
||||||
self.langs = langs
|
self.langs = langs
|
||||||
self.embed_dir = embed_dir
|
self.embed_dir = embed_dir
|
||||||
|
@ -62,13 +65,16 @@ class GeneralizedFunnelling:
|
||||||
self.n_jobs = n_jobs
|
self.n_jobs = n_jobs
|
||||||
self.first_tier_learners = []
|
self.first_tier_learners = []
|
||||||
self.metaclassifier = None
|
self.metaclassifier = None
|
||||||
self.aggfunc = "mean"
|
self.aggfunc = aggfunc
|
||||||
self.load_trained = load_trained
|
self.load_trained = load_trained
|
||||||
self.dataset_name = dataset_name
|
self.dataset_name = dataset_name
|
||||||
self._init()
|
self._init()
|
||||||
|
|
||||||
def _init(self):
|
def _init(self):
|
||||||
print("[Init GeneralizedFunnelling]")
|
print("[Init GeneralizedFunnelling]")
|
||||||
|
assert not (
|
||||||
|
self.aggfunc == "mean" and self.probabilistic is False
|
||||||
|
), "When using averaging aggreagation function probabilistic must be True"
|
||||||
if self.load_trained is not None:
|
if self.load_trained is not None:
|
||||||
print("- loading trained VGFs, metaclassifer and vectorizer")
|
print("- loading trained VGFs, metaclassifer and vectorizer")
|
||||||
self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load(
|
self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load(
|
||||||
|
@ -90,7 +96,7 @@ class GeneralizedFunnelling:
|
||||||
langs=self.langs,
|
langs=self.langs,
|
||||||
n_jobs=self.n_jobs,
|
n_jobs=self.n_jobs,
|
||||||
cached=self.cached,
|
cached=self.cached,
|
||||||
probabilistic=True,
|
probabilistic=self.probabilistic,
|
||||||
)
|
)
|
||||||
self.first_tier_learners.append(multilingual_vgf)
|
self.first_tier_learners.append(multilingual_vgf)
|
||||||
|
|
||||||
|
@ -100,6 +106,7 @@ class GeneralizedFunnelling:
|
||||||
|
|
||||||
if self.trasformer_vgf:
|
if self.trasformer_vgf:
|
||||||
transformer_vgf = TextualTransformerGen(
|
transformer_vgf = TextualTransformerGen(
|
||||||
|
dataset_name=self.dataset_name,
|
||||||
model_name=self.transformer_name,
|
model_name=self.transformer_name,
|
||||||
lr=self.lr_transformer,
|
lr=self.lr_transformer,
|
||||||
epochs=self.epochs,
|
epochs=self.epochs,
|
||||||
|
@ -107,11 +114,10 @@ class GeneralizedFunnelling:
|
||||||
max_length=self.max_length,
|
max_length=self.max_length,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
print_steps=50,
|
print_steps=50,
|
||||||
probabilistic=True,
|
probabilistic=self.probabilistic,
|
||||||
evaluate_step=self.evaluate_step,
|
evaluate_step=self.evaluate_step,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
patience=self.patience,
|
patience=self.patience,
|
||||||
dataset_name=self.dataset_name,
|
|
||||||
)
|
)
|
||||||
self.first_tier_learners.append(transformer_vgf)
|
self.first_tier_learners.append(transformer_vgf)
|
||||||
|
|
||||||
|
@ -174,10 +180,18 @@ class GeneralizedFunnelling:
|
||||||
def aggregate(self, first_tier_projections):
|
def aggregate(self, first_tier_projections):
|
||||||
if self.aggfunc == "mean":
|
if self.aggfunc == "mean":
|
||||||
aggregated = self._aggregate_mean(first_tier_projections)
|
aggregated = self._aggregate_mean(first_tier_projections)
|
||||||
|
elif self.aggfunc == "concat":
|
||||||
|
aggregated = self._aggregate_concat(first_tier_projections)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
return aggregated
|
return aggregated
|
||||||
|
|
||||||
|
def _aggregate_concat(self, first_tier_projections):
|
||||||
|
aggregated = {}
|
||||||
|
for lang in self.langs:
|
||||||
|
aggregated[lang] = np.hstack([v[lang] for v in first_tier_projections])
|
||||||
|
return aggregated
|
||||||
|
|
||||||
def _aggregate_mean(self, first_tier_projections):
|
def _aggregate_mean(self, first_tier_projections):
|
||||||
aggregated = {
|
aggregated = {
|
||||||
lang: np.zeros(data.shape)
|
lang: np.zeros(data.shape)
|
||||||
|
|
|
@ -20,7 +20,6 @@ transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
|
||||||
# TODO: add support to loggers
|
# TODO: add support to loggers
|
||||||
# TODO: multiple inheritance - maybe define a superclass for TransformerGenerator, whether it is a Textual or a Visual one, implementing dataset creation functions
|
|
||||||
|
|
||||||
|
|
||||||
class TextualTransformerGen(ViewGen, TransformerGen):
|
class TextualTransformerGen(ViewGen, TransformerGen):
|
||||||
|
@ -42,7 +41,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
||||||
patience=5,
|
patience=5,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_name,
|
self._validate_model_name(model_name),
|
||||||
dataset_name,
|
dataset_name,
|
||||||
epochs,
|
epochs,
|
||||||
lr,
|
lr,
|
||||||
|
@ -58,28 +57,19 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
||||||
patience,
|
patience,
|
||||||
)
|
)
|
||||||
self.fitted = False
|
self.fitted = False
|
||||||
self._init()
|
|
||||||
|
|
||||||
def _init(self):
|
|
||||||
if self.probabilistic:
|
|
||||||
self.feature2posterior_projector = FeatureSet2Posteriors(
|
|
||||||
n_jobs=self.n_jobs, verbose=False
|
|
||||||
)
|
|
||||||
self.model_name = self._get_model_name(self.model_name)
|
|
||||||
print(
|
print(
|
||||||
f"- init TransformerModel model_name: {self.model_name}, device: {self.device}]"
|
f"- init Textual TransformerModel model_name: {self.model_name}, device: {self.device}]"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_model_name(self, name):
|
def _validate_model_name(self, model_name):
|
||||||
if "bert" == name:
|
if "bert" == model_name:
|
||||||
name_model = "bert-base-uncased"
|
return "bert-base-uncased"
|
||||||
elif "mbert" == name:
|
elif "mbert" == model_name:
|
||||||
name_model = "bert-base-multilingual-uncased"
|
return "bert-base-multilingual-uncased"
|
||||||
elif "xlm" == name:
|
elif "xlm" == model_name:
|
||||||
name_model = "xlm-roberta-base"
|
return "xlm-roberta-base"
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
return name_model
|
|
||||||
|
|
||||||
def load_pretrained_model(self, model_name, num_labels):
|
def load_pretrained_model(self, model_name, num_labels):
|
||||||
return AutoModelForSequenceClassification.from_pretrained(
|
return AutoModelForSequenceClassification.from_pretrained(
|
||||||
|
@ -192,6 +182,8 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
||||||
|
|
||||||
if self.probabilistic and self.fitted:
|
if self.probabilistic and self.fitted:
|
||||||
l_embeds = self.feature2posterior_projector.transform(l_embeds)
|
l_embeds = self.feature2posterior_projector.transform(l_embeds)
|
||||||
|
elif not self.probabilistic and self.fitted:
|
||||||
|
l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}
|
||||||
|
|
||||||
return l_embeds
|
return l_embeds
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from vgfs.learners.svms import FeatureSet2Posteriors
|
||||||
|
|
||||||
|
|
||||||
class TransformerGen:
|
class TransformerGen:
|
||||||
|
@ -45,9 +46,16 @@ class TransformerGen:
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.patience = patience
|
self.patience = patience
|
||||||
self.datasets = {}
|
self.datasets = {}
|
||||||
|
self.feature2posterior_projector = (
|
||||||
|
self.make_probabilistic() if probabilistic else None
|
||||||
|
)
|
||||||
|
|
||||||
def make_probabilistic(self):
|
def make_probabilistic(self):
|
||||||
raise NotImplementedError
|
if self.probabilistic:
|
||||||
|
feature2posterior_projector = FeatureSet2Posteriors(
|
||||||
|
n_jobs=self.n_jobs, verbose=False
|
||||||
|
)
|
||||||
|
return feature2posterior_projector
|
||||||
|
|
||||||
def build_dataloader(
|
def build_dataloader(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -40,6 +40,10 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
||||||
evaluate_step=evaluate_step,
|
evaluate_step=evaluate_step,
|
||||||
patience=patience,
|
patience=patience,
|
||||||
)
|
)
|
||||||
|
self.fitted = False
|
||||||
|
print(
|
||||||
|
f"- init Visual TransformerModel model_name: {self.model_name}, device: {self.device}]"
|
||||||
|
)
|
||||||
|
|
||||||
def _validate_model_name(self, model_name):
|
def _validate_model_name(self, model_name):
|
||||||
if "vit" == model_name:
|
if "vit" == model_name:
|
||||||
|
@ -128,6 +132,9 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
||||||
epochs=self.epochs,
|
epochs=self.epochs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.probabilistic:
|
||||||
|
self.feature2posterior_projector.fit(self.transform(lX), lY)
|
||||||
|
|
||||||
def transform(self, lX):
|
def transform(self, lX):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
34
main.py
34
main.py
|
@ -41,7 +41,7 @@ def get_dataset(datasetname):
|
||||||
dataset = (
|
dataset = (
|
||||||
MultilingualDataset(dataset_name="rcv1-2")
|
MultilingualDataset(dataset_name="rcv1-2")
|
||||||
.load(RCV_DATAPATH)
|
.load(RCV_DATAPATH)
|
||||||
.reduce_data(langs=["en", "it", "fr"], maxn=500)
|
.reduce_data(langs=["en", "it", "fr"], maxn=args.nrows)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -54,9 +54,9 @@ def main(args):
|
||||||
dataset, MultiNewsDataset
|
dataset, MultiNewsDataset
|
||||||
):
|
):
|
||||||
lX, lY = dataset.training()
|
lX, lY = dataset.training()
|
||||||
# lX_te, lY_te = dataset.test()
|
lX_te, lY_te = dataset.test()
|
||||||
print("[NB: for debug purposes, training set is also used as test set]\n")
|
# print("[NB: for debug purposes, training set is also used as test set]\n")
|
||||||
lX_te, lY_te = dataset.training()
|
# lX_te, lY_te = dataset.training()
|
||||||
else:
|
else:
|
||||||
_lX = dataset.dX
|
_lX = dataset.dX
|
||||||
_lY = dataset.dY
|
_lY = dataset.dY
|
||||||
|
@ -75,24 +75,32 @@ def main(args):
|
||||||
), "At least one of VGF must be True"
|
), "At least one of VGF must be True"
|
||||||
|
|
||||||
gfun = GeneralizedFunnelling(
|
gfun = GeneralizedFunnelling(
|
||||||
|
# dataset params ----------------------
|
||||||
dataset_name=args.dataset,
|
dataset_name=args.dataset,
|
||||||
posterior=args.posteriors,
|
|
||||||
multilingual=args.multilingual,
|
|
||||||
wce=args.wce,
|
|
||||||
transformer=args.transformer,
|
|
||||||
langs=dataset.langs(),
|
langs=dataset.langs(),
|
||||||
|
# Posterior VGF params ----------------
|
||||||
|
posterior=args.posteriors,
|
||||||
|
# Multilingual VGF params -------------
|
||||||
|
multilingual=args.multilingual,
|
||||||
embed_dir="~/resources/muse_embeddings",
|
embed_dir="~/resources/muse_embeddings",
|
||||||
n_jobs=args.n_jobs,
|
# WCE VGF params ----------------------
|
||||||
max_length=args.max_length,
|
wce=args.wce,
|
||||||
|
# Transformer VGF params --------------
|
||||||
|
transformer=args.transformer,
|
||||||
|
transformer_name=args.transformer_name,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
epochs=args.epochs,
|
epochs=args.epochs,
|
||||||
lr=args.lr,
|
lr=args.lr,
|
||||||
|
max_length=args.max_length,
|
||||||
patience=args.patience,
|
patience=args.patience,
|
||||||
evaluate_step=args.evaluate_step,
|
evaluate_step=args.evaluate_step,
|
||||||
transformer_name=args.transformer_name,
|
|
||||||
device="cuda",
|
device="cuda",
|
||||||
|
# General params ----------------------
|
||||||
|
probabilistic=args.features,
|
||||||
|
aggfunc=args.aggfunc,
|
||||||
optimc=args.optimc,
|
optimc=args.optimc,
|
||||||
load_trained=args.load_trained,
|
load_trained=args.load_trained,
|
||||||
|
n_jobs=args.n_jobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# gfun.get_config()
|
# gfun.get_config()
|
||||||
|
@ -125,7 +133,7 @@ if __name__ == "__main__":
|
||||||
# Dataset parameters -------------------
|
# Dataset parameters -------------------
|
||||||
parser.add_argument("-d", "--dataset", type=str, default="multinews")
|
parser.add_argument("-d", "--dataset", type=str, default="multinews")
|
||||||
parser.add_argument("--domains", type=str, default="all")
|
parser.add_argument("--domains", type=str, default="all")
|
||||||
parser.add_argument("--nrows", type=int, default=10000)
|
parser.add_argument("--nrows", type=int, default=100)
|
||||||
parser.add_argument("--min_count", type=int, default=10)
|
parser.add_argument("--min_count", type=int, default=10)
|
||||||
parser.add_argument("--max_labels", type=int, default=50)
|
parser.add_argument("--max_labels", type=int, default=50)
|
||||||
# gFUN parameters ----------------------
|
# gFUN parameters ----------------------
|
||||||
|
@ -135,6 +143,8 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("-t", "--transformer", action="store_true")
|
parser.add_argument("-t", "--transformer", action="store_true")
|
||||||
parser.add_argument("--n_jobs", type=int, default=1)
|
parser.add_argument("--n_jobs", type=int, default=1)
|
||||||
parser.add_argument("--optimc", action="store_true")
|
parser.add_argument("--optimc", action="store_true")
|
||||||
|
parser.add_argument("--features", action="store_false")
|
||||||
|
parser.add_argument("--aggfunc", type=str, default="mean")
|
||||||
# transformer parameters ---------------
|
# transformer parameters ---------------
|
||||||
parser.add_argument("--transformer_name", type=str, default="mbert")
|
parser.add_argument("--transformer_name", type=str, default="mbert")
|
||||||
parser.add_argument("--batch_size", type=int, default=32)
|
parser.add_argument("--batch_size", type=int, default=32)
|
||||||
|
|
Loading…
Reference in New Issue