diff --git a/merge_data.py b/merge_data.py new file mode 100644 index 0000000..46d6157 --- /dev/null +++ b/merge_data.py @@ -0,0 +1,110 @@ +import argparse +import os +import shutil +from pathlib import Path + +import numpy as np +import pandas as pd + +from quacc.evaluation.comp import CE +from quacc.evaluation.report import DatasetReport, DatasetReportInfo + + +def load_report_info(path: Path) -> DatasetReportInfo: + return DatasetReport.unpickle(path, report_info=True) + + +def list_reports(base_path: Path | str): + if isinstance(base_path, str): + base_path = Path(base_path) + + if base_path.name == "plot": + return [] + + reports = [] + for f in os.listdir(base_path): + fp = base_path / f + if fp.is_dir(): + reports.extend(list_reports(fp)) + elif fp.is_file(): + if fp.suffix == ".pickle" and fp.stem == base_path.name: + reports.append(load_report_info(fp)) + + return reports + + +def playground(): + data_a = np.array(np.random.random((4, 6))) + data_b = np.array(np.random.random((4, 4))) + _ind1 = pd.MultiIndex.from_product([["0.2", "0.8"], ["0", "1"]]) + _col1 = pd.MultiIndex.from_product([["a", "b"], ["1", "2", "5"]]) + _col2 = pd.MultiIndex.from_product([["a", "b"], ["1", "2"]]) + a = pd.DataFrame(data_a, index=_ind1, columns=_col1) + b = pd.DataFrame(data_b, index=_ind1, columns=_col2) + print(a) + print(b) + print((a.index == b.index).all()) + update_col = a.columns.intersection(b.columns) + col_to_join = b.columns.difference(update_col) + _b = b.drop(columns=[(slice(None), "2")]) + _join = pd.concat([a, _b.loc[:, col_to_join]], axis=1) + _join.loc[:, update_col.to_list()] = _b.loc[:, update_col.to_list()] + _join.sort_index(axis=1, level=0, sort_remaining=False, inplace=True) + + print(_join) + + +def merge(dri1: DatasetReportInfo, dri2: DatasetReportInfo, path: Path): + drm = dri1.dr.join(dri2.dr, estimators=CE.name.all) + + # save merged dr + _path = path / drm.name / f"{drm.name}.pickle" + os.makedirs(_path.parent, exist_ok=True) + drm.pickle(_path) + + # rename dri1 pickle + dri1_bp = Path(dri1.name) / f"{dri1.name.split('/')[-1]}.pickle" + os.rename(dri1_bp, dri1_bp.with_suffix(f".pickle.pre_{dri2.name.split('/')[-2]}")) + + # copy merged pickle in place of old dri1 one + shutil.copyfile(_path, dri1_bp) + + # copy dri2 log file inside dri1 folder + dri2_bp = Path(dri2.name) / f"{dri2.name.split('/')[-1]}.pickle" + shutil.copyfile( + dri2_bp.with_suffix(".log"), + dri1_bp.with_name(f"{dri1_bp.stem}_{dri2.name.split('/')[-2]}.log"), + ) + + +def run(): + parser = argparse.ArgumentParser() + parser.add_argument("path1", nargs="?", default=None) + parser.add_argument("path2", nargs="?", default=None) + parser.add_argument("-l", "--list", action="store_true", dest="list") + parser.add_argument("-v", "--verbose", action="store_true", dest="verbose") + parser.add_argument( + "-o", "--output", action="store", dest="output", default="output/merge" + ) + args = parser.parse_args() + + reports = list_reports("output") + reports = {r.name: r for r in reports} + + if args.list: + for i, r in enumerate(reports.values()): + if args.verbose: + print(f"{i}: {r}") + else: + print(f"{i}: {r.name}") + else: + dri1, dri2 = reports.get(args.path1, None), reports.get(args.path2, None) + if dri1 is None or dri2 is None: + raise ValueError( + f"({args.path1}, {args.path2}) is not a valid pair of paths" + ) + merge(dri1, dri2, path=Path(args.output)) + + +if __name__ == "__main__": + run()