fixed imports

andrea 2021-01-26 15:52:09 +01:00
parent ce4e32aad2
commit 20c76a2103
19 changed files with 52 additions and 60 deletions

View File

@@ -1,12 +1,11 @@
 from argparse import ArgumentParser
-from data.dataset_builder import MultilingualDataset
-from funnelling import *
-from util.common import MultilingualIndex, get_params, get_method_name
-from util.evaluation import evaluate
-from util.results_csv import CSVlog
-from view_generators import *
-from time import time
+from src.data.dataset_builder import MultilingualDataset
+from src.funnelling import *
+from src.util.common import MultilingualIndex, get_params, get_method_name
+from src.util.evaluation import evaluate
+from src.util.results_csv import CSVlog
+from src.view_generators import *

 def main(args):
@@ -60,18 +59,17 @@ def main(args):
     # Training ---------------------------------------
     print('\n[Training Generalized Funnelling]')
-    time_init = time()
-    time_tr = time()
+    time_init = time.time()
     gfun.fit(lX, ly)
-    time_tr = round(time() - time_tr, 3)
+    time_tr = round(time.time() - time_init, 3)
     print(f'Training completed in {time_tr} seconds!')

     # Testing ----------------------------------------
     print('\n[Testing Generalized Funnelling]')
-    time_te = time()
+    time_te = time.time()
     ly_ = gfun.predict(lXte)
     l_eval = evaluate(ly_true=lyte, ly_pred=ly_)
-    time_te = round(time() - time_te, 3)
+    time_te = round(time.time() - time_te, 3)
     print(f'Testing completed in {time_te} seconds!')

     # Logging ---------------------------------------
@@ -101,7 +99,7 @@ def main(args):
                      notes='')
     print('Averages: MF1, mF1, MK, mK', np.round(np.mean(np.array(metrics), axis=0), 3))
-    overall_time = round(time() - time_init, 3)
+    overall_time = round(time.time() - time_init, 3)
     exit(f'\nExecuted in: {overall_time} seconds!')
@@ -112,7 +110,7 @@ if __name__ == '__main__':
     parser.add_argument('-o', '--output', dest='csv_dir',
                         help='Result file (default ../csv_log/gfun_results.csv)', type=str,
-                        default='csv_logs/gfun/gfun_results.csv')
+                        default='../csv_logs/gfun/gfun_results.csv')
     parser.add_argument('-x', '--post_embedder', dest='post_embedder', action='store_true',
                         help='deploy posterior probabilities embedder to compute document embeddings',
@@ -138,7 +136,7 @@ if __name__ == '__main__':
                         help='Optimize SVMs C hyperparameter',
                         default=False)
-    parser.add_argument('-n', '--nepochs', dest='nepochs', type=str,
+    parser.add_argument('-n', '--nepochs', dest='nepochs', type=int,
                         help='Number of max epochs to train Recurrent embedder (i.e., -g)')
     parser.add_argument('-j', '--n_jobs', dest='n_jobs', type=int,
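The timing hunks above swap bare time() calls for time.time() while deleting from time import time; the matching module-level import time is outside the visible context, so its exact placement is an assumption. A minimal sketch of the pattern the new code relies on:

import time  # assumed replacement for the removed `from time import time`

time_init = time.time()                      # timestamp before the long-running step
# gfun.fit(lX, ly) would run here
time_tr = round(time.time() - time_init, 3)  # elapsed seconds, three decimals
print(f'Training completed in {time_tr} seconds!')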

View File

View File

@@ -135,15 +135,15 @@ class RecurrentDataModule(pl.LightningDataModule):
                                       lPad_index=self.multilingualIndex.l_pad())

     def train_dataloader(self):
-        return DataLoader(self.training_dataset, batch_size=self.batchsize, num_workers=self.n_jobs,
+        return DataLoader(self.training_dataset, batch_size=self.batchsize, num_workers=N_WORKERS,
                           collate_fn=self.training_dataset.collate_fn)

     def val_dataloader(self):
-        return DataLoader(self.val_dataset, batch_size=self.batchsize, num_workers=self.n_jobs,
+        return DataLoader(self.val_dataset, batch_size=self.batchsize, num_workers=N_WORKERS,
                           collate_fn=self.val_dataset.collate_fn)

     def test_dataloader(self):
-        return DataLoader(self.test_dataset, batch_size=self.batchsize, num_workers=self.n_jobs,
+        return DataLoader(self.test_dataset, batch_size=self.batchsize, num_workers=N_WORKERS,
                           collate_fn=self.test_dataset.collate_fn)
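Here num_workers=self.n_jobs becomes num_workers=N_WORKERS in all three dataloaders; the N_WORKERS constant is defined outside the visible hunk, so the value below is a placeholder. A minimal sketch of the pattern:

import torch
from torch.utils.data import DataLoader, TensorDataset

N_WORKERS = 0  # placeholder; the real module-level constant lives elsewhere in this file

dataset = TensorDataset(torch.zeros(8, 2))  # stand-in dataset for illustration
loader = DataLoader(dataset, batch_size=4, num_workers=N_WORKERS)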

View File

@@ -1,5 +1,4 @@
 import itertools
-import pickle
 import re
 from os.path import exists
@@ -12,10+11,10 @@ from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import MultiLabelBinarizer
 from tqdm import tqdm

-from data.languages import NLTK_LANGMAP, RCV2_LANGS_WITH_NLTK_STEMMING
-from data.reader.jrcacquis_reader import *
-from data.reader.rcv_reader import fetch_RCV1, fetch_RCV2
-from data.text_preprocessor import NLTKStemTokenizer, preprocess_documents
+from src.data.languages import NLTK_LANGMAP, RCV2_LANGS_WITH_NLTK_STEMMING
+from src.data.reader.jrcacquis_reader import *
+from src.data.reader.rcv_reader import fetch_RCV1, fetch_RCV2
+from src.data.text_preprocessor import NLTKStemTokenizer, preprocess_documents

 class MultilingualDataset:

View File

@@ -14,9 +14,9 @@ import rdflib
 from rdflib.namespace import RDF, SKOS
 from sklearn.datasets import get_data_home

-from data.languages import JRC_LANGS
-from data.languages import lang_set
-from util.file import download_file, list_dirs, list_files
+from src.data.languages import JRC_LANGS
+from src.data.languages import lang_set
+from src.util.file import download_file, list_dirs, list_files

 """
 JRC Acquis' Nomenclature:

View File

@@ -5,8 +5,8 @@ from zipfile import ZipFile
 import numpy as np

-from util.file import download_file_if_not_exists
-from util.file import list_files
+from src.util.file import download_file_if_not_exists
+from src.util.file import list_files

 """
 RCV2's Nomenclature:

View File

@@ -11,7 +11,6 @@ from os.path import join
 from xml.sax.saxutils import escape

 import numpy as np

 from util.file import list_dirs, list_files

 policies = ["IN_ALL_LANGS", "IN_ANY_LANG"]

View File

@@ -2,7 +2,7 @@ from nltk import word_tokenize
 from nltk.corpus import stopwords
 from nltk.stem import SnowballStemmer

-from data.languages import NLTK_LANGMAP
+from src.data.languages import NLTK_LANGMAP

 def preprocess_documents(documents, lang):

View File

@@ -1,6 +1,6 @@
-from models.learners import *
-from util.common import _normalize
-from view_generators import VanillaFunGen
+from src.models.learners import *
+from src.util.common import _normalize
+from src.view_generators import VanillaFunGen

 class DocEmbedderList:

View File

@@ -7,7 +7,7 @@ from sklearn.model_selection import GridSearchCV
 from sklearn.multiclass import OneVsRestClassifier
 from sklearn.svm import SVC

-from util.standardizer import StandardizeTransformer
+from src.util.standardizer import StandardizeTransformer

 def get_learner(calibrate=False, kernel='linear', C=1):

View File

@@ -1,7 +1,6 @@
 #taken from https://github.com/prakashpandey9/Text-Classification-Pytorch/blob/master/models/LSTM.py
-from torch.autograd import Variable
 from models.helpers import *
+from torch.autograd import Variable

 class RNNMultilingualClassifier(nn.Module):

View File

@@ -3,8 +3,8 @@ import torch
 from torch.optim.lr_scheduler import StepLR
 from transformers import BertForSequenceClassification, AdamW

-from util.common import define_pad_length, pad
-from util.pl_metrics import CustomF1, CustomK
+from src.util.common import define_pad_length, pad
+from src.util.pl_metrics import CustomF1, CustomK

 class BertModel(pl.LightningModule):

View File

@@ -7,9 +7,9 @@ from torch.autograd import Variable
 from torch.optim.lr_scheduler import StepLR
 from transformers import AdamW

-from models.helpers import init_embeddings
-from util.common import define_pad_length, pad
-from util.pl_metrics import CustomF1, CustomK
+from src.models.helpers import init_embeddings
+from src.util.common import define_pad_length, pad
+from src.util.pl_metrics import CustomF1, CustomK

 class RecurrentModel(pl.LightningModule):

View File

@@ -4,7 +4,7 @@ from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import normalize

-from util.embeddings_manager import supervised_embeddings_tfidf
+from src.util.embeddings_manager import supervised_embeddings_tfidf

 class TfidfVectorizerMultilingual:

View File

@@ -4,7 +4,7 @@ import numpy as np
 import torch
 from torchtext.vocab import Vectors

-from util.SIF_embed import remove_pc
+from src.util.SIF_embed import remove_pc

 class PretrainedEmbeddings(ABC):

View File

@@ -1,7 +1,6 @@
-import numpy as np
 from joblib import Parallel, delayed

-from util.metrics import *
+from src.util.metrics import *

 def evaluation_metrics(y, y_):

View File

@@ -1,7 +1,7 @@
 import torch
 from pytorch_lightning.metrics import Metric

-from util.common import is_false, is_true
+from src.util.common import is_false, is_true

 def _update(pred, target, device):

View File

@@ -21,12 +21,12 @@ from time import time
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger

-from data.datamodule import RecurrentDataModule, BertDataModule, tokenize
-from models.learners import *
-from models.pl_bert import BertModel
-from models.pl_gru import RecurrentModel
-from util.common import TfidfVectorizerMultilingual, _normalize
-from util.embeddings_manager import MuseLoader, XdotM, wce_matrix
+from src.data.datamodule import RecurrentDataModule, BertDataModule, tokenize
+from src.models.learners import *
+from src.models.pl_bert import BertModel
+from src.models.pl_gru import RecurrentModel
+from src.util.common import TfidfVectorizerMultilingual, _normalize
+from src.util.embeddings_manager import MuseLoader, XdotM, wce_matrix

 class ViewGen(ABC):
@@ -232,7 +232,7 @@ class RecurrentGen(ViewGen):
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
         self.multilingualIndex.embedding_matrices(self.pretrained, supervised=self.wce)
         self.model = self._init_model()
-        self.logger = TensorBoardLogger(save_dir='tb_logs', name='rnn', default_hp_metric=False)
+        self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
         # self.logger = CSVLogger(save_dir='csv_logs', name='rnn_dev')

     def _init_model(self):
@@ -293,9 +293,9 @@ class RecurrentGen(ViewGen):
         data = self.multilingualIndex.l_devel_index()
         self.model.to('cuda' if self.gpus else 'cpu')
         self.model.eval()
-        time_init = time()
+        time_init = time.time()
         l_embeds = self.model.encode(data, l_pad, batch_size=256)
-        transform_time = round(time() - time_init, 3)
+        transform_time = round(time.time() - time_init, 3)
         print(f'Executed! Transform took: {transform_time}')
         return l_embeds
@@ -328,7 +328,7 @@ class BertGen(ViewGen):
         self.n_jobs = n_jobs
         self.stored_path = stored_path
         self.model = self._init_model()
-        self.logger = TensorBoardLogger(save_dir='tb_logs', name='bert', default_hp_metric=False)
+        self.logger = TensorBoardLogger(save_dir='../tb_logs', name='bert', default_hp_metric=False)

     def _init_model(self):
         output_size = self.multilingualIndex.get_target_dim()
@@ -362,14 +362,12 @@ class BertGen(ViewGen):
         data = tokenize(data, max_len=512)
         self.model.to('cuda' if self.gpus else 'cpu')
         self.model.eval()
-        time_init = time()
+        time_init = time.time()
         l_emebds = self.model.encode(data, batch_size=64)
-        transform_time = round(time() - time_init, 3)
+        transform_time = round(time.time() - time_init, 3)
         print(f'Executed! Transform took: {transform_time}')
         return l_emebds

     def fit_transform(self, lX, ly):
         # we can assume that we have already indexed data for transform() since we are first calling fit()
         return self.fit(lX, ly).transform(lX)
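Taken together, the hunks rewrite the project's absolute imports against a top-level src package (and point the CSV and TensorBoard defaults one directory up, to ../csv_logs and ../tb_logs). For from src.util.common import ... to resolve, the repository root must be on sys.path; a minimal sketch of the two usual ways to arrange that, where the entry-point name is an assumption:

# Option 1: run the entry script as a module from the repository root
# (assuming the script is src/main.py):
#   python -m src.main

# Option 2: put the repository root on sys.path explicitly before importing:
import sys
from pathlib import Path

repo_root = Path(__file__).resolve().parent  # adjust if this file sits below the root
sys.path.insert(0, str(repo_root))
# from src.util.common import MultilingualIndex  # now resolvable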