import ast
import os
from dataclasses import dataclass
from typing import Optional, Union, List

import torch
from safetensors import safe_open

def update_args_from_yaml(group, args, parser):
    """Copy (possibly nested) YAML config values onto an argparse namespace.

    Nested dicts are flattened: each leaf is set on ``args`` under its own key.
    The strings ``'None'`` and ``'null'`` become ``None``. Every other value is
    converted with the ``type`` of the parser action whose ``dest`` equals the
    key (``str`` when no action matches). Values are kept unchanged when the
    action's type is ``None`` or ``ast.literal_eval``, or when the value is
    already an instance of the type.

    Args:
        group: Mapping of config keys to values or nested mappings.
        args: Object (typically an ``argparse.Namespace``) to set attributes on.
        parser: ``argparse.ArgumentParser`` whose actions supply the types.

    Raises:
        ValueError: if a value cannot be converted to its argument type.
    """
    for key, value in group.items():
        if isinstance(value, dict):
            update_args_from_yaml(value, args, parser)
            continue

        if value == 'None' or value == 'null':
            value = None
        else:
            arg_type = next((action.type for action in parser._actions if action.dest == key), str)

            if arg_type is not None and arg_type is not ast.literal_eval:
                # argparse accepts any callable as ``type`` (e.g. a str2bool
                # helper); isinstance() only accepts classes, so the
                # "already converted" shortcut applies to real types only.
                already_converted = isinstance(arg_type, type) and isinstance(value, arg_type)
                if not already_converted:
                    try:
                        value = arg_type(value)
                    except (TypeError, ValueError) as e:
                        raise ValueError(f"Cannot convert {key} to {arg_type}: {e}") from e

        setattr(args, key, value)


def safe_load(model_path, device="cpu"):
    """Load every tensor from a safetensors checkpoint into a state dict.

    Args:
        model_path: Checkpoint path (``str`` or ``os.PathLike``). Its text must
            contain ``"safetensors"``.
        device: Device passed to ``safe_open``; defaults to ``"cpu"``.

    Returns:
        dict mapping each tensor name in the file to its tensor.

    Raises:
        ValueError: if the path does not look like a safetensors checkpoint.
    """
    path = os.fspath(model_path)
    # Explicit check rather than ``assert``: asserts vanish under ``python -O``.
    if "safetensors" not in path:
        raise ValueError(f"Expected a safetensors checkpoint path, got {path!r}")
    with safe_open(path, framework="pt", device=device) as f:
        return {k: f.get_tensor(k) for k in f.keys()}


@dataclass
class DDIMSchedulerStepOutput:
    """Result of a single DDIM step.

    Attributes:
        prev_sample: The sample for the next (less noisy) timestep, x_{t-1}.
        pred_original_sample: The predicted clean sample x0, or ``None``.
    """

    prev_sample: torch.Tensor
    pred_original_sample: Union[torch.Tensor, None] = None


@dataclass
class DDIMSchedulerConversionOutput:
    """One model output expressed in all three DDIM parameterizations.

    Attributes:
        pred_epsilon: Predicted noise (epsilon).
        pred_original_sample: Predicted clean sample (x0).
        pred_velocity: Predicted velocity (v).
    """

    pred_epsilon: torch.Tensor          # epsilon
    pred_original_sample: torch.Tensor  # x0
    pred_velocity: torch.Tensor         # v


