from functools import partial

import numpy as np
import torch
from neuralpredictors.measures import modules
from neuralpredictors.training import (LongCycler, MultipleObjectiveTracker,
                                       early_stopping)
from nnfabrik.utility.nn_helpers import set_random_seed
from sklearn.cluster import KMeans
from torch.nn import KLDivLoss
from tqdm import tqdm

import wandb

from ..utility import scores
from ..utility.scores import get_correlations, get_poisson_loss


def standard_trainer(
    model,
    dataloaders,
    seed,
    avg_loss=False,
    scale_loss=True,
    loss_function="PoissonLoss",
    stop_function="get_correlations",
    loss_accum_batch_n=None,
    device="cuda",
    verbose=True,
    interval=1,
    patience=5,
    epoch=0,
    lr_init=0.005,
    max_iter=200,
    maximize=True,
    tolerance=1e-6,
    restore_best=True,
    lr_decay_steps=5,
    lr_decay_factor=0.3,
    min_lr=0.0001,
    cb=None,
    track_training=False,
    detach_core=False,
    deeplake_ds=False,
    include_kldivergence=True,
    cluster_number=10,
    alpha=1.0,
    dec_starting_epoch=1,
    kmeans_init=20,
    base_multiplier=4e3,
    use_diag_cov=True,
    learn_alpha=False,
    exponent=2,
    load_pretrain=False,
    include_mixingcoefficients=False,
    **kwargs,
):
    """Train ``model`` with early stopping and an optional DEC-style clustering loss.

    The main objective is ``loss_function`` plus the core/readout regularizers.
    When ``include_kldivergence`` is set, a KL-divergence clustering loss on the
    concatenated readout ``features`` is added from ``dec_starting_epoch`` on,
    and the cluster parameters are refreshed by an EM-style update after every
    optimizer step.

    Args:
        model: model to be trained; must expose ``core`` and ``readout``
        dataloaders: dict with at least ``"train"`` and ``"validation"`` entries
        seed: random seed (also used as ``random_state`` for KMeans)
        avg_loss: whether to average (or sum) the loss over a batch
        scale_loss: whether to scale the loss according to the size of the dataset
        loss_function: name of the loss class looked up in ``neuralpredictors.measures.modules``
        stop_function: name of the metric in ``..utility.scores`` used for early stopping
        loss_accum_batch_n: number of batches to accumulate gradients over;
            defaults to the number of keys in ``dataloaders["train"]``
        device: device to run the training on
        verbose: forwarded to the LR scheduler; also enables printing of tracker values
        interval: interval at which objective is evaluated to consider early stopping
        patience: number of times the objective is allowed to not become better before the iterator terminates
        epoch: starting epoch
        lr_init: initial learning rate
        max_iter: maximum number of training iterations
        maximize: whether to maximize or minimize the objective function
        tolerance: tolerance for early stopping
        restore_best: whether to restore the model to the best state after early stopping
        lr_decay_steps: how many times to decay the learning rate after no improvement
        lr_decay_factor: factor to decay the learning rate with
        min_lr: minimum learning rate
        cb: optional callable invoked without arguments at the start of every epoch
        track_training: whether to track and print out the training progress
        detach_core: if True the core regularizer is dropped from the objective;
            the flag is also forwarded to the model call as a keyword argument
        deeplake_ds: forwarded to ``get_correlations`` in the per-epoch validation call

        include_kldivergence: whether to add the clustering (KL) loss
        cluster_number: number of clusters for the DEC clustering algorithm
        alpha: alpha (degrees of freedom) used for calculation of the soft assignment
        dec_starting_epoch: first epoch at which the cluster centroids are
            initialised and the clustering loss becomes active
        base_multiplier: multiplier to get KL to same order of magnitude as Poisson loss
        kmeans_init: passed as ``n_init`` to ``sklearn.cluster.KMeans`` for the
            cluster initialisation
        exponent: the exponent for the target distribution for DEC
        use_diag_cov: whether to use a diagonal covariance (one value per cluster
            and feature dimension) or a single scale value per cluster in the EM step
        learn_alpha: learn alpha jointly with the model instead of keeping it fixed
        load_pretrain: whether to load a state dict before training
        include_mixingcoefficients: whether the mixing coefficients enter the
            soft assignments (only used together with ``use_diag_cov``)
        **kwargs: ``pretrained_path`` (optional) overrides the checkpoint path
            used when ``load_pretrain`` is set; other entries are ignored

    Returns:
        ``(score, output, model.state_dict())`` when ``include_kldivergence`` is
        False. Otherwise ``(score, output, cluster_centers_np, sigma_np,
        mixing_coefficients, predicted, model.state_dict())``; the four
        clustering entries are ``None`` if the clustering was never initialised
        (training ended before ``dec_starting_epoch``), and
        ``mixing_coefficients`` is ``None`` unless ``include_mixingcoefficients``.
    """

    def get_multiplier(epoch, base_multiplier=4e3):
        """Return the scale applied to the KL loss.

        The multiplier is 0 before ``dec_starting_epoch`` and ``base_multiplier``
        afterwards (a hard switch; there is no warm-up ramp).
        """
        if epoch < dec_starting_epoch:
            return 0
        return base_multiplier

    def target_distribution(batch: torch.Tensor, exponent=exponent) -> torch.Tensor:
        """
        Compute the target distribution p_ij, given the batch (q_ij), as in 3.1.3 Equation 3 of
        Xie/Girshick/Farhadi; this is used in the KL-divergence loss function.
        p_ij = (q_ij^e/f_j) / sum_j'(q_ij'^e/f_j')  with f_j = sum_i(q_ij), e = exponent

        :param batch: [batch size, number of clusters] Tensor of dtype float
        :return: [batch size, number of clusters] Tensor of dtype float
        """
        weight = (batch**exponent) / torch.sum(batch, 0)
        return (weight.t() / torch.sum(weight, 1)).t()

    def soft_assignments_mult(encoded_features, cluster_centers, sigma, alpha, p=1, mixing_coefficients=None):
        """Calculates the q_ij as the t mixture components. Moves to log space to avoid numerical issues.

        ``sigma`` holds the diagonal of each cluster's scale matrix. If
        ``mixing_coefficients`` is given, its log is added to the component
        log-densities before normalising over clusters.
        """
        sigma_inv = 1.0 / sigma  # (K, D) per the original author's notes
        diff = encoded_features.T.unsqueeze(1) - cluster_centers.unsqueeze(0)  # (N, K, D)
        norm_sigma = torch.sum(diff * sigma_inv * diff, dim=2)  # (N, K)
        det = torch.sum(torch.log(sigma), dim=1)  # log(det) since sigma is diagonal
        log_gamma_top = torch.lgamma((alpha + p) / 2)
        log_gamma_bottom = torch.lgamma(alpha / 2)
        # Log-density formula for multivariate Student-t
        log_pdf = (
            log_gamma_top
            - log_gamma_bottom
            - 0.5 * det
            - (p / 2) * torch.log(alpha * torch.pi)
            - ((alpha + p) / 2) * torch.log(1 + (norm_sigma / alpha))
        )
        if mixing_coefficients is not None:
            log_pdf = log_pdf + torch.log(mixing_coefficients.squeeze())
        log_assignments = log_pdf - torch.logsumexp(log_pdf, dim=1, keepdim=True)
        return torch.exp(log_assignments)  # Convert log-assignments to probabilities

    def EM_t_mult(features, resp, cluster_centers, sigma, alpha, d=1):
        """Does EM updates of centers, shape matrix and mixing coefficients for multivariate t-distribution.

        All returned tensors are detached from the autograd graph.
        """
        sigma_inv = 1.0 / sigma
        diff = features.T.unsqueeze(1) - cluster_centers.unsqueeze(0)
        norm_sigma = torch.sum((diff**2 * sigma_inv), 2)
        u = ((alpha + d) / (alpha + norm_sigma)).detach()  # latent scale weights, (N, K)
        mixing_coefficients = 1 / features.shape[1] * torch.sum(resp, dim=0, keepdim=True).T.detach()

        # M step: weighted means for the centers ...
        numerator = torch.matmul(features, resp * u).T.detach()
        denominator = torch.sum(resp * u, dim=0, keepdim=True).T.detach()
        cluster_centers = numerator / denominator
        # ... and weighted squared deviations (from the previous centers) for sigma.
        weighted_sq_diff = resp.unsqueeze(2) * u.unsqueeze(2) * (diff**2)  # (N, K, D)
        numerator = weighted_sq_diff.sum(dim=0)  # (K, D)
        denominator = torch.sum(resp, dim=0, keepdim=True)
        sigma = (numerator / denominator.T).detach()
        # Keep the scales in a numerically safe range for the log/inverse above.
        sigma = torch.clamp(sigma, min=1e-4, max=1e4)

        return cluster_centers, sigma, mixing_coefficients

    def EM_t_1D(features, resp, cluster_centers, taus, alpha, d=1):
        """EM update of centers and per-cluster scalar scales ``taus`` (single-scale variant)."""
        norm_squared = torch.sum(
            (features.T.unsqueeze(1) - cluster_centers.unsqueeze(0)) ** 2, dim=2
        )
        u = (alpha + d) / (alpha + norm_squared * (taus ** (-1)))

        # M step
        numerator = torch.matmul(features, resp * u).T.detach()
        denominator = torch.sum(resp * u, dim=0, keepdim=True).T.detach()
        cluster_centers = numerator / denominator

        weighted_sums = torch.sum(resp * u * norm_squared, dim=0)
        taus = (weighted_sums / torch.sum(resp, dim=0, keepdim=True)).detach()
        return cluster_centers, taus

    def soft_assignments_1D(encoded_features, cluster_centers, tau, alpha=1):
        """Soft assignments q_ij for the single-scale-per-cluster Student-t kernel."""
        norm_squared = torch.sum(
            (encoded_features.T.unsqueeze(1) - cluster_centers.unsqueeze(0)) ** 2, 2
        )
        assignments = 1.0 / (1.0 + (norm_squared / (alpha * tau)))
        # Normalise by sqrt(tau). The previous `tau**1 / 2` parsed as `tau / 2`
        # because `**` binds tighter than `/`.
        assignments = (assignments ** ((alpha + 1) / 2)) / (tau**0.5)
        return assignments / torch.sum(assignments, dim=1, keepdim=True)

    def full_objective(model, dataloader, data_key, *args, **kwargs):
        """Return ``(loss + regularizers, (loss, regularizers))`` for one batch.

        ``args[0]`` is the model input and ``args[1]`` the target; ``kwargs``
        are forwarded to the model call.
        """
        loss_scale = (
            np.sqrt(len(dataloader[data_key].dataset) / args[0].shape[0])
            if scale_loss
            else 1.0
        )
        regularizers = int(
            not detach_core
        ) * model.core.regularizer() + model.readout.regularizer(data_key)

        tot_main_loss = loss_scale * criterion(
            model(args[0].to(device), data_key=data_key, **kwargs),
            args[1].to(device),
        )
        return (tot_main_loss + regularizers), (tot_main_loss, regularizers)

    ##### Model training ####################################################################################################

    model.to("cpu")
    set_random_seed(seed)
    if load_pretrain:  # Load weights if given
        # The checkpoint location can be supplied as `pretrained_path=...`;
        # the historical placeholder remains the default.
        pretrained_path = kwargs.get("pretrained_path", "give path")
        k = list(model.readout.keys())[0]
        dim = model.readout[k].features.shape[1]
        print(dim)
        model.load_state_dict(
            torch.load(pretrained_path, map_location=torch.device("cpu"))
        )
        print("Loaded pretrained model!")

    model.to(device)
    model.train()
    # KL losses are summed for each minibatch (`reduction="sum"` is the
    # non-deprecated spelling of `size_average=False`).
    kldiv_criterion = KLDivLoss(reduction="sum")
    criterion = getattr(modules, loss_function)(avg=avg_loss)
    stop_closure = partial(
        getattr(scores, stop_function),
        dataloaders=dataloaders["validation"],
        device=device,
        per_neuron=False,
        avg=True,
    )

    n_iterations = len(LongCycler(dataloaders["train"]))

    if learn_alpha:
        # alpha is parameterised as softplus(raw_alpha) + 0.1 in the KL step.
        raw_alpha = torch.nn.Parameter(
            torch.tensor(1.0, device=device, requires_grad=True)
        )
        optimizer = torch.optim.Adam(list(model.parameters()) + [raw_alpha], lr=lr_init)
    else:
        alpha = torch.tensor(alpha, device=device, requires_grad=False)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr_init)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="max" if maximize else "min",
        factor=lr_decay_factor,
        patience=patience,
        threshold=tolerance,
        min_lr=min_lr,
        verbose=verbose,
        threshold_mode="abs",
    )

    # set the number of iterations over which you would like to accumulate gradients
    optim_step_count = (
        len(dataloaders["train"].keys())
        if loss_accum_batch_n is None
        else loss_accum_batch_n
    )

    if track_training:
        tracker_dict = dict(
            correlation=partial(
                get_correlations,
                model,
                dataloaders["validation"],
                device=device,
                per_neuron=False,
            ),
            poisson_loss=partial(
                get_poisson_loss,
                model,
                dataloaders["validation"],
                device=device,
                per_neuron=False,
                avg=False,
            ),
        )
        if hasattr(model, "tracked_values"):
            tracker_dict.update(model.tracked_values)
        tracker = MultipleObjectiveTracker(**tracker_dict)
    else:
        tracker = None

    # Clustering state. It stays None until the DEC initialisation has run so
    # that later code can tell whether the clustering was ever set up.
    cluster_centers = None
    sigma = None
    mixing_coefficients = None
    p = None
    cluster_centers_list = []  # numpy snapshots of the centers, one per KL step

    # train over epochs
    for epoch, val_obj in early_stopping(
        model,
        stop_closure,
        interval=interval,
        patience=patience,
        start=epoch,
        max_iter=max_iter,
        maximize=maximize,
        tolerance=tolerance,
        restore_best=restore_best,
        tracker=tracker,
        scheduler=scheduler,
        lr_decay_steps=lr_decay_steps,
    ):

        # Initialise the clustering on the first epoch at/after dec_starting_epoch.
        # Using `>=` together with the None check (instead of `==`) guarantees the
        # state exists whenever the KL step below runs, even if the exact epoch
        # `dec_starting_epoch` is never yielded.
        if (
            include_kldivergence
            and cluster_centers is None
            and epoch >= dec_starting_epoch
        ):
            # TODO: include hidden dimension
            kmeans = KMeans(
                n_clusters=cluster_number, n_init=kmeans_init, random_state=seed
            )
            # form initial cluster centres from the stacked readout features
            with torch.no_grad():
                feature_list = [
                    np.array(readout.features.cpu().detach().squeeze().T.numpy())
                    for _, readout in model.readout.items()
                ]
                features = np.vstack(feature_list)
                predicted = kmeans.fit_predict(features)
            cluster_centers = torch.tensor(
                kmeans.cluster_centers_, dtype=torch.float, device=device
            )
            if use_diag_cov:
                p = features.shape[1]
                sigma = torch.zeros((cluster_number, p), device=device)
                for k in range(cluster_number):
                    cluster_points = torch.from_numpy(features[predicted == k]).to(
                        device
                    )
                    if len(cluster_points) > 1:
                        # per-dimension variance, with a small floor
                        sigma[k] = (
                            torch.var(cluster_points, dim=0, unbiased=True) + 1e-6
                        )
                    else:
                        # variance is undefined for a single point; use the floor
                        sigma[k] = torch.full_like(cluster_points[0], 1e-6)
                mixing_coefficients = (
                    torch.ones(cluster_number, device=device, requires_grad=False)
                    / cluster_number
                )
            else:
                sigma = torch.zeros(cluster_number, device=device)
                for k in range(cluster_number):
                    cluster_points = torch.from_numpy(features[predicted == k]).to(
                        device
                    )
                    if len(cluster_points) > 0:
                        # mean squared distance to the cluster centre
                        sigma[k] = torch.mean(
                            torch.sum((cluster_points - cluster_centers[k]) ** 2, 1)
                        )
                sigma = sigma.unsqueeze(0)

            if learn_alpha:
                raw_alpha.data = torch.tensor(1.0, dtype=torch.float, device=device)

        model.train()
        # print the quantities from tracker
        if verbose and tracker is not None:
            print("=======================================")
            for key in tracker.log.keys():
                print(key, tracker.log[key][-1], flush=True)

        # executes callback function if passed in keyword args
        if cb is not None:
            cb()

        # train over batches
        optimizer.zero_grad()
        epoch_loss = 0
        epoch_loss_main = 0
        epoch_loss_reg = 0
        epoch_loss_kldiv = 0
        epoch_loss_kldiv_without_scaling = 0

        for batch_no, (data_key, data) in tqdm(
            enumerate(LongCycler(dataloaders["train"])),
            total=n_iterations,
            desc="Epoch {}".format(epoch),
        ):
            batch_args = list(data)
            batch_kwargs = data._asdict() if not isinstance(data, dict) else data
            loss, loss_parts = full_objective(
                model,
                dataloaders["train"],
                data_key,
                *batch_args,
                **batch_kwargs,
                detach_core=detach_core,
            )

            loss.backward()
            epoch_loss += loss.detach()
            epoch_loss_main += loss_parts[0].detach()
            epoch_loss_reg += loss_parts[1].detach()
            if (batch_no + 1) % optim_step_count == 0:
                # TODO maybe remove the hidden dimensions
                if include_kldivergence and epoch >= dec_starting_epoch:
                    # Concatenate the (still differentiable) features of all readouts.
                    all_features = torch.cat(
                        [
                            readout.features.squeeze()
                            for _, readout in model.readout.items()
                        ],
                        dim=1,
                    )
                    if learn_alpha:
                        alpha = torch.nn.functional.softplus(raw_alpha) + 0.1
                    if use_diag_cov:
                        if include_mixingcoefficients:
                            q = soft_assignments_mult(
                                all_features,
                                cluster_centers,
                                sigma,
                                alpha,
                                p,
                                mixing_coefficients,
                            )
                        else:
                            q = soft_assignments_mult(
                                all_features, cluster_centers, sigma, alpha, p
                            )
                    else:
                        q = soft_assignments_1D(
                            all_features, cluster_centers, sigma, alpha
                        )

                    # Clamp before taking logs to avoid log(0).
                    q = q.clamp(min=1e-8)
                    # NOTE(review): the target is not detached, so gradients also
                    # flow through it; kept as in the original implementation.
                    target = target_distribution(q, exponent).clamp(min=1e-8)

                    # Compute the raw KL first and scale afterwards, so the
                    # unscaled value never has to be recovered by a division
                    # (which yields NaN/inf when the multiplier is 0).
                    kldiv_unscaled = kldiv_criterion(q.log(), target)
                    kldiv_loss = get_multiplier(epoch, base_multiplier) * kldiv_unscaled
                    kldiv_loss.backward()
                    epoch_loss_kldiv += kldiv_loss.detach()
                    epoch_loss_kldiv_without_scaling += kldiv_unscaled.detach()
                    epoch_loss += kldiv_loss.detach()

                    # Snapshot the centers used for this step (always as numpy so
                    # the final np.array(...) sees homogeneous elements).
                    cluster_centers_list.append(
                        cluster_centers.detach().cpu().numpy()
                    )

                    # EM refresh of the cluster parameters with the current q.
                    if use_diag_cov:
                        cluster_centers, sigma, mixing_coefficients = EM_t_mult(
                            all_features, q, cluster_centers, sigma, alpha, p
                        )
                    else:
                        cluster_centers, sigma = EM_t_1D(
                            all_features, q, cluster_centers, sigma, alpha
                        )
                optimizer.step()
                optimizer.zero_grad()

        validation_correlation = get_correlations(
            model,
            dataloaders["validation"],
            device=device,
            as_dict=False,
            per_neuron=False,
            deeplake_ds=deeplake_ds,
        )
        # NOTE(review): this evaluates the objective on the LAST TRAINING batch,
        # only the loss scale uses the validation dataset size. It is purely
        # informational, so no autograd graph is needed.
        with torch.no_grad():
            val_loss, val_loss_parts = full_objective(
                model,
                dataloaders["validation"],
                data_key,
                *batch_args,
                **batch_kwargs,
                detach_core=detach_core,
            )
        print(
            f"Epoch {epoch}, Batch {batch_no}, Train loss {loss}, Validation loss {val_loss}"
        )
        print(
            f"EPOCH={epoch}  validation_correlation={validation_correlation}  Epoch Train loss Kullback-Leibler-divergence={epoch_loss_kldiv_without_scaling}"
        )
        model.train()

    ##### Model evaluation ####################################################################################################
    model.eval()
    cluster_centers_np = None
    sigma_np = None
    mixing_coefficients_np = None
    predicted = None
    if include_kldivergence:
        if cluster_centers is None:
            # Training ended before the clustering was initialised; there is
            # nothing to evaluate, so the clustering outputs stay None.
            print(
                "Clustering was never initialised (training ended before "
                "dec_starting_epoch); returning None for clustering outputs."
            )
        else:
            soft_assignments_list = []
            with torch.no_grad():
                for _, readout in model.readout.items():
                    features = readout.features.detach().squeeze()
                    # Use the same assignment function as during training.
                    if not use_diag_cov:
                        assignments = soft_assignments_1D(
                            features, cluster_centers, sigma, alpha
                        )
                    elif include_mixingcoefficients:
                        assignments = soft_assignments_mult(
                            features,
                            cluster_centers,
                            sigma,
                            alpha,
                            p,
                            mixing_coefficients,
                        )
                    else:
                        assignments = soft_assignments_mult(
                            features, cluster_centers, sigma, alpha, p
                        )
                    soft_assignments_list.append(assignments)
            # Hard label = index of the largest soft assignment.
            predicted = torch.cat(soft_assignments_list).max(1)[1]
            cluster_centers_list.append(cluster_centers.cpu().detach().numpy())
            cluster_centers_np = np.array(cluster_centers_list)
            sigma_np = sigma.cpu().detach().numpy()
            if include_mixingcoefficients and mixing_coefficients is not None:
                mixing_coefficients_np = mixing_coefficients.cpu().detach().numpy()
    if track_training:
        tracker.finalize()
    validation_correlation = get_correlations(
        model, dataloaders["validation"], device=device, as_dict=False, per_neuron=False
    )
    print("Training complete")

    # return the whole tracker output as a dict
    output = {k: v for k, v in tracker.log.items()} if track_training else {}
    output["validation_corr"] = validation_correlation

    score = np.mean(validation_correlation)

    if include_kldivergence:
        return (
            score,
            output,
            cluster_centers_np,
            sigma_np,
            mixing_coefficients_np,
            predicted,
            model.state_dict(),
        )
    else:
        return score, output, model.state_dict()
