optimized imports

This commit is contained in:
andrea 2021-01-26 13:12:37 +01:00
parent 5958df3e3c
commit 2a8075bbc2
20 changed files with 91 additions and 74 deletions

View File

@@ -1,7 +1,7 @@
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer from transformers import BertTokenizer
N_WORKERS = 8 N_WORKERS = 8

View File

@@ -1,19 +1,21 @@
from os.path import join, exists
from nltk.corpus import stopwords
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.preprocessing import MultiLabelBinarizer
from data.reader.jrcacquis_reader import *
from data.languages import lang_set, NLTK_LANGMAP, RCV2_LANGS_WITH_NLTK_STEMMING
from data.reader.rcv_reader import fetch_RCV1, fetch_RCV2, fetch_topic_hierarchy
from data.text_preprocessor import NLTKStemTokenizer, preprocess_documents
import pickle
import numpy as np
from sklearn.model_selection import train_test_split
from scipy.sparse import issparse
import itertools import itertools
from tqdm import tqdm import pickle
import re import re
from os.path import exists
import numpy as np
from nltk.corpus import stopwords
from scipy.sparse import csr_matrix from scipy.sparse import csr_matrix
from scipy.sparse import issparse
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from tqdm import tqdm
from data.languages import NLTK_LANGMAP, RCV2_LANGS_WITH_NLTK_STEMMING
from data.reader.jrcacquis_reader import *
from data.reader.rcv_reader import fetch_RCV1, fetch_RCV2
from data.text_preprocessor import NLTKStemTokenizer, preprocess_documents
class MultilingualDataset: class MultilingualDataset:

View File

@@ -1,19 +1,22 @@
from __future__ import print_function from __future__ import print_function
import os, sys
from os.path import join import os
import pickle
import sys
import tarfile import tarfile
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from sklearn.datasets import get_data_home import zipfile
import pickle from collections import Counter
from util.file import download_file, list_dirs, list_files from os.path import join
from random import shuffle
import rdflib import rdflib
from rdflib.namespace import RDF, SKOS from rdflib.namespace import RDF, SKOS
from rdflib import URIRef from sklearn.datasets import get_data_home
import zipfile
from data.languages import JRC_LANGS from data.languages import JRC_LANGS
from collections import Counter
from random import shuffle
from data.languages import lang_set from data.languages import lang_set
from util.file import download_file, list_dirs, list_files
""" """
JRC Acquis' Nomenclature: JRC Acquis' Nomenclature:

View File

@@ -1,15 +1,12 @@
from zipfile import ZipFile
import xml.etree.ElementTree as ET
from data.languages import RCV2_LANGS_WITH_NLTK_STEMMING, RCV2_LANGS
from util.file import list_files
from sklearn.datasets import get_data_home
import gzip
from os.path import join, exists
from util.file import download_file_if_not_exists
import re import re
from collections import Counter import xml.etree.ElementTree as ET
from os.path import join, exists
from zipfile import ZipFile
import numpy as np import numpy as np
import sys
from util.file import download_file_if_not_exists
from util.file import list_files
""" """
RCV2's Nomenclature: RCV2's Nomenclature:

View File

@@ -1,16 +1,19 @@
from __future__ import print_function from __future__ import print_function
# import ijson # import ijson
# from ijson.common import ObjectBuilder # from ijson.common import ObjectBuilder
import os, sys import os
from os.path import join
from bz2 import BZ2File
import pickle import pickle
from util.file import list_dirs, list_files, makedirs_if_not_exist
from itertools import islice
import re import re
from bz2 import BZ2File
from itertools import islice
from os.path import join
from xml.sax.saxutils import escape from xml.sax.saxutils import escape
import numpy as np import numpy as np
from util.file import list_dirs, list_files
policies = ["IN_ALL_LANGS", "IN_ANY_LANG"] policies = ["IN_ALL_LANGS", "IN_ANY_LANG"]
""" """

View File

@@ -1,8 +1,9 @@
from nltk.corpus import stopwords
from data.languages import NLTK_LANGMAP
from nltk import word_tokenize from nltk import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer from nltk.stem import SnowballStemmer
from data.languages import NLTK_LANGMAP
def preprocess_documents(documents, lang): def preprocess_documents(documents, lang):
tokens = NLTKStemTokenizer(lang, verbose=True) tokens = NLTKStemTokenizer(lang, verbose=True)

View File

@@ -1,8 +1,9 @@
import math import math
import numpy as np import numpy as np
from scipy.stats import t
from joblib import Parallel, delayed from joblib import Parallel, delayed
from scipy.sparse import csr_matrix, csc_matrix from scipy.sparse import csr_matrix, csc_matrix
from scipy.stats import t
def get_probs(tpr, fpr, pc): def get_probs(tpr, fpr, pc):

View File

@@ -1,6 +1,6 @@
from models.learners import * from models.learners import *
from view_generators import VanillaFunGen
from util.common import _normalize from util.common import _normalize
from view_generators import VanillaFunGen
class DocEmbedderList: class DocEmbedderList:

View File

@@ -1,11 +1,11 @@
from argparse import ArgumentParser from argparse import ArgumentParser
from funnelling import *
from view_generators import *
from data.dataset_builder import MultilingualDataset from data.dataset_builder import MultilingualDataset
from funnelling import *
from util.common import MultilingualIndex, get_params, get_method_name from util.common import MultilingualIndex, get_params, get_method_name
from util.evaluation import evaluate from util.evaluation import evaluate
from util.results_csv import CSVlog from util.results_csv import CSVlog
from time import time from view_generators import *
def main(args): def main(args):

View File

@@ -1,10 +1,12 @@
import numpy as np
import time import time
from scipy.sparse import issparse
from sklearn.multiclass import OneVsRestClassifier import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from joblib import Parallel, delayed from joblib import Parallel, delayed
from scipy.sparse import issparse
from sklearn.model_selection import GridSearchCV
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
from util.standardizer import StandardizeTransformer from util.standardizer import StandardizeTransformer

View File

@@ -1,7 +1,6 @@
#taken from https://github.com/prakashpandey9/Text-Classification-Pytorch/blob/master/models/LSTM.py #taken from https://github.com/prakashpandey9/Text-Classification-Pytorch/blob/master/models/LSTM.py
import torch
import torch.nn as nn
from torch.autograd import Variable from torch.autograd import Variable
from models.helpers import * from models.helpers import *

View File

@@ -1,9 +1,10 @@
import torch
import pytorch_lightning as pl import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import StepLR from torch.optim.lr_scheduler import StepLR
from transformers import BertForSequenceClassification, AdamW from transformers import BertForSequenceClassification, AdamW
from util.pl_metrics import CustomF1, CustomK
from util.common import define_pad_length, pad from util.common import define_pad_length, pad
from util.pl_metrics import CustomF1, CustomK
class BertModel(pl.LightningModule): class BertModel(pl.LightningModule):

View File

@@ -1,14 +1,15 @@
# Lightning modules, see https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html # Lightning modules, see https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html
import pytorch_lightning as pl
import torch import torch
from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR from torch.optim.lr_scheduler import StepLR
from transformers import AdamW from transformers import AdamW
import pytorch_lightning as pl
from models.helpers import init_embeddings from models.helpers import init_embeddings
from util.pl_metrics import CustomF1, CustomK
from util.common import define_pad_length, pad from util.common import define_pad_length, pad
from util.pl_metrics import CustomF1, CustomK
class RecurrentModel(pl.LightningModule): class RecurrentModel(pl.LightningModule):

View File

@@ -1,9 +1,9 @@
import numpy as np import numpy as np
import torch import torch
from tqdm import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
from util.embeddings_manager import supervised_embeddings_tfidf from util.embeddings_manager import supervised_embeddings_tfidf

View File

@@ -1,7 +1,9 @@
from torchtext.vocab import Vectors
import torch
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import numpy as np import numpy as np
import torch
from torchtext.vocab import Vectors
from util.SIF_embed import remove_pc from util.SIF_embed import remove_pc

View File

@@ -1,6 +1,7 @@
from joblib import Parallel, delayed
from util.metrics import *
import numpy as np import numpy as np
from joblib import Parallel, delayed
from util.metrics import *
def evaluation_metrics(y, y_): def evaluation_metrics(y, y_):

View File

@@ -1,6 +1,6 @@
import urllib
from os import listdir, makedirs from os import listdir, makedirs
from os.path import isdir, isfile, join, exists, dirname from os.path import isdir, isfile, join, exists, dirname
import urllib
from pathlib import Path from pathlib import Path

View File

@@ -1,5 +1,6 @@
import torch import torch
from pytorch_lightning.metrics import Metric from pytorch_lightning.metrics import Metric
from util.common import is_false, is_true from util.common import is_false, is_true

View File

@@ -1,6 +1,7 @@
import os import os
import pandas as pd
import numpy as np import numpy as np
import pandas as pd
class CSVlog: class CSVlog:

View File

@@ -16,16 +16,18 @@ This module contains the view generators that take care of computing the view sp
- View generator (-b): generates document embedding via mBERT model. - View generator (-b): generates document embedding via mBERT model.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from models.learners import *
from util.embeddings_manager import MuseLoader, XdotM, wce_matrix
from util.common import TfidfVectorizerMultilingual, _normalize
from models.pl_gru import RecurrentModel
from models.pl_bert import BertModel
from pytorch_lightning import Trainer
from data.datamodule import RecurrentDataModule, BertDataModule, tokenize
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from time import time from time import time
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from data.datamodule import RecurrentDataModule, BertDataModule, tokenize
from models.learners import *
from models.pl_bert import BertModel
from models.pl_gru import RecurrentModel
from util.common import TfidfVectorizerMultilingual, _normalize
from util.embeddings_manager import MuseLoader, XdotM, wce_matrix
class ViewGen(ABC): class ViewGen(ABC):
@abstractmethod @abstractmethod