import numpy as np
from math import ceil, floor
from skimage.measure import label as connected_components


# TODO: Consider combining all label transforms into a single class in the future.
class LabelTransformJointTraining:
    """Build a 3-channel target for joint interactive + semantic segmentation.

    Output channels, stacked along axis 0:
        0: the (optionally relabelled) instance labels,
        1: background mask (``labels == 0``),
        2: foreground mask (``labels != 0``).
    ``np.concatenate`` promotes the boolean masks to a common dtype with the
    labels (the labels' own dtype for numeric labels).

    Args:
        ensure_cc: If True, relabel with connected components first so that
            every object gets an individual id.
    """

    def __init__(self, ensure_cc: bool = False):
        self.ensure_cc = ensure_cc

    def __call__(self, labels):
        if self.ensure_cc:
            # Give each object its own id while keeping the input dtype.
            source_dtype = labels.dtype
            labels = connected_components(labels).astype(source_dtype)

        # 2D inputs get a leading singleton axis so they stack as one channel.
        instances = np.expand_dims(labels, axis=0) if labels.ndim == 2 else labels

        background = instances == 0
        foreground = instances != 0
        return np.concatenate((instances, background, foreground), axis=0)


class LabelTrafoToBinary:
    """Binarise (or remap) label arrays, optionally swapping their last two axes.

    Args:
        label_id_mapping: Optional dict ``{current_id: new_id}``. When None,
            every label > 0 becomes 1 and everything else 0 (dtype preserved).
            When given, each listed id is replaced by its new id and every
            unlisted value becomes 0.
        switch_last_axes: If True, swap the last two axes of the result.
    """

    def __init__(self, label_id_mapping=None, switch_last_axes=False):
        self.switch_last_axes = switch_last_axes
        self.label_id_mapping = label_id_mapping

    def _binarise_labels(self, labels):
        """Return a binarised / remapped array with the same dtype as ``labels``."""
        if self.label_id_mapping is None:
            return (labels > 0).astype(labels.dtype)

        remapped = np.zeros_like(labels)
        # Masks are always computed on the *input* labels, so chained mappings
        # such as {1: 2, 2: 3} are not applied transitively.
        for curr_id, new_id in self.label_id_mapping.items():
            remapped[labels == curr_id] = new_id
        return remapped

    def _switch_last_axes_for_labels(self, labels):
        """Swap the last two axes of ``labels`` (returns a view)."""
        # Identical to ``transpose(0, 2, 1)`` for 3D input, but also valid for
        # 2D and higher-dimensional arrays instead of raising.
        return np.swapaxes(labels, -1, -2)

    def __call__(self, labels):
        labels = self._binarise_labels(labels)
        if self.switch_last_axes:
            labels = self._switch_last_axes_for_labels(labels)
        return labels


# For 3D volumes, e.g. SegA.
class LabelResizeTrafoFor3dInputs(LabelTrafoToBinary):
    """Optionally binarise labels, then centre-pad them to ``desired_shape``.

    Args:
        desired_shape: Target shape, one entry per label axis. Each entry must
            be >= the corresponding input size (this transform only pads).
        padding: ``mode`` forwarded to ``np.pad`` (e.g. ``"constant"``).
        switch_last_axes: If True, swap the last two axes after padding.
        binary: If True, binarise the labels before padding.
        label_id_mapping: Optional ``{current_id: new_id}`` mapping used by the
            binarisation step instead of ``labels > 0``. Defaults to None.

    Raises:
        ValueError: If any input axis is larger than ``desired_shape``.
    """

    def __init__(self, desired_shape, padding="constant", switch_last_axes=False, binary=True,
                 label_id_mapping=None):
        # Initialise the base class so ``label_id_mapping`` exists; previously it
        # was never set and ``_binarise_labels`` raised AttributeError whenever
        # ``binary=True`` (the default).
        super().__init__(label_id_mapping=label_id_mapping, switch_last_axes=switch_last_axes)
        self.desired_shape = desired_shape
        self.padding = padding
        self.binary = binary

    def __call__(self, labels):
        if self.binary:
            labels = self._binarise_labels(labels)

        # Per-axis amount of padding needed to reach the desired shape.
        deltas = [target - current for target, current in zip(self.desired_shape, labels.shape)]
        if any(delta < 0 for delta in deltas):
            raise ValueError(
                f"Labels of shape {labels.shape} are larger than the desired shape "
                f"{tuple(self.desired_shape)}; this transform can only pad, not crop."
            )

        # Centre the input; for odd differences the extra voxel goes before.
        pad_width = [(ceil(delta / 2), floor(delta / 2)) for delta in deltas]
        labels = np.pad(labels, pad_width=pad_width, mode=self.padding)

        if self.switch_last_axes:
            labels = self._switch_last_axes_for_labels(labels)

        return labels