class DDIMScheduler:
    """DDIM scheduler (https://arxiv.org/abs/2010.02502) over a discrete beta schedule.

    Holds the training noise schedule (``betas`` -> ``alphas_cumprod``) and the
    Python list of inference timesteps ``self.timesteps``, ordered from the
    noisiest to the cleanest step. It implements the DDIM update, forward
    noising, and conversions between the epsilon / x0 / v parameterizations.
    """

    prediction_types = ["epsilon", "sample", "v_prediction"]

    def __init__(
        self,
        num_train_timesteps: int,
        num_inference_timesteps: int,
        betas: torch.Tensor,
        set_alpha_to_one: bool = True,
        set_inference_timesteps_from_pure_noise: bool = True,
        inference_timesteps: Union[str, List[int], torch.Tensor] = "trailing",
        device: Optional[Union[str, torch.device]] = None,
        dtype: torch.dtype = torch.float32,
        skip_step: bool = False,
        original_inference_step: int = 20,
        steps_offset: int = 0,
    ):
        """Build the noise schedule and the inference timesteps.

        Args:
            num_train_timesteps: Training schedule length; must equal ``betas.size(0)``.
            num_inference_timesteps: Number of sampling steps.
            betas: 1-D tensor of per-step betas.
            set_alpha_to_one: If True, the step after the last inference timestep
                uses alpha_cumprod = 1. Otherwise it uses ``alphas_cumprod[0]``.
            set_inference_timesteps_from_pure_noise: Selects the string recipes
                "trailing" / "linspace" / "leading". When False, every string value
                uses a leading-style ``range`` schedule.
            inference_timesteps: A recipe name, or an explicit list/tuple/tensor of
                ``num_inference_timesteps`` timesteps (expected noisiest first).
            device: Device for the schedule tensors; defaults to ``betas.device``.
            dtype: dtype of ``final_alpha_cumprod`` when ``set_alpha_to_one`` is True.
            skip_step: "trailing" only. Subsample a schedule of
                ``original_inference_step`` steps instead of spacing evenly.
            original_inference_step: Base step count used when ``skip_step`` is True.
            steps_offset: Stored on the instance; not applied to the timesteps here.
        """
        assert num_train_timesteps > 0
        assert num_train_timesteps >= num_inference_timesteps
        assert num_train_timesteps == betas.size(0)
        assert betas.ndim == 1

        self.module_name = 'AutoAIGC'
        self.config_list = {"num_train_timesteps": num_train_timesteps,
                            "num_inference_timesteps": num_inference_timesteps,
                            "betas": betas,
                            "set_alpha_to_one": set_alpha_to_one,
                            "set_inference_timesteps_from_pure_noise": set_inference_timesteps_from_pure_noise,
                            "inference_timesteps": inference_timesteps}
        self.module_info = str(self.config_list)

        device = device or betas.device

        self.num_train_timesteps = num_train_timesteps
        self.num_inference_steps = num_inference_timesteps
        self.steps_offset = steps_offset

        self.betas = betas
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_cumprod used as the "previous" value after the last inference step.
        self.final_alpha_cumprod = (
            torch.tensor(1.0, device=device, dtype=dtype) if set_alpha_to_one else self.alphas_cumprod[0]
        )

        if isinstance(inference_timesteps, (torch.Tensor, list, tuple)):
            # Explicit schedule supplied by the caller.
            assert len(inference_timesteps) == num_inference_timesteps
            if isinstance(inference_timesteps, torch.Tensor):
                self.timesteps = inference_timesteps.cpu().numpy().tolist()
            else:
                self.timesteps = [int(t) for t in inference_timesteps]
        elif set_inference_timesteps_from_pure_noise:
            if inference_timesteps == "trailing":
                # 20 steps: [999, 949, 899, ..., 99, 49]
                if skip_step:
                    # Build an ``original_inference_step`` trailing schedule, then
                    # keep every k-th entry so exactly num_inference_timesteps remain.
                    original_timesteps = torch.arange(num_train_timesteps - 1, -1, -num_train_timesteps / original_inference_step, device=device).round().int().tolist()
                    skipping_step = len(original_timesteps) // num_inference_timesteps
                    self.timesteps = original_timesteps[::skipping_step][:num_inference_timesteps]
                else:
                    # 10 steps: [999, 899, 799, 699, 599, 499, 399, 299, 199, 99]
                    self.timesteps = torch.arange(num_train_timesteps - 1, -1, -num_train_timesteps / num_inference_timesteps, device=device).round().int().tolist()
            elif inference_timesteps == "linspace":
                # Evenly spaced from num_train_timesteps - 1 down to 0. 20 steps:
                # [999, 946, 894, 841, 789, 736, 684, 631, 578, 526, 473, 421, 368, 315, 263, 210, 158, 105, 53, 0]
                self.timesteps = torch.linspace(0, num_train_timesteps - 1, num_inference_timesteps, device=device).round().int().flip(0).tolist()
            elif inference_timesteps == "leading":
                # Integer multiples of the step ratio, ending at 0. 20 steps:
                # [950, 900, ..., 50, 0]. The schedule does not start from 999;
                # see https://github.com/huggingface/diffusers/issues/2585.
                # Stored as a Python list like every other recipe (step() uses list.index).
                step_ratio = num_train_timesteps // num_inference_timesteps
                self.timesteps = [i * step_ratio for i in reversed(range(num_inference_timesteps))]
            else:
                raise NotImplementedError
        else:
            # "leading" (and, as a fallback, any other string): every
            # (num_train_timesteps // num_inference_timesteps)-th timestep from 0,
            # reversed. 20 steps: [950, 900, ..., 50, 0]. When the division is not
            # exact this yields more than num_inference_timesteps entries, so the
            # step lookups use len(self.timesteps) as the boundary.
            self.timesteps = list(reversed(range(0, num_train_timesteps, num_train_timesteps // num_inference_timesteps)))

        self.to(device=device)

    def to(self, device):
        """Move every schedule tensor to ``device`` and return ``self``."""
        self.betas = self.betas.to(device)
        self.alphas = self.alphas.to(device)
        self.alphas_cumprod = self.alphas_cumprod.to(device)
        self.final_alpha_cumprod = self.final_alpha_cumprod.to(device)
        return self

    def _get_alpha_prods(self, timestep, ndim):
        """Return ``(alpha_prod_t, alpha_prod_t_prev)`` for ``timestep``.

        ``alpha_prod_t_prev`` belongs to the entry after ``timestep`` in
        ``self.timesteps``. ``final_alpha_cumprod`` is used for the last entry,
        or when that next entry is negative.

        Args:
            timestep: An int-like value, or a tensor of timesteps. Every value must
                occur in ``self.timesteps``.
            ndim: Rank of the sample the results broadcast against (tensor case).

        Returns:
            Two 0-d tensors for an int-like timestep. For a tensor timestep, two
            tensors reshaped to ``(-1, 1, ..., 1)`` with ``ndim`` dimensions.

        Raises:
            ValueError: if a timestep does not occur in ``self.timesteps``.
        """
        num_steps = len(self.timesteps)

        if not isinstance(timestep, torch.Tensor):
            timestep = int(timestep)
            idx = self.timesteps.index(timestep)  # raises ValueError if absent
            alpha_prod_t = self.alphas_cumprod[timestep]
            if idx + 1 < num_steps and self.timesteps[idx + 1] >= 0:
                alpha_prod_t_prev = self.alphas_cumprod[self.timesteps[idx + 1]]
            else:
                alpha_prod_t_prev = self.final_alpha_cumprod
            return alpha_prod_t, alpha_prod_t_prev

        schedule = torch.tensor(self.timesteps).to(timestep.device)
        # matches[i, j] is True when the i-th requested timestep equals schedule[j].
        matches = timestep.reshape(-1, 1).eq(schedule.reshape(1, -1))
        if not bool(matches.any(dim=1).all()):
            raise ValueError("every timestep must be one of scheduler.timesteps")
        idx = matches.long().argmax(dim=1)  # first matching position per timestep

        next_idx = idx + 1
        is_last = next_idx >= num_steps
        # The clamp only keeps indexing in range; is_last selects final_alpha_cumprod.
        prev_timestep = schedule[next_idx.clamp_max(num_steps - 1)]

        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = torch.where(
            is_last | (prev_timestep < 0),
            self.final_alpha_cumprod,
            self.alphas_cumprod[prev_timestep],
        )

        shape = (-1, *([1] * (ndim - 1)))
        return alpha_prod_t.reshape(shape), alpha_prod_t_prev.reshape(shape)

    def step(
        self,
        model_output: torch.Tensor,
        model_output_type: str,
        timestep: Union[torch.Tensor, int],
        sample: torch.Tensor,
        eta: float = 0.0,
        clip_sample: bool = False,
        dynamic_threshold: Optional[float] = None,
        variance_noise: Optional[torch.Tensor] = None,
    ) -> "DDIMSchedulerStepOutput":
        """Run one DDIM update from ``timestep`` to the next inference timestep.

        Args:
            model_output: Network prediction, interpreted according to ``model_output_type``.
            model_output_type: One of ``prediction_types``.
            timestep: Current timestep (int) or per-sample timesteps (tensor). Each
                value must occur in ``self.timesteps``.
            sample: Current noisy sample x_t, with batch as the first dimension.
            eta: Weight of the DDIM sigma term. Noise is only added when ``eta > 0``.
            clip_sample: Clamp the predicted x0 to [-1, 1].
            dynamic_threshold: Quantile for dynamic thresholding of x0 (arXiv:2205.11487).
            variance_noise: Noise added when ``eta > 0``. Sampled if omitted.

        Returns:
            ``DDIMSchedulerStepOutput`` with ``prev_sample`` (x_{t-1}) and
            ``pred_original_sample`` (x0).
        """
        # 1-2. Cumulative alphas/betas for t and for the previous (less noisy) step.
        alpha_prod_t, alpha_prod_t_prev = self._get_alpha_prods(timestep, sample.ndim)
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # rcfg: expose the previous-step products on the instance. Nothing in this
        # class reads them; presumably an external caller does.
        self.stock_alpha_prod_t_prev = alpha_prod_t_prev
        self.stock_beta_prod_t_prev = beta_prod_t_prev

        # 3. Compute the predicted original sample (and noise) from the model output.
        model_output_conversion = self.convert_output(model_output, model_output_type, sample, timestep)
        pred_original_sample = model_output_conversion.pred_original_sample
        pred_epsilon = model_output_conversion.pred_epsilon

        # 4. Clip or threshold "predicted x_0" and re-derive epsilon to match.
        if clip_sample:
            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
            pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon

        if dynamic_threshold is not None:
            # Dynamic thresholding in https://arxiv.org/abs/2205.11487
            dynamic_max_val = pred_original_sample \
                .flatten(1) \
                .abs() \
                .float() \
                .quantile(dynamic_threshold, dim=1) \
                .type_as(pred_original_sample) \
                .clamp_min(1) \
                .view(-1, *([1] * (pred_original_sample.ndim - 1)))
            pred_original_sample = pred_original_sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
            pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon

        # 5. compute variance: "sigma_t(η)" -> see formula (16) from https://arxiv.org/pdf/2010.02502.pdf
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
        std_dev_t = eta * variance ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        # 8. add "random noise" if needed.
        if eta > 0:
            if variance_noise is None:
                variance_noise = torch.randn_like(model_output)
            prev_sample = prev_sample + std_dev_t * variance_noise

        return DDIMSchedulerStepOutput(
            prev_sample=prev_sample,  # x_{t-1}
            pred_original_sample=pred_original_sample,  # x0
        )

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: Union[torch.Tensor, int],
        replace_noise=True
    ) -> torch.Tensor:
        """Forward-noise ``original_samples`` to ``timesteps``.

        Computes sqrt(a_t) * x0 + sqrt(1 - a_t) * noise. When ``replace_noise`` is
        True, samples at the final training timestep (num_train_timesteps - 1)
        get a_t = 0, so they become pure ``noise``.
        """
        shape = (-1, *([1] * (original_samples.ndim - 1)))
        alpha_prod_t = self.alphas_cumprod[timesteps].reshape(shape)
        if replace_noise:
            # torch.where builds a new tensor. In-place assignment could write
            # through a view into self.alphas_cumprod when the timestep is an int.
            t = torch.as_tensor(timesteps, device=alpha_prod_t.device).reshape(shape)
            is_pure_noise = t == self.num_train_timesteps - 1
            alpha_prod_t = torch.where(is_pure_noise, torch.zeros_like(alpha_prod_t), alpha_prod_t)
        return alpha_prod_t ** (0.5) * original_samples + (1 - alpha_prod_t) ** (0.5) * noise

    def add_noise_lcm(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timestep: Union[torch.Tensor, int],
    ) -> torch.Tensor:
        """Noise ``original_samples`` to the level of the inference step after ``timestep``.

        Uses the alpha_cumprod of the next entry in ``self.timesteps``, or
        ``final_alpha_cumprod`` when ``timestep`` is the last entry.
        """
        _, alpha_prod_t_prev = self._get_alpha_prods(timestep, original_samples.ndim)
        alpha_prod_t_prev = alpha_prod_t_prev.reshape(-1, *([1] * (original_samples.ndim - 1)))
        return alpha_prod_t_prev ** (0.5) * original_samples + (1 - alpha_prod_t_prev) ** (0.5) * noise

    def convert_output(
        self,
        model_output: torch.Tensor,
        model_output_type: str,
        sample: torch.Tensor,
        timesteps: Union[torch.Tensor, int]
    ) -> "DDIMSchedulerConversionOutput":
        """Express ``model_output`` as epsilon, x0 and v at ``timesteps``.

        Args:
            model_output: Prediction of kind ``model_output_type``.
            model_output_type: "epsilon", "sample" (x0) or "v_prediction".
            sample: Noisy input x_t, with batch as the first dimension.
            timesteps: Timestep(s) used to index ``alphas_cumprod``.

        Raises:
            ValueError: for an unknown ``model_output_type`` (under ``python -O``,
                where the assert is stripped).
        """
        assert model_output_type in self.prediction_types

        alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (sample.ndim - 1)))
        beta_prod_t = 1 - alpha_prod_t

        if model_output_type == "epsilon":
            pred_epsilon = model_output
            pred_original_sample = (sample - beta_prod_t ** (0.5) * pred_epsilon) / alpha_prod_t ** (0.5)
            pred_velocity = alpha_prod_t ** (0.5) * pred_epsilon - (1 - alpha_prod_t) ** (0.5) * pred_original_sample
        elif model_output_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
            pred_velocity = alpha_prod_t ** (0.5) * pred_epsilon - (1 - alpha_prod_t) ** (0.5) * pred_original_sample
        elif model_output_type == "v_prediction":
            pred_velocity = model_output
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError("Unknown prediction type")

        return DDIMSchedulerConversionOutput(
            pred_epsilon=pred_epsilon,
            pred_original_sample=pred_original_sample,
            pred_velocity=pred_velocity)

    def get_velocity(
        self,
        sample: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor
    ) -> torch.FloatTensor:
        """Return the v-target sqrt(a_t) * noise - sqrt(1 - a_t) * sample."""
        alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (sample.ndim - 1)))
        return alpha_prod_t ** (0.5) * noise - (1 - alpha_prod_t) ** (0.5) * sample