import numpy as np
import torch
import pickle
from nnfabrik.utility.nn_helpers import set_random_seed
from torch.utils.data import Dataset, DataLoader, Sampler, SubsetRandomSampler
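
# torch.utils.data ships no SubsetSequentialSampler; this is a minimal version
# (modeled on the sampler of the same name in neuralpredictors) that iterates
# a fixed index subset in order, so validation/test batches are deterministic.
class SubsetSequentialSampler(Sampler):
    """Sample elements sequentially from a given list of indices."""

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)
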

def synthetic_loader(
    image_path,
    response_path,
    seed: int = 42,
    splits: tuple = (0.7, 0.15, 0.15),
    dtype=torch.float32,
    device=torch.device("cpu"),
    batch_size: int = 64,
    shuffle: bool = True,
    cuda: bool = True
):
    """
    Creates a DataLoader for synthetic data (images and responses).

    Args:
        stimulus (np.ndarray): shape (frames, height, width)
        responses (np.ndarray): shape (frames, neurons)
        batch_size (int): batch size for the loader
        shuffle (bool): whether to shuffle data
        cuda (bool): whether to move tensors to GPU

    Returns:
        DataLoader: yields batches of (images, responses)
    """
    set_random_seed(seed)
    with open(image_path, "rb") as f:
        images_dict = pickle.load(f)
    with open(response_path, "rb") as f:
        responses_dict = pickle.load(f)

    # Convert dict -> list sorted by index
    images = [images_dict[i] for i in sorted(images_dict.keys())]
    responses = [responses_dict[i] for i in sorted(responses_dict.keys())]

    # Convert to arrays
    images = np.stack(images)
    responses = np.stack(responses)

    # To tensors
    images = torch.tensor(images, dtype=dtype, device=device)
    responses = torch.tensor(responses, dtype=dtype, device=device)
    N = len(images_dict.keys())
    indices = np.arange(N)
    np.random.shuffle(indices)

    n_train = int(splits[0] * N)
    n_val = int(splits[1] * N)
    n_test = N - n_train - n_val  # remainder, so the three splits cover all N samples

    idx_train = indices[:n_train]
    idx_val = indices[n_train:n_train+n_val]
    idx_test = indices[n_train+n_val:]

    # --- Dataloaders ---
    # The train split is shuffled each epoch (unless shuffle=False);
    # validation and test always iterate in a fixed order.
    train_sampler = (
        SubsetRandomSampler(idx_train) if shuffle else SubsetSequentialSampler(idx_train)
    )

    dataloaders = {}

    dataloaders["train"] = DataLoader(
        dataset, sampler=train_sampler, batch_size=batch_size
    )
    dataloaders["validation"] = DataLoader(
        dataset, sampler=SubsetSequentialSampler(idx_val), batch_size=batch_size
    )
    dataloaders["test"] = DataLoader(
        dataset, sampler=SubsetSequentialSampler(idx_test), batch_size=batch_size
    )

    return dataloaders

class ImageResponseDataset(Dataset):
    """Dataset holding paired images and responses."""

    def __init__(self, images, responses):
        assert len(images) == len(responses), "Images and responses must match in length"
        self.images = images
        self.responses = responses

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.responses[idx]


def load_pickle_to_tensors(image_pkl, response_pkl, dtype=torch.float32):
    """Load pickle dicts into tensors."""
    with open(image_pkl, "rb") as f:
        images_dict = pickle.load(f)
    with open(response_pkl, "rb") as f:
        responses_dict = pickle.load(f)

    images = [images_dict[i] for i in sorted(images_dict.keys())]
    responses = [responses_dict[i] for i in sorted(responses_dict.keys())]

    images = torch.tensor(np.stack(images), dtype=dtype)
    responses = torch.tensor(np.stack(responses), dtype=dtype)

    return images, responses


def create_dataloaders(image_pkl, response_pkl, batch_size=64, splits=(0.7, 0.15, 0.15), seed=42):
    """
    Create train/validation/test dataloaders from pickle files.

    Args:
        image_pkl (str): path to image pickle file.
        response_pkl (str): path to response pickle file.
        batch_size (int): batch size.
        splits (tuple): fractions for (train, validation, test).
        seed (int): random seed.

    Returns:
        dict: dataloaders with keys ['train', 'validation', 'test', 'final_test']
    """
    images, responses = load_pickle_to_tensors(image_pkl, response_pkl)
    dataset = ImageResponseDataset(images, responses)

    # --- Split indices ---
    N = len(dataset)
    indices = np.arange(N)
    np.random.seed(seed)
    np.random.shuffle(indices)

    n_train = int(splits[0] * N)
    n_val = int(splits[1] * N)
    n_test = N - n_train - n_val  # remainder, so the three splits cover all N samples

    idx_train = indices[:n_train]
    idx_val = indices[n_train:n_train+n_val]
    idx_test = indices[n_train+n_val:]

    # --- Dataloaders ---
    dataloaders = {}

    dataloaders["train"] = DataLoader(
        dataset, sampler=SubsetRandomSampler(idx_train), batch_size=batch_size
    )
    dataloaders["validation"] = DataLoader(
        dataset, sampler=SubsetSequentialSampler(idx_val), batch_size=batch_size
    )
    dataloaders["test"] = DataLoader(
        dataset, sampler=SubsetSequentialSampler(idx_test), batch_size=batch_size
    )
    dataloaders["final_test"] = dataloaders["test"]  # alias

    return dataloaders
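

# Minimal usage sketch (the paths below are placeholders, not files shipped with
# this module): build the loaders and pull one training batch.
if __name__ == "__main__":
    loaders = create_dataloaders(
        "data/images.pkl",     # hypothetical pickle: dict {index: (H, W) array}
        "data/responses.pkl",  # hypothetical pickle: dict {index: (neurons,) array}
        batch_size=32,
    )
    batch_images, batch_responses = next(iter(loaders["train"]))
    print(batch_images.shape, batch_responses.shape)  # e.g. (32, H, W), (32, neurons)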