diff --git a/refactor/data/datamodule.py b/refactor/data/datamodule.py index 12d7e02..1121a58 100644 --- a/refactor/data/datamodule.py +++ b/refactor/data/datamodule.py @@ -1,7 +1,7 @@ -import torch -from torch.utils.data import Dataset, DataLoader import numpy as np import pytorch_lightning as pl +import torch +from torch.utils.data import Dataset, DataLoader from transformers import BertTokenizer N_WORKERS = 8 diff --git a/refactor/data/dataset_builder.py b/refactor/data/dataset_builder.py index b9650c7..0e91316 100644 --- a/refactor/data/dataset_builder.py +++ b/refactor/data/dataset_builder.py @@ -1,19 +1,21 @@ -from os.path import join, exists -from nltk.corpus import stopwords -from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer -from sklearn.preprocessing import MultiLabelBinarizer -from data.reader.jrcacquis_reader import * -from data.languages import lang_set, NLTK_LANGMAP, RCV2_LANGS_WITH_NLTK_STEMMING -from data.reader.rcv_reader import fetch_RCV1, fetch_RCV2, fetch_topic_hierarchy -from data.text_preprocessor import NLTKStemTokenizer, preprocess_documents -import pickle -import numpy as np -from sklearn.model_selection import train_test_split -from scipy.sparse import issparse import itertools -from tqdm import tqdm +import pickle import re +from os.path import exists + +import numpy as np +from nltk.corpus import stopwords from scipy.sparse import csr_matrix +from scipy.sparse import issparse +from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import MultiLabelBinarizer +from tqdm import tqdm + +from data.languages import NLTK_LANGMAP, RCV2_LANGS_WITH_NLTK_STEMMING +from data.reader.jrcacquis_reader import * +from data.reader.rcv_reader import fetch_RCV1, fetch_RCV2 +from data.text_preprocessor import NLTKStemTokenizer, preprocess_documents class MultilingualDataset: diff --git a/refactor/data/reader/jrcacquis_reader.py b/refactor/data/reader/jrcacquis_reader.py index c0441ed..e911996 100644 --- a/refactor/data/reader/jrcacquis_reader.py +++ b/refactor/data/reader/jrcacquis_reader.py @@ -1,19 +1,22 @@ from __future__ import print_function -import os, sys -from os.path import join + +import os +import pickle +import sys import tarfile import xml.etree.ElementTree as ET -from sklearn.datasets import get_data_home -import pickle -from util.file import download_file, list_dirs, list_files +import zipfile +from collections import Counter +from os.path import join +from random import shuffle + import rdflib from rdflib.namespace import RDF, SKOS -from rdflib import URIRef -import zipfile +from sklearn.datasets import get_data_home + from data.languages import JRC_LANGS -from collections import Counter -from random import shuffle from data.languages import lang_set +from util.file import download_file, list_dirs, list_files """ JRC Acquis' Nomenclature: diff --git a/refactor/data/reader/rcv_reader.py b/refactor/data/reader/rcv_reader.py index cd4b416..b3db098 100644 --- a/refactor/data/reader/rcv_reader.py +++ b/refactor/data/reader/rcv_reader.py @@ -1,15 +1,12 @@ -from zipfile import ZipFile -import xml.etree.ElementTree as ET -from data.languages import RCV2_LANGS_WITH_NLTK_STEMMING, RCV2_LANGS -from util.file import list_files -from sklearn.datasets import get_data_home -import gzip -from os.path import join, exists -from util.file import download_file_if_not_exists import re -from collections import Counter +import xml.etree.ElementTree as ET +from os.path import join, exists +from zipfile import ZipFile + import numpy as np -import sys + +from util.file import download_file_if_not_exists +from util.file import list_files """ RCV2's Nomenclature: diff --git a/refactor/data/reader/wikipedia_tools.py b/refactor/data/reader/wikipedia_tools.py index 83e11e3..9558fb6 100644 --- a/refactor/data/reader/wikipedia_tools.py +++ b/refactor/data/reader/wikipedia_tools.py @@ -1,16 +1,19 @@ from __future__ import print_function + # import ijson # from ijson.common import ObjectBuilder -import os, sys -from os.path import join -from bz2 import BZ2File +import os import pickle -from util.file import list_dirs, list_files, makedirs_if_not_exist -from itertools import islice import re +from bz2 import BZ2File +from itertools import islice +from os.path import join from xml.sax.saxutils import escape + import numpy as np +from util.file import list_dirs, list_files + policies = ["IN_ALL_LANGS", "IN_ANY_LANG"] """ diff --git a/refactor/data/text_preprocessor.py b/refactor/data/text_preprocessor.py index 1a6e3ae..fcfddba 100644 --- a/refactor/data/text_preprocessor.py +++ b/refactor/data/text_preprocessor.py @@ -1,8 +1,9 @@ -from nltk.corpus import stopwords -from data.languages import NLTK_LANGMAP from nltk import word_tokenize +from nltk.corpus import stopwords from nltk.stem import SnowballStemmer +from data.languages import NLTK_LANGMAP + def preprocess_documents(documents, lang): tokens = NLTKStemTokenizer(lang, verbose=True) diff --git a/refactor/data/tsr_function__.py b/refactor/data/tsr_function__.py index 0af8690..c458029 100755 --- a/refactor/data/tsr_function__.py +++ b/refactor/data/tsr_function__.py @@ -1,8 +1,9 @@ import math + import numpy as np -from scipy.stats import t from joblib import Parallel, delayed from scipy.sparse import csr_matrix, csc_matrix +from scipy.stats import t def get_probs(tpr, fpr, pc): diff --git a/refactor/funnelling.py b/refactor/funnelling.py index 6c79ae9..4d19e1a 100644 --- a/refactor/funnelling.py +++ b/refactor/funnelling.py @@ -1,6 +1,6 @@ from models.learners import * -from view_generators import VanillaFunGen from util.common import _normalize +from view_generators import VanillaFunGen class DocEmbedderList: diff --git a/refactor/main.py b/refactor/main.py index d043d76..48936d0 100644 --- a/refactor/main.py +++ b/refactor/main.py @@ -1,11 +1,11 @@ from argparse import ArgumentParser -from funnelling import * -from view_generators import * + from data.dataset_builder import MultilingualDataset +from funnelling import * from util.common import MultilingualIndex, get_params, get_method_name from util.evaluation import evaluate from util.results_csv import CSVlog -from time import time +from view_generators import * def main(args): diff --git a/refactor/models/learners.py b/refactor/models/learners.py index 1c60072..2654109 100644 --- a/refactor/models/learners.py +++ b/refactor/models/learners.py @@ -1,10 +1,12 @@ -import numpy as np import time -from scipy.sparse import issparse -from sklearn.multiclass import OneVsRestClassifier -from sklearn.model_selection import GridSearchCV -from sklearn.svm import SVC + +import numpy as np from joblib import Parallel, delayed +from scipy.sparse import issparse +from sklearn.model_selection import GridSearchCV +from sklearn.multiclass import OneVsRestClassifier +from sklearn.svm import SVC + from util.standardizer import StandardizeTransformer diff --git a/refactor/models/lstm_class.py b/refactor/models/lstm_class.py index 98424f1..7f2cf59 100755 --- a/refactor/models/lstm_class.py +++ b/refactor/models/lstm_class.py @@ -1,7 +1,6 @@ #taken from https://github.com/prakashpandey9/Text-Classification-Pytorch/blob/master/models/LSTM.py -import torch -import torch.nn as nn from torch.autograd import Variable + from models.helpers import * diff --git a/refactor/models/pl_bert.py b/refactor/models/pl_bert.py index 48f5b9a..afb28b5 100644 --- a/refactor/models/pl_bert.py +++ b/refactor/models/pl_bert.py @@ -1,9 +1,10 @@ -import torch import pytorch_lightning as pl +import torch from torch.optim.lr_scheduler import StepLR from transformers import BertForSequenceClassification, AdamW -from util.pl_metrics import CustomF1, CustomK + from util.common import define_pad_length, pad +from util.pl_metrics import CustomF1, CustomK class BertModel(pl.LightningModule): diff --git a/refactor/models/pl_gru.py b/refactor/models/pl_gru.py index eaf7304..afb12e6 100644 --- a/refactor/models/pl_gru.py +++ b/refactor/models/pl_gru.py @@ -1,14 +1,15 @@ # Lightning modules, see https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html +import pytorch_lightning as pl import torch -from torch import nn import torch.nn.functional as F +from torch import nn from torch.autograd import Variable from torch.optim.lr_scheduler import StepLR from transformers import AdamW -import pytorch_lightning as pl + from models.helpers import init_embeddings -from util.pl_metrics import CustomF1, CustomK from util.common import define_pad_length, pad +from util.pl_metrics import CustomF1, CustomK class RecurrentModel(pl.LightningModule): diff --git a/refactor/util/common.py b/refactor/util/common.py index 0cd95e6..61ac52f 100644 --- a/refactor/util/common.py +++ b/refactor/util/common.py @@ -1,9 +1,9 @@ import numpy as np import torch -from tqdm import tqdm from sklearn.feature_extraction.text import TfidfVectorizer -from sklearn.preprocessing import normalize from sklearn.model_selection import train_test_split +from sklearn.preprocessing import normalize + from util.embeddings_manager import supervised_embeddings_tfidf diff --git a/refactor/util/embeddings_manager.py b/refactor/util/embeddings_manager.py index c0aca54..1d708fa 100644 --- a/refactor/util/embeddings_manager.py +++ b/refactor/util/embeddings_manager.py @@ -1,7 +1,9 @@ -from torchtext.vocab import Vectors -import torch from abc import ABC, abstractmethod + import numpy as np +import torch +from torchtext.vocab import Vectors + from util.SIF_embed import remove_pc diff --git a/refactor/util/evaluation.py b/refactor/util/evaluation.py index 03c1792..010d0e9 100644 --- a/refactor/util/evaluation.py +++ b/refactor/util/evaluation.py @@ -1,6 +1,7 @@ -from joblib import Parallel, delayed -from util.metrics import * import numpy as np +from joblib import Parallel, delayed + +from util.metrics import * def evaluation_metrics(y, y_): diff --git a/refactor/util/file.py b/refactor/util/file.py index 98c9910..8754f5a 100644 --- a/refactor/util/file.py +++ b/refactor/util/file.py @@ -1,6 +1,6 @@ +import urllib from os import listdir, makedirs from os.path import isdir, isfile, join, exists, dirname -import urllib from pathlib import Path diff --git a/refactor/util/pl_metrics.py b/refactor/util/pl_metrics.py index 9b44eb0..bf8aa99 100644 --- a/refactor/util/pl_metrics.py +++ b/refactor/util/pl_metrics.py @@ -1,5 +1,6 @@ import torch from pytorch_lightning.metrics import Metric + from util.common import is_false, is_true diff --git a/refactor/util/results_csv.py b/refactor/util/results_csv.py index df80c59..be0ff84 100644 --- a/refactor/util/results_csv.py +++ b/refactor/util/results_csv.py @@ -1,6 +1,7 @@ import os -import pandas as pd + import numpy as np +import pandas as pd class CSVlog: diff --git a/refactor/view_generators.py b/refactor/view_generators.py index e366d7d..6cdd4a9 100644 --- a/refactor/view_generators.py +++ b/refactor/view_generators.py @@ -16,16 +16,18 @@ This module contains the view generators that take care of computing the view sp - View generator (-b): generates document embedding via mBERT model. """ from abc import ABC, abstractmethod -from models.learners import * -from util.embeddings_manager import MuseLoader, XdotM, wce_matrix -from util.common import TfidfVectorizerMultilingual, _normalize -from models.pl_gru import RecurrentModel -from models.pl_bert import BertModel -from pytorch_lightning import Trainer -from data.datamodule import RecurrentDataModule, BertDataModule, tokenize -from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger from time import time +from pytorch_lightning import Trainer +from pytorch_lightning.loggers import TensorBoardLogger + +from data.datamodule import RecurrentDataModule, BertDataModule, tokenize +from models.learners import * +from models.pl_bert import BertModel +from models.pl_gru import RecurrentModel +from util.common import TfidfVectorizerMultilingual, _normalize +from util.embeddings_manager import MuseLoader, XdotM, wce_matrix + class ViewGen(ABC): @abstractmethod