From d949c773172956c3ef696dbc20309033b78ece4f Mon Sep 17 00:00:00 2001
From: Alejandro Moreo <alejandro.moreo@isti.cnr.it>
Date: Tue, 15 Mar 2022 14:01:40 +0100
Subject: [PATCH] generating BERT outputs for textual documents

---
 Ordinal/utils.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/Ordinal/utils.py b/Ordinal/utils.py
index 88a278e..ac22671 100644
--- a/Ordinal/utils.py
+++ b/Ordinal/utils.py
@@ -12,6 +12,28 @@ def load_samples(path_dir, classes):
         yield LabelledCollection.load(join(path_dir, f'{id}.txt'), loader_func=qp.data.reader.from_text, classes=classes)
 
 
+def load_samples_as_csv(path_dir, debug=False):
+    import pandas as pd
+    import csv
+    import datasets
+    from datasets import Dataset
+
+    nsamples = len(glob(join(path_dir, f'*.txt')))
+    for id in range(nsamples):
+        df = pd.read_csv(join(path_dir, f'{id}.txt'), sep='\t', names=['labels', 'review'], quoting=csv.QUOTE_NONE)
+        labels = df.pop('labels').to_frame()
+        X = df
+
+        features = datasets.Features({'review': datasets.Value('string')})
+        if debug:
+            sample = Dataset.from_pandas(df=X, features=features).select(range(50))
+            labels = labels[:50]
+        else:
+            sample = Dataset.from_pandas(df=X, features=features)
+
+        yield sample, labels
+
+
 def load_samples_pkl(path_dir, filter=None):
     nsamples = len(glob(join(path_dir, f'*.pkl')))
     for id in range(nsamples):