QuAcc/merge_data.py

111 lines
3.5 KiB
Python

import argparse
import os
import shutil
from pathlib import Path
import numpy as np
import pandas as pd
from quacc.evaluation.estimators import CE
from quacc.evaluation.report import DatasetReport, DatasetReportInfo
def load_report_info(path: Path) -> DatasetReportInfo:
return DatasetReport.unpickle(path, report_info=True)
def list_reports(base_path: Path | str):
if isinstance(base_path, str):
base_path = Path(base_path)
if base_path.name == "plot":
return []
reports = []
for f in os.listdir(base_path):
fp = base_path / f
if fp.is_dir():
reports.extend(list_reports(fp))
elif fp.is_file():
if fp.suffix == ".pickle" and fp.stem == base_path.name:
reports.append(load_report_info(fp))
return reports
def playground():
data_a = np.array(np.random.random((4, 6)))
data_b = np.array(np.random.random((4, 4)))
_ind1 = pd.MultiIndex.from_product([["0.2", "0.8"], ["0", "1"]])
_col1 = pd.MultiIndex.from_product([["a", "b"], ["1", "2", "5"]])
_col2 = pd.MultiIndex.from_product([["a", "b"], ["1", "2"]])
a = pd.DataFrame(data_a, index=_ind1, columns=_col1)
b = pd.DataFrame(data_b, index=_ind1, columns=_col2)
print(a)
print(b)
print((a.index == b.index).all())
update_col = a.columns.intersection(b.columns)
col_to_join = b.columns.difference(update_col)
_b = b.drop(columns=[(slice(None), "2")])
_join = pd.concat([a, _b.loc[:, col_to_join]], axis=1)
_join.loc[:, update_col.to_list()] = _b.loc[:, update_col.to_list()]
_join.sort_index(axis=1, level=0, sort_remaining=False, inplace=True)
print(_join)
def merge(dri1: DatasetReportInfo, dri2: DatasetReportInfo, path: Path):
drm = dri1.dr.join(dri2.dr, estimators=CE.name.all)
# save merged dr
_path = path / drm.name / f"{drm.name}.pickle"
os.makedirs(_path.parent, exist_ok=True)
drm.pickle(_path)
# rename dri1 pickle
dri1_bp = Path(dri1.name) / f"{dri1.name.split('/')[-1]}.pickle"
os.rename(dri1_bp, dri1_bp.with_suffix(f".pickle.pre_{dri2.name.split('/')[-2]}"))
# copy merged pickle in place of old dri1 one
shutil.copyfile(_path, dri1_bp)
# copy dri2 log file inside dri1 folder
dri2_bp = Path(dri2.name) / f"{dri2.name.split('/')[-1]}.pickle"
shutil.copyfile(
dri2_bp.with_suffix(".log"),
dri1_bp.with_name(f"{dri1_bp.stem}_{dri2.name.split('/')[-2]}.log"),
)
def run():
parser = argparse.ArgumentParser()
parser.add_argument("path1", nargs="?", default=None)
parser.add_argument("path2", nargs="?", default=None)
parser.add_argument("-l", "--list", action="store_true", dest="list")
parser.add_argument("-v", "--verbose", action="store_true", dest="verbose")
parser.add_argument(
"-o", "--output", action="store", dest="output", default="output/merge"
)
args = parser.parse_args()
reports = list_reports("output")
reports = {r.name: r for r in reports}
if args.list:
for i, r in enumerate(reports.values()):
if args.verbose:
print(f"{i}: {r}")
else:
print(f"{i}: {r.name}")
else:
dri1, dri2 = reports.get(args.path1, None), reports.get(args.path2, None)
if dri1 is None or dri2 is None:
raise ValueError(
f"({args.path1}, {args.path2}) is not a valid pair of paths"
)
merge(dri1, dri2, path=Path(args.output))
if __name__ == "__main__":
run()