from os.path import expanduser, join
from dataManager.gFunDataset import gFunDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.amazonDataset import AmazonDataset


def load_from_pickle(path, dataset_name, nrows):
    """Load a cached dataset pickle named ``{dataset_name}_{nrows}.pkl``.

    Args:
        path: directory containing the pickle file.
        dataset_name: dataset identifier used in the file name.
        nrows: row-count suffix used in the file name.

    Returns:
        The unpickled dataset object (its ``show_dimension()`` is called
        for logging before returning).
    """
    import pickle

    filepath = join(path, f"{dataset_name}_{nrows}.pkl")

    # NOTE(review): pickle.load is only safe on trusted, locally-produced
    # files — these caches are written by save_as_pickle in this codebase.
    with open(filepath, "rb") as fh:
        dataset = pickle.load(fh)

    print(f"- Loaded dataset from {filepath}")
    dataset.show_dimension()
    return dataset


def get_dataset(dataset_name, args):
    """Build (or load from a cached pickle) the requested dataset.

    Args:
        dataset_name: one of "multinews", "amazon", "jrc", "rcv1-2",
            "glami", "cls", "webis".
        args: parsed CLI namespace; ``args.nrows`` is always read, and
            ``args.save_dataset`` is read for the "glami" branch.

    Returns:
        A ``gFunDataset`` instance (or a previously pickled one for
        "glami" when ``args.save_dataset`` is False).

    Raises:
        AssertionError: if ``dataset_name`` is not a supported name.
        NotImplementedError: for "multinews" and "amazon", which are not
            yet ported to ``gFunDataset``.
    """
    # BUG FIX: "jrc" was missing from this list even though a jrc branch
    # exists below, which made that branch unreachable (the assert fired
    # first for dataset_name == "jrc").
    # NOTE(review): `assert` is stripped under `python -O`; a ValueError
    # would be more robust, but the assert is kept so callers that catch
    # AssertionError keep working.
    assert dataset_name in [
        "multinews",
        "amazon",
        "jrc",
        "rcv1-2",
        "glami",
        "cls",
        "webis",
    ], "dataset not supported"

    RCV_DATAPATH = expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    JRC_DATAPATH = expanduser(
        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
    )
    CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")

    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")

    WEBIS_CLS = expanduser(
        "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
    )

    # Text-only gFunDataset variants share one construction pattern;
    # table maps name -> (data path, is_multilabel).
    TEXTUAL_DATASETS = {
        "jrc": (JRC_DATAPATH, True),
        "rcv1-2": (RCV_DATAPATH, True),
        "cls": (CLS_DATAPATH, False),
        "webis": (WEBIS_CLS, False),
    }

    if dataset_name in ("multinews", "amazon"):
        # TODO: convert MultiNewsDataset / AmazonDataset to gFunDataset.
        # (The original constructor calls sat after a bare `raise` and were
        # unreachable dead code; they have been removed.)
        raise NotImplementedError
    elif dataset_name in TEXTUAL_DATASETS:
        datapath, is_multilabel = TEXTUAL_DATASETS[dataset_name]
        dataset = gFunDataset(
            dataset_dir=datapath,
            is_textual=True,
            is_visual=False,
            is_multilabel=is_multilabel,
            nrows=args.nrows,
        )
    elif dataset_name == "glami":
        # GLAMI is the only text+image dataset; building it is expensive,
        # so a pickle cache is used unless the caller asks to rebuild.
        if args.save_dataset is False:
            dataset = load_from_pickle(GLAMI_DATAPATH, dataset_name, args.nrows)
        else:
            dataset = gFunDataset(
                dataset_dir=GLAMI_DATAPATH,
                is_textual=True,
                is_visual=True,
                is_multilabel=False,
                nrows=args.nrows,
            )

            dataset.save_as_pickle(GLAMI_DATAPATH)
    else:
        raise NotImplementedError
    return dataset