from typing import Tuple

import numpy as np
from math import ceil, floor

from torch_em.transform.raw import normalize, normalize_percentile, standardize


# TODO: Consider combining all raw transforms into a single class in the future.
class RawTransformJointTraining:
    """Raw-data transform used for joint training.

    The input is percentile-normalized, clipped to the unit interval [0, 1] and
    then multiplied by 255. No cast to an integer dtype is performed, so the
    result keeps whatever dtype the clipping step produced.
    """

    def __call__(self, raw: np.ndarray):
        # Clip because percentile normalization can leave values outside [0, 1]
        # (presumably intensities beyond the chosen percentile bounds).
        unit_range = np.clip(normalize_percentile(raw), 0, 1)
        # Stretch onto a 0-255 intensity scale (values only; dtype is unchanged).
        return unit_range * 255


class RawTrafoFor3dInputs:
    """Prepare a 3D volume as a channel-first network input.

    Two preprocessing styles are supported:

    - ``benchmark_models=True`` (models trained from scratch): the volume is
      passed through ``standardize`` and a singleton channel axis is prepended.
    - ``benchmark_models=False`` (SAM-pretrained model training style): the
      ``normalize`` output is multiplied by 255 and the volume is replicated into
      three identical channels.

    Optionally, the last two axes of the resulting channel-first array are swapped.

    Args:
        switch_last_axes: Swap the last two axes of the output.
        benchmark_models: Use the from-scratch (``standardize``) preprocessing.
    """

    def __init__(self, switch_last_axes: bool = False, benchmark_models: bool = False):
        self.benchmark_models = benchmark_models
        self.switch_last_axes = switch_last_axes

    def _normalize_inputs(self, raw):
        """Apply ``normalize`` and multiply the result by 255 (SAM-style scaling)."""
        scaled = normalize(raw) * 255
        return scaled

    def _set_channels_for_inputs(self, raw):
        """Replicate ``raw`` three times along a new leading channel axis."""
        return np.stack((raw, raw, raw), axis=0)

    def _switch_last_axes_for_inputs(self, raw):
        """Swap the last two axes of a 4D, channel-first array."""
        return raw.transpose((0, 1, 3, 2))

    def __call__(self, raw):
        if self.benchmark_models:
            # Models trained from scratch: standardized volume, single channel.
            prepared = standardize(raw)[None]
        else:
            # SAM-pretrained model training style: three identical channels.
            prepared = self._set_channels_for_inputs(self._normalize_inputs(raw))

        if not self.switch_last_axes:
            return prepared
        return self._switch_last_axes_for_inputs(prepared)


# For 3D volumes (e.g. SegA) that must be padded to a fixed shape before being fed to the model.
class RawResizeTrafoFor3dInputs(RawTrafoFor3dInputs):
    """Normalize a 3D volume, center-pad it to a fixed shape and add a channel axis.

    Args:
        desired_shape: Target shape; its first three entries give the padded size
            of the three input axes. No input axis may be larger than the
            corresponding entry, because this transform only pads and never crops.
        padding: Padding mode forwarded as ``mode`` to ``np.pad``.
        switch_last_axes: Swap the last two axes of the final channel-first output.
        benchmark_models: If True, ``standardize`` the input and add a singleton
            channel axis (models trained from scratch). Otherwise apply the
            SAM-style normalization and replicate the volume into three channels.
    """

    def __init__(
        self,
        desired_shape: Tuple[int, ...],
        padding: str = "constant",
        switch_last_axes: bool = False,
        benchmark_models: bool = False
    ):
        # The base class owns the shared flags; forward them instead of
        # re-assigning them after a bare super().__init__().
        super().__init__(switch_last_axes=switch_last_axes, benchmark_models=benchmark_models)
        self.desired_shape = desired_shape
        self.padding = padding

    def _pad_inputs(self, raw):
        """Center-pad the three axes of the (assumed 3D) ``raw`` to ``self.desired_shape``.

        If the padding along an axis is odd, the extra voxel goes before the data.

        Raises:
            ValueError: If ``raw`` is larger than the desired shape along any axis.
        """
        pad_width = []
        for axis in range(3):
            missing = self.desired_shape[axis] - raw.shape[axis]
            if missing < 0:
                # np.pad would otherwise fail with an opaque "negative values" error.
                raise ValueError(
                    f"Input shape {tuple(raw.shape)} exceeds the desired shape "
                    f"{tuple(self.desired_shape)} along axis {axis}; "
                    "this transform can only pad, not crop."
                )
            pad_width.append((ceil(missing / 2), floor(missing / 2)))
        return np.pad(raw, pad_width=tuple(pad_width), mode=self.padding)

    def __call__(self, raw):
        if self.benchmark_models:  # Models trained from scratch.
            raw = standardize(raw)
        else:  # SAM-pretrained model training style.
            raw = self._normalize_inputs(raw)

        # Padding happens after normalization, so the padded voxels hold np.pad's
        # fill value (0 for the default "constant" mode) in normalized intensity space.
        raw = self._pad_inputs(raw)

        if self.benchmark_models:
            raw = raw[None]
        else:
            raw = self._set_channels_for_inputs(raw)

        if self.switch_last_axes:
            raw = self._switch_last_axes_for_inputs(raw)
        return raw
