diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b39cfd2a29bedded4e8aac69833506e3654f0eb
--- /dev/null
+++ b/app.py
@@ -0,0 +1,276 @@
+import torch
+import torchvision
+
+import os
+import os.path as osp
+import random
+from argparse import ArgumentParser
+from datetime import datetime
+
+import gradio as gr
+
+from foleycrafter.utils.util import build_foleycrafter, read_frames_with_moviepy
+from foleycrafter.pipelines.auffusion_pipeline import denormalize_spectrogram
+from foleycrafter.pipelines.auffusion_pipeline import Generator
+from foleycrafter.models.time_detector.model import VideoOnsetNet
+from foleycrafter.models.specvqgan.onset_baseline.utils import torch_utils
+
+from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor 
+from huggingface_hub import snapshot_download
+from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
+
+import soundfile as sf
+from moviepy.editor import AudioFileClip, VideoFileClip
+os.environ['GRADIO_TEMP_DIR'] = './tmp'
+
+sample_idx = 0
+scheduler_dict = {
+    "DDIM": DDIMScheduler,
+    "Euler": EulerDiscreteScheduler,
+    "PNDM": PNDMScheduler,
+}
+
+css = """
+.toolbutton {
+    margin-bottom: 0em;
+    max-width: 2.5em;
+    min-width: 2.5em !important;
+    height: 2.5em;
+}
+"""
+
+parser = ArgumentParser()
+parser.add_argument("--config", type=str, default="example/config/base.yaml")
+parser.add_argument("--server-name", type=str, default="0.0.0.0")
+parser.add_argument("--port", type=int, default=11451)
+parser.add_argument("--share", action="store_true")
+
+parser.add_argument("--save-path", default="samples")
+
+args = parser.parse_args()
+
+
+N_PROMPT = (
+    ""
+)
+
+class FoleyController:
+    def __init__(self):
+        # config dirs
+        self.basedir = os.getcwd()
+        self.model_dir = os.path.join(self.basedir, "models")
+        self.savedir = os.path.join(self.basedir, args.save_path, datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
+        self.savedir_sample = os.path.join(self.savedir, "sample")
+        os.makedirs(self.savedir, exist_ok=True)
+
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.pipeline = None
+
+        self.loaded = False
+
+        self.load_model()
+
+    def load_model(self):
+        gr.Info("Start Load Models...")
+        print("Start Load Models...")
+
+        # download ckpt
+        pretrained_model_name_or_path = 'auffusion/auffusion-full-no-adapter'
+        if not os.path.isdir(pretrained_model_name_or_path):
+            pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path, local_dir='models/auffusion') 
+
+        fc_ckpt = 'ymzhang319/FoleyCrafter'
+        if not os.path.isdir(fc_ckpt):
+            fc_ckpt = snapshot_download(fc_ckpt, local_dir='models/') 
+
+        # set model config
+        temporal_ckpt_path = osp.join(self.model_dir, 'temporal_adapter.ckpt')
+
+        # load vocoder
+        vocoder_config_path= "./models/auffusion"
+        self.vocoder       = Generator.from_pretrained(
+                        vocoder_config_path, 
+                        subfolder="vocoder").to(self.device)
+        
+        # load time detector
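+        # the onset detector predicts a per-frame sound-onset probability, which is later
+        # thresholded in foley() to build the temporal (ControlNet) condition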
+        time_detector_ckpt = osp.join(self.model_dir, 'timestamp_detector.pth.tar')
+        time_detector      = VideoOnsetNet(False)
+        self.time_detector, _   = torch_utils.load_model(time_detector_ckpt, time_detector, strict=True, device=self.device)
+
+        self.pipeline = build_foleycrafter().to(self.device)
+        ckpt = torch.load(temporal_ckpt_path, map_location="cpu")
+
+        # load temporal adapter
+        if 'state_dict' in ckpt.keys():
+            ckpt = ckpt['state_dict']
+        load_gligen_ckpt = {}
+        for key, value in ckpt.items():
+            if key.startswith('module.'):
+                load_gligen_ckpt[key[len('module.'):]] = value
+            else:
+                load_gligen_ckpt[key] = value
+        m, u        = self.pipeline.controlnet.load_state_dict(load_gligen_ckpt, strict=False)
+        print(f"### Control Net missing keys: {len(m)}; \n### unexpected keys: {len(u)};") 
+
+        self.image_processor      = CLIPImageProcessor()
+        self.image_encoder        = CLIPVisionModelWithProjection.from_pretrained('h94/IP-Adapter', subfolder='models/image_encoder').to(self.device)
+
+        self.pipeline.load_ip_adapter(fc_ckpt, subfolder='semantic', weight_name='semantic_adapter.bin', image_encoder_folder=None)
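+        # the semantic adapter lets the UNet cross-attend to CLIP image embeddings of the video frames (IP-Adapter style)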
+
+        gr.Info("Load Finish!")
+        print("Load Finish!")
+        self.loaded = True
+
+        return "Load"
+
+    def foley(
+        self,
+        input_video,
+        prompt_textbox,
+        negative_prompt_textbox, 
+        ip_adapter_scale,
+        temporal_scale,
+        sampler_dropdown,
+        sample_step_slider,
+        cfg_scale_slider,
+        seed_textbox, 
+    ):
+        
+        vision_transform_list = [
+            torchvision.transforms.Resize((128, 128)),
+            torchvision.transforms.CenterCrop((112, 112)),
+            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ]
+        video_transform = torchvision.transforms.Compose(vision_transform_list)
+        if not self.loaded:
+            raise gr.Error("Models failed to load")
+        generator  = torch.Generator()
+        if seed_textbox != "":
+            torch.manual_seed(int(seed_textbox))
+            generator.manual_seed(int(seed_textbox))
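+        # sample at most max_frame_nums frames from the clip; they feed both the onset detector and the CLIP image encoder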
+        max_frame_nums = 15
+        frames, duration  = read_frames_with_moviepy(input_video, max_frame_nums=max_frame_nums)
+        if duration >= 10:
+            duration = 10
+        time_frames = torch.FloatTensor(frames).permute(0, 3, 1, 2)
+        time_frames = video_transform(time_frames)
+        time_frames = {'frames': time_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)}
+        preds       = self.time_detector(time_frames)
+        preds       = torch.sigmoid(preds)
+
+        # map onset predictions to the temporal condition: the generated spectrogram spans 10 s over
+        # 1024 columns, so each column inside the clip duration is set to +1 (onset) or -1 (no onset)
+        # from the nearest sampled frame; columns beyond the clip are padded with -1
+        time_condition = [-1 if preds[0][int(i / (1024 / 10 * duration) * max_frame_nums)] < 0.5 else 1 for i in range(int(1024 / 10 * duration))]
+        time_condition = time_condition + [-1] * (1024 - len(time_condition))
+        # w -> b c h w
+        time_condition = torch.FloatTensor(time_condition).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(1, 1, 256, 1)
+        
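+        # CLIP image embeddings are averaged over the sampled frames into a single semantic prompt;
+        # the all-zero embedding is concatenated as the negative branch for classifier-free guidance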
+        images = self.image_processor(images=frames, return_tensors="pt").to(self.device)
+        image_embeddings = self.image_encoder(**images).image_embeds
+        image_embeddings = torch.mean(image_embeddings, dim=0, keepdim=True).unsqueeze(0).unsqueeze(0)
+        neg_image_embeddings = torch.zeros_like(image_embeddings)
+        image_embeddings = torch.cat([neg_image_embeddings, image_embeddings], dim=1)
+        self.pipeline.set_ip_adapter_scale(ip_adapter_scale)
+        # switch to the sampler selected in the UI
+        self.pipeline.scheduler = scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
+        sample = self.pipeline(
+            prompt=prompt_textbox,
+            negative_prompt=negative_prompt_textbox,
+            ip_adapter_image_embeds=image_embeddings,
+            image=time_condition,
+            controlnet_conditioning_scale=float(temporal_scale),
+            num_inference_steps=sample_step_slider,
+            guidance_scale=float(cfg_scale_slider),
+            height=256,
+            width=1024,
+            output_type="pt", 
+            generator=generator,
+        )
+        name = 'output'
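+        # the diffusion output is a 256x1024 spectrogram image; denormalize it and run the vocoder
+        # to recover a 16 kHz waveform, then trim it to the clip duration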
+        audio_img = sample.images[0]
+        audio     = denormalize_spectrogram(audio_img)
+        audio     = self.vocoder.inference(audio, lengths=160000)[0]
+        audio_save_path = osp.join(self.savedir_sample, 'audio')
+        os.makedirs(audio_save_path, exist_ok=True)
+        audio = audio[:int(duration * 16000)]
+
+        save_path = osp.join(audio_save_path, f'{name}.wav')
+        sf.write(save_path, audio, 16000)
+
+        audio = AudioFileClip(osp.join(audio_save_path, f'{name}.wav'))
+        video = VideoFileClip(input_video)
+        audio = audio.subclip(0, duration)
+        video = video.set_audio(audio)
+        video = video.subclip(0, duration)
+        video.write_videofile(osp.join(self.savedir_sample, f'{name}.mp4'))
+        save_sample_path = os.path.join(self.savedir_sample, f"{name}.mp4")
+
+        return save_sample_path 
+
+controller = FoleyController()
+
+def ui():
+    with gr.Blocks(css=css) as demo:
+        gr.HTML(
+            "<div align='center'><font size='6'>FoleyCrafter: Bring Silent Videos to Life with Lifelike and Synchronized Sounds</font></div>"
+        )
+        with gr.Row():
+            gr.Markdown(
+                "<div align='center'><font size='5'><a href='https://foleycrafter.github.io/'>Project Page</a> &ensp;"  # noqa
+                "<a href='https://arxiv.org/abs/xxxx.xxxxx/'>Paper</a> &ensp;"
+                "<a href='https://github.com/open-mmlab/foleycrafter'>Code</a> &ensp;"
+                "<a href='https://huggingface.co/spaces/ymzhang319/FoleyCrafter'>Demo</a> </font></div>"
+            )
+
+        with gr.Column(variant="panel"):
+            with gr.Row(equal_height=False):
+                with gr.Column():
+                    with gr.Row():
+                        init_img = gr.Video(label="Input Video")
+                    with gr.Row():
+                        prompt_textbox = gr.Textbox(value='', label="Prompt", lines=1)
+                    with gr.Row():
+                        negative_prompt_textbox = gr.Textbox(value=N_PROMPT, label="Negative prompt", lines=1)
+
+                    with gr.Row():
+                        sampler_dropdown = gr.Dropdown(
+                            label="Sampling method",
+                            choices=list(scheduler_dict.keys()),
+                            value=list(scheduler_dict.keys())[0],
+                        )
+                        sample_step_slider = gr.Slider(
+                            label="Sampling steps", value=25, minimum=10, maximum=100, step=1
+                        )
+
+                    cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)
+                    ip_adapter_scale = gr.Slider(label="Visual Content Scale", value=1.0, minimum=0, maximum=1)
+                    temporal_scale = gr.Slider(label="Temporal Align Scale", value=0., minimum=0., maximum=1.0)
+
+                    with gr.Row():
+                        seed_textbox = gr.Textbox(label="Seed", value=42)
+                        seed_button = gr.Button(value="\U0001f3b2", elem_classes="toolbutton")
+                    seed_button.click(fn=lambda: random.randint(1, 10**8), outputs=[seed_textbox], queue=False)
+
+                    generate_button = gr.Button(value="Generate", variant="primary")
+
+                result_video = gr.Video(label="Generated Audio", interactive=False)
+
+            generate_button.click(
+                fn=controller.foley,
+                inputs=[
+                    init_img,
+                    prompt_textbox,
+                    negative_prompt_textbox,
+                    ip_adapter_scale,
+                    temporal_scale,
+                    sampler_dropdown,
+                    sample_step_slider,
+                    cfg_scale_slider,
+                    seed_textbox,
+                ],
+                outputs=[result_video],
+            )
+
+    return demo
+
+if __name__ == "__main__":
+    demo = ui()
+    demo.queue(3)
+    demo.launch(server_name=args.server_name, server_port=args.port, share=args.share)
\ No newline at end of file
diff --git a/configs/auffusion/vocoder/config.json b/configs/auffusion/vocoder/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..07860a8422ad8ffd7838b0b87c5a2f7126fbff06
--- /dev/null
+++ b/configs/auffusion/vocoder/config.json
@@ -0,0 +1,37 @@
+{
+    "resblock": "1",
+    "num_gpus": 0,
+    "batch_size": 16,
+    "learning_rate": 0.0002,
+    "adam_b1": 0.8,
+    "adam_b2": 0.99,
+    "lr_decay": 0.999,
+    "seed": 1234,
+
+    "upsample_rates": [5,4,4,2],
+    "upsample_kernel_sizes": [11,8,8,4],
+    "upsample_initial_channel": 512,
+    "resblock_kernel_sizes": [3,7,11],
+    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+    "segment_size": 5120,
+    "num_mels": 256,
+    "num_freq": 2049,
+    "n_fft": 2048,
+    "hop_size": 160,
+    "win_size": 1024,
+
+    "sampling_rate": 16000,
+
+    "fmin": 0,
+    "fmax": null,
+    "fmax_for_loss": null,
+
+    "num_workers": 4,
+
+    "dist_config": {
+        "dist_backend": "nccl",
+        "dist_url": "tcp://localhost:54321",
+        "world_size": 1
+    }
+}
diff --git a/configs/train/train_semantic_adapter.yaml b/configs/train/train_semantic_adapter.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e967440443d7c9c8b51a085f4c32d63f0180c871
--- /dev/null
+++ b/configs/train/train_semantic_adapter.yaml
@@ -0,0 +1,54 @@
+output_dir: "outputs"
+
+pretrained_model_path: ""
+
+motion_module_path: "models/mm_sd_v15_v2.ckpt"
+
+train_data:
+  csv_path: "./curated.csv"
+  audio_fps: 48000
+  audio_size: 480000
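+  # 480000 samples at 48 kHz = 10 s clips (assuming audio_size is measured in samples)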
+
+validation_data:
+  prompts:
+    - "./data/input/lighthouse.png"
+    - "./data/input/guitar.png"
+    - "./data/input/lion.png"
+    - "./data/input/gun.png"
+  num_inference_steps: 25
+  guidance_scale: 7.5
+  sample_size: 512
+
+trainable_modules:
+  - 'to_k_ip'
+  - 'to_v_ip'
+
+audio_unet_checkpoint_path: ""
+
+learning_rate:    1.0e-4
+train_batch_size: 1 # maximum batch size that fits with mixed-precision training
+gradient_accumulation_steps: 1
+
+max_train_epoch:      -1
+max_train_steps:      200000
+checkpointing_epochs: 4000
+checkpointing_steps:  500
+
+validation_steps:       3000
+validation_steps_tuple: [2, 50, 300, 1000]
+
+global_seed: 42
+mixed_precision_training: true
+
+is_debug: False
+
+resume_ckpt: ""
+
+# params for adapter
+init_from_ip_adapter: false
+
+always_null_text: false
+
+reverse_null_text_prob: true
+
+frame_wise_condition: true
diff --git a/configs/train/train_temporal_adapter.yaml b/configs/train/train_temporal_adapter.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..92018e38460bf3c57af8f83bb20a026fad15427a
--- /dev/null
+++ b/configs/train/train_temporal_adapter.yaml
@@ -0,0 +1,48 @@
+output_dir: "outputs"
+
+pretrained_model_path: ""
+
+motion_module_path: "models/mm_sd_v15_v2.ckpt"
+
+train_data:
+  csv_path: "./curated.csv"
+  audio_fps: 48000
+  audio_size: 480000
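+  # 480000 samples at 48 kHz = 10 s clips (assuming audio_size is measured in samples)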
+
+validation_data:
+  prompts:
+    - "./data/input/lighthouse.png"
+    - "./data/input/guitar.png"
+    - "./data/input/lion.png"
+    - "./data/input/gun.png"
+  num_inference_steps: 25
+  guidance_scale: 7.5
+  sample_size: 512
+
+trainable_modules:
+  - 'time_conv_in.'
+  - 'conv_in.'
+
+video_unet_checkpoint_path: "models/vggsound_unet.ckpt"
+audio_unet_checkpoint_path: ""
+
+learning_rate:    5.0e-5
+train_batch_size: 1 # maximum batch size that fits with mixed-precision training
+gradient_accumulation_steps: 1
+
+max_train_epoch:      -1
+max_train_steps:      500000
+checkpointing_epochs: 4000
+checkpointing_steps:  500
+
+validation_steps:       3000
+validation_steps_tuple: [2, 300, 1000]
+
+global_seed: 42
+mixed_precision_training: true
+
+is_debug: False
+
+resume_ckpt: ""
+
+zero_no_label_mel: false
\ No newline at end of file
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dddf02b89aa390a24d543ed1ff60413003707022
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,24 @@
+name: foleycrafter
+channels:
+  - pytorch
+  - nvidia
+dependencies:
+  - python=3.10
+  - pytorch=2.2.0
+  - torchvision=0.17.0
+  - pytorch-cuda=11.8
+  - pip
+  - pip:
+    - diffusers==0.25.1
+    - transformers==4.30.2
+    - xformers
+    - imageio==2.33.1
+    - decord==0.6.0
+    - einops
+    - omegaconf
+    - safetensors
+    - gradio
+    - tqdm==4.66.1
+    - soundfile==0.12.1
+    - wandb
+    - moviepy==1.0.3
\ No newline at end of file
diff --git a/foleycrafter/data/dataset.py b/foleycrafter/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7b77b07232caee25bc1fc661cbdaf086ba9e7a1
--- /dev/null
+++ b/foleycrafter/data/dataset.py
@@ -0,0 +1,175 @@
+import torch
+import torchvision.transforms as transforms
+from torch.utils.data.dataset import Dataset
+import torch.distributed as dist
+import torchaudio
+import torchvision
+import torchvision.io
+
+import os, io, csv, math, random
+import os.path as osp
+from pathlib import Path
+import numpy as np
+import pandas as pd
+from einops import rearrange
+import glob
+
+from decord import VideoReader, AudioReader
+import decord
+from copy import deepcopy
+import pickle
+
+from petrel_client.client import Client
+import sys
+sys.path.append('./')
+from foleycrafter.data import video_transforms
+
+from foleycrafter.utils.util import \
+    random_audio_video_clip, get_full_indices, video_tensor_to_np, get_video_frames 
+from foleycrafter.utils.spec_to_mel import wav_tensor_to_fbank, read_wav_file_io, load_audio, normalize_wav, pad_wav
+from foleycrafter.utils.converter import get_mel_spectrogram_from_audio, pad_spec, normalize, normalize_spectrogram
+
+def zero_rank_print(s):
+    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s, flush=True)
+
+@torch.no_grad()
+def get_mel(audio_data, audio_cfg):
+    # mel shape: (n_mels, T)
+    mel = torchaudio.transforms.MelSpectrogram(
+        sample_rate=audio_cfg["sample_rate"],
+        n_fft=audio_cfg["window_size"],
+        win_length=audio_cfg["window_size"],
+        hop_length=audio_cfg["hop_size"],
+        center=True,
+        pad_mode="reflect",
+        power=2.0,
+        norm=None,
+        onesided=True,
+        n_mels=64,
+        f_min=audio_cfg["fmin"],
+        f_max=audio_cfg["fmax"],
+    ).to(audio_data.device)
+    mel = mel(audio_data)
+    # we use log mel spectrogram as input
+    mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
+    return mel  # (n_mels, T)
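+
+# A minimal example audio_cfg for get_mel (illustrative values only, not taken from this repo's configs):
+#   audio_cfg = {"sample_rate": 16000, "window_size": 1024, "hop_size": 160, "fmin": 0, "fmax": 8000}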
+
+def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
+    """
+    PARAMS
+    ------
+    C: compression factor
+    """
+    return normalize_fun(torch.clamp(x, min=clip_val) * C)
+
+class CPU_Unpickler(pickle.Unpickler):
+    def find_class(self, module, name):
+        if module == 'torch.storage' and name == '_load_from_bytes':
+            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
+        else:
+            return super().find_class(module, name)
+
+class AudioSetStrong(Dataset):
+    # read feature and audio
+    def __init__(
+        self, 
+    ):
+        super().__init__()
+        self.data_path = 'data/AudioSetStrong/train/feature'
+        # storage client used to list and fetch samples (assumes the default petrel_client configuration)
+        self._client = Client()
+        self.data_list = list(self._client.list(self.data_path))
+        self.length = len(self.data_list)
+        # get video feature
+        self.video_path = 'data/AudioSetStrong/train/video'
+        vision_transform_list = [
+            transforms.Resize((128, 128)),
+            transforms.CenterCrop((112, 112)),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ]
+        self.video_transform = transforms.Compose(vision_transform_list) 
+
+    def get_batch(self, idx):
+        embeds = self.data_list[idx]
+        mel           = embeds['mel']
+        save_bsz      = mel.shape[0]
+        audio_info    = embeds['audio_info'] 
+        text_embeds   = embeds['text_embeds']
+
+        audio_info_array = np.array(audio_info['label_list'])
+        prompts = []
+        for i in range(save_bsz):
+            prompts.append(', '.join(audio_info_array[i, :audio_info['event_num'][i]].tolist()))
+        # read videos  
+        videos = None
+        for video_name in audio_info['audio_name']:
+            video_bytes  = self._client.Get(osp.join(self.video_path, video_name+'.mp4'))
+            video_bytes  = io.BytesIO(video_bytes)
+            video_reader = VideoReader(video_bytes)
+            video        = video_reader.get_batch(get_full_indices(video_reader)).asnumpy()
+            video        = get_video_frames(video, 150)
+            video        = torch.from_numpy(video).permute(0, 3, 1, 2).contiguous().float()
+            video        = self.video_transform(video)
+            video        = video.unsqueeze(0)
+            if videos is None:
+                videos = video
+            else:
+                videos = torch.cat([videos, video], dim=0)
+        assert videos is not None, 'no video read'
+
+        return mel, audio_info, text_embeds, prompts, videos
+    
+    def __len__(self):
+        return self.length
+    
+    def __getitem__(self, idx):
+        while True:
+            try:
+                mel, audio_info, text_embeds, prompts, videos = self.get_batch(idx)
+                break
+            except Exception as e:
+                zero_rank_print(f' >>> load error: {e} <<<')
+                idx = random.randint(0, self.length-1)
+        sample = dict(mel=mel, audio_info=audio_info, text_embeds=text_embeds, prompts=prompts, videos=videos)
+        return sample
+    
+class VGGSound(Dataset):
+    # read feature and audio
+    def __init__(
+        self,
+    ):
+        super().__init__()
+        self.data_path = 'data/VGGSound/train/video'
+        self.visual_data_path = 'data/VGGSound/train/feature'
+        # sort both lists so the audio/text embeds and visual embeds stay index-aligned
+        self.embeds_list = sorted(glob.glob(f'{self.data_path}/*.pt'))
+        self.visual_list = sorted(glob.glob(f'{self.visual_data_path}/*.pt'))
+        self.length = len(self.embeds_list)
+
+    def get_batch(self, idx):
+        embeds = torch.load(self.embeds_list[idx], map_location='cpu')
+        visual_embeds = torch.load(self.visual_list[idx], map_location='cpu')
+
+        # audio_embeds  = embeds['audio_embeds']
+        visual_embeds = visual_embeds['visual_embeds']
+        video_name    = embeds['video_name']
+        text          = embeds['text']
+        mel           = embeds['mel']
+
+        audio = mel
+        
+        return visual_embeds, audio, text
+    
+    def __len__(self):
+        return self.length
+    
+    def __getitem__(self, idx):
+        while True:
+            try:
+                visual_embeds, audio, text = self.get_batch(idx)
+                break
+            except Exception as e:
+                zero_rank_print(f'load error: {e}')
+                idx = random.randint(0, self.length-1)
+        sample = dict(visual_embeds=visual_embeds, audio=audio, text=text)
+        return sample
\ No newline at end of file
diff --git a/foleycrafter/data/video_transforms.py b/foleycrafter/data/video_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..909f555105e4851b0da5747e0cdba991060b4428
--- /dev/null
+++ b/foleycrafter/data/video_transforms.py
@@ -0,0 +1,400 @@
+import torch
+import random
+import numbers
+from torchvision.transforms import RandomCrop, RandomResizedCrop
+
+def _is_tensor_video_clip(clip):
+    if not torch.is_tensor(clip):
+        raise TypeError("clip should be Tensor. Got %s" % type(clip))
+
+    if not clip.ndimension() == 4:
+        raise ValueError("clip should be 4D. Got %dD" % clip.dim())
+
+    return True
+
+
+def crop(clip, i, j, h, w):
+    """
+    Args:
+        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+    """
+    if len(clip.size()) != 4:
+        raise ValueError("clip should be a 4D tensor")
+    return clip[..., i : i + h, j : j + w]
+
+
+def resize(clip, target_size, interpolation_mode):
+    if len(target_size) != 2:
+        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
+    return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
+
+def resize_scale(clip, target_size, interpolation_mode):
+    if len(target_size) != 2:
+        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
+    _, _, H, W = clip.shape
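+    # scale factor chosen so that the shorter spatial side matches target_size[0]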
+    scale_ = target_size[0] / min(H, W)
+    return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
+
+
+def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
+    """
+    Do spatial cropping and resizing to the video clip
+    Args:
+        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+        i (int): i in (i,j) i.e coordinates of the upper left corner.
+        j (int): j in (i,j) i.e coordinates of the upper left corner.
+        h (int): Height of the cropped region.
+        w (int): Width of the cropped region.
+        size (tuple(int, int)): height and width of resized clip
+    Returns:
+        clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
+    """
+    if not _is_tensor_video_clip(clip):
+        raise ValueError("clip should be a 4D torch.tensor")
+    clip = crop(clip, i, j, h, w)
+    clip = resize(clip, size, interpolation_mode)
+    return clip
+
+
+def center_crop(clip, crop_size):
+    if not _is_tensor_video_clip(clip):
+        raise ValueError("clip should be a 4D torch.tensor")
+    h, w = clip.size(-2), clip.size(-1)
+    th, tw = crop_size
+    if h < th or w < tw:
+        raise ValueError("height and width must be no smaller than crop_size")
+
+    i = int(round((h - th) / 2.0))
+    j = int(round((w - tw) / 2.0))
+    return crop(clip, i, j, th, tw)
+
+def random_shift_crop(clip):
+    '''
+    Slide along the long edge, with the short edge as crop size
+    '''
+    if not _is_tensor_video_clip(clip):
+        raise ValueError("clip should be a 4D torch.tensor")
+    h, w = clip.size(-2), clip.size(-1)
+    
+    if h <= w:
+        long_edge = w
+        short_edge = h
+    else:
+        long_edge = h
+        short_edge = w
+
+    th, tw = short_edge, short_edge
+
+    i = torch.randint(0, h - th + 1, size=(1,)).item()
+    j = torch.randint(0, w - tw + 1, size=(1,)).item()
+    return crop(clip, i, j, th, tw)
+
+
+def to_tensor(clip):
+    """
+    Convert tensor data type from uint8 to float, divide value by 255.0 and
+    permute the dimensions of clip tensor
+    Args:
+        clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
+    Return:
+        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
+    """
+    _is_tensor_video_clip(clip)
+    if not clip.dtype == torch.uint8:
+        raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
+    # return clip.float().permute(3, 0, 1, 2) / 255.0
+    return clip.float() / 255.0
+
+
+def normalize(clip, mean, std, inplace=False):
+    """
+    Args:
+        clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
+        mean (tuple): pixel RGB mean. Size is (3)
+        std (tuple): pixel standard deviation. Size is (3)
+    Returns:
+        normalized clip (torch.tensor): Size is (T, C, H, W)
+    """
+    if not _is_tensor_video_clip(clip):
+        raise ValueError("clip should be a 4D torch.tensor")
+    if not inplace:
+        clip = clip.clone()
+    mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
+    std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
+    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
+    return clip
+
+
+def hflip(clip):
+    """
+    Args:
+        clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
+    Returns:
+        flipped clip (torch.tensor): Size is (T, C, H, W)
+    """
+    if not _is_tensor_video_clip(clip):
+        raise ValueError("clip should be a 4D torch.tensor")
+    return clip.flip(-1)
+
+
+class RandomCropVideo:
+    def __init__(self, size):
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            self.size = size
+
+    def __call__(self, clip):
+        """
+        Args:
+            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+        Returns:
+            torch.tensor: randomly cropped video clip.
+                size is (T, C, OH, OW)
+        """
+        i, j, h, w = self.get_params(clip)
+        return crop(clip, i, j, h, w)
+    
+    def get_params(self, clip):
+        h, w = clip.shape[-2:]
+        th, tw = self.size
+
+        if h < th or w < tw:
+            raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
+
+        if w == tw and h == th:
+            return 0, 0, h, w
+
+        i = torch.randint(0, h - th + 1, size=(1,)).item()
+        j = torch.randint(0, w - tw + 1, size=(1,)).item()
+
+        return i, j, th, tw
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(size={self.size})"
+    
+
+class UCFCenterCropVideo:
+    def __init__(
+        self,
+        size,
+        interpolation_mode="bilinear",
+    ):
+        if isinstance(size, tuple):
+            if len(size) != 2:
+                raise ValueError(f"size should be tuple (height, width), instead got {size}")
+            self.size = size
+        else:
+            self.size = (size, size)
+
+        self.interpolation_mode = interpolation_mode
+       
+
+    def __call__(self, clip):
+        """
+        Args:
+            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+        Returns:
+            torch.tensor: scale resized / center cropped video clip.
+                size is (T, C, crop_size, crop_size)
+        """
+        clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
+        clip_center_crop = center_crop(clip_resize, self.size)
+        return clip_center_crop
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
+    
+class KineticsRandomCropResizeVideo:
+    '''
+    Slide along the long edge, with the short edge as crop size, then resize to the desired size.
+    '''
+    def __init__(
+        self,
+        size,
+        interpolation_mode="bilinear",
+    ):
+        if isinstance(size, tuple):
+            if len(size) != 2:
+                raise ValueError(f"size should be tuple (height, width), instead got {size}")
+            self.size = size
+        else:
+            self.size = (size, size)
+
+        self.interpolation_mode = interpolation_mode
+
+    def __call__(self, clip):
+        clip_random_crop = random_shift_crop(clip)
+        clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
+        return clip_resize
+
+
+class CenterCropVideo:
+    def __init__(
+        self,
+        size,
+        interpolation_mode="bilinear",
+    ):
+        if isinstance(size, tuple):
+            if len(size) != 2:
+                raise ValueError(f"size should be tuple (height, width), instead got {size}")
+            self.size = size
+        else:
+            self.size = (size, size)
+
+        self.interpolation_mode = interpolation_mode
+       
+
+    def __call__(self, clip):
+        """
+        Args:
+            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+        Returns:
+            torch.tensor: center cropped video clip.
+                size is (T, C, crop_size, crop_size)
+        """
+        clip_center_crop = center_crop(clip, self.size)
+        return clip_center_crop
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
+    
+
+class NormalizeVideo:
+    """
+    Normalize the video clip by mean subtraction and division by standard deviation
+    Args:
+        mean (3-tuple): pixel RGB mean
+        std (3-tuple): pixel RGB standard deviation
+        inplace (boolean): whether do in-place normalization
+    """
+
+    def __init__(self, mean, std, inplace=False):
+        self.mean = mean
+        self.std = std
+        self.inplace = inplace
+
+    def __call__(self, clip):
+        """
+        Args:
+            clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
+        """
+        return normalize(clip, self.mean, self.std, self.inplace)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
+
+
+class ToTensorVideo:
+    """
+    Convert tensor data type from uint8 to float, divide value by 255.0 and
+    permute the dimensions of clip tensor
+    """
+
+    def __init__(self):
+        pass
+
+    def __call__(self, clip):
+        """
+        Args:
+            clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
+        Return:
+            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
+        """
+        return to_tensor(clip)
+
+    def __repr__(self) -> str:
+        return self.__class__.__name__
+
+
+class RandomHorizontalFlipVideo:
+    """
+    Flip the video clip along the horizontal direction with a given probability
+    Args:
+        p (float): probability of the clip being flipped. Default value is 0.5
+    """
+
+    def __init__(self, p=0.5):
+        self.p = p
+
+    def __call__(self, clip):
+        """
+        Args:
+            clip (torch.tensor): Size is (T, C, H, W)
+        Return:
+            clip (torch.tensor): Size is (T, C, H, W)
+        """
+        if random.random() < self.p:
+            clip = hflip(clip)
+        return clip
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(p={self.p})"
+    
+#  ------------------------------------------------------------
+#  ---------------------  Sampling  ---------------------------
+#  ------------------------------------------------------------
+class TemporalRandomCrop(object):
+    """Temporally crop the given frame indices at a random location.
+
+    Args:
+        size (int): Desired number of frames to be seen by the model.
+    """
+
+    def __init__(self, size):
+        self.size = size
+
+    def __call__(self, total_frames):
+        rand_end = max(0, total_frames - self.size - 1)
+        begin_index = random.randint(0, rand_end)
+        end_index = min(begin_index + self.size, total_frames)
+        return begin_index, end_index
+    
+
+if __name__ == '__main__':
+    from torchvision import transforms
+    import torchvision.io as io
+    import numpy as np
+    from torchvision.utils import save_image
+    import os
+
+    vframes, aframes, info = io.read_video(
+        filename='./v_Archery_g01_c03.avi',
+        pts_unit='sec',
+        output_format='TCHW',
+    )
+ 
+    trans = transforms.Compose([
+        ToTensorVideo(),
+        RandomHorizontalFlipVideo(),
+        UCFCenterCropVideo(512),
+        # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+    ])
+
+    target_video_len = 32
+    frame_interval = 1
+    total_frames = len(vframes)
+    print(total_frames)
+
+    temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
+
+
+    # Sampling video frames
+    start_frame_ind, end_frame_ind = temporal_sample(total_frames)
+    # print(start_frame_ind)
+    # print(end_frame_ind)
+    assert end_frame_ind - start_frame_ind >= target_video_len
+    frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
+
+    select_vframes = vframes[frame_indice]
+
+    select_vframes_trans = trans(select_vframes)
+
+    select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
+
+    io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
+    
+    for i in range(target_video_len):
+        save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True, value_range=(-1, 1))
\ No newline at end of file
diff --git a/foleycrafter/models/adapters/attention_processor.py b/foleycrafter/models/adapters/attention_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..de165385bf77c483ee7844918adf1adc493e9b51
--- /dev/null
+++ b/foleycrafter/models/adapters/attention_processor.py
@@ -0,0 +1,653 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Union
+from einops import rearrange, repeat
+
+from diffusers.utils import logging 
+from foleycrafter.models.adapters.ip_adapter import MLPProjModel
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+class AttnProcessor(nn.Module):
+    r"""
+    Default processor for performing attention-related computations.
+    """
+
+    def __init__(
+        self,
+        hidden_size=None,
+        cross_attention_dim=None,
+    ):
+        super().__init__()
+
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+    ):
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class IPAttnProcessor(nn.Module):
+    r"""
+    Attention processor for IP-Adapter.
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        scale (`float`, defaults to 1.0):
+            the weight scale of image prompt.
+        num_tokens (`int`, defaults to 4; use 16 for ip_adapter_plus):
+            The context length of the image features.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.scale = scale
+        self.num_tokens = num_tokens
+
+        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+    ):
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            # get encoder_hidden_states, ip_hidden_states
+            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+            encoder_hidden_states, ip_hidden_states = (
+                encoder_hidden_states[:, :end_pos, :],
+                encoder_hidden_states[:, end_pos:, :],
+            )
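+            # the trailing num_tokens entries are the image-prompt tokens appended by the
+            # IP-Adapter image projection; the leading entries are the text tokens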
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # for ip-adapter
+        ip_key = self.to_k_ip(ip_hidden_states)
+        ip_value = self.to_v_ip(ip_hidden_states)
+
+        ip_key = attn.head_to_batch_dim(ip_key)
+        ip_value = attn.head_to_batch_dim(ip_value)
+
+        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+        self.attn_map = ip_attention_probs
+        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
+
+        hidden_states = hidden_states + self.scale * ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class AttnProcessor2_0(torch.nn.Module):
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(
+        self,
+        hidden_size=None,
+        cross_attention_dim=None,
+    ):
+        super().__init__()
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+    ):  
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+    
+class AttnProcessor2_0WithProjection(torch.nn.Module):
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(
+        self,
+        hidden_size=None,
+        cross_attention_dim=None,
+    ):
+        super().__init__()
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+        self.before_proj_size = 1024
+        self.after_proj_size = 768
+        self.visual_proj = nn.Linear(self.before_proj_size, self.after_proj_size)
+
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+    ):  
+        residual = hidden_states
+        # encoder_hidden_states = self.visual_proj(encoder_hidden_states)
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+    
+class IPAttnProcessor2_0(torch.nn.Module):
+    r"""
+    Attention processor for IP-Adapter for PyTorch 2.0.
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        scale (`float`, defaults to 1.0):
+            the weight scale of image prompt.
+        num_tokens (`int`, defaults to 4; use 16 for ip_adapter_plus):
+            The context length of the image features.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+        super().__init__()
+
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.scale = scale
+        self.num_tokens = num_tokens
+
+        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+    ):
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            # get encoder_hidden_states, ip_hidden_states
+            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+            encoder_hidden_states, ip_hidden_states = (
+                encoder_hidden_states[:, :end_pos, :],
+                encoder_hidden_states[:, end_pos:, :],
+            )
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # for ip-adapter
+        ip_key = self.to_k_ip(ip_hidden_states)
+        ip_value = self.to_v_ip(ip_hidden_states)
+
+        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        ip_hidden_states = F.scaled_dot_product_attention(
+            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+        )
+        with torch.no_grad():
+            # attention map kept only for visualization/debugging
+            self.attn_map = (query @ ip_key.transpose(-2, -1)).softmax(dim=-1)
+
+        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+        hidden_states = hidden_states + self.scale * ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+## for controlnet
+class CNAttnProcessor:
+    r"""
+    Default processor for performing attention-related computations.
+    """
+
+    def __init__(self, num_tokens=4):
+        self.num_tokens = num_tokens
+
+    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class CNAttnProcessor2_0:
+    r"""
+    PyTorch 2.0 variant of `CNAttnProcessor` using `F.scaled_dot_product_attention`;
+    it likewise strips the trailing image-prompt tokens from `encoder_hidden_states`.
+    """
+
+    def __init__(self, num_tokens=4):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("CNAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or newer.")
+        self.num_tokens = num_tokens
+
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+    ):
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
\ No newline at end of file
diff --git a/foleycrafter/models/adapters/ip_adapter.py b/foleycrafter/models/adapters/ip_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6bb9b5d2d63ce17add49c1a8eb2acb091212ab1
--- /dev/null
+++ b/foleycrafter/models/adapters/ip_adapter.py
@@ -0,0 +1,217 @@
+import torch
+import torch.nn as nn
+
+import numpy as np
+
+import os
+from typing import List
+
+from diffusers import StableDiffusionPipeline
+from diffusers.pipelines.controlnet import MultiControlNetModel
+from PIL import Image
+from safetensors import safe_open
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+from foleycrafter.models.adapters.resampler import Resampler
+from foleycrafter.models.adapters.utils import is_torch2_available
+
+class IPAdapter(torch.nn.Module):
+    """IP-Adapter"""
+    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
+        super().__init__()
+        self.unet = unet
+        self.image_proj_model = image_proj_model
+        self.adapter_modules = adapter_modules
+
+        if ckpt_path is not None:
+            self.load_from_checkpoint(ckpt_path)
+
+    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
+        ip_tokens = self.image_proj_model(image_embeds)
+        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
+        # Predict the noise residual
+        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
+        return noise_pred
+
+    def load_from_checkpoint(self, ckpt_path: str):
+        # Calculate original checksums
+        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
+        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
+
+        state_dict = torch.load(ckpt_path, map_location="cpu")
+
+        # Load state dict for image_proj_model and adapter_modules
+        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
+        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
+
+        # Calculate new checksums
+        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
+        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
+
+        # Verify if the weights have changed
+        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
+        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
+
+        print(f"Successfully loaded weights from checkpoint {ckpt_path}")
+
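+# Editorial note (hypothetical usage sketch; `unet`, `image_proj_model`, `adapter_modules`
+# and "ip_ckpt.bin" are placeholders, not names defined in this repo):
+#
+#     adapter = IPAdapter(unet, image_proj_model, adapter_modules, ckpt_path="ip_ckpt.bin")
+#     noise_pred = adapter(noisy_latents, timesteps, text_embeds, image_embeds)
+#
+# i.e. the projected image tokens are concatenated to the text tokens along the sequence
+# dimension before the UNet call.
+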
+class VideoProjModel(torch.nn.Module):
+    """Projection Model"""
+
+    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=1, video_frame=50):
+        super().__init__()
+
+        self.cross_attention_dim = cross_attention_dim
+        self.clip_extra_context_tokens = clip_extra_context_tokens
+        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
+        self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+        self.video_frame = video_frame
+
+    def forward(self, image_embeds):
+        embeds = image_embeds
+        clip_extra_context_tokens = self.proj(embeds)
+        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
+        return clip_extra_context_tokens
+
+class ImageProjModel(torch.nn.Module):
+    """Projection Model"""
+
+    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
+        super().__init__()
+
+        self.cross_attention_dim = cross_attention_dim
+        self.clip_extra_context_tokens = clip_extra_context_tokens
+        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
+        self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+    def forward(self, image_embeds):
+        embeds = image_embeds
+        clip_extra_context_tokens = self.proj(embeds).reshape(
+            -1, self.clip_extra_context_tokens, self.cross_attention_dim
+        )
+        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
+        return clip_extra_context_tokens
+
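+# Editorial note (shape sketch under the default arguments above): a pooled CLIP image
+# embedding of shape (B, clip_embeddings_dim) is projected to
+# (B, clip_extra_context_tokens * cross_attention_dim) and reshaped to
+# (B, clip_extra_context_tokens, cross_attention_dim) -- (B, 4, 1024) with the defaults --
+# ready to be concatenated with the text encoder hidden states.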
+
+class MLPProjModel(torch.nn.Module):
+    """SD model with image prompt"""
+    def zero_initialize(self):
+        # zero out every parameter of this module
+        for param in self.parameters():
+            param.data.zero_()
+
+    def zero_initialize_last_layer(self):
+        # zero-init only the last Linear layer so its initial contribution is zero
+        last_layer = None
+        for _, layer in self.named_modules():
+            if isinstance(layer, torch.nn.Linear):
+                last_layer = layer
+
+        if last_layer is not None:
+            last_layer.weight.data.zero_()
+            last_layer.bias.data.zero_()
+
+    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
+        
+        super().__init__()
+        
+        self.proj = torch.nn.Sequential(
+            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
+            torch.nn.GELU(),
+            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
+            torch.nn.LayerNorm(cross_attention_dim)
+        )
+        # zero initialize the last layer
+        # self.zero_initialize_last_layer()
+        
+    def forward(self, image_embeds):
+        clip_extra_context_tokens = self.proj(image_embeds)
+        return clip_extra_context_tokens
+
+class V2AMapperMLP(torch.nn.Module):
+    def __init__(self, cross_attention_dim=512, clip_embeddings_dim=512, mult=4):
+        super().__init__()
+        self.proj = torch.nn.Sequential(
+            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim * mult),
+            torch.nn.GELU(),
+            torch.nn.Linear(clip_embeddings_dim * mult, cross_attention_dim),
+            torch.nn.LayerNorm(cross_attention_dim)
+        )
+
+    def forward(self, image_embeds):
+        clip_extra_context_tokens = self.proj(image_embeds)
+        return clip_extra_context_tokens
+
+class TimeProjModel(torch.nn.Module):
+    def __init__(self, positive_len, out_dim, feature_type="text-only", frame_nums:int=64):
+        super().__init__()
+        self.positive_len = positive_len
+        self.out_dim = out_dim
+
+        self.position_dim = frame_nums 
+
+        if isinstance(out_dim, tuple):
+            out_dim = out_dim[0]
+
+        if feature_type == "text-only":
+            self.linears = nn.Sequential(
+                nn.Linear(self.positive_len + self.position_dim, 512),
+                nn.SiLU(),
+                nn.Linear(512, 512),
+                nn.SiLU(),
+                nn.Linear(512, out_dim),
+            )
+            self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+
+        elif feature_type == "text-image":
+            self.linears_text = nn.Sequential(
+                nn.Linear(self.positive_len + self.position_dim, 512),
+                nn.SiLU(),
+                nn.Linear(512, 512),
+                nn.SiLU(),
+                nn.Linear(512, out_dim),
+            )
+            self.linears_image = nn.Sequential(
+                nn.Linear(self.positive_len + self.position_dim, 512),
+                nn.SiLU(),
+                nn.Linear(512, 512),
+                nn.SiLU(),
+                nn.Linear(512, out_dim),
+            )
+            self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+            self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+
+        # self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
+
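+    # Editorial note (interpretation, hedged): despite the GLIGEN-style argument names
+    # (`boxes`, `masks`), `boxes` here appears to carry a per-frame temporal signal of
+    # length `frame_nums` (e.g. detected onsets); it is concatenated with the (optionally
+    # null-masked) semantic embedding and mapped to `out_dim` tokens by the MLP above.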
+    def forward(
+        self,
+        boxes,
+        masks,
+        positive_embeddings=None,
+    ):
+        masks = masks.unsqueeze(-1)
+
+        # # embedding position (it may includes padding as placeholder)
+        # xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 -> B*N*C
+
+        # # learnable null embedding
+        # xyxy_null = self.null_position_feature.view(1, 1, -1)
+
+        # # replace padding with learnable null embedding
+        # xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
+
+        time_embeds = boxes
+
+        # positionet with text only information
+        if positive_embeddings is not None:
+            # learnable null embedding
+            positive_null = self.null_positive_feature.view(1, 1, -1)
+
+            # replace padding with learnable null embedding
+            positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
+
+            objs = self.linears(torch.cat([positive_embeddings, time_embeds], dim=-1))
+
+        # PositionNet with text and image information
+        else:
+            raise NotImplementedError
+
+        return objs
\ No newline at end of file
diff --git a/foleycrafter/models/adapters/resampler.py b/foleycrafter/models/adapters/resampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..f18a6751cd795a607e6fe34d4f050da1aa2045c1
--- /dev/null
+++ b/foleycrafter/models/adapters/resampler.py
@@ -0,0 +1,158 @@
+# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
+
+import math
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from einops.layers.torch import Rearrange
+
+
+# FFN
+def FeedForward(dim, mult=4):
+    inner_dim = int(dim * mult)
+    return nn.Sequential(
+        nn.LayerNorm(dim),
+        nn.Linear(dim, inner_dim, bias=False),
+        nn.GELU(),
+        nn.Linear(inner_dim, dim, bias=False),
+    )
+
+
+def reshape_tensor(x, heads):
+    bs, length, width = x.shape
+    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+    x = x.view(bs, length, heads, -1)
+    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+    x = x.transpose(1, 2)
+    # make the transposed tensor contiguous; shape stays (bs, n_heads, length, dim_per_head)
+    x = x.reshape(bs, heads, length, -1)
+    return x
+
+
+class PerceiverAttention(nn.Module):
+    def __init__(self, *, dim, dim_head=64, heads=8):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.dim_head = dim_head
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+    def forward(self, x, latents):
+        """
+        Args:
+            x (torch.Tensor): image features
+                shape (b, n1, D)
+            latents (torch.Tensor): latent features
+                shape (b, n2, D)
+        """
+        x = self.norm1(x)
+        latents = self.norm2(latents)
+
+        b, l, _ = latents.shape
+
+        q = self.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+        q = reshape_tensor(q, self.heads)
+        k = reshape_tensor(k, self.heads)
+        v = reshape_tensor(v, self.heads)
+
+        # attention
+        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        out = weight @ v
+
+        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+        return self.to_out(out)
+
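+# Editorial note: the split scaling above (multiplying q and k each by dim_head**-0.25
+# before the matmul) is mathematically equivalent to softmax(q @ k^T / sqrt(dim_head)) @ v,
+# but keeps the intermediate logits smaller, which is friendlier to float16.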
+
+class Resampler(nn.Module):
+    def __init__(
+        self,
+        dim=1024,
+        depth=8,
+        dim_head=64,
+        heads=16,
+        num_queries=8,
+        embedding_dim=768,
+        output_dim=1024,
+        ff_mult=4,
+        max_seq_len: int = 257,  # CLIP tokens + CLS token
+        apply_pos_emb: bool = False,
+        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
+    ):
+        super().__init__()
+        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
+
+        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+
+        self.proj_in = nn.Linear(embedding_dim, dim)
+
+        self.proj_out = nn.Linear(dim, output_dim)
+        self.norm_out = nn.LayerNorm(output_dim)
+
+        self.to_latents_from_mean_pooled_seq = (
+            nn.Sequential(
+                nn.LayerNorm(dim),
+                nn.Linear(dim, dim * num_latents_mean_pooled),
+                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
+            )
+            if num_latents_mean_pooled > 0
+            else None
+        )
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+                        FeedForward(dim=dim, mult=ff_mult),
+                    ]
+                )
+            )
+
+    def forward(self, x):
+        if self.pos_emb is not None:
+            n, device = x.shape[1], x.device
+            pos_emb = self.pos_emb(torch.arange(n, device=device))
+            x = x + pos_emb
+
+        latents = self.latents.repeat(x.size(0), 1, 1)
+
+        x = self.proj_in(x)
+
+        if self.to_latents_from_mean_pooled_seq:
+            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
+            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
+            latents = torch.cat((meanpooled_latents, latents), dim=-2)
+
+        for attn, ff in self.layers:
+            latents = attn(x, latents) + latents
+            latents = ff(latents) + latents
+
+        latents = self.proj_out(latents)
+        return self.norm_out(latents)
+
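+# Editorial note (hypothetical shape example using the defaults above): CLIP patch tokens
+# of shape (B, 257, 768) are projected to `dim` and repeatedly cross-attended by
+# `num_queries` learned latents, yielding (B, num_queries, output_dim); with
+# `num_latents_mean_pooled > 0` a mean-pooled summary latent is prepended as well.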
+
+def masked_mean(t, *, dim, mask=None):
+    if mask is None:
+        return t.mean(dim=dim)
+
+    denom = mask.sum(dim=dim, keepdim=True)
+    mask = rearrange(mask, "b n -> b n 1")
+    masked_t = t.masked_fill(~mask, 0.0)
+
+    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
\ No newline at end of file
diff --git a/foleycrafter/models/adapters/transformer.py b/foleycrafter/models/adapters/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..16309b4d70ca9f77b46d14cf9c2a14650833330a
--- /dev/null
+++ b/foleycrafter/models/adapters/transformer.py
@@ -0,0 +1,327 @@
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from typing import Any, Optional, Tuple, Union
+
+class Attention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, hidden_size, num_attention_heads, attention_head_dim, attention_dropout=0.0):
+        super().__init__()
+        self.embed_dim = hidden_size
+        self.num_heads = num_attention_heads
+        self.head_dim = attention_head_dim
+
+        self.scale = self.head_dim**-0.5
+        self.dropout = attention_dropout
+
+        self.inner_dim = self.head_dim * self.num_heads
+
+        self.k_proj = nn.Linear(self.embed_dim, self.inner_dim)
+        self.v_proj = nn.Linear(self.embed_dim, self.inner_dim)
+        self.q_proj = nn.Linear(self.embed_dim, self.inner_dim)
+        self.out_proj = nn.Linear(self.inner_dim, self.embed_dim)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        causal_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scale
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.view(*proj_shape)
+        value_states = value_states.view(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        # apply the causal_attention_mask first
+        if causal_attention_mask is not None:
+            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+                    f" {causal_attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights has to be reshaped
+            # twice and reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped
+
+
+class MLP(nn.Module):
+    def __init__(self, hidden_size, intermediate_size, mult=4):
+        super().__init__()
+        self.activation_fn = nn.SiLU() 
+        self.fc1 = nn.Linear(hidden_size, intermediate_size * mult)
+        self.fc2 = nn.Linear(intermediate_size * mult, hidden_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+class Transformer(nn.Module):
+    def __init__(self, depth=12):
+        super().__init__()
+        self.layers = nn.ModuleList([TransformerBlock() for _ in range(depth)])
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor=None,
+        causal_attention_mask: torch.Tensor=None,
+        output_attentions: Optional[bool] = False,
+    ) -> torch.FloatTensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        for layer in self.layers:
+            hidden_states = layer(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                causal_attention_mask=causal_attention_mask,
+                output_attentions=output_attentions,
+            )
+
+        return hidden_states
+
+class TransformerBlock(nn.Module):
+    def __init__(self, hidden_size=512, num_attention_heads=12, attention_head_dim=64, attention_dropout=0.0, dropout=0.0, eps=1e-5):
+        super().__init__()
+        self.embed_dim = hidden_size
+        self.self_attn = Attention(hidden_size=hidden_size, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim)
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)
+        self.mlp = MLP(hidden_size=hidden_size, intermediate_size=hidden_size)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor=None,
+        causal_attention_mask: torch.Tensor=None,
+        output_attentions: Optional[bool] = False,
+    ) -> torch.FloatTensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states, attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            causal_attention_mask=causal_attention_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs[0]
+    
+class DiffusionTransformerBlock(nn.Module):
+    def __init__(self, hidden_size=512, num_attention_heads=12, attention_head_dim=64, attention_dropout=0.0, dropout=0.0, eps=1e-5):
+        super().__init__()
+        self.embed_dim = hidden_size
+        self.self_attn = Attention(hidden_size=hidden_size, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim)
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)
+        self.mlp = MLP(hidden_size=hidden_size, intermediate_size=hidden_size)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)
+        self.output_token = nn.Parameter(torch.randn(1, hidden_size))
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor=None,
+        causal_attention_mask: torch.Tensor=None,
+        output_attentions: Optional[bool] = False,
+    ) -> torch.FloatTensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        output_token = self.output_token.unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)
+        hidden_states = torch.cat([output_token, hidden_states], dim=1)
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states, attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            causal_attention_mask=causal_attention_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs[0][:,0:1,...]
+
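+# Editorial note: DiffusionTransformerBlock prepends a learnable `output_token`, runs one
+# self-attention + MLP block over [output_token, inputs], and returns only that first
+# token -- effectively an attention-pooled (B, 1, hidden_size) summary of the sequence.
+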
+class V2AMapperMLP(nn.Module):
+    def __init__(self, input_dim=512, output_dim=512, expansion_rate=4):
+        super().__init__()
+        self.linear = nn.Linear(input_dim, input_dim * expansion_rate)
+        self.silu = nn.SiLU()
+        self.layer_norm = nn.LayerNorm(input_dim * expansion_rate)
+        self.linear2 = nn.Linear(input_dim * expansion_rate, output_dim)
+
+    def forward(self, x):
+        x = self.linear(x)
+        x = self.silu(x)
+        x = self.layer_norm(x)
+        x = self.linear2(x)
+        return x
+
+class ImageProjModel(torch.nn.Module):
+    """Projection Model"""
+
+    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
+        super().__init__()
+
+        self.cross_attention_dim = cross_attention_dim
+        self.clip_extra_context_tokens = clip_extra_context_tokens
+        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
+        self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+        self.zero_initialize_last_layer()
+
+    def zero_initialize_last_layer(self):
+        # zero-init the final Linear so the projected image tokens start as zeros
+        last_layer = None
+        for _, layer in self.named_modules():
+            if isinstance(layer, torch.nn.Linear):
+                last_layer = layer
+
+        if last_layer is not None:
+            last_layer.weight.data.zero_()
+            last_layer.bias.data.zero_()
+
+    def forward(self, image_embeds):
+        embeds = image_embeds
+        clip_extra_context_tokens = self.proj(embeds).reshape(
+            -1, self.clip_extra_context_tokens, self.cross_attention_dim
+        )
+        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
+        return clip_extra_context_tokens
+    
+class VisionAudioAdapter(torch.nn.Module):
+    def __init__(
+            self,
+            embedding_size=768,
+            expand_dim=4,
+            token_num=4,
+        ):
+        super().__init__()
+
+        self.mapper = V2AMapperMLP(
+            embedding_size, 
+            embedding_size, 
+            expansion_rate=expand_dim,
+        )
+
+        self.proj = ImageProjModel(
+            cross_attention_dim=embedding_size, 
+            clip_embeddings_dim=embedding_size,
+            clip_extra_context_tokens=token_num,
+        )
+
+    def forward(self, image_embeds):
+        image_embeds = self.mapper(image_embeds)
+        image_embeds = self.proj(image_embeds)
+        return image_embeds
+
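+# Editorial note (data-flow sketch): VisionAudioAdapter remaps a CLIP vision embedding with
+# V2AMapperMLP -- (B, embedding_size) -> (B, embedding_size) -- and then expands it into
+# `token_num` tokens via ImageProjModel, giving (B, token_num, embedding_size) image-prompt
+# context that can be attended to through the cross-attention layers.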
+    
\ No newline at end of file
diff --git a/foleycrafter/models/adapters/utils.py b/foleycrafter/models/adapters/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..edd7879590a495d11f11d7a1265445705d8bfb72
--- /dev/null
+++ b/foleycrafter/models/adapters/utils.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+from PIL import Image
+
+attn_maps = {}
+def hook_fn(name):
+    def forward_hook(module, input, output):
+        if hasattr(module.processor, "attn_map"):
+            attn_maps[name] = module.processor.attn_map
+            del module.processor.attn_map
+
+    return forward_hook
+
+def register_cross_attention_hook(unet):
+    for name, module in unet.named_modules():
+        if name.split('.')[-1].startswith('attn2'):
+            module.register_forward_hook(hook_fn(name))
+
+    return unet
+
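+# Editorial note (hypothetical usage; `pipe` is a placeholder pipeline object):
+#
+#     register_cross_attention_hook(pipe.unet)
+#
+# registers a forward hook on every cross-attention (`attn2`) module; processors that set
+# `self.attn_map` (e.g. IPAttnProcessor2_0) then deposit their maps into the module-level
+# `attn_maps` dict, which `get_net_attn_map` below averages and upscales for visualization.
+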
+def upscale(attn_map, target_size):
+    attn_map = torch.mean(attn_map, dim=0)
+    attn_map = attn_map.permute(1,0)
+    temp_size = None
+
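+    # Editorial note (assumption about the geometry): `attn_map.shape[1]` is the number of
+    # latent tokens at some UNet level; the factor 64 == 8 * 8 accounts for the VAE's 8x
+    # spatial downsampling, and `scale = 2**i` walks over the UNet's additional
+    # downsampling levels until the token count matches `target_size`.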
+    for i in range(0, 5):
+        scale = 2 ** i
+        if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
+            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
+            break
+
+    assert temp_size is not None, "temp_size cannot be None: attn_map size does not match target_size"
+
+    attn_map = attn_map.view(attn_map.shape[0], *temp_size)
+
+    attn_map = F.interpolate(
+        attn_map.unsqueeze(0).to(dtype=torch.float32),
+        size=target_size,
+        mode='bilinear',
+        align_corners=False
+    )[0]
+
+    attn_map = torch.softmax(attn_map, dim=0)
+    return attn_map
+
+
+def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
+    idx = 0 if instance_or_negative else 1
+    net_attn_maps = []
+
+    for name, attn_map in attn_maps.items():
+        attn_map = attn_map.cpu() if detach else attn_map
+        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
+        attn_map = upscale(attn_map, image_size) 
+        net_attn_maps.append(attn_map) 
+
+    net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
+
+    return net_attn_maps
+
+def attnmaps2images(net_attn_maps):
+
+    #total_attn_scores = 0
+    images = []
+
+    for attn_map in net_attn_maps:
+        attn_map = attn_map.cpu().numpy()
+        #total_attn_scores += attn_map.mean().item()
+
+        normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
+        normalized_attn_map = normalized_attn_map.astype(np.uint8)
+        #print("norm: ", normalized_attn_map.shape)
+        image = Image.fromarray(normalized_attn_map)
+
+        #image = fix_save_attn_map(attn_map)
+        images.append(image)
+
+    #print(total_attn_scores)
+    return images
+
+
+def is_torch2_available():
+    return hasattr(F, "scaled_dot_product_attention")
\ No newline at end of file
diff --git a/foleycrafter/models/auffusion/attention.py b/foleycrafter/models/auffusion/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc362a8718b8f79f7d1a875cf56cf70e8da17b6c
--- /dev/null
+++ b/foleycrafter/models/auffusion/attention.py
@@ -0,0 +1,669 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import USE_PEFT_BACKEND
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
+from diffusers.models.embeddings import SinusoidalPositionalEmbedding
+from diffusers.models.lora import LoRACompatibleLinear
+from diffusers.models.normalization import (
+    AdaLayerNorm,
+    AdaLayerNormContinuous,
+    AdaLayerNormZero,
+    RMSNorm,
+)
+
+from foleycrafter.models.auffusion.attention_processor import Attention
+
+def _chunked_feed_forward(
+    ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
+):
+    # "feed_forward_chunk_size" can be used to save memory
+    if hidden_states.shape[chunk_dim] % chunk_size != 0:
+        raise ValueError(
+            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+        )
+
+    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
+    if lora_scale is None:
+        ff_output = torch.cat(
+            [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+            dim=chunk_dim,
+        )
+    else:
+        # TODO(Patrick): LoRA scale can be removed once PEFT refactor is complete
+        ff_output = torch.cat(
+            [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+            dim=chunk_dim,
+        )
+
+    return ff_output
+
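+# Editorial note (illustrative numbers, not from the original code): with 1024 tokens and
+# `chunk_size=256`, `_chunked_feed_forward` runs the feed-forward on 4 slices of 256 tokens
+# along `chunk_dim` and concatenates the results, trading some speed for a smaller peak
+# activation footprint.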
+
+@maybe_allow_in_graph
+class GatedSelfAttentionDense(nn.Module):
+    r"""
+    A gated self-attention dense layer that combines visual features and object features.
+
+    Parameters:
+        query_dim (`int`): The number of channels in the query.
+        context_dim (`int`): The number of channels in the context.
+        n_heads (`int`): The number of heads to use for attention.
+        d_head (`int`): The number of channels in each head.
+    """
+
+    def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
+        super().__init__()
+
+        # we need a linear projection since we need cat visual feature and obj feature
+        self.linear = nn.Linear(context_dim, query_dim)
+
+        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+        self.ff = FeedForward(query_dim, activation_fn="geglu")
+
+        self.norm1 = nn.LayerNorm(query_dim)
+        self.norm2 = nn.LayerNorm(query_dim)
+
+        self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
+        self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
+
+        self.enabled = True
+
+    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
+        if not self.enabled:
+            return x
+
+        n_visual = x.shape[1]
+        objs = self.linear(objs)
+
+        x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
+        x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
+
+        return x
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module):
+    r"""
+    A basic Transformer block.
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        num_embeds_ada_norm (:
+            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+        attention_bias (:
+            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+        only_cross_attention (`bool`, *optional*):
+            Whether to use only cross-attention layers. In this case two cross attention layers are used.
+        double_self_attention (`bool`, *optional*):
+            Whether to use two self-attention layers. In this case no cross attention layers are used.
+        upcast_attention (`bool`, *optional*):
+            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+        final_dropout (`bool` *optional*, defaults to False):
+            Whether to apply a final dropout after the last feed-forward layer.
+        attention_type (`str`, *optional*, defaults to `"default"`):
+            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+        positional_embeddings (`str`, *optional*, defaults to `None`):
+            The type of positional embeddings to apply to.
+        num_positional_embeddings (`int`, *optional*, defaults to `None`):
+            The maximum number of positional embeddings to apply.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        dropout=0.0,
+        cross_attention_dim: Optional[int] = None,
+        activation_fn: str = "geglu",
+        num_embeds_ada_norm: Optional[int] = None,
+        attention_bias: bool = False,
+        only_cross_attention: bool = False,
+        double_self_attention: bool = False,
+        upcast_attention: bool = False,
+        norm_elementwise_affine: bool = True,
+        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
+        norm_eps: float = 1e-5,
+        final_dropout: bool = False,
+        attention_type: str = "default",
+        positional_embeddings: Optional[str] = None,
+        num_positional_embeddings: Optional[int] = None,
+        ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
+        ada_norm_bias: Optional[int] = None,
+        ff_inner_dim: Optional[int] = None,
+        ff_bias: bool = True,
+        attention_out_bias: bool = True,
+    ):
+        super().__init__()
+        self.only_cross_attention = only_cross_attention
+
+        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+        self.use_layer_norm = norm_type == "layer_norm"
+        self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
+
+        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+            raise ValueError(
+                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+            )
+
+        if positional_embeddings and (num_positional_embeddings is None):
+            raise ValueError(
+                "If `positional_embeddings` is defined, `num_positional_embeddings` must also be defined."
+            )
+
+        if positional_embeddings == "sinusoidal":
+            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+        else:
+            self.pos_embed = None
+
+        # Define 3 blocks. Each block has its own normalization layer.
+        # 1. Self-Attn
+        if self.use_ada_layer_norm:
+            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+        elif self.use_ada_layer_norm_zero:
+            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+        elif self.use_ada_layer_norm_continuous:
+            self.norm1 = AdaLayerNormContinuous(
+                dim,
+                ada_norm_continous_conditioning_embedding_dim,
+                norm_elementwise_affine,
+                norm_eps,
+                ada_norm_bias,
+                "rms_norm",
+            )
+        else:
+            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            cross_attention_dim=cross_attention_dim if (only_cross_attention and not double_self_attention) else None,
+            upcast_attention=upcast_attention,
+            out_bias=attention_out_bias,
+        )
+
+        # 2. Cross-Attn
+        if cross_attention_dim is not None or double_self_attention:
+            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+            # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
+            # the second cross attention block.
+            if self.use_ada_layer_norm:
+                self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
+            elif self.use_ada_layer_norm_continuous:
+                self.norm2 = AdaLayerNormContinuous(
+                    dim,
+                    ada_norm_continous_conditioning_embedding_dim,
+                    norm_elementwise_affine,
+                    norm_eps,
+                    ada_norm_bias,
+                    "rms_norm",
+                )
+            else:
+                self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                upcast_attention=upcast_attention,
+                out_bias=attention_out_bias,
+            )  # is self-attn if encoder_hidden_states is none
+        else:
+            self.norm2 = None
+            self.attn2 = None
+
+        # 3. Feed-forward
+        if self.use_ada_layer_norm_continuous:
+            self.norm3 = AdaLayerNormContinuous(
+                dim,
+                ada_norm_continous_conditioning_embedding_dim,
+                norm_elementwise_affine,
+                norm_eps,
+                ada_norm_bias,
+                "layer_norm",
+            )
+        elif not self.use_ada_layer_norm_single:
+            self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,
+            activation_fn=activation_fn,
+            final_dropout=final_dropout,
+            inner_dim=ff_inner_dim,
+            bias=ff_bias,
+        )
+
+        # 4. Fuser
+        if attention_type == "gated" or attention_type == "gated-text-image":
+            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+        # 5. Scale-shift for PixArt-Alpha.
+        if self.use_ada_layer_norm_single:
+            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.FloatTensor:
+        # Notice that normalization is always applied before the real computation in the following blocks.
+        # 0. Self-Attention
+        batch_size = hidden_states.shape[0]
+
+        if self.use_ada_layer_norm:
+            norm_hidden_states = self.norm1(hidden_states, timestep)
+        elif self.use_ada_layer_norm_zero:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+        elif self.use_layer_norm:
+            norm_hidden_states = self.norm1(hidden_states)
+        elif self.use_ada_layer_norm_continuous:
+            norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
+        elif self.use_ada_layer_norm_single:
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+            ).chunk(6, dim=1)
+            norm_hidden_states = self.norm1(hidden_states)
+            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+            norm_hidden_states = norm_hidden_states.squeeze(1)
+        else:
+            raise ValueError("Incorrect norm used")
+
+        if self.pos_embed is not None:
+            norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+        # 1. Retrieve lora scale.
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+        # 2. Prepare GLIGEN inputs
+        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+        attn_output = self.attn1(
+            norm_hidden_states,
+            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+            attention_mask=attention_mask,
+            **cross_attention_kwargs,
+        )
+        if self.use_ada_layer_norm_zero:
+            attn_output = gate_msa.unsqueeze(1) * attn_output
+        elif self.use_ada_layer_norm_single:
+            attn_output = gate_msa * attn_output
+
+        hidden_states = attn_output + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
+        # 2.5 GLIGEN Control
+        if gligen_kwargs is not None:
+            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+        # 3. Cross-Attention
+        if self.attn2 is not None:
+            if self.use_ada_layer_norm:
+                norm_hidden_states = self.norm2(hidden_states, timestep)
+            elif self.use_ada_layer_norm_zero or self.use_layer_norm:
+                norm_hidden_states = self.norm2(hidden_states)
+            elif self.use_ada_layer_norm_single:
+                # For PixArt norm2 isn't applied here:
+                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+                norm_hidden_states = hidden_states
+            elif self.use_ada_layer_norm_continuous:
+                norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
+            else:
+                raise ValueError("Incorrect norm")
+
+            if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
+                norm_hidden_states = self.pos_embed(norm_hidden_states)
+            
+            attn_output = self.attn2(
+                norm_hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                **cross_attention_kwargs,
+            )
+            hidden_states = attn_output + hidden_states
+
+        # 4. Feed-forward
+        if self.use_ada_layer_norm_continuous:
+            norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
+        elif not self.use_ada_layer_norm_single:
+            norm_hidden_states = self.norm3(hidden_states)
+
+        if self.use_ada_layer_norm_zero:
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        if self.use_ada_layer_norm_single:
+            norm_hidden_states = self.norm2(hidden_states)
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            ff_output = _chunked_feed_forward(
+                self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
+            )
+        else:
+            ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+
+        if self.use_ada_layer_norm_zero:
+            ff_output = gate_mlp.unsqueeze(1) * ff_output
+        elif self.use_ada_layer_norm_single:
+            ff_output = gate_mlp * ff_output
+
+        hidden_states = ff_output + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
+        return hidden_states
+
+
+@maybe_allow_in_graph
+class TemporalBasicTransformerBlock(nn.Module):
+    r"""
+    A basic Transformer block for video like data.
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        time_mix_inner_dim (`int`): The number of channels for temporal attention.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        time_mix_inner_dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        cross_attention_dim: Optional[int] = None,
+    ):
+        super().__init__()
+        self.is_res = dim == time_mix_inner_dim
+
+        # Define 3 blocks. Each block has its own normalization layer.
+        # 1. Self-Attn
+        self.norm_in = nn.LayerNorm(dim)
+        self.ff_in = FeedForward(
+            dim,
+            dim_out=time_mix_inner_dim,
+            activation_fn="geglu",
+        )
+
+        self.norm1 = nn.LayerNorm(time_mix_inner_dim)
+        self.attn1 = Attention(
+            query_dim=time_mix_inner_dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            cross_attention_dim=None,
+        )
+
+        # 2. Cross-Attn
+        if cross_attention_dim is not None:
+            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+            # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
+            # the second cross attention block.
+            self.norm2 = nn.LayerNorm(time_mix_inner_dim)
+            self.attn2 = Attention(
+                query_dim=time_mix_inner_dim,
+                cross_attention_dim=cross_attention_dim,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+            )  # is self-attn if encoder_hidden_states is none
+        else:
+            self.norm2 = None
+            self.attn2 = None
+
+        # 3. Feed-forward
+        self.norm3 = nn.LayerNorm(time_mix_inner_dim)
+        self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = None
+
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
+        # Sets chunk feed-forward
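+        # Chunking splits the feed-forward input along `_chunk_dim` into pieces of `chunk_size`
+        # that are processed sequentially, lowering peak memory at the cost of extra kernel launches.
+        # Illustrative call (not part of the original file): block.set_chunk_feed_forward(chunk_size=8)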
+        self._chunk_size = chunk_size
+        # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
+        self._chunk_dim = 1
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        num_frames: int,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        # Notice that normalization is always applied before the real computation in the following blocks.
+        # 0. Self-Attention
+        batch_frames, seq_length, channels = hidden_states.shape
+        batch_size = batch_frames // num_frames
+
+        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
+        hidden_states = hidden_states.permute(0, 2, 1, 3)
+        hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
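+        # (batch * frames, seq, ch) has been rearranged into (batch * seq, frames, ch): each spatial
+        # position now carries its own temporal sequence, so the attention below mixes information
+        # across frames rather than across spatial tokens.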
+
+        residual = hidden_states
+        hidden_states = self.norm_in(hidden_states)
+
+        if self._chunk_size is not None:
+            hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            hidden_states = self.ff_in(hidden_states)
+
+        if self.is_res:
+            hidden_states = hidden_states + residual
+
+        norm_hidden_states = self.norm1(hidden_states)
+        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
+        hidden_states = attn_output + hidden_states
+
+        # 3. Cross-Attention
+        if self.attn2 is not None:
+            norm_hidden_states = self.norm2(hidden_states)
+            attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+            hidden_states = attn_output + hidden_states
+
+        # 4. Feed-forward
+        norm_hidden_states = self.norm3(hidden_states)
+
+        if self._chunk_size is not None:
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            ff_output = self.ff(norm_hidden_states)
+
+        if self.is_res:
+            hidden_states = ff_output + hidden_states
+        else:
+            hidden_states = ff_output
+
+        hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
+        hidden_states = hidden_states.permute(0, 2, 1, 3)
+        hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
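+        # Undo the temporal rearrangement: back to (batch * frames, seq, ch) so the output matches
+        # the layout expected by the surrounding spatial blocks.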
+
+        return hidden_states
+
+
+class SkipFFTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        kv_input_dim: int,
+        kv_input_dim_proj_use_bias: bool,
+        dropout=0.0,
+        cross_attention_dim: Optional[int] = None,
+        attention_bias: bool = False,
+        attention_out_bias: bool = True,
+    ):
+        super().__init__()
+        if kv_input_dim != dim:
+            self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
+        else:
+            self.kv_mapper = None
+
+        self.norm1 = RMSNorm(dim, 1e-06)
+
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            cross_attention_dim=cross_attention_dim,
+            out_bias=attention_out_bias,
+        )
+
+        self.norm2 = RMSNorm(dim, 1e-06)
+
+        self.attn2 = Attention(
+            query_dim=dim,
+            cross_attention_dim=cross_attention_dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            out_bias=attention_out_bias,
+        )
+
+    def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
+        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+
+        if self.kv_mapper is not None:
+            encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
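+            # The encoder states have been projected to `dim` (after a SiLU nonlinearity)
+            # so both attention layers below can consume them.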
+
+        norm_hidden_states = self.norm1(hidden_states)
+
+        attn_output = self.attn1(
+            norm_hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            **cross_attention_kwargs,
+        )
+
+        hidden_states = attn_output + hidden_states
+
+        norm_hidden_states = self.norm2(hidden_states)
+
+        attn_output = self.attn2(
+            norm_hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            **cross_attention_kwargs,
+        )
+
+        hidden_states = attn_output + hidden_states
+
+        return hidden_states
+
+
+class FeedForward(nn.Module):
+    r"""
+    A feed-forward layer.
+
+    Parameters:
+        dim (`int`): The number of channels in the input.
+        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+        inner_dim (`int`, *optional*, defaults to `None`): The hidden dimension. If not given, computed as `dim * mult`.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        dim_out: Optional[int] = None,
+        mult: int = 4,
+        dropout: float = 0.0,
+        activation_fn: str = "geglu",
+        final_dropout: bool = False,
+        inner_dim=None,
+        bias: bool = True,
+    ):
+        super().__init__()
+        if inner_dim is None:
+            inner_dim = int(dim * mult)
+        dim_out = dim_out if dim_out is not None else dim
+        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+
+        if activation_fn == "gelu":
+            act_fn = GELU(dim, inner_dim, bias=bias)
+        if activation_fn == "gelu-approximate":
+            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
+        elif activation_fn == "geglu":
+            act_fn = GEGLU(dim, inner_dim, bias=bias)
+        elif activation_fn == "geglu-approximate":
+            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
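+        # All activation modules project from `dim` to `inner_dim`; GEGLU internally doubles the
+        # projection and gates one half with GELU before returning `inner_dim` channels.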
+
+        self.net = nn.ModuleList([])
+        # project in
+        self.net.append(act_fn)
+        # project dropout
+        self.net.append(nn.Dropout(dropout))
+        # project out
+        self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
+        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+        if final_dropout:
+            self.net.append(nn.Dropout(dropout))
+
+    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
+        for module in self.net:
+            if isinstance(module, compatible_cls):
+                hidden_states = module(hidden_states, scale)
+            else:
+                hidden_states = module(hidden_states)
+        return hidden_states
\ No newline at end of file
diff --git a/foleycrafter/models/auffusion/attention_processor.py b/foleycrafter/models/auffusion/attention_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c46ac9a8773a2a535a758e9cf5eddc9c73f04df6
--- /dev/null
+++ b/foleycrafter/models/auffusion/attention_processor.py
@@ -0,0 +1,2682 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from importlib import import_module
+from typing import Callable, Optional, Union, List
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+import math
+
+from einops import rearrange
+
+from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+if is_xformers_available():
+    import xformers
+    import xformers.ops
+else:
+    xformers = None
+
+
+@maybe_allow_in_graph
+class Attention(nn.Module):
+    r"""
+    A cross attention layer.
+
+    Parameters:
+        query_dim (`int`):
+            The number of channels in the query.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+        heads (`int`,  *optional*, defaults to 8):
+            The number of heads to use for multi-head attention.
+        dim_head (`int`,  *optional*, defaults to 64):
+            The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        bias (`bool`, *optional*, defaults to False):
+            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+        upcast_attention (`bool`, *optional*, defaults to False):
+            Set to `True` to upcast the attention computation to `float32`.
+        upcast_softmax (`bool`, *optional*, defaults to False):
+            Set to `True` to upcast the softmax computation to `float32`.
+        cross_attention_norm (`str`, *optional*, defaults to `None`):
+            The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+            The number of groups to use for the group norm in the cross attention.
+        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the added key and value projections. If `None`, no projection is used.
+        norm_num_groups (`int`, *optional*, defaults to `None`):
+            The number of groups to use for the group norm in the attention.
+        spatial_norm_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the spatial normalization.
+        out_bias (`bool`, *optional*, defaults to `True`):
+            Set to `True` to use a bias in the output linear layer.
+        scale_qk (`bool`, *optional*, defaults to `True`):
+            Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+        only_cross_attention (`bool`, *optional*, defaults to `False`):
+            Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+            `added_kv_proj_dim` is not `None`.
+        eps (`float`, *optional*, defaults to 1e-5):
+            An additional value added to the denominator in group normalization that is used for numerical stability.
+        rescale_output_factor (`float`, *optional*, defaults to 1.0):
+            A factor to rescale the output by dividing it with this value.
+        residual_connection (`bool`, *optional*, defaults to `False`):
+            Set to `True` to add the residual connection to the output.
+        _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+            Set to `True` if the attention block is loaded from a deprecated state dict.
+        processor (`AttnProcessor`, *optional*, defaults to `None`):
+            The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
+            `AttnProcessor` otherwise.
+    """
+
+    def __init__(
+        self,
+        query_dim: int,
+        cross_attention_dim: Optional[int] = None,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        upcast_attention: bool = False,
+        upcast_softmax: bool = False,
+        cross_attention_norm: Optional[str] = None,
+        cross_attention_norm_num_groups: int = 32,
+        added_kv_proj_dim: Optional[int] = None,
+        norm_num_groups: Optional[int] = None,
+        spatial_norm_dim: Optional[int] = None,
+        out_bias: bool = True,
+        scale_qk: bool = True,
+        only_cross_attention: bool = False,
+        eps: float = 1e-5,
+        rescale_output_factor: float = 1.0,
+        residual_connection: bool = False,
+        _from_deprecated_attn_block: bool = False,
+        processor: Optional["AttnProcessor"] = None,
+        out_dim: int = None,
+    ):
+        super().__init__()
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.query_dim = query_dim
+        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.upcast_attention = upcast_attention
+        self.upcast_softmax = upcast_softmax
+        self.rescale_output_factor = rescale_output_factor
+        self.residual_connection = residual_connection
+        self.dropout = dropout
+        self.fused_projections = False
+        self.out_dim = out_dim if out_dim is not None else query_dim
+
+        # we make use of this private variable to know whether this class is loaded
+        # with a deprecated state dict so that we can convert it on the fly
+        self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+        self.scale_qk = scale_qk
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
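+        # Standard scaled dot-product attention factor: attention scores are multiplied by
+        # 1 / sqrt(dim_head) unless `scale_qk` is disabled.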
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+        # for slice_size > 0 the attention score computation
+        # is split across the batch axis to save memory
+        # You can set slice_size with `set_attention_slice`
+        self.sliceable_head_dim = heads
+
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.only_cross_attention = only_cross_attention
+
+        if self.added_kv_proj_dim is None and self.only_cross_attention:
+            raise ValueError(
+                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+            )
+
+        if norm_num_groups is not None:
+            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+        else:
+            self.group_norm = None
+
+        if spatial_norm_dim is not None:
+            self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
+        else:
+            self.spatial_norm = None
+
+        if cross_attention_norm is None:
+            self.norm_cross = None
+        elif cross_attention_norm == "layer_norm":
+            self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
+                # (batch_size, seq_len, added_kv_proj_dim) before being projected
+                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+                # before the projection, so we need to use `added_kv_proj_dim` as
+                # the number of channels for the group norm.
+                norm_cross_num_channels = added_kv_proj_dim
+            else:
+                norm_cross_num_channels = self.cross_attention_dim
+
+            self.norm_cross = nn.GroupNorm(
+                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+            )
+        else:
+            raise ValueError(
+                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+            )
+
+        if USE_PEFT_BACKEND:
+            linear_cls = nn.Linear
+        else:
+            linear_cls = LoRACompatibleLinear
+
+        self.linear_cls = linear_cls
+        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
+
+        if not self.only_cross_attention:
+            # only relevant for the `AddedKVProcessor` classes
+            self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+        else:
+            self.to_k = None
+            self.to_v = None
+
+        if self.added_kv_proj_dim is not None:
+            self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+            self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Dropout(dropout))
+
+        # set attention processor
+        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+        if processor is None:
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
+        self.set_processor(processor)
+
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ) -> None:
+        r"""
+        Set whether to use memory efficient attention from `xformers` or not.
+
+        Args:
+            use_memory_efficient_attention_xformers (`bool`):
+                Whether to use memory efficient attention from `xformers` or not.
+            attention_op (`Callable`, *optional*):
+                The attention operation to use. Defaults to `None` which uses the default attention operation from
+                `xformers`.
+        """
+        is_lora = hasattr(self, "processor") and isinstance(
+            self.processor,
+            LORA_ATTENTION_PROCESSORS,
+        )
+        is_custom_diffusion = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
+        )
+        is_added_kv_processor = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (
+                AttnAddedKVProcessor,
+                AttnAddedKVProcessor2_0,
+                SlicedAttnAddedKVProcessor,
+                XFormersAttnAddedKVProcessor,
+                LoRAAttnAddedKVProcessor,
+            ),
+        )
+
+        if use_memory_efficient_attention_xformers:
+            if is_added_kv_processor and (is_lora or is_custom_diffusion):
+                raise NotImplementedError(
+                    f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
+                )
+            if not is_xformers_available():
+                raise ModuleNotFoundError(
+                    (
+                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                        " xformers"
+                    ),
+                    name="xformers",
+                )
+            elif not torch.cuda.is_available():
+                raise ValueError(
+                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+                    " only available for GPU "
+                )
+            else:
+                try:
+                    # Make sure we can run the memory efficient attention
+                    _ = xformers.ops.memory_efficient_attention(
+                        torch.randn((1, 2, 40), device="cuda"),
+                        torch.randn((1, 2, 40), device="cuda"),
+                        torch.randn((1, 2, 40), device="cuda"),
+                    )
+                except Exception as e:
+                    raise e
+
+            if is_lora:
+                # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
+                # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
+                processor = LoRAXFormersAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    rank=self.processor.rank,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                processor.to(self.processor.to_q_lora.up.weight.device)
+            elif is_custom_diffusion:
+                processor = CustomDiffusionXFormersAttnProcessor(
+                    train_kv=self.processor.train_kv,
+                    train_q_out=self.processor.train_q_out,
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_custom_diffusion"):
+                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
+            elif is_added_kv_processor:
+                # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
+                # which uses this type of cross attention ONLY because the attention mask of format
+                # [0, ..., -10.000, ..., 0, ...,] is not supported
+                # throw warning
+                logger.info(
+                    "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
+                )
+                processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
+            else:
+                processor = XFormersAttnProcessor(attention_op=attention_op)
+        else:
+            if is_lora:
+                attn_processor_class = (
+                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+                )
+                processor = attn_processor_class(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    rank=self.processor.rank,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                processor.to(self.processor.to_q_lora.up.weight.device)
+            elif is_custom_diffusion:
+                attn_processor_class = (
+                    CustomDiffusionAttnProcessor2_0
+                    if hasattr(F, "scaled_dot_product_attention")
+                    else CustomDiffusionAttnProcessor
+                )
+                processor = attn_processor_class(
+                    train_kv=self.processor.train_kv,
+                    train_q_out=self.processor.train_q_out,
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_custom_diffusion"):
+                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
+            else:
+                # set attention processor
+                # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+                # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+                # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+                processor = (
+                    AttnProcessor2_0()
+                    if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+                    else AttnProcessor()
+                )
+
+        self.set_processor(processor)
+
+    def set_attention_slice(self, slice_size: int) -> None:
+        r"""
+        Set the slice size for attention computation.
+
+        Args:
+            slice_size (`int`):
+                The slice size for attention computation.
+        """
+        if slice_size is not None and slice_size > self.sliceable_head_dim:
+            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+        if slice_size is not None and self.added_kv_proj_dim is not None:
+            processor = SlicedAttnAddedKVProcessor(slice_size)
+        elif slice_size is not None:
+            processor = SlicedAttnProcessor(slice_size)
+        elif self.added_kv_proj_dim is not None:
+            processor = AttnAddedKVProcessor()
+        else:
+            # set attention processor
+            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
+
+        self.set_processor(processor)
+
+    def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None:
+        r"""
+        Set the attention processor to use.
+
+        Args:
+            processor (`AttnProcessor`):
+                The attention processor to use.
+            _remove_lora (`bool`, *optional*, defaults to `False`):
+                Set to `True` to remove LoRA layers from the model.
+        """
+        if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
+            deprecate(
+                "set_processor to offload LoRA",
+                "0.26.0",
+                "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
+            )
+            # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
+            # We need to remove all LoRA layers
+            # Don't forget to remove ALL `_remove_lora` from the codebase
+            for module in self.modules():
+                if hasattr(module, "set_lora_layer"):
+                    module.set_lora_layer(None)
+
+        # if current processor is in `self._modules` and if passed `processor` is not, we need to
+        # pop `processor` from `self._modules`
+        if (
+            hasattr(self, "processor")
+            and isinstance(self.processor, torch.nn.Module)
+            and not isinstance(processor, torch.nn.Module)
+        ):
+            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+            self._modules.pop("processor")
+
+        self.processor = processor
+
+    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
+        r"""
+        Get the attention processor in use.
+
+        Args:
+            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+                Set to `True` to return the deprecated LoRA attention processor.
+
+        Returns:
+            "AttentionProcessor": The attention processor in use.
+        """
+        if not return_deprecated_lora:
+            return self.processor
+
+        # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
+        # serialization format for LoRA Attention Processors. It should be deleted once the integration
+        # with PEFT is completed.
+        is_lora_activated = {
+            name: module.lora_layer is not None
+            for name, module in self.named_modules()
+            if hasattr(module, "lora_layer")
+        }
+
+        # 1. if no layer has a LoRA activated we can return the processor as usual
+        if not any(is_lora_activated.values()):
+            return self.processor
+
+        # LoRA is not applied to `add_k_proj` or `add_v_proj`, so exclude them from the check
+        is_lora_activated.pop("add_k_proj", None)
+        is_lora_activated.pop("add_v_proj", None)
+        # 2. else it is not possible that only some layers have LoRA activated
+        if not all(is_lora_activated.values()):
+            raise ValueError(
+                f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
+            )
+
+        # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
+        non_lora_processor_cls_name = self.processor.__class__.__name__
+        lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
+
+        hidden_size = self.inner_dim
+
+        # now create a LoRA attention processor from the LoRA layers
+        if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
+            kwargs = {
+                "cross_attention_dim": self.cross_attention_dim,
+                "rank": self.to_q.lora_layer.rank,
+                "network_alpha": self.to_q.lora_layer.network_alpha,
+                "q_rank": self.to_q.lora_layer.rank,
+                "q_hidden_size": self.to_q.lora_layer.out_features,
+                "k_rank": self.to_k.lora_layer.rank,
+                "k_hidden_size": self.to_k.lora_layer.out_features,
+                "v_rank": self.to_v.lora_layer.rank,
+                "v_hidden_size": self.to_v.lora_layer.out_features,
+                "out_rank": self.to_out[0].lora_layer.rank,
+                "out_hidden_size": self.to_out[0].lora_layer.out_features,
+            }
+
+            if hasattr(self.processor, "attention_op"):
+                kwargs["attention_op"] = self.processor.attention_op
+
+            lora_processor = lora_processor_cls(hidden_size, **kwargs)
+            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
+            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
+            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
+            lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
+        elif lora_processor_cls == LoRAAttnAddedKVProcessor:
+            lora_processor = lora_processor_cls(
+                hidden_size,
+                cross_attention_dim=self.add_k_proj.weight.shape[0],
+                rank=self.to_q.lora_layer.rank,
+                network_alpha=self.to_q.lora_layer.network_alpha,
+            )
+            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
+            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
+            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
+            lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
+
+            # only save if used
+            if self.add_k_proj.lora_layer is not None:
+                lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
+                lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
+            else:
+                lora_processor.add_k_proj_lora = None
+                lora_processor.add_v_proj_lora = None
+        else:
+            raise ValueError(f"{lora_processor_cls} does not exist.")
+
+        return lora_processor
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        **cross_attention_kwargs,
+    ) -> torch.Tensor:
+        r"""
+        The forward method of the `Attention` class.
+
+        Args:
+            hidden_states (`torch.Tensor`):
+                The hidden states of the query.
+            encoder_hidden_states (`torch.Tensor`, *optional*):
+                The hidden states of the encoder.
+            attention_mask (`torch.Tensor`, *optional*):
+                The attention mask to use. If `None`, no mask is applied.
+            **cross_attention_kwargs:
+                Additional keyword arguments to pass along to the cross attention.
+
+        Returns:
+            `torch.Tensor`: The output of the attention layer.
+        """
+        # The `Attention` class can call different attention processors / attention functions
+        # here we simply pass along all tensors to the selected processor class
+        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **cross_attention_kwargs,
+        )
+
+    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+        r"""
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
+        is the number of heads initialized while constructing the `Attention` class.
+
+        Args:
+            tensor (`torch.Tensor`): The tensor to reshape.
+
+        Returns:
+            `torch.Tensor`: The reshaped tensor.
+        """
+        head_size = self.heads
+        batch_size, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
+
+    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+        r"""
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]`. `heads`
+        is the number of heads initialized while constructing the `Attention` class.
+
+        Args:
+            tensor (`torch.Tensor`): The tensor to reshape.
+            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
+                reshaped to `[batch_size * heads, seq_len, dim // heads]`.
+
+        Returns:
+            `torch.Tensor`: The reshaped tensor.
+        """
+        head_size = self.heads
+        batch_size, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+
+        if out_dim == 3:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+
+        return tensor
+
+    def get_attention_scores(
+        self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
+    ) -> torch.Tensor:
+        r"""
+        Compute the attention scores.
+
+        Args:
+            query (`torch.Tensor`): The query tensor.
+            key (`torch.Tensor`): The key tensor.
+            attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
+
+        Returns:
+            `torch.Tensor`: The attention probabilities/scores.
+        """
+        dtype = query.dtype
+        if self.upcast_attention:
+            query = query.float()
+            key = key.float()
+
+        if attention_mask is None:
+            baddbmm_input = torch.empty(
+                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+            )
+            beta = 0
+        else:
+            baddbmm_input = attention_mask
+            beta = 1
+
+        attention_scores = torch.baddbmm(
+            baddbmm_input,
+            query,
+            key.transpose(-1, -2),
+            beta=beta,
+            alpha=self.scale,
+        )
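+        # baddbmm computes beta * mask + scale * (Q @ K^T), yielding attention scores of shape
+        # (batch * heads, query_len, key_len).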
+        del baddbmm_input
+
+        if self.upcast_softmax:
+            attention_scores = attention_scores.float()
+
+        attention_probs = attention_scores.softmax(dim=-1)
+        del attention_scores
+
+        attention_probs = attention_probs.to(dtype)
+
+        return attention_probs
+
+    def prepare_attention_mask(
+        self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
+    ) -> torch.Tensor:
+        r"""
+        Prepare the attention mask for the attention computation.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                The attention mask to prepare.
+            target_length (`int`):
+                The target length of the attention mask. This is the length of the attention mask after padding.
+            batch_size (`int`):
+                The batch size, which is used to repeat the attention mask.
+            out_dim (`int`, *optional*, defaults to `3`):
+                The output dimension of the attention mask. Can be either `3` or `4`.
+
+        Returns:
+            `torch.Tensor`: The prepared attention mask.
+        """
+        head_size = self.heads
+        if attention_mask is None:
+            return attention_mask
+
+        current_length: int = attention_mask.shape[-1]
+        if current_length != target_length:
+            if attention_mask.device.type == "mps":
+                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+                # Instead, we can manually construct the padding tensor.
+                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+                attention_mask = torch.cat([attention_mask, padding], dim=2)
+            else:
+                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+                #       we want to instead pad by (0, remaining_length), where remaining_length is:
+                #       remaining_length: int = target_length - current_length
+                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+        if out_dim == 3:
+            if attention_mask.shape[0] < batch_size * head_size:
+                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        elif out_dim == 4:
+            attention_mask = attention_mask.unsqueeze(1)
+            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
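+            # out_dim == 4 keeps an explicit head axis, (batch, heads, ..., key_len), the layout
+            # accepted by torch.nn.functional.scaled_dot_product_attention masks.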
+
+        return attention_mask
+
+    def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+        r"""
+        Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
+        `Attention` class.
+
+        Args:
+            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
+
+        Returns:
+            `torch.Tensor`: The normalized encoder hidden states.
+        """
+        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+
+        if isinstance(self.norm_cross, nn.LayerNorm):
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+        elif isinstance(self.norm_cross, nn.GroupNorm):
+            # Group norm norms along the channels dimension and expects
+            # input to be in the shape of (N, C, *). In this case, we want
+            # to norm along the hidden dimension, so we need to move
+            # (batch_size, sequence_length, hidden_size) ->
+            # (batch_size, hidden_size, sequence_length)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+        else:
+            assert False
+
+        return encoder_hidden_states
+
+    @torch.no_grad()
+    def fuse_projections(self, fuse=True):
+        is_cross_attention = self.cross_attention_dim != self.query_dim
+        device = self.to_q.weight.data.device
+        dtype = self.to_q.weight.data.dtype
+
+        if not is_cross_attention:
+            # fetch weight matrices.
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            # create a new single projection layer and copy over the weights.
+            self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
+            self.to_qkv.weight.copy_(concatenated_weights)
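+            # Self-attention case: a single fused projection computes Q, K and V in one matmul;
+            # processors that support fused projections are expected to split the output back into three chunks.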
+
+        else:
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
+            self.to_kv.weight.copy_(concatenated_weights)
+
+        self.fused_projections = fuse
+
+
+class AttnProcessor:
+    r"""
+    Default processor for performing attention-related computations.
+    """
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states, *args)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
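+        # query/key/value are now (batch * heads, seq_len, dim // heads), so a single bmm below
+        # computes the per-head attention for all heads in parallel.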
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states, *args)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class CustomDiffusionAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing attention for the Custom Diffusion method.
+
+    Args:
+        train_kv (`bool`, defaults to `True`):
+            Whether to newly train the key and value matrices corresponding to the text features.
+        train_q_out (`bool`, defaults to `True`):
+            Whether to newly train query matrices corresponding to the latent image features.
+        hidden_size (`int`, *optional*, defaults to `None`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        out_bias (`bool`, defaults to `True`):
+            Whether to include the bias parameter in `train_q_out`.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+    """
+
+    def __init__(
+        self,
+        train_kv: bool = True,
+        train_q_out: bool = True,
+        hidden_size: Optional[int] = None,
+        cross_attention_dim: Optional[int] = None,
+        out_bias: bool = True,
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
+        else:
+            query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
+            value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
+            key = key.to(attn.to_q.weight.dtype)
+            value = value.to(attn.to_q.weight.dtype)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
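+            # The key/value of the first text token is detached from the graph (receives no gradient),
+            # while the projections for the remaining tokens stay trainable.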
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
+class AttnAddedKVProcessor:
+    r"""
+    Processor for performing attention-related computations with extra learnable key and value matrices for the text
+    encoder.
+    """
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
+        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states, *args)
+        query = attn.head_to_batch_dim(query)
+
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states, *args)
+            value = attn.to_v(hidden_states, *args)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states, *args)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
+class AttnAddedKVProcessor2_0:
+    r"""
+    Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
+    learnable key and value matrices for the text encoder.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AttnAddedKVProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
+        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states, *args)
+        query = attn.head_to_batch_dim(query, out_dim=4)
+
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
+
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states, *args)
+            value = attn.to_v(hidden_states, *args)
+            key = attn.head_to_batch_dim(key, out_dim=4)
+            value = attn.head_to_batch_dim(value, out_dim=4)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
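+            # Keys/values projected from the encoder states are prepended along the sequence axis
+            # (dim=2 in the (batch, heads, seq, head_dim) layout), so self- and cross-attention are
+            # computed jointly in one scaled_dot_product_attention call.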
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states, *args)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
+class XFormersAttnAddedKVProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.
+
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query)
+
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
+class XFormersAttnProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.
+
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, key_tokens, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
+        if attention_mask is not None:
+            # expand our mask's singleton query_tokens dimension:
+            #   [batch*heads,            1, key_tokens] ->
+            #   [batch*heads, query_tokens, key_tokens]
+            # so that it can be added as a bias onto the attention scores that xformers computes:
+            #   [batch*heads, query_tokens, key_tokens]
+            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+            _, query_tokens, _ = hidden_states.shape
+            attention_mask = attention_mask.expand(-1, query_tokens, -1)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states, *args)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states, *args)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class AttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+        query = attn.to_q(hidden_states, *args)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states, *args)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class FusedAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
+    key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+    <Tip warning={true}>
+
+    This API is currently 🧪 experimental in nature and can change in future.
+
+    </Tip>
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedAttnProcessor2_0 requires at least PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or newer."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+        if encoder_hidden_states is None:
+            qkv = attn.to_qkv(hidden_states, *args)
+            split_size = qkv.shape[-1] // 3
+            query, key, value = torch.split(qkv, split_size, dim=-1)
+        else:
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+            query = attn.to_q(hidden_states, *args)
+
+            kv = attn.to_kv(encoder_hidden_states, *args)
+            split_size = kv.shape[-1] // 2
+            key, value = torch.split(kv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states, *args)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class CustomDiffusionXFormersAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
+
+    Args:
+    train_kv (`bool`, defaults to `True`):
+        Whether to newly train the key and value matrices corresponding to the text features.
+    train_q_out (`bool`, defaults to `True`):
+        Whether to newly train query matrices corresponding to the latent image features.
+    hidden_size (`int`, *optional*, defaults to `None`):
+        The hidden size of the attention layer.
+    cross_attention_dim (`int`, *optional*, defaults to `None`):
+        The number of channels in the `encoder_hidden_states`.
+    out_bias (`bool`, defaults to `True`):
+        Whether to include the bias parameter in `train_q_out`.
+    dropout (`float`, *optional*, defaults to 0.0):
+        The dropout probability to use.
+    attention_op (`Callable`, *optional*, defaults to `None`):
+        The base
+        [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
+        as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
+    """
+
+    def __init__(
+        self,
+        train_kv: bool = True,
+        train_q_out: bool = False,
+        hidden_size: Optional[int] = None,
+        cross_attention_dim: Optional[int] = None,
+        out_bias: bool = True,
+        dropout: float = 0.0,
+        attention_op: Optional[Callable] = None,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.attention_op = attention_op
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
+        else:
+            query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
+            value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
+            key = key.to(attn.to_q.weight.dtype)
+            value = value.to(attn.to_q.weight.dtype)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
+class CustomDiffusionAttnProcessor2_0(nn.Module):
+    r"""
+    Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0’s memory-efficient scaled
+    dot-product attention.
+
+    Args:
+        train_kv (`bool`, defaults to `True`):
+            Whether to newly train the key and value matrices corresponding to the text features.
+        train_q_out (`bool`, defaults to `True`):
+            Whether to newly train query matrices corresponding to the latent image features.
+        hidden_size (`int`, *optional*, defaults to `None`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        out_bias (`bool`, defaults to `True`):
+            Whether to include the bias parameter in `train_q_out`.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+    """
+
+    def __init__(
+        self,
+        train_kv: bool = True,
+        train_q_out: bool = True,
+        hidden_size: Optional[int] = None,
+        cross_attention_dim: Optional[int] = None,
+        out_bias: bool = True,
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
+            value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
+            key = key.to(attn.to_q.weight.dtype)
+            value = value.to(attn.to_q.weight.dtype)
+
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        inner_dim = hidden_states.shape[-1]
+
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
+class SlicedAttnProcessor:
+    r"""
+    Processor for implementing sliced attention.
+
+    Args:
+        slice_size (`int`, *optional*):
+            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
+            `attention_head_dim` must be a multiple of the `slice_size`.
+    """
+
+    def __init__(self, slice_size: int):
+        self.slice_size = slice_size
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+        dim = query.shape[-1]
+        query = attn.head_to_batch_dim(query)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        batch_size_attention, query_tokens, _ = query.shape
+        hidden_states = torch.zeros(
+            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
+        )
+
+        for i in range(batch_size_attention // self.slice_size):
+            start_idx = i * self.slice_size
+            end_idx = (i + 1) * self.slice_size
+
+            query_slice = query[start_idx:end_idx]
+            key_slice = key[start_idx:end_idx]
+            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+
+            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
+
+            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+            hidden_states[start_idx:end_idx] = attn_slice
+
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class SlicedAttnAddedKVProcessor:
+    r"""
+    Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
+
+    Args:
+        slice_size (`int`, *optional*):
+            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
+            `attention_head_dim` must be a multiple of the `slice_size`.
+    """
+
+    def __init__(self, slice_size):
+        self.slice_size = slice_size
+
+    def __call__(
+        self,
+        attn: "Attention",
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+        dim = query.shape[-1]
+        query = attn.head_to_batch_dim(query)
+
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+
+        batch_size_attention, query_tokens, _ = query.shape
+        hidden_states = torch.zeros(
+            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
+        )
+
+        for i in range(batch_size_attention // self.slice_size):
+            start_idx = i * self.slice_size
+            end_idx = (i + 1) * self.slice_size
+
+            query_slice = query[start_idx:end_idx]
+            key_slice = key[start_idx:end_idx]
+            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+
+            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
+
+            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+            hidden_states[start_idx:end_idx] = attn_slice
+
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
+class SpatialNorm(nn.Module):
+    """
+    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
+
+    Args:
+        f_channels (`int`):
+            The number of channels for input to group normalization layer, and output of the spatial norm layer.
+        zq_channels (`int`):
+            The number of channels for the quantized vector as described in the paper.
+    """
+
+    def __init__(
+        self,
+        f_channels: int,
+        zq_channels: int,
+    ):
+        super().__init__()
+        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
+        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
+        f_size = f.shape[-2:]
+        zq = F.interpolate(zq, size=f_size, mode="nearest")
+        norm_f = self.norm_layer(f)
+        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+        return new_f
+
+
+## Deprecated
+class LoRAAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism.
+
+    Args:
+        hidden_size (`int`, *optional*):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+        network_alpha (`int`, *optional*):
+            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
+        kwargs (`dict`):
+            Additional keyword arguments to pass to the `LoRALinearLayer` layers.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        cross_attention_dim: Optional[int] = None,
+        rank: int = 4,
+        network_alpha: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+
+        q_rank = kwargs.pop("q_rank", None)
+        q_hidden_size = kwargs.pop("q_hidden_size", None)
+        q_rank = q_rank if q_rank is not None else rank
+        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
+
+        v_rank = kwargs.pop("v_rank", None)
+        v_hidden_size = kwargs.pop("v_hidden_size", None)
+        v_rank = v_rank if v_rank is not None else rank
+        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
+
+        out_rank = kwargs.pop("out_rank", None)
+        out_hidden_size = kwargs.pop("out_hidden_size", None)
+        out_rank = out_rank if out_rank is not None else rank
+        out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
+
+        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
+
+    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+        self_cls_name = self.__class__.__name__
+        deprecate(
+            self_cls_name,
+            "0.26.0",
+            (
+                f"Make sure use {self_cls_name[4:]} instead by setting"
+                "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
+                " `LoraLoaderMixin.load_lora_weights`"
+            ),
+        )
+        attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
+        attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
+        attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
+        attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
+
+        attn._modules.pop("processor")
+        attn.processor = AttnProcessor()
+        return attn.processor(attn, hidden_states, *args, **kwargs)
+
+
+class LoRAAttnProcessor2_0(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
+    attention.
+
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+        network_alpha (`int`, *optional*):
+            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
+        kwargs (`dict`):
+            Additional keyword arguments to pass to the `LoRALinearLayer` layers.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        cross_attention_dim: Optional[int] = None,
+        rank: int = 4,
+        network_alpha: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__()
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+
+        q_rank = kwargs.pop("q_rank", None)
+        q_hidden_size = kwargs.pop("q_hidden_size", None)
+        q_rank = q_rank if q_rank is not None else rank
+        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
+
+        v_rank = kwargs.pop("v_rank", None)
+        v_hidden_size = kwargs.pop("v_hidden_size", None)
+        v_rank = v_rank if v_rank is not None else rank
+        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
+
+        out_rank = kwargs.pop("out_rank", None)
+        out_hidden_size = kwargs.pop("out_hidden_size", None)
+        out_rank = out_rank if out_rank is not None else rank
+        out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
+
+        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
+
+    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+        self_cls_name = self.__class__.__name__
+        deprecate(
+            self_cls_name,
+            "0.26.0",
+            (
+                f"Make sure use {self_cls_name[4:]} instead by setting"
+                "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
+                " `LoraLoaderMixin.load_lora_weights`"
+            ),
+        )
+        attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
+        attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
+        attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
+        attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
+
+        attn._modules.pop("processor")
+        attn.processor = AttnProcessor2_0()
+        return attn.processor(attn, hidden_states, *args, **kwargs)
+
+
+class LoRAXFormersAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
+
+    Args:
+        hidden_size (`int`, *optional*):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+        network_alpha (`int`, *optional*):
+            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
+        kwargs (`dict`):
+            Additional keyword arguments to pass to the `LoRALinearLayer` layers.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        cross_attention_dim: int,
+        rank: int = 4,
+        attention_op: Optional[Callable] = None,
+        network_alpha: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+        self.attention_op = attention_op
+
+        q_rank = kwargs.pop("q_rank", None)
+        q_hidden_size = kwargs.pop("q_hidden_size", None)
+        q_rank = q_rank if q_rank is not None else rank
+        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
+
+        v_rank = kwargs.pop("v_rank", None)
+        v_hidden_size = kwargs.pop("v_hidden_size", None)
+        v_rank = v_rank if v_rank is not None else rank
+        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
+
+        out_rank = kwargs.pop("out_rank", None)
+        out_hidden_size = kwargs.pop("out_hidden_size", None)
+        out_rank = out_rank if out_rank is not None else rank
+        out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
+
+        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
+
+    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+        self_cls_name = self.__class__.__name__
+        deprecate(
+            self_cls_name,
+            "0.26.0",
+            (
+                f"Make sure use {self_cls_name[4:]} instead by setting"
+                "LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
+                " `LoraLoaderMixin.load_lora_weights`"
+            ),
+        )
+        attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
+        attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
+        attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
+        attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
+
+        attn._modules.pop("processor")
+        attn.processor = XFormersAttnProcessor()
+        return attn.processor(attn, hidden_states, *args, **kwargs)
+
+
+class LoRAAttnAddedKVProcessor(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
+    encoder.
+
+    Args:
+        hidden_size (`int`, *optional*):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+        network_alpha (`int`, *optional*):
+            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
+        kwargs (`dict`):
+            Additional keyword arguments to pass to the `LoRALinearLayer` layers.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        cross_attention_dim: Optional[int] = None,
+        rank: int = 4,
+        network_alpha: Optional[int] = None,
+    ):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+        self_cls_name = self.__class__.__name__
+        deprecate(
+            self_cls_name,
+            "0.26.0",
+            (
+                f"Make sure use {self_cls_name[4:]} instead by setting"
+                "LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
+                " `LoraLoaderMixin.load_lora_weights`"
+            ),
+        )
+        attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
+        attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
+        attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
+        attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
+
+        attn._modules.pop("processor")
+        attn.processor = AttnAddedKVProcessor()
+        return attn.processor(attn, hidden_states, *args, **kwargs)
+
+
+class IPAdapterAttnProcessor(nn.Module):
+    r"""
+    Attention processor for IP-Adapter.
+
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        num_tokens (`int`, defaults to 4):
+            The context length of the image features.
+        scale (`float`, defaults to 1.0):
+            The weight scale of the image prompt.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.num_tokens = num_tokens
+        self.scale = scale
+
+        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+        scale=1.0,
+    ):
+        if scale != 1.0:
+            logger.warning("`scale` of IPAdapterAttnProcessor should be set with `set_ip_adapter_scale`.")
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        # split hidden states
+        end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+        encoder_hidden_states, ip_hidden_states = (
+            encoder_hidden_states[:, :end_pos, :],
+            encoder_hidden_states[:, end_pos:, :],
+        )
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # for ip-adapter
+        ip_key = self.to_k_ip(ip_hidden_states)
+        ip_value = self.to_v_ip(ip_hidden_states)
+
+        ip_key = attn.head_to_batch_dim(ip_key)
+        ip_value = attn.head_to_batch_dim(ip_value)
+
+        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
+
+        hidden_states = hidden_states + self.scale * ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class VPTemporalAdapterAttnProcessor2_0(torch.nn.Module):
+    r"""
+    Attention processor for the frame-wise VP (temporal) adapter, built on PyTorch 2.0's scaled
+    dot-product attention (an IP-Adapter-style processor with per-time-window image conditioning).
+
+    The adapter embeddings passed alongside the text embeddings are laid out as
+    `(I, B, N * T, C)`, where `I` is the number of IP adapters, `B` the batch size, `N` the number
+    of tokens per time condition, `T` the number of time conditions, and `C` the channel dimension.
+    `ip_adapter_masks` (bool), if given, has the same shape as these embeddings.
+
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
+            The context length of the image features.
+        scale (`float` or `List[float]`, defaults to 1.0):
+            The weight scale of the image prompt.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
+        super().__init__()
+
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        if not isinstance(num_tokens, (tuple, list)):
+            num_tokens = [num_tokens]
+        self.num_tokens = num_tokens
+
+        if not isinstance(scale, list):
+            scale = [scale] * len(num_tokens)
+        if len(scale) != len(num_tokens):
+            raise ValueError("`scale` should be a list of floats with the same length as `num_tokens`.")
+        self.scale = scale
+
+        self.to_k_ip = nn.ModuleList(
+            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+        )
+        self.to_v_ip = nn.ModuleList(
+            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+        )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+        ip_adapter_masks: Optional[torch.FloatTensor] = None,
+        time_conditions: Optional[list] = None,
+        audio_length_in_s: Optional[int] = None,
+    ):
+        residual = hidden_states
+
+        # separate ip_hidden_states from encoder_hidden_states
+        if encoder_hidden_states is not None:
+            if isinstance(encoder_hidden_states, tuple):
+                encoder_hidden_states, ip_hidden_states = encoder_hidden_states
+            else:
+                deprecation_message = (
+                    "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
+                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
+                )
+                deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
+                end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
+                encoder_hidden_states, ip_hidden_states = (
+                    encoder_hidden_states[:, :end_pos, :],
+                    [encoder_hidden_states[:, end_pos:, :]],
+                )
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if ip_adapter_masks is not None:
+            if not isinstance(ip_adapter_masks, List):
+                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+                raise ValueError(
+                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                    f"({len(ip_hidden_states)})"
+                )
+            else:
+                for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                        raise ValueError(
+                            "Each element of the ip_adapter_masks array should be a tensor with shape "
+                            "[1, num_images_for_ip_adapter, height, width]."
+                            " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                        )
+                    if mask.shape[1] != ip_state.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                        )
+                    if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of scales ({len(scale)}) at index {index}"
+                        )
+        else:
+            ip_adapter_masks = [None] * len(self.scale)
+        # for ip-adapter
+        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+        ):
+            skip = False
+            if isinstance(scale, list):
+                if all(s == 0 for s in scale):
+                    skip = True
+            elif scale == 0:
+                skip = True
+            if not skip:
+                time_condition_masks = None
+                for time_condition in time_conditions:
+                    # hard-coded mel-latent grid: the flattened latent of length
+                    # hidden_states.shape[1] is treated as (height, width) with width = 4 * height
+                    time_condition_mask = torch.zeros((
+                        batch_size,
+                        int(math.sqrt(hidden_states.shape[1]) // 2),
+                        int(2 * math.sqrt(hidden_states.shape[1])),
+                    )).bool().to(device=hidden_states.device)
+                    mel_latent_length = time_condition_mask.shape[-1]
+                    # map the (start, end) time condition in seconds onto latent columns
+                    time_start = int(time_condition[0] / audio_length_in_s * mel_latent_length)
+                    time_end = int(time_condition[1] / audio_length_in_s * mel_latent_length)
+
+                    time_condition_mask[:, :, time_start:time_end] = True
+                    # flatten the grid and repeat once per image-prompt token (4 tokens per condition)
+                    time_condition_mask = time_condition_mask.flatten(-2).unsqueeze(-1).repeat(1, 1, 4)
+                    if time_condition_masks is None:
+                        time_condition_masks = time_condition_mask
+                    else:
+                        time_condition_masks = torch.cat([time_condition_masks, time_condition_mask], dim=-1)
+
+                current_ip_hidden_states = rearrange(current_ip_hidden_states, 'L B N C -> B (L N) C')
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)
+
+                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                time_condition_masks = time_condition_masks.unsqueeze(1).repeat(1, attn.heads, 1, 1)
+
+                # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                # TODO: add support for attn.scale when we move to Torch 2.1
+                current_ip_hidden_states = F.scaled_dot_product_attention(
+                    query, ip_key, ip_value, attn_mask=time_condition_masks, dropout_p=0.0, is_causal=False
+                )
+
+                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                    batch_size, -1, attn.heads * head_dim
+                )
+                current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+
+                hidden_states = hidden_states + scale * current_ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
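+# Minimal sketch (not called anywhere in this file): how the temporal adapter above turns a
+# (start, end) time condition, given in seconds, into the boolean key mask passed to
+# `F.scaled_dot_product_attention`. `num_image_tokens=4` is an assumption matching the
+# default `num_tokens=(4,)`.
+def _time_condition_mask_sketch(
+    batch_size: int,
+    num_latent_tokens: int,
+    time_condition: tuple,
+    audio_length_in_s: float,
+    num_image_tokens: int = 4,
+) -> torch.Tensor:
+    latent_height = int(math.sqrt(num_latent_tokens) // 2)
+    latent_width = int(2 * math.sqrt(num_latent_tokens))
+    mask = torch.zeros(batch_size, latent_height, latent_width, dtype=torch.bool)
+    start = int(time_condition[0] / audio_length_in_s * latent_width)
+    end = int(time_condition[1] / audio_length_in_s * latent_width)
+    mask[:, :, start:end] = True  # attend only inside the conditioned time span
+    # flatten the (height, width) grid and repeat once per image-prompt token
+    return mask.flatten(-2).unsqueeze(-1).repeat(1, 1, num_image_tokens)
+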
+class IPAdapterAttnProcessor2_0(torch.nn.Module):
+    r"""
+    Attention processor for IP-Adapter for PyTorch 2.0.
+
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
+            The context length of the image features.
+        scale (`float` or `List[float]`, defaults to 1.0):
+            The weight scale of the image prompt.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
+        super().__init__()
+
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                f"{self.__class__.__name__} requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        if not isinstance(num_tokens, (tuple, list)):
+            num_tokens = [num_tokens]
+        self.num_tokens = num_tokens
+
+        if not isinstance(scale, list):
+            scale = [scale] * len(num_tokens)
+        if len(scale) != len(num_tokens):
+            raise ValueError("`scale` should be a list of floats with the same length as `num_tokens`.")
+        self.scale = scale
+        self.to_k_ip = nn.ModuleList(
+            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+        )
+        self.to_v_ip = nn.ModuleList(
+            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+        )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+        ip_adapter_masks: Optional[torch.FloatTensor] = None,
+    ):
+        residual = hidden_states
+
+        # separate ip_hidden_states from encoder_hidden_states
+        if encoder_hidden_states is not None:
+            if isinstance(encoder_hidden_states, tuple):
+                encoder_hidden_states, ip_hidden_states = encoder_hidden_states
+            else:
+                deprecation_message = (
+                    "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
+                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
+                )
+                deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
+                end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
+                encoder_hidden_states, ip_hidden_states = (
+                    encoder_hidden_states[:, :end_pos, :],
+                    [encoder_hidden_states[:, end_pos:, :]],
+                )
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if ip_adapter_masks is not None:
+            if not isinstance(ip_adapter_masks, List):
+                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+                raise ValueError(
+                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                    f"({len(ip_hidden_states)})"
+                )
+            else:
+                for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                        raise ValueError(
+                            "Each element of the ip_adapter_masks array should be a tensor with shape "
+                            "[1, num_images_for_ip_adapter, height, width]."
+                            " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                        )
+                    if mask.shape[1] != ip_state.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                        )
+                    if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of scales ({len(scale)}) at index {index}"
+                        )
+        else:
+            ip_adapter_masks = [None] * len(self.scale)
+
+        # for ip-adapter
+        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+        ):
+            skip = False
+            if isinstance(scale, list):
+                if all(s == 0 for s in scale):
+                    skip = True
+            elif scale == 0:
+                skip = True
+            if not skip:
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)
+
+                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                # TODO: add support for attn.scale when we move to Torch 2.1
+                current_ip_hidden_states = F.scaled_dot_product_attention(
+                    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                )
+
+                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                    batch_size, -1, attn.heads * head_dim
+                )
+                current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+
+                hidden_states = hidden_states + scale * current_ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
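+# Minimal wiring sketch (illustrative only, not used by this repository): how processors such as
+# `IPAdapterAttnProcessor2_0` are typically attached to a diffusers `UNet2DConditionModel` via
+# `set_attn_processor`, deriving each layer's hidden size from `block_out_channels`.
+def _install_ip_adapter_processors_sketch(unet, num_tokens=(4,), scale=1.0):
+    attn_procs = {}
+    for name in unet.attn_processors.keys():
+        if name.endswith("attn1.processor"):
+            # self-attention layers receive no image-prompt tokens
+            attn_procs[name] = AttnProcessor2_0()
+            continue
+        if name.startswith("mid_block"):
+            hidden_size = unet.config.block_out_channels[-1]
+        elif name.startswith("up_blocks"):
+            block_id = int(name[len("up_blocks.")])
+            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+        else:  # down_blocks
+            block_id = int(name[len("down_blocks.")])
+            hidden_size = unet.config.block_out_channels[block_id]
+        attn_procs[name] = IPAdapterAttnProcessor2_0(
+            hidden_size=hidden_size,
+            cross_attention_dim=unet.config.cross_attention_dim,
+            num_tokens=num_tokens,
+            scale=scale,
+        )
+    unet.set_attn_processor(attn_procs)
+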
+LORA_ATTENTION_PROCESSORS = (
+    LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    LoRAAttnAddedKVProcessor,
+)
+
+ADDED_KV_ATTENTION_PROCESSORS = (
+    AttnAddedKVProcessor,
+    SlicedAttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
+    XFormersAttnAddedKVProcessor,
+    LoRAAttnAddedKVProcessor,
+)
+
+CROSS_ATTENTION_PROCESSORS = (
+    AttnProcessor,
+    AttnProcessor2_0,
+    XFormersAttnProcessor,
+    SlicedAttnProcessor,
+    LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    IPAdapterAttnProcessor,
+    IPAdapterAttnProcessor2_0,
+)
+
+AttentionProcessor = Union[
+    AttnProcessor,
+    AttnProcessor2_0,
+    FusedAttnProcessor2_0,
+    XFormersAttnProcessor,
+    SlicedAttnProcessor,
+    AttnAddedKVProcessor,
+    SlicedAttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
+    XFormersAttnAddedKVProcessor,
+    CustomDiffusionAttnProcessor,
+    CustomDiffusionXFormersAttnProcessor,
+    CustomDiffusionAttnProcessor2_0,
+    # deprecated
+    LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    LoRAAttnAddedKVProcessor,
+]
\ No newline at end of file
diff --git a/foleycrafter/models/auffusion/dual_transformer_2d.py b/foleycrafter/models/auffusion/dual_transformer_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3f27b61e001347f0093c039ad10ae79975b7691
--- /dev/null
+++ b/foleycrafter/models/auffusion/dual_transformer_2d.py
@@ -0,0 +1,156 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+from torch import nn
+
+from foleycrafter.models.auffusion.transformer_2d \
+    import Transformer2DModel, Transformer2DModelOutput
+
+
+class DualTransformer2DModel(nn.Module):
+    """
+    Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
+
+    Parameters:
+        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+        in_channels (`int`, *optional*):
+            Pass if the input is continuous. The number of channels in the input and output.
+        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+        dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
+        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
+            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
+            `ImagePositionalEmbeddings`.
+        num_vector_embeds (`int`, *optional*):
+            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
+            Includes the class for the masked latent pixel.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
+            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
+            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
+            at most `num_embeds_ada_norm` steps.
+        attention_bias (`bool`, *optional*):
+            Configure if the TransformerBlocks' attention should contain a bias parameter.
+    """
+
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 88,
+        in_channels: Optional[int] = None,
+        num_layers: int = 1,
+        dropout: float = 0.0,
+        norm_num_groups: int = 32,
+        cross_attention_dim: Optional[int] = None,
+        attention_bias: bool = False,
+        sample_size: Optional[int] = None,
+        num_vector_embeds: Optional[int] = None,
+        activation_fn: str = "geglu",
+        num_embeds_ada_norm: Optional[int] = None,
+    ):
+        super().__init__()
+        self.transformers = nn.ModuleList(
+            [
+                Transformer2DModel(
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    in_channels=in_channels,
+                    num_layers=num_layers,
+                    dropout=dropout,
+                    norm_num_groups=norm_num_groups,
+                    cross_attention_dim=cross_attention_dim,
+                    attention_bias=attention_bias,
+                    sample_size=sample_size,
+                    num_vector_embeds=num_vector_embeds,
+                    activation_fn=activation_fn,
+                    num_embeds_ada_norm=num_embeds_ada_norm,
+                )
+                for _ in range(2)
+            ]
+        )
+
+        # Variables that can be set by a pipeline:
+
+        # The ratio of transformer1 to transformer2's output states to be combined during inference
+        self.mix_ratio = 0.5
+
+        # The shape of `encoder_hidden_states` is expected to be
+        # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
+        self.condition_lengths = [77, 257]
+
+        # Which transformer to use to encode which condition.
+        # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
+        self.transformer_index_for_condition = [1, 0]
+
+    def forward(
+        self,
+        hidden_states,
+        encoder_hidden_states,
+        timestep=None,
+        attention_mask=None,
+        cross_attention_kwargs=None,
+        return_dict: bool = True,
+    ):
+        """
+        Args:
+            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
+                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
+                hidden_states.
+            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+                self-attention.
+            timestep ( `torch.long`, *optional*):
+                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
+            attention_mask (`torch.FloatTensor`, *optional*):
+                Optional attention mask to be applied in Attention.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
+            [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is the sample tensor.
+        """
+        input_states = hidden_states
+
+        encoded_states = []
+        tokens_start = 0
+        # attention_mask is not used yet
+        for i in range(2):
+            # for each of the two transformers, pass the corresponding condition tokens
+            condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
+            transformer_index = self.transformer_index_for_condition[i]
+            encoded_state = self.transformers[transformer_index](
+                input_states,
+                encoder_hidden_states=condition_state,
+                timestep=timestep,
+                cross_attention_kwargs=cross_attention_kwargs,
+                return_dict=False,
+            )[0]
+            encoded_states.append(encoded_state - input_states)
+            tokens_start += self.condition_lengths[i]
+
+        output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
+        output_states = output_states + input_states
+
+        if not return_dict:
+            return (output_states,)
+
+        return Transformer2DModelOutput(sample=output_states)
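+
+# Routing/blend summary: with the defaults above, tokens [:, :77] of `encoder_hidden_states`
+# are encoded by `transformers[1]` and tokens [:, 77:77 + 257] by `transformers[0]`; each
+# residual (encoded_state - input_states) is then blended as
+#     mix_ratio * delta_0 + (1 - mix_ratio) * delta_1 + input_states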
\ No newline at end of file
diff --git a/foleycrafter/models/auffusion/loaders/ip_adapter.py b/foleycrafter/models/auffusion/loaders/ip_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..faba325670450f3a3d2885ce32e74e3811ba8405
--- /dev/null
+++ b/foleycrafter/models/auffusion/loaders/ip_adapter.py
@@ -0,0 +1,520 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+
+import torch
+from huggingface_hub.utils import validate_hf_hub_args
+from safetensors import safe_open
+
+from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
+from diffusers.utils import (
+    _get_model_file,
+    is_accelerate_available,
+    is_torch_version,
+    is_transformers_available,
+    logging,
+)
+
+
+if is_transformers_available():
+    from transformers import (
+        CLIPImageProcessor,
+        CLIPVisionModelWithProjection,
+    )
+
+    from diffusers.models.attention_processor import (
+        IPAdapterAttnProcessor,
+    )
+
+from foleycrafter.models.auffusion.attention_processor import IPAdapterAttnProcessor2_0, VPTemporalAdapterAttnProcessor2_0
+
+logger = logging.get_logger(__name__)
+
+
+class IPAdapterMixin:
+    """Mixin for handling IP Adapters."""
+
+    @validate_hf_hub_args
+    def load_ip_adapter(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
+        subfolder: Union[str, List[str]],
+        weight_name: Union[str, List[str]],
+        image_encoder_folder: Optional[str] = "image_encoder",
+        **kwargs,
+    ):
+        """
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+            subfolder (`str` or `List[str]`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+                If a list is passed, it should have the same length as `weight_name`.
+            weight_name (`str` or `List[str]`):
+                The name of the weight file to load. If a list is passed, it should have the same length as
+                `subfolder`.
+            image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
+                The subfolder location of the image encoder within a larger model repository on the Hub or locally.
+                Pass `None` to not load the image encoder. If the image encoder is located in a folder inside `subfolder`,
+                you only need to pass the name of the folder that contains image encoder weights, e.g. `image_encoder_folder="image_encoder"`.
+                If the image encoder is located in a folder other than `subfolder`, you should pass the path to the folder that contains image encoder weights,
+                for example, `image_encoder_folder="different_subfolder/image_encoder"`.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+                incompletely downloaded files are deleted.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
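+
+        Example (a minimal sketch; `h94/IP-Adapter` and the weight name below are the public
+        IP-Adapter release and only illustrate the call signature, not the checkpoints used by
+        this project):
+
+        ```py
+        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+        pipeline.set_ip_adapter_scale(0.6)
+        ```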
+        """
+
+        # handle the list inputs for multiple IP Adapters
+        if not isinstance(weight_name, list):
+            weight_name = [weight_name]
+
+        if not isinstance(pretrained_model_name_or_path_or_dict, list):
+            pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
+        if len(pretrained_model_name_or_path_or_dict) == 1:
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
+
+        if not isinstance(subfolder, list):
+            subfolder = [subfolder]
+        if len(subfolder) == 1:
+            subfolder = subfolder * len(weight_name)
+
+        if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
+            raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
+
+        if len(weight_name) != len(subfolder):
+            raise ValueError("`weight_name` and `subfolder` must have the same length.")
+
+        # Load the main state dict first.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+        state_dicts = []
+        for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
+            pretrained_model_name_or_path_or_dict, weight_name, subfolder
+        ):
+            if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path_or_dict,
+                    weights_name=weight_name,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                if weight_name.endswith(".safetensors"):
+                    state_dict = {"image_proj": {}, "ip_adapter": {}}
+                    with safe_open(model_file, framework="pt", device="cpu") as f:
+                        for key in f.keys():
+                            if key.startswith("image_proj."):
+                                state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+                            elif key.startswith("ip_adapter."):
+                                state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+                else:
+                    state_dict = torch.load(model_file, map_location="cpu")
+            else:
+                state_dict = pretrained_model_name_or_path_or_dict
+
+            keys = list(state_dict.keys())
+            if keys != ["image_proj", "ip_adapter"]:
+                raise ValueError("Required keys (`image_proj` and `ip_adapter`) are missing from the state dict.")
+
+            state_dicts.append(state_dict)
+
+            # load CLIP image encoder here if it has not been registered to the pipeline yet
+            if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+                if image_encoder_folder is not None:
+                    if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+                        logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
+                        if image_encoder_folder.count("/") == 0:
+                            image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
+                        else:
+                            image_encoder_subfolder = Path(image_encoder_folder).as_posix()
+
+                        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+                            pretrained_model_name_or_path_or_dict,
+                            subfolder=image_encoder_subfolder,
+                            low_cpu_mem_usage=low_cpu_mem_usage,
+                        ).to(self.device, dtype=self.dtype)
+                        self.register_modules(image_encoder=image_encoder)
+                    else:
+                        raise ValueError(
+                            "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
+                        )
+                else:
+                    logger.warning(
+                        "image_encoder is not loaded since `image_encoder_folder=None` was passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter. "
+                        "Use `ip_adapter_image_embeds` to pass pre-generated image embeddings instead."
+                    )
+
+            # create feature extractor if it has not been registered to the pipeline yet
+            if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
+                feature_extractor = CLIPImageProcessor()
+                self.register_modules(feature_extractor=feature_extractor)
+
+        # load ip-adapter into unet
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet._load_ip_adapter_weights(state_dicts)
+
+    def set_ip_adapter_scale(self, scale):
+        """
+        Sets the conditioning scale between text and image.
+
+        Example:
+
+        ```py
+        pipeline.set_ip_adapter_scale(0.5)
+        ```
+        """
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        for attn_processor in unet.attn_processors.values():
+            if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
+                if not isinstance(scale, list):
+                    scale = [scale] * len(attn_processor.scale)
+                if len(attn_processor.scale) != len(scale):
+                    raise ValueError(
+                        f"`scale` should be a list of the same length as the number of IP-Adapters. "
+                        f"Expected {len(attn_processor.scale)} but got {len(scale)}."
+                    )
+                attn_processor.scale = scale
+
+    def unload_ip_adapter(self):
+        """
+        Unloads the IP Adapter weights
+
+        Examples:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+        >>> pipeline.unload_ip_adapter()
+        >>> ...
+        ```
+        """
+        # remove CLIP image encoder
+        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
+            self.image_encoder = None
+            self.register_to_config(image_encoder=[None, None])
+
+        # remove feature extractor only when safety_checker is None as safety_checker uses
+        # the feature_extractor later
+        if not hasattr(self, "safety_checker"):
+            if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
+                self.feature_extractor = None
+                self.register_to_config(feature_extractor=[None, None])
+
+        # remove hidden encoder
+        self.unet.encoder_hid_proj = None
+        self.config.encoder_hid_dim_type = None
+
+        # restore original Unet attention processors layers
+        self.unet.set_default_attn_processor()
+
+
+class VPAdapterMixin:
+    """Mixin for handling VP Adapters."""
+
+    @validate_hf_hub_args
+    def load_ip_adapter(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
+        subfolder: Union[str, List[str]],
+        weight_name: Union[str, List[str]],
+        image_encoder_folder: Optional[str] = "image_encoder",
+        **kwargs,
+    ):
+        """
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+            subfolder (`str` or `List[str]`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+                If a list is passed, it should have the same length as `weight_name`.
+            weight_name (`str` or `List[str]`):
+                The name of the weight file to load. If a list is passed, it should have the same length as
+                `subfolder`.
+            image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
+                The subfolder location of the image encoder within a larger model repository on the Hub or locally.
+                Pass `None` to not load the image encoder. If the image encoder is located in a folder inside `subfolder`,
+                you only need to pass the name of the folder that contains image encoder weights, e.g. `image_encoder_folder="image_encoder"`.
+                If the image encoder is located in a folder other than `subfolder`, you should pass the path to the folder that contains image encoder weights,
+                for example, `image_encoder_folder="different_subfolder/image_encoder"`.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+                incompletely downloaded files are deleted.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+        """
+
+        # handle the list inputs for multiple IP Adapters
+        if not isinstance(weight_name, list):
+            weight_name = [weight_name]
+
+        if not isinstance(pretrained_model_name_or_path_or_dict, list):
+            pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
+        if len(pretrained_model_name_or_path_or_dict) == 1:
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
+
+        if not isinstance(subfolder, list):
+            subfolder = [subfolder]
+        if len(subfolder) == 1:
+            subfolder = subfolder * len(weight_name)
+
+        if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
+            raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
+
+        if len(weight_name) != len(subfolder):
+            raise ValueError("`weight_name` and `subfolder` must have the same length.")
+
+        # Load the main state dict first.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+        state_dicts = []
+        for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
+            pretrained_model_name_or_path_or_dict, weight_name, subfolder
+        ):
+            if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path_or_dict,
+                    weights_name=weight_name,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                if weight_name.endswith(".safetensors"):
+                    state_dict = {"image_proj": {}, "ip_adapter": {}}
+                    with safe_open(model_file, framework="pt", device="cpu") as f:
+                        for key in f.keys():
+                            if key.startswith("image_proj."):
+                                state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+                            elif key.startswith("ip_adapter."):
+                                state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+                else:
+                    state_dict = torch.load(model_file, map_location="cpu")
+            else:
+                state_dict = pretrained_model_name_or_path_or_dict
+
+            keys = list(state_dict.keys())
+            if keys != ["image_proj", "ip_adapter"]:
+                raise ValueError("Required keys (`image_proj` and `ip_adapter`) are missing from the state dict.")
+
+            state_dicts.append(state_dict)
+
+            # load CLIP image encoder here if it has not been registered to the pipeline yet
+            if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+                if image_encoder_folder is not None:
+                    if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+                        logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
+                        if image_encoder_folder.count("/") == 0:
+                            image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
+                        else:
+                            image_encoder_subfolder = Path(image_encoder_folder).as_posix()
+
+                        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+                            pretrained_model_name_or_path_or_dict,
+                            subfolder=image_encoder_subfolder,
+                            low_cpu_mem_usage=low_cpu_mem_usage,
+                        ).to(self.device, dtype=self.dtype)
+                        self.register_modules(image_encoder=image_encoder)
+                    else:
+                        raise ValueError(
+                            "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
+                        )
+                else:
+                    logger.warning(
+                        "image_encoder is not loaded since `image_encoder_folder=None` was passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter. "
+                        "Use `ip_adapter_image_embeds` to pass pre-generated image embeddings instead."
+                    )
+
+            # create feature extractor if it has not been registered to the pipeline yet
+            if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
+                feature_extractor = CLIPImageProcessor()
+                self.register_modules(feature_extractor=feature_extractor)
+
+        # load ip-adapter into unet
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet._load_ip_adapter_weights_VPAdapter(state_dicts)
+
+    def set_ip_adapter_scale(self, scale):
+        """
+        Sets the conditioning scale between text and image.
+
+        Example:
+
+        ```py
+        pipeline.set_ip_adapter_scale(0.5)
+        ```
+        """
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        for attn_processor in unet.attn_processors.values():
+            if isinstance(attn_processor, (IPAdapterAttnProcessor, VPTemporalAdapterAttnProcessor2_0)):
+                if not isinstance(scale, list):
+                    scale = [scale] * len(attn_processor.scale)
+                if len(attn_processor.scale) != len(scale):
+                    raise ValueError(
+                        f"`scale` should be a list of the same length as the number of IP-Adapters. "
+                        f"Expected {len(attn_processor.scale)} but got {len(scale)}."
+                    )
+                attn_processor.scale = scale
+
+    def unload_ip_adapter(self):
+        """
+        Unloads the IP Adapter weights
+
+        Examples:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+        >>> pipeline.unload_ip_adapter()
+        >>> ...
+        ```
+        """
+        # remove CLIP image encoder
+        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
+            self.image_encoder = None
+            self.register_to_config(image_encoder=[None, None])
+
+        # remove feature extractor only when safety_checker is None as safety_checker uses
+        # the feature_extractor later
+        if not hasattr(self, "safety_checker"):
+            if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
+                self.feature_extractor = None
+                self.register_to_config(feature_extractor=[None, None])
+
+        # remove hidden encoder
+        self.unet.encoder_hid_proj = None
+        self.config.encoder_hid_dim_type = None
+
+        # restore original Unet attention processors layers
+        self.unet.set_default_attn_processor()
diff --git a/foleycrafter/models/auffusion/loaders/unet.py b/foleycrafter/models/auffusion/loaders/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6ab346cb819ab59126ddffc18a548dae9242063
--- /dev/null
+++ b/foleycrafter/models/auffusion/loaders/unet.py
@@ -0,0 +1,1100 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+import os
+from collections import defaultdict
+from contextlib import nullcontext
+from functools import partial
+from typing import Callable, Dict, List, Optional, Union, Tuple
+
+import safetensors
+import torch
+import torch.nn.functional as F
+from huggingface_hub.utils import validate_hf_hub_args
+from torch import nn
+
+from diffusers.models.embeddings import ImageProjection, MLPProjection, Resampler
+from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    _get_model_file,
+    delete_adapter_layers,
+    is_accelerate_available,
+    logging,
+    is_torch_version,
+    set_adapter_layers,
+    set_weights_and_activate_adapters,
+)
+from diffusers.loaders.utils import AttnProcsLayers
+
+from foleycrafter.models.adapters.ip_adapter import VideoProjModel
+from foleycrafter.models.auffusion.attention_processor import IPAdapterAttnProcessor2_0, VPTemporalAdapterAttnProcessor2_0, AttnProcessor2_0
+
+
+if is_accelerate_available():
+    from accelerate import init_empty_weights
+    from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
+
+logger = logging.get_logger(__name__)
+
+class VPAdapterImageProjection(nn.Module):
+    def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
+        super().__init__()
+        self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
+
+    def forward(self, image_embeds: List[torch.FloatTensor]):
+        projected_image_embeds = []
+
+        # currently, we accept `image_embeds` as
+        #  1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]
+        #  2. list of `n` tensors where `n` is number of ip-adapters, each tensor can have shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]
+        if not isinstance(image_embeds, list):
+            deprecation_message = (
+                "You have passed a tensor as `image_embeds`. This is deprecated and will be removed in a future release."
+                " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning."
+            )
+            image_embeds = [image_embeds.unsqueeze(1)]
+
+        if len(image_embeds) != len(self.image_projection_layers):
+            raise ValueError(
+                f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
+            )
+
+        for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
+            image_embed = image_embed.squeeze(1)
+            batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
+            image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
+            image_embed = image_projection_layer(image_embed)
+            image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
+
+            projected_image_embeds.append(image_embed)
+
+        return projected_image_embeds
+
+class MultiIPAdapterImageProjection(nn.Module):
+    def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
+        super().__init__()
+        self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
+
+    def forward(self, image_embeds: List[torch.FloatTensor]):
+        projected_image_embeds = []
+
+        # currently, we accept `image_embeds` as
+        #  1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]
+        #  2. list of `n` tensors where `n` is number of ip-adapters, each tensor can have shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]
+        if not isinstance(image_embeds, list):
+            deprecation_message = (
+                "You have passed a tensor as `image_embeds`. This is deprecated and will be removed in a future release."
+                " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning."
+            )
+            image_embeds = [image_embeds.unsqueeze(1)]
+
+        if len(image_embeds) != len(self.image_projection_layers):
+            raise ValueError(
+                f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
+            )
+
+        for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
+            batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
+            image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
+            image_embed = image_projection_layer(image_embed)
+            image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
+
+            projected_image_embeds.append(image_embed)
+
+        return projected_image_embeds
+
+
+TEXT_ENCODER_NAME = "text_encoder"
+UNET_NAME = "unet"
+
+LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
+LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+
+CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
+CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
+
+
+class UNet2DConditionLoadersMixin:
+    """
+    Load LoRA layers into a [`UNet2DConditionModel`].
+    """
+
+    text_encoder_name = TEXT_ENCODER_NAME
+    unet_name = UNET_NAME
+
+    @validate_hf_hub_args
+    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+        r"""
+        Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
+        defined in
+        [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)
+        and be a `torch.nn.Module` class.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a directory (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+                incompletely downloaded files are deleted.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+            mirror (`str`, *optional*):
+                Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
+                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+                information.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.unet.load_attn_procs(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        ```
+        """
+        from diffusers.models.attention_processor import CustomDiffusionAttnProcessor
+        from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
+
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+        network_alphas = kwargs.pop("network_alphas", None)
+
+        _pipeline = kwargs.pop("_pipeline", None)
+
+        is_network_alphas_none = network_alphas is None
+
+        allow_pickle = False
+
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
+        model_file = None
+        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+            # Let's first try to load .safetensors weights
+            if (use_safetensors and weight_name is None) or (
+                weight_name is not None and weight_name.endswith(".safetensors")
+            ):
+                try:
+                    model_file = _get_model_file(
+                        pretrained_model_name_or_path_or_dict,
+                        weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        token=token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                    )
+                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                except IOError as e:
+                    if not allow_pickle:
+                        raise e
+                    # try loading non-safetensors weights
+                    pass
+            if model_file is None:
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path_or_dict,
+                    weights_name=weight_name or LORA_WEIGHT_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                state_dict = torch.load(model_file, map_location="cpu")
+        else:
+            state_dict = pretrained_model_name_or_path_or_dict
+
+        # fill attn processors
+        lora_layers_list = []
+
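+        # Figure out which checkpoint format we are dealing with: plain LoRA layers vs. Custom Diffusion attention processors.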
+        is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND
+        is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
+
+        if is_lora:
+            # correct keys
+            state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)
+
+            if network_alphas is not None:
+                network_alphas_keys = list(network_alphas.keys())
+                used_network_alphas_keys = set()
+
+            lora_grouped_dict = defaultdict(dict)
+            mapped_network_alphas = {}
+
+            all_keys = list(state_dict.keys())
+            for key in all_keys:
+                value = state_dict.pop(key)
+                attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+                lora_grouped_dict[attn_processor_key][sub_key] = value
+
+                # Create another `mapped_network_alphas` dictionary so that we can properly map them.
+                if network_alphas is not None:
+                    for k in network_alphas_keys:
+                        if k.replace(".alpha", "") in key:
+                            mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
+                            used_network_alphas_keys.add(k)
+
+            if not is_network_alphas_none:
+                if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
+                    raise ValueError(
+                        f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
+                    )
+
+            if len(state_dict) > 0:
+                raise ValueError(
+                    f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
+                )
+
+            for key, value_dict in lora_grouped_dict.items():
+                attn_processor = self
+                for sub_key in key.split("."):
+                    attn_processor = getattr(attn_processor, sub_key)
+
+                # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
+                # or add_{k,v,q,out_proj}_proj_lora layers.
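+                # The LoRA rank is the output dimension of the down-projection weight.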
+                rank = value_dict["lora.down.weight"].shape[0]
+
+                if isinstance(attn_processor, LoRACompatibleConv):
+                    in_features = attn_processor.in_channels
+                    out_features = attn_processor.out_channels
+                    kernel_size = attn_processor.kernel_size
+
+                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+                    with ctx():
+                        lora = LoRAConv2dLayer(
+                            in_features=in_features,
+                            out_features=out_features,
+                            rank=rank,
+                            kernel_size=kernel_size,
+                            stride=attn_processor.stride,
+                            padding=attn_processor.padding,
+                            network_alpha=mapped_network_alphas.get(key),
+                        )
+                elif isinstance(attn_processor, LoRACompatibleLinear):
+                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+                    with ctx():
+                        lora = LoRALinearLayer(
+                            attn_processor.in_features,
+                            attn_processor.out_features,
+                            rank,
+                            mapped_network_alphas.get(key),
+                        )
+                else:
+                    raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
+
+                value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
+                lora_layers_list.append((attn_processor, lora))
+
+                if low_cpu_mem_usage:
+                    device = next(iter(value_dict.values())).device
+                    dtype = next(iter(value_dict.values())).dtype
+                    load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
+                else:
+                    lora.load_state_dict(value_dict)
+
+        elif is_custom_diffusion:
+            attn_processors = {}
+            custom_diffusion_grouped_dict = defaultdict(dict)
+            for key, value in state_dict.items():
+                if len(value) == 0:
+                    custom_diffusion_grouped_dict[key] = {}
+                else:
+                    if "to_out" in key:
+                        attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+                    else:
+                        attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
+                    custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
+
+            for key, value_dict in custom_diffusion_grouped_dict.items():
+                if len(value_dict) == 0:
+                    attn_processors[key] = CustomDiffusionAttnProcessor(
+                        train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
+                    )
+                else:
+                    cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
+                    hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
+                    train_q_out = "to_q_custom_diffusion.weight" in value_dict
+                    attn_processors[key] = CustomDiffusionAttnProcessor(
+                        train_kv=True,
+                        train_q_out=train_q_out,
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                    )
+                    attn_processors[key].load_state_dict(value_dict)
+        elif USE_PEFT_BACKEND:
+            # In that case we have nothing to do as loading the adapter weights is already handled above by `set_peft_model_state_dict`
+            # on the Unet
+            pass
+        else:
+            raise ValueError(
+                f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
+            )
+
+        # <Unsafe code
+        # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
+        # Now we remove any existing accelerate hooks so the new layers can be loaded; the offloading hooks are re-applied below.
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+
+        # For the PEFT backend the UNet is already offloaded at this stage, as it is handled inside `load_lora_into_unet`
+        if not USE_PEFT_BACKEND:
+            if _pipeline is not None:
+                for _, component in _pipeline.components.items():
+                    if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                        is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                        is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+
+                        logger.info(
+                            "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                        )
+                        remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+            # only custom diffusion needs to set attn processors
+            if is_custom_diffusion:
+                self.set_attn_processor(attn_processors)
+
+            # set lora layers
+            for target_module, lora_layer in lora_layers_list:
+                target_module.set_lora_layer(lora_layer)
+
+            self.to(dtype=self.dtype, device=self.device)
+
+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />
+
+    def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
+        is_new_lora_format = all(
+            key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
+        )
+        if is_new_lora_format:
+            # Strip the `"unet"` prefix.
+            is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
+            if is_text_encoder_present:
+                warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
+                logger.warn(warn_message)
+            unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
+            state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
+
+        # change processor format to 'pure' LoRACompatibleLinear format
+        if any("processor" in k.split(".") for k in state_dict.keys()):
+
+            def format_to_lora_compatible(key):
+                if "processor" not in key.split("."):
+                    return key
+                return key.replace(".processor", "").replace("to_out_lora", "to_out.0.lora").replace("_lora", ".lora")
+
+            state_dict = {format_to_lora_compatible(k): v for k, v in state_dict.items()}
+
+            if network_alphas is not None:
+                network_alphas = {format_to_lora_compatible(k): v for k, v in network_alphas.items()}
+        return state_dict, network_alphas
+
+    def save_attn_procs(
+        self,
+        save_directory: Union[str, os.PathLike],
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+        **kwargs,
+    ):
+        r"""
+        Save attention processor layers to a directory so that it can be reloaded with the
+        [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save an attention processor to (will be created if it doesn't exist).
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training when you
+                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+                process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or with `pickle`.
+
+        Example:
+
+        ```py
+        import torch
+        from diffusers import DiffusionPipeline
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "CompVis/stable-diffusion-v1-4",
+            torch_dtype=torch.float16,
+        ).to("cuda")
+        pipeline.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
+        pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
+        ```
+        """
+        from diffusers.models.attention_processor import (
+            CustomDiffusionAttnProcessor,
+            CustomDiffusionAttnProcessor2_0,
+            CustomDiffusionXFormersAttnProcessor,
+        )
+
+        if os.path.isfile(save_directory):
+            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+            return
+
+        if save_function is None:
+            if safe_serialization:
+
+                def save_function(weights, filename):
+                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+            else:
+                save_function = torch.save
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        is_custom_diffusion = any(
+            isinstance(
+                x,
+                (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
+            )
+            for (_, x) in self.attn_processors.items()
+        )
+        if is_custom_diffusion:
+            model_to_save = AttnProcsLayers(
+                {
+                    y: x
+                    for (y, x) in self.attn_processors.items()
+                    if isinstance(
+                        x,
+                        (
+                            CustomDiffusionAttnProcessor,
+                            CustomDiffusionAttnProcessor2_0,
+                            CustomDiffusionXFormersAttnProcessor,
+                        ),
+                    )
+                }
+            )
+            state_dict = model_to_save.state_dict()
+            for name, attn in self.attn_processors.items():
+                if len(attn.state_dict()) == 0:
+                    state_dict[name] = {}
+        else:
+            model_to_save = AttnProcsLayers(self.attn_processors)
+            state_dict = model_to_save.state_dict()
+
+        if weight_name is None:
+            if safe_serialization:
+                weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
+            else:
+                weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
+
+        # Save the model
+        save_function(state_dict, os.path.join(save_directory, weight_name))
+        logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
+
+    def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
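+        """Merge the loaded LoRA weights into their base layers in place, scaling the LoRA contribution by `lora_scale`."""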
+        self.lora_scale = lora_scale
+        self._safe_fusing = safe_fusing
+        self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
+
+    def _fuse_lora_apply(self, module, adapter_names=None):
+        if not USE_PEFT_BACKEND:
+            if hasattr(module, "_fuse_lora"):
+                module._fuse_lora(self.lora_scale, self._safe_fusing)
+
+            if adapter_names is not None:
+                raise ValueError(
+                    "The `adapter_names` argument is not supported in your environment. Please switch"
+                    " to PEFT backend to use this argument by installing latest PEFT and transformers."
+                    " `pip install -U peft transformers`"
+                )
+        else:
+            from peft.tuners.tuners_utils import BaseTunerLayer
+
+            merge_kwargs = {"safe_merge": self._safe_fusing}
+
+            if isinstance(module, BaseTunerLayer):
+                if self.lora_scale != 1.0:
+                    module.scale_layer(self.lora_scale)
+
+                # For BC with previous PEFT versions, we need to check the signature
+                # of the `merge` method to see if it supports the `adapter_names` argument.
+                supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+                if "adapter_names" in supported_merge_kwargs:
+                    merge_kwargs["adapter_names"] = adapter_names
+                elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
+                    raise ValueError(
+                        "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
+                        " to the latest version of PEFT. `pip install -U peft`"
+                    )
+
+                module.merge(**merge_kwargs)
+
+    def unfuse_lora(self):
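+        """Undo `fuse_lora` by un-merging the LoRA weights from their base layers."""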
+        self.apply(self._unfuse_lora_apply)
+
+    def _unfuse_lora_apply(self, module):
+        if not USE_PEFT_BACKEND:
+            if hasattr(module, "_unfuse_lora"):
+                module._unfuse_lora()
+        else:
+            from peft.tuners.tuners_utils import BaseTunerLayer
+
+            if isinstance(module, BaseTunerLayer):
+                module.unmerge()
+
+    def set_adapters(
+        self,
+        adapter_names: Union[List[str], str],
+        weights: Optional[Union[List[float], float]] = None,
+    ):
+        """
+        Set the currently active adapters for use in the UNet.
+
+        Args:
+            adapter_names (`List[str]` or `str`):
+                The names of the adapters to use.
+            weights (`Union[List[float], float]`, *optional*):
+                The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
+                adapters.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
+        ```
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for `set_adapters()`.")
+
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+
+        if weights is None:
+            weights = [1.0] * len(adapter_names)
+        elif isinstance(weights, float):
+            weights = [weights] * len(adapter_names)
+
+        if len(adapter_names) != len(weights):
+            raise ValueError(
+                f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
+            )
+
+        set_weights_and_activate_adapters(self, adapter_names, weights)
+
+    def disable_lora(self):
+        """
+        Disable the UNet's active LoRA layers.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.disable_lora()
+        ```
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+        set_adapter_layers(self, enabled=False)
+
+    def enable_lora(self):
+        """
+        Enable the UNet's active LoRA layers.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.enable_lora()
+        ```
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+        set_adapter_layers(self, enabled=True)
+
+    def delete_adapters(self, adapter_names: Union[List[str], str]):
+        """
+        Delete an adapter's LoRA layers from the UNet.
+
+        Args:
+            adapter_names (`Union[List[str], str]`):
+                The names (single string or list of strings) of the adapter to delete.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
+        )
+        pipeline.delete_adapters("cinematic")
+        ```
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        if isinstance(adapter_names, str):
+            adapter_names = [adapter_names]
+
+        for adapter_name in adapter_names:
+            delete_adapter_layers(self, adapter_name)
+
+            # Pop also the corresponding adapter from the config
+            if hasattr(self, "peft_config"):
+                self.peft_config.pop(adapter_name, None)
+
+    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        updated_state_dict = {}
+        image_projection = None
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+
+        if "proj.weight" in state_dict:
+            # IP-Adapter
+            num_image_text_embeds = 4
+            clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
+            cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
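+            # `proj.weight` has shape [num_image_text_embeds * cross_attention_dim, clip_embeddings_dim],
+            # so both projection dimensions can be recovered from it. Only the plain IP-Adapter projection
+            # layout is handled here; other layouts would leave `image_projection` as None.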
+
+            with init_context():
+                image_projection = ImageProjection(
+                    cross_attention_dim=cross_attention_dim,
+                    image_embed_dim=clip_embeddings_dim,
+                    num_image_text_embeds=num_image_text_embeds,
+                )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("proj", "image_embeds")
+                updated_state_dict[diffusers_name] = value
+
+        if not low_cpu_mem_usage:
+            image_projection.load_state_dict(updated_state_dict)
+        else:
+            load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
+
+        return image_projection
+
+    # def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, multi_frames_condition):
+    #     updated_state_dict = {}
+    #     image_projection = None
+
+    #     if "proj.weight" in state_dict:
+    #         # IP-Adapter
+    #         # NOTE: adapt for  multi-frame
+    #         num_image_text_embeds = 4
+    #         clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
+    #         cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
+    #         # cross_attention_dim = state_dict["proj.weight"].shape[0]
+
+    #         if not multi_frames_condition:
+    #             image_projection = ImageProjection(
+    #                 cross_attention_dim=cross_attention_dim,
+    #                 image_embed_dim=clip_embeddings_dim,
+    #                 num_image_text_embeds=num_image_text_embeds,
+    #             )
+    #         else:
+    #             num_image_text_embeds = 50
+    #             cross_attention_dim = state_dict["proj.weight"].shape[0]
+    #             image_projection = VideoProjModel(
+    #                 cross_attention_dim=cross_attention_dim, 
+    #                 clip_embeddings_dim=clip_embeddings_dim, 
+    #                 clip_extra_context_tokens=1,
+    #                 video_frame=num_image_text_embeds,
+    #             )
+
+    #         for key, value in state_dict.items():
+    #             if not multi_frames_condition:
+    #                 diffusers_name = key.replace("proj", "image_embeds")
+    #             else:
+    #                 diffusers_name = key
+    #             updated_state_dict[diffusers_name] = value
+
+    #     elif "proj.3.weight" in state_dict:
+    #         # IP-Adapter Full
+    #         clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
+    #         cross_attention_dim = state_dict["proj.3.weight"].shape[0]
+
+    #         image_projection = MLPProjection(
+    #             cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
+    #         )
+
+    #         for key, value in state_dict.items():
+    #             diffusers_name = key.replace("proj.0", "ff.net.0.proj")
+    #             diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
+    #             diffusers_name = diffusers_name.replace("proj.3", "norm")
+    #             updated_state_dict[diffusers_name] = value
+
+    #     else:
+    #         # IP-Adapter Plus
+    #         num_image_text_embeds = state_dict["latents"].shape[1]
+    #         embed_dims = state_dict["proj_in.weight"].shape[1]
+    #         output_dims = state_dict["proj_out.weight"].shape[0]
+    #         hidden_dims = state_dict["latents"].shape[2]
+    #         heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
+
+    #         image_projection = Resampler(
+    #             embed_dims=embed_dims,
+    #             output_dims=output_dims,
+    #             hidden_dims=hidden_dims,
+    #             heads=heads,
+    #             num_queries=num_image_text_embeds,
+    #         )
+
+    #         for key, value in state_dict.items():
+    #             diffusers_name = key.replace("0.to", "2.to")
+    #             diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
+    #             diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
+    #             diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight")
+    #             diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")
+
+    #             if "norm1" in diffusers_name:
+    #                 updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
+    #             elif "norm2" in diffusers_name:
+    #                 updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
+    #             elif "to_kv" in diffusers_name:
+    #                 v_chunk = value.chunk(2, dim=0)
+    #                 updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
+    #                 updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
+    #             elif "to_out" in diffusers_name:
+    #                 updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
+    #             else:
+    #                 updated_state_dict[diffusers_name] = value
+
+    #     image_projection.load_state_dict(updated_state_dict)
+    #     return image_projection
+
+    def _convert_ip_adapter_attn_to_diffusers_VPAdapter(self, state_dicts, low_cpu_mem_usage=False):
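+        # Same wiring as `_convert_ip_adapter_attn_to_diffusers` below, but cross-attention layers get the temporal
+        # `VPTemporalAdapterAttnProcessor2_0` instead of the standard IP-Adapter processor.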
+        from diffusers.models.attention_processor import (
+            AttnProcessor,
+            IPAdapterAttnProcessor,
+        )
+
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        # set ip-adapter cross-attention processors & load state_dict
+        attn_procs = {}
+        key_id = 1
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+        for name in self.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = self.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(self.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = self.config.block_out_channels[block_id]
+
+            if cross_attention_dim is None or "motion_modules" in name or 'fuser' in name:
+                attn_processor_class = (
+                    AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
+                )
+                attn_procs[name] = attn_processor_class()
+            else:
+                attn_processor_class = (
+                    VPTemporalAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
+                )
+                num_image_text_embeds = []
+                for state_dict in state_dicts:
+                    if "proj.weight" in state_dict["image_proj"]:
+                        # IP-Adapter
+                        num_image_text_embeds += [4]
+                    elif "proj.3.weight" in state_dict["image_proj"]:
+                        # IP-Adapter Full Face
+                        num_image_text_embeds += [257]  # 256 CLIP tokens + 1 CLS token
+                    else:
+                        # IP-Adapter Plus
+                        num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
+
+                with init_context():
+                    attn_procs[name] = attn_processor_class(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        scale=1.0,
+                        num_tokens=num_image_text_embeds,
+                    )
+
+                value_dict = {}
+                for i, state_dict in enumerate(state_dicts):
+                    value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
+                    value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
+
+                if not low_cpu_mem_usage:
+                    attn_procs[name].load_state_dict(value_dict)
+                else:
+                    device = next(iter(value_dict.values())).device
+                    dtype = next(iter(value_dict.values())).dtype
+                    load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+
+                key_id += 2
+
+        return attn_procs
+
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
+        from diffusers.models.attention_processor import (
+            AttnProcessor,
+            IPAdapterAttnProcessor,
+        )
+
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        # set ip-adapter cross-attention processors & load state_dict
+        attn_procs = {}
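+        # In the original IP-Adapter checkpoints, attention processors are numbered across both self- and
+        # cross-attention, so cross-attention weights sit at odd indices (1, 3, 5, ...); `key_id` follows that
+        # numbering and is advanced by 2 per cross-attention layer.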
+        key_id = 1
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+        for name in self.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = self.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(self.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = self.config.block_out_channels[block_id]
+
+            if cross_attention_dim is None or "motion_modules" in name or 'fuser' in name:
+                attn_processor_class = (
+                    AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
+                )
+                attn_procs[name] = attn_processor_class()
+            else:
+                attn_processor_class = (
+                    IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
+                )
+                num_image_text_embeds = []
+                for state_dict in state_dicts:
+                    if "proj.weight" in state_dict["image_proj"]:
+                        # IP-Adapter
+                        num_image_text_embeds += [4]
+                    elif "proj.3.weight" in state_dict["image_proj"]:
+                        # IP-Adapter Full Face
+                        num_image_text_embeds += [257]  # 256 CLIP tokens + 1 CLS token
+                    else:
+                        # IP-Adapter Plus
+                        num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
+
+                with init_context():
+                    attn_procs[name] = attn_processor_class(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        scale=1.0,
+                        num_tokens=num_image_text_embeds,
+                    )
+
+                value_dict = {}
+                for i, state_dict in enumerate(state_dicts):
+                    value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
+                    value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
+
+                if not low_cpu_mem_usage:
+                    attn_procs[name].load_state_dict(value_dict)
+                else:
+                    device = next(iter(value_dict.values())).device
+                    dtype = next(iter(value_dict.values())).dtype
+                    load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+
+                key_id += 2
+
+        return attn_procs
+
+    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
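+        # Install the IP-Adapter attention processors and image-projection layers onto this UNet from one or more
+        # IP-Adapter state dicts (each holding "image_proj" and "ip_adapter" entries).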
+        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+        self.set_attn_processor(attn_procs)
+
+        # convert IP-Adapter Image Projection layers to diffusers
+        image_projection_layers = []
+        for state_dict in state_dicts:
+            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
+                state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
+            )
+            image_projection_layers.append(image_projection_layer)
+
+        self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
+        self.config.encoder_hid_dim_type = "ip_image_proj"
+
+        self.to(dtype=self.dtype, device=self.device)
+
+    def _load_ip_adapter_weights_VPAdapter(self, state_dicts, low_cpu_mem_usage=False):
+        attn_procs = self._convert_ip_adapter_attn_to_diffusers_VPAdapter(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+        self.set_attn_processor(attn_procs)
+
+        # convert IP-Adapter Image Projection layers to diffusers
+        image_projection_layers = []
+        for state_dict in state_dicts:
+            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
+                state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
+            )
+            image_projection_layers.append(image_projection_layer)
+
+        self.encoder_hid_proj = VPAdapterImageProjection(image_projection_layers)
+        self.config.encoder_hid_dim_type = "ip_image_proj"
+
+        self.to(dtype=self.dtype, device=self.device)
\ No newline at end of file
diff --git a/foleycrafter/models/auffusion/resnet.py b/foleycrafter/models/auffusion/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..6434630129a0ec88eec27b22d3258c591574e39f
--- /dev/null
+++ b/foleycrafter/models/auffusion/resnet.py
@@ -0,0 +1,685 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.utils import USE_PEFT_BACKEND
+from diffusers.models.activations import get_activation
+from diffusers.models.downsampling import (  # noqa
+    Downsample1D,
+    Downsample2D,
+    FirDownsample2D,
+    KDownsample2D,
+    downsample_2d,
+)
+from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
+from diffusers.models.normalization import AdaGroupNorm
+from diffusers.models.upsampling import (  # noqa
+    FirUpsample2D,
+    KUpsample2D,
+    Upsample1D,
+    Upsample2D,
+    upfirdn2d_native,
+    upsample_2d,
+)
+from foleycrafter.models.auffusion.attention_processor import SpatialNorm
+
+
+class ResnetBlock2D(nn.Module):
+    r"""
+    A Resnet block.
+
+    Parameters:
+        in_channels (`int`): The number of channels in the input.
+        out_channels (`int`, *optional*, default to be `None`):
+            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
+        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
+        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
+        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
+        groups_out (`int`, *optional*, default to None):
+            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
+        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
+        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
+        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
+            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
+            "ada_group" for a stronger conditioning with scale and shift.
+        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
+            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
+        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
+        use_in_shortcut (`bool`, *optional*, default to `True`):
+            If `True`, add a 1x1 nn.conv2d layer for skip-connection.
+        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
+        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
+        conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the
+            `conv_shortcut` output.
+        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
+            If None, same as `out_channels`.
+    """
+
+    def __init__(
+        self,
+        *,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        conv_shortcut: bool = False,
+        dropout: float = 0.0,
+        temb_channels: int = 512,
+        groups: int = 32,
+        groups_out: Optional[int] = None,
+        pre_norm: bool = True,
+        eps: float = 1e-6,
+        non_linearity: str = "swish",
+        skip_time_act: bool = False,
+        time_embedding_norm: str = "default",  # default, scale_shift, ada_group, spatial
+        kernel: Optional[torch.FloatTensor] = None,
+        output_scale_factor: float = 1.0,
+        use_in_shortcut: Optional[bool] = None,
+        up: bool = False,
+        down: bool = False,
+        conv_shortcut_bias: bool = True,
+        conv_2d_out_channels: Optional[int] = None,
+    ):
+        super().__init__()
+        self.pre_norm = pre_norm
+        self.pre_norm = True
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+        self.up = up
+        self.down = down
+        self.output_scale_factor = output_scale_factor
+        self.time_embedding_norm = time_embedding_norm
+        self.skip_time_act = skip_time_act
+
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
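+        # With the PEFT backend, LoRA is injected externally and plain torch layers suffice; otherwise the
+        # LoRA-compatible wrappers (which accept a `scale` argument in their forward) are used.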
+
+        if groups_out is None:
+            groups_out = groups
+
+        if self.time_embedding_norm == "ada_group":
+            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
+        elif self.time_embedding_norm == "spatial":
+            self.norm1 = SpatialNorm(in_channels, temb_channels)
+        else:
+            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+        if temb_channels is not None:
+            if self.time_embedding_norm == "default":
+                self.time_emb_proj = linear_cls(temb_channels, out_channels)
+            elif self.time_embedding_norm == "scale_shift":
+                self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
+            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
+                self.time_emb_proj = None
+            else:
+                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
+        else:
+            self.time_emb_proj = None
+
+        if self.time_embedding_norm == "ada_group":
+            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
+        elif self.time_embedding_norm == "spatial":
+            self.norm2 = SpatialNorm(out_channels, temb_channels)
+        else:
+            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+
+        self.dropout = torch.nn.Dropout(dropout)
+        conv_2d_out_channels = conv_2d_out_channels or out_channels
+        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+
+        self.nonlinearity = get_activation(non_linearity)
+
+        self.upsample = self.downsample = None
+        if self.up:
+            if kernel == "fir":
+                fir_kernel = (1, 3, 3, 1)
+                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
+            elif kernel == "sde_vp":
+                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+            else:
+                self.upsample = Upsample2D(in_channels, use_conv=False)
+        elif self.down:
+            if kernel == "fir":
+                fir_kernel = (1, 3, 3, 1)
+                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
+            elif kernel == "sde_vp":
+                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+            else:
+                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
+
+        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
+
+        self.conv_shortcut = None
+        if self.use_in_shortcut:
+            self.conv_shortcut = conv_cls(
+                in_channels,
+                conv_2d_out_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=conv_shortcut_bias,
+            )
+
+    def forward(
+        self,
+        input_tensor: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
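+        # `scale` is a LoRA scale that is threaded through to the LoRA-aware conv/linear and up/downsample layers.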
+        hidden_states = input_tensor
+
+        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
+            hidden_states = self.norm1(hidden_states, temb)
+        else:
+            hidden_states = self.norm1(hidden_states)
+
+        hidden_states = self.nonlinearity(hidden_states)
+
+        if self.upsample is not None:
+            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+            if hidden_states.shape[0] >= 64:
+                input_tensor = input_tensor.contiguous()
+                hidden_states = hidden_states.contiguous()
+            input_tensor = (
+                self.upsample(input_tensor, scale=scale)
+                if isinstance(self.upsample, Upsample2D)
+                else self.upsample(input_tensor)
+            )
+            hidden_states = (
+                self.upsample(hidden_states, scale=scale)
+                if isinstance(self.upsample, Upsample2D)
+                else self.upsample(hidden_states)
+            )
+        elif self.downsample is not None:
+            input_tensor = (
+                self.downsample(input_tensor, scale=scale)
+                if isinstance(self.downsample, Downsample2D)
+                else self.downsample(input_tensor)
+            )
+            hidden_states = (
+                self.downsample(hidden_states, scale=scale)
+                if isinstance(self.downsample, Downsample2D)
+                else self.downsample(hidden_states)
+            )
+
+        hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
+
+        if self.time_emb_proj is not None:
+            if not self.skip_time_act:
+                temb = self.nonlinearity(temb)
+            temb = (
+                self.time_emb_proj(temb, scale)[:, :, None, None]
+                if not USE_PEFT_BACKEND
+                # NOTE: maybe we can use a different prompt at different timesteps
+                else self.time_emb_proj(temb)[:, :, None, None]
+            )
+
+        if temb is not None and self.time_embedding_norm == "default":
+            hidden_states = hidden_states + temb
+
+        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
+            hidden_states = self.norm2(hidden_states, temb)
+        else:
+            hidden_states = self.norm2(hidden_states)
+
+        if temb is not None and self.time_embedding_norm == "scale_shift":
+            scale, shift = torch.chunk(temb, 2, dim=1)
+            hidden_states = hidden_states * (1 + scale) + shift
+
+        hidden_states = self.nonlinearity(hidden_states)
+
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
+
+        if self.conv_shortcut is not None:
+            input_tensor = (
+                self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
+            )
+
+        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+        return output_tensor
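+
+    # Usage sketch (illustrative only; assumes the constructor keyword arguments
+    # defined earlier in this file, which follow the diffusers `ResnetBlock2D` API):
+    #   block = ResnetBlock2D(in_channels=320, out_channels=640, temb_channels=1280)
+    #   x = torch.randn(2, 320, 32, 32)   # (batch, channels, height, width)
+    #   temb = torch.randn(2, 1280)       # (batch, temb_channels)
+    #   block(x, temb).shape              # -> torch.Size([2, 640, 32, 32])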
+
+
+# unet_rl.py
+def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:
+    if len(tensor.shape) == 2:
+        return tensor[:, :, None]
+    if len(tensor.shape) == 3:
+        return tensor[:, :, None, :]
+    elif len(tensor.shape) == 4:
+        return tensor[:, :, 0, :]
+    else:
+        raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
+
+
+class Conv1dBlock(nn.Module):
+    """
+    Conv1d --> GroupNorm --> Mish
+
+    Parameters:
+        inp_channels (`int`): Number of input channels.
+        out_channels (`int`): Number of output channels.
+        kernel_size (`int` or `tuple`): Size of the convolving kernel.
+        n_groups (`int`, default `8`): Number of groups to separate the channels into.
+        activation (`str`, defaults to `mish`): Name of the activation function.
+    """
+
+    def __init__(
+        self,
+        inp_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int]],
+        n_groups: int = 8,
+        activation: str = "mish",
+    ):
+        super().__init__()
+
+        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
+        self.group_norm = nn.GroupNorm(n_groups, out_channels)
+        self.mish = get_activation(activation)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        intermediate_repr = self.conv1d(inputs)
+        intermediate_repr = rearrange_dims(intermediate_repr)
+        intermediate_repr = self.group_norm(intermediate_repr)
+        intermediate_repr = rearrange_dims(intermediate_repr)
+        output = self.mish(intermediate_repr)
+        return output
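+
+    # Usage sketch (illustrative only): with an odd kernel size the temporal
+    # length is preserved.
+    #   block = Conv1dBlock(inp_channels=16, out_channels=32, kernel_size=5)
+    #   x = torch.randn(2, 16, 64)   # (batch, inp_channels, horizon)
+    #   block(x).shape               # -> torch.Size([2, 32, 64])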
+
+
+# unet_rl.py
+class ResidualTemporalBlock1D(nn.Module):
+    """
+    Residual 1D block with temporal convolutions.
+
+    Parameters:
+        inp_channels (`int`): Number of input channels.
+        out_channels (`int`): Number of output channels.
+        embed_dim (`int`): Embedding dimension.
+        kernel_size (`int` or `tuple`): Size of the convolving kernel.
+        activation (`str`, defaults to `mish`): Name of the activation function to use.
+    """
+
+    def __init__(
+        self,
+        inp_channels: int,
+        out_channels: int,
+        embed_dim: int,
+        kernel_size: Union[int, Tuple[int, int]] = 5,
+        activation: str = "mish",
+    ):
+        super().__init__()
+        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
+        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
+
+        self.time_emb_act = get_activation(activation)
+        self.time_emb = nn.Linear(embed_dim, out_channels)
+
+        self.residual_conv = (
+            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
+        )
+
+    def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            inputs : [ batch_size x inp_channels x horizon ]
+            t : [ batch_size x embed_dim ]
+
+        returns:
+            out : [ batch_size x out_channels x horizon ]
+        """
+        t = self.time_emb_act(t)
+        t = self.time_emb(t)
+        out = self.conv_in(inputs) + rearrange_dims(t)
+        out = self.conv_out(out)
+        return out + self.residual_conv(inputs)
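+
+    # Usage sketch (illustrative only): the time embedding is projected to
+    # `out_channels` and broadcast over the horizon.
+    #   block = ResidualTemporalBlock1D(inp_channels=16, out_channels=32, embed_dim=128)
+    #   x = torch.randn(2, 16, 64)   # (batch, inp_channels, horizon)
+    #   t = torch.randn(2, 128)      # (batch, embed_dim)
+    #   block(x, t).shape            # -> torch.Size([2, 32, 64])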
+
+
+class TemporalConvLayer(nn.Module):
+    """
+    Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
+    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
+
+    Parameters:
+        in_dim (`int`): Number of input channels.
+        out_dim (`int`): Number of output channels.
+        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
+    """
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: Optional[int] = None,
+        dropout: float = 0.0,
+        norm_num_groups: int = 32,
+    ):
+        super().__init__()
+        out_dim = out_dim or in_dim
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+
+        # conv layers
+        self.conv1 = nn.Sequential(
+            nn.GroupNorm(norm_num_groups, in_dim),
+            nn.SiLU(),
+            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
+        )
+        self.conv2 = nn.Sequential(
+            nn.GroupNorm(norm_num_groups, out_dim),
+            nn.SiLU(),
+            nn.Dropout(dropout),
+            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
+        )
+        self.conv3 = nn.Sequential(
+            nn.GroupNorm(norm_num_groups, out_dim),
+            nn.SiLU(),
+            nn.Dropout(dropout),
+            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
+        )
+        self.conv4 = nn.Sequential(
+            nn.GroupNorm(norm_num_groups, out_dim),
+            nn.SiLU(),
+            nn.Dropout(dropout),
+            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
+        )
+
+        # zero out the last layer params, so the conv block is an identity at initialization
+        nn.init.zeros_(self.conv4[-1].weight)
+        nn.init.zeros_(self.conv4[-1].bias)
+
+    def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:
+        hidden_states = (
+            hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
+        )
+
+        identity = hidden_states
+        hidden_states = self.conv1(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+        hidden_states = self.conv3(hidden_states)
+        hidden_states = self.conv4(hidden_states)
+
+        hidden_states = identity + hidden_states
+
+        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
+            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
+        )
+        return hidden_states
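+
+    # Usage sketch (illustrative only): the layer expects video frames folded
+    # into the batch dimension, and it is an exact identity right after
+    # construction because the last conv is zero-initialized.
+    #   layer = TemporalConvLayer(in_dim=64)
+    #   x = torch.randn(2 * 8, 64, 32, 32)   # (batch * num_frames, C, H, W)
+    #   out = layer(x, num_frames=8)         # same shape as `x`; torch.allclose(out, x) holds at init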
+
+
+class TemporalResnetBlock(nn.Module):
+    r"""
+    A Resnet block.
+
+    Parameters:
+        in_channels (`int`): The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `None`):
+            The number of output channels for the first convolution. If `None`, same as `in_channels`.
+        temb_channels (`int`, *optional*, defaults to `512`): The number of channels in the timestep embedding.
+        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        temb_channels: int = 512,
+        eps: float = 1e-6,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+
+        kernel_size = (3, 1, 1)
+        padding = [k // 2 for k in kernel_size]
+
+        self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=eps, affine=True)
+        self.conv1 = nn.Conv3d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=1,
+            padding=padding,
+        )
+
+        if temb_channels is not None:
+            self.time_emb_proj = nn.Linear(temb_channels, out_channels)
+        else:
+            self.time_emb_proj = None
+
+        self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=eps, affine=True)
+
+        self.dropout = torch.nn.Dropout(0.0)
+        self.conv2 = nn.Conv3d(
+            out_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=1,
+            padding=padding,
+        )
+
+        self.nonlinearity = get_activation("silu")
+
+        self.use_in_shortcut = self.in_channels != out_channels
+
+        self.conv_shortcut = None
+        if self.use_in_shortcut:
+            self.conv_shortcut = nn.Conv3d(
+                in_channels,
+                out_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            )
+
+    def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
+        hidden_states = input_tensor
+
+        hidden_states = self.norm1(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+        hidden_states = self.conv1(hidden_states)
+
+        if self.time_emb_proj is not None:
+            temb = self.nonlinearity(temb)
+            temb = self.time_emb_proj(temb)[:, :, :, None, None]
+            temb = temb.permute(0, 2, 1, 3, 4)
+            hidden_states = hidden_states + temb
+
+        hidden_states = self.norm2(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+
+        if self.conv_shortcut is not None:
+            input_tensor = self.conv_shortcut(input_tensor)
+
+        output_tensor = input_tensor + hidden_states
+
+        return output_tensor
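+
+    # Usage sketch (illustrative only): inputs are 5D (batch, channels, frames,
+    # height, width) and the time embedding is provided per frame.
+    #   block = TemporalResnetBlock(in_channels=64, temb_channels=512)
+    #   x = torch.randn(2, 64, 8, 16, 16)
+    #   temb = torch.randn(2, 8, 512)   # (batch, frames, temb_channels)
+    #   block(x, temb).shape            # -> torch.Size([2, 64, 8, 16, 16])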
+
+
+# VideoResBlock
+class SpatioTemporalResBlock(nn.Module):
+    r"""
+    A SpatioTemporal Resnet block.
+
+    Parameters:
+        in_channels (`int`): The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `None`):
+            The number of output channels for the first convolution. If `None`, same as `in_channels`.
+        temb_channels (`int`, *optional*, defaults to `512`): The number of channels in the timestep embedding.
+        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the spatial resnet.
+        temporal_eps (`float`, *optional*, defaults to `eps`): The epsilon to use for the temporal resnet.
+        merge_factor (`float`, *optional*, defaults to `0.5`): The merge factor to use for the temporal mixing.
+        merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
+            The merge strategy to use for the temporal mixing.
+        switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
+            If `True`, switch the spatial and temporal mixing.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        temb_channels: int = 512,
+        eps: float = 1e-6,
+        temporal_eps: Optional[float] = None,
+        merge_factor: float = 0.5,
+        merge_strategy="learned_with_images",
+        switch_spatial_to_temporal_mix: bool = False,
+    ):
+        super().__init__()
+
+        self.spatial_res_block = ResnetBlock2D(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            eps=eps,
+        )
+
+        self.temporal_res_block = TemporalResnetBlock(
+            in_channels=out_channels if out_channels is not None else in_channels,
+            out_channels=out_channels if out_channels is not None else in_channels,
+            temb_channels=temb_channels,
+            eps=temporal_eps if temporal_eps is not None else eps,
+        )
+
+        self.time_mixer = AlphaBlender(
+            alpha=merge_factor,
+            merge_strategy=merge_strategy,
+            switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        image_only_indicator: Optional[torch.Tensor] = None,
+    ):
+        num_frames = image_only_indicator.shape[-1]
+        hidden_states = self.spatial_res_block(hidden_states, temb)
+
+        batch_frames, channels, height, width = hidden_states.shape
+        batch_size = batch_frames // num_frames
+
+        hidden_states_mix = (
+            hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
+        )
+        hidden_states = (
+            hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
+        )
+
+        if temb is not None:
+            temb = temb.reshape(batch_size, num_frames, -1)
+
+        hidden_states = self.temporal_res_block(hidden_states, temb)
+        hidden_states = self.time_mixer(
+            x_spatial=hidden_states_mix,
+            x_temporal=hidden_states,
+            image_only_indicator=image_only_indicator,
+        )
+
+        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
+        return hidden_states
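+
+    # Usage sketch (illustrative only): spatial features arrive with frames
+    # folded into the batch dimension; `image_only_indicator` is (batch, frames),
+    # where 1 marks a still image (spatial branch only) and 0 a video frame.
+    #   block = SpatioTemporalResBlock(in_channels=64, temb_channels=512)
+    #   x = torch.randn(2 * 8, 64, 32, 32)   # (batch * frames, C, H, W)
+    #   temb = torch.randn(2 * 8, 512)       # (batch * frames, temb_channels)
+    #   indicator = torch.zeros(2, 8)
+    #   block(x, temb, image_only_indicator=indicator).shape  # -> torch.Size([16, 64, 32, 32])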
+
+
+class AlphaBlender(nn.Module):
+    r"""
+    A module to blend spatial and temporal features.
+
+    Parameters:
+        alpha (`float`): The initial value of the blending factor.
+        merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
+            The merge strategy to use for the temporal mixing.
+        switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
+            If `True`, switch the spatial and temporal mixing.
+    """
+
+    strategies = ["learned", "fixed", "learned_with_images"]
+
+    def __init__(
+        self,
+        alpha: float,
+        merge_strategy: str = "learned_with_images",
+        switch_spatial_to_temporal_mix: bool = False,
+    ):
+        super().__init__()
+        self.merge_strategy = merge_strategy
+        self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix  # For TemporalVAE
+
+        if merge_strategy not in self.strategies:
+            raise ValueError(f"merge_strategy needs to be in {self.strategies}")
+
+        if self.merge_strategy == "fixed":
+            self.register_buffer("mix_factor", torch.Tensor([alpha]))
+        elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
+            self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
+        else:
+            raise ValueError(f"Unknown merge strategy {self.merge_strategy}")
+
+    def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:
+        if self.merge_strategy == "fixed":
+            alpha = self.mix_factor
+
+        elif self.merge_strategy == "learned":
+            alpha = torch.sigmoid(self.mix_factor)
+
+        elif self.merge_strategy == "learned_with_images":
+            if image_only_indicator is None:
+                raise ValueError("Please provide image_only_indicator to use learned_with_images merge strategy")
+
+            alpha = torch.where(
+                image_only_indicator.bool(),
+                torch.ones(1, 1, device=image_only_indicator.device),
+                torch.sigmoid(self.mix_factor)[..., None],
+            )
+
+            # (batch, channel, frames, height, width)
+            if ndims == 5:
+                alpha = alpha[:, None, :, None, None]
+            # (batch*frames, height*width, channels)
+            elif ndims == 3:
+                alpha = alpha.reshape(-1)[:, None, None]
+            else:
+                raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5")
+
+        else:
+            raise NotImplementedError
+
+        return alpha
+
+    def forward(
+        self,
+        x_spatial: torch.Tensor,
+        x_temporal: torch.Tensor,
+        image_only_indicator: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
+        alpha = alpha.to(x_spatial.dtype)
+
+        if self.switch_spatial_to_temporal_mix:
+            alpha = 1.0 - alpha
+
+        x = alpha * x_spatial + (1.0 - alpha) * x_temporal
+        return x
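+
+    # Usage sketch (illustrative only): with the "learned_with_images" strategy,
+    # frames flagged in `image_only_indicator` keep the spatial branch only,
+    # while the rest use a learned sigmoid blend of both branches.
+    #   mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
+    #   x_sp = torch.randn(2, 64, 8, 16, 16)   # (batch, channels, frames, H, W)
+    #   x_tmp = torch.randn(2, 64, 8, 16, 16)
+    #   indicator = torch.zeros(2, 8)          # all-video: blend both branches
+    #   mixer(x_sp, x_tmp, image_only_indicator=indicator).shape  # -> (2, 64, 8, 16, 16)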
\ No newline at end of file
diff --git a/foleycrafter/models/auffusion/transformer_2d.py b/foleycrafter/models/auffusion/transformer_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ed523786e81e266eaec914648a779464bc794e5
--- /dev/null
+++ b/foleycrafter/models/auffusion/transformer_2d.py
@@ -0,0 +1,460 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.embeddings import ImagePositionalEmbeddings
+from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
+from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection
+from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNormSingle
+
+from foleycrafter.models.auffusion.attention import BasicTransformerBlock
+
+@dataclass
+class Transformer2DModelOutput(BaseOutput):
+    """
+    The output of [`Transformer2DModel`].
+
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
+            distributions for the unnoised latent pixels.
+    """
+
+    sample: torch.FloatTensor
+
+class Transformer2DModel(ModelMixin, ConfigMixin):
+    """
+    A 2D Transformer model for image-like data.
+
+    Parameters:
+        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+        in_channels (`int`, *optional*):
+            The number of channels in the input and output (specify if the input is **continuous**).
+        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+            This is fixed during training since it is used to learn a number of position embeddings.
+        num_vector_embeds (`int`, *optional*):
+            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
+            Includes the class for the masked latent pixel.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+        num_embeds_ada_norm ( `int`, *optional*):
+            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+            added to the hidden states.
+
+            During inference, you can denoise for up to but not more than `num_embeds_ada_norm` steps.
+        attention_bias (`bool`, *optional*):
+            Configure if the `TransformerBlocks` attention should contain a bias parameter.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 88,
+        in_channels: Optional[int] = None,
+        out_channels: Optional[int] = None,
+        num_layers: int = 1,
+        dropout: float = 0.0,
+        norm_num_groups: int = 32,
+        cross_attention_dim: Optional[int] = None,
+        attention_bias: bool = False,
+        sample_size: Optional[int] = None,
+        num_vector_embeds: Optional[int] = None,
+        patch_size: Optional[int] = None,
+        activation_fn: str = "geglu",
+        num_embeds_ada_norm: Optional[int] = None,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        double_self_attention: bool = False,
+        upcast_attention: bool = False,
+        norm_type: str = "layer_norm",
+        norm_elementwise_affine: bool = True,
+        norm_eps: float = 1e-5,
+        attention_type: str = "default",
+        caption_channels: Optional[int] = None,
+    ):
+        super().__init__()
+        self.use_linear_projection = use_linear_projection
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+        inner_dim = num_attention_heads * attention_head_dim
+
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+
+        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
+        # Define whether input is continuous or discrete depending on configuration
+        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
+        self.is_input_vectorized = num_vector_embeds is not None
+        self.is_input_patches = in_channels is not None and patch_size is not None
+
+        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
+            deprecation_message = (
+                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
+                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
+                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
+                " results in future versions. If you have downloaded this checkpoint from the HF中国镜像站 Hub, it"
+                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
+            )
+            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
+            norm_type = "ada_norm"
+
+        if self.is_input_continuous and self.is_input_vectorized:
+            raise ValueError(
+                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
+                " sure that either `in_channels` or `num_vector_embeds` is None."
+            )
+        elif self.is_input_vectorized and self.is_input_patches:
+            raise ValueError(
+                f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
+                " sure that either `num_vector_embeds` or `num_patches` is None."
+            )
+        elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
+            raise ValueError(
+                f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
+                f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
+            )
+
+        # 2. Define input layers
+        if self.is_input_continuous:
+            self.in_channels = in_channels
+
+            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+            if use_linear_projection:
+                self.proj_in = linear_cls(in_channels, inner_dim)
+            else:
+                self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+        elif self.is_input_vectorized:
+            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
+
+            self.height = sample_size
+            self.width = sample_size
+            self.num_vector_embeds = num_vector_embeds
+            self.num_latent_pixels = self.height * self.width
+
+            self.latent_image_embedding = ImagePositionalEmbeddings(
+                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
+            )
+        elif self.is_input_patches:
+            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+            self.height = sample_size
+            self.width = sample_size
+
+            self.patch_size = patch_size
+            interpolation_scale = self.config.sample_size // 64  # => 64 (= 512 pixart) has interpolation scale 1
+            interpolation_scale = max(interpolation_scale, 1)
+            self.pos_embed = PatchEmbed(
+                height=sample_size,
+                width=sample_size,
+                patch_size=patch_size,
+                in_channels=in_channels,
+                embed_dim=inner_dim,
+                interpolation_scale=interpolation_scale,
+            )
+
+        # 3. Define transformers blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                # NOTE: remember to change
+                BasicTransformerBlock(
+                    inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    dropout=dropout,
+                    cross_attention_dim=cross_attention_dim,
+                    activation_fn=activation_fn,
+                    num_embeds_ada_norm=num_embeds_ada_norm,
+                    attention_bias=attention_bias,
+                    only_cross_attention=only_cross_attention,
+                    double_self_attention=double_self_attention,
+                    upcast_attention=upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=norm_elementwise_affine,
+                    norm_eps=norm_eps,
+                    attention_type=attention_type,
+                )
+                for d in range(num_layers)
+            ]
+        )
+
+        # 4. Define output layers
+        self.out_channels = in_channels if out_channels is None else out_channels
+        if self.is_input_continuous:
+            # TODO: should use out_channels for continuous projections
+            if use_linear_projection:
+                self.proj_out = linear_cls(inner_dim, in_channels)
+            else:
+                self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+        elif self.is_input_vectorized:
+            self.norm_out = nn.LayerNorm(inner_dim)
+            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
+        elif self.is_input_patches and norm_type != "ada_norm_single":
+            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
+            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+        elif self.is_input_patches and norm_type == "ada_norm_single":
+            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+            self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+            self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+
+        # 5. PixArt-Alpha blocks.
+        self.adaln_single = None
+        self.use_additional_conditions = False
+        if norm_type == "ada_norm_single":
+            self.use_additional_conditions = self.config.sample_size == 128
+            # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
+            # additional conditions until we find better name
+            self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
+
+        self.caption_projection = None
+        if caption_channels is not None:
+            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+
+        self.gradient_checkpointing = False
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        added_cond_kwargs: Dict[str, torch.Tensor] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ):
+        """
+        The [`Transformer2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+                Input `hidden_states`.
+            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+                self-attention.
+            timestep ( `torch.LongTensor`, *optional*):
+                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+                `AdaLayerZeroNorm`.
+            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            attention_mask ( `torch.Tensor`, *optional*):
+                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+                negative values to the attention scores corresponding to "discard" tokens.
+            encoder_attention_mask ( `torch.Tensor`, *optional*):
+                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+                    * Mask `(batch, sequence_length)` True = keep, False = discard.
+                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+                above. This bias will be added to the cross-attention scores.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+        """
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch,                    1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None and attention_mask.ndim == 2:
+            # assume that mask is expressed as:
+            #   (1 = keep,      0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #       (keep = +0,     discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # Retrieve lora scale.
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+        # 1. Input 
+        if self.is_input_continuous:
+            batch, _, height, width = hidden_states.shape
+            inner_dim = hidden_states.shape[1]
+            residual = hidden_states
+
+            hidden_states = self.norm(hidden_states)
+            if not self.use_linear_projection:
+                hidden_states = (
+                    self.proj_in(hidden_states, scale=lora_scale)
+                    if not USE_PEFT_BACKEND
+                    else self.proj_in(hidden_states)
+                )
+                inner_dim = hidden_states.shape[1]
+                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+            else:
+                inner_dim = hidden_states.shape[1]
+                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+                hidden_states = (
+                    self.proj_in(hidden_states, scale=lora_scale)
+                    if not USE_PEFT_BACKEND
+                    else self.proj_in(hidden_states)
+                )
+
+        elif self.is_input_vectorized:
+            hidden_states = self.latent_image_embedding(hidden_states)
+        elif self.is_input_patches:
+            height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
+            self.height, self.width = height, width
+            hidden_states = self.pos_embed(hidden_states)
+
+            if self.adaln_single is not None:
+                if self.use_additional_conditions and added_cond_kwargs is None:
+                    raise ValueError(
+                        "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
+                    )
+                batch_size = hidden_states.shape[0]
+                timestep, embedded_timestep = self.adaln_single(
+                    timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+                )
+
+        if self.caption_projection is not None:
+            batch_size = hidden_states.shape[0]
+            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+        # 2. Blocks
+        for block in self.transformer_blocks:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    timestep,
+                    cross_attention_kwargs,
+                    class_labels,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    timestep=timestep,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    class_labels=class_labels,
+                )
+
+        # 3. Output
+        if self.is_input_continuous:
+            if not self.use_linear_projection:
+                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+                hidden_states = (
+                    self.proj_out(hidden_states, scale=lora_scale)
+                    if not USE_PEFT_BACKEND
+                    else self.proj_out(hidden_states)
+                )
+            else:
+                hidden_states = (
+                    self.proj_out(hidden_states, scale=lora_scale)
+                    if not USE_PEFT_BACKEND
+                    else self.proj_out(hidden_states)
+                )
+                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+            output = hidden_states + residual
+        elif self.is_input_vectorized:
+            hidden_states = self.norm_out(hidden_states)
+            logits = self.out(hidden_states)
+            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+            logits = logits.permute(0, 2, 1)
+
+            # log(p(x_0))
+            output = F.log_softmax(logits.double(), dim=1).float()
+
+        if self.is_input_patches:
+            if self.config.norm_type != "ada_norm_single":
+                conditioning = self.transformer_blocks[0].norm1.emb(
+                    timestep, class_labels, hidden_dtype=hidden_states.dtype
+                )
+                shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+                hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+                hidden_states = self.proj_out_2(hidden_states)
+            elif self.config.norm_type == "ada_norm_single":
+                shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+                hidden_states = self.norm_out(hidden_states)
+                # Modulation
+                hidden_states = hidden_states * (1 + scale) + shift
+                hidden_states = self.proj_out(hidden_states)
+                hidden_states = hidden_states.squeeze(1)
+
+            # unpatchify
+            if self.adaln_single is None:
+                height = width = int(hidden_states.shape[1] ** 0.5)
+            hidden_states = hidden_states.reshape(
+                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+            )
+            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+            output = hidden_states.reshape(
+                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+            )
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
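+
+    # Usage sketch (illustrative only; parameter values are hypothetical and it is
+    # assumed that the custom `BasicTransformerBlock` keeps the diffusers call
+    # signature): for the continuous-input path the sample keeps its spatial shape.
+    #   model = Transformer2DModel(
+    #       num_attention_heads=8, attention_head_dim=40,
+    #       in_channels=320, cross_attention_dim=768, norm_num_groups=32,
+    #   )
+    #   x = torch.randn(2, 320, 32, 32)
+    #   text = torch.randn(2, 77, 768)
+    #   model(x, encoder_hidden_states=text).sample.shape  # -> torch.Size([2, 320, 32, 32])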
\ No newline at end of file
diff --git a/foleycrafter/models/auffusion/unet_2d_blocks.py b/foleycrafter/models/auffusion/unet_2d_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c186bd2113a36c2502f5059b08d16b67eb74817
--- /dev/null
+++ b/foleycrafter/models/auffusion/unet_2d_blocks.py
@@ -0,0 +1,3498 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import is_torch_version, logging
+from diffusers.utils.torch_utils import apply_freeu
+from diffusers.models.activations import get_activation
+from diffusers.models.normalization import AdaGroupNorm
+
+from foleycrafter.models.auffusion.resnet import \
+    Downsample2D, FirDownsample2D, FirUpsample2D, \
+    KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
+from foleycrafter.models.auffusion.transformer_2d import \
+    Transformer2DModel
+from foleycrafter.models.auffusion.dual_transformer_2d import \
+    DualTransformer2DModel
+from foleycrafter.models.auffusion.attention_processor import \
+    Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+def get_down_block(
+    down_block_type: str,
+    num_layers: int,
+    in_channels: int,
+    out_channels: int,
+    temb_channels: int,
+    add_downsample: bool,
+    resnet_eps: float,
+    resnet_act_fn: str,
+    transformer_layers_per_block: int = 1,
+    num_attention_heads: Optional[int] = None,
+    resnet_groups: Optional[int] = None,
+    cross_attention_dim: Optional[int] = None,
+    downsample_padding: Optional[int] = None,
+    dual_cross_attention: bool = False,
+    use_linear_projection: bool = False,
+    only_cross_attention: bool = False,
+    upcast_attention: bool = False,
+    resnet_time_scale_shift: str = "default",
+    attention_type: str = "default",
+    resnet_skip_time_act: bool = False,
+    resnet_out_scale_factor: float = 1.0,
+    cross_attention_norm: Optional[str] = None,
+    attention_head_dim: Optional[int] = None,
+    downsample_type: Optional[str] = None,
+    dropout: float = 0.0,
+):
+    # If attn head dim is not defined, we default it to the number of heads
+    if attention_head_dim is None:
+        logger.warning(
+            f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+        )
+        attention_head_dim = num_attention_heads
+
+    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+    if down_block_type == "DownBlock2D":
+        return DownBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            dropout=dropout,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            downsample_padding=downsample_padding,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+        )
+    elif down_block_type == "ResnetDownsampleBlock2D":
+        return ResnetDownsampleBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            dropout=dropout,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
+        )
+    elif down_block_type == "AttnDownBlock2D":
+        if add_downsample is False:
+            downsample_type = None
+        else:
+            downsample_type = downsample_type or "conv"  # default to 'conv'
+        return AttnDownBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            dropout=dropout,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            downsample_padding=downsample_padding,
+            attention_head_dim=attention_head_dim,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            downsample_type=downsample_type,
+        )
+    elif down_block_type == "CrossAttnDownBlock2D":
+        if cross_attention_dim is None:
+            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
+        return CrossAttnDownBlock2D(
+            num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            dropout=dropout,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            downsample_padding=downsample_padding,
+            cross_attention_dim=cross_attention_dim,
+            num_attention_heads=num_attention_heads,
+            dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
+            only_cross_attention=only_cross_attention,
+            upcast_attention=upcast_attention,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            attention_type=attention_type,
+        )
+    elif down_block_type == "SimpleCrossAttnDownBlock2D":
+        if cross_attention_dim is None:
+            raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
+        return SimpleCrossAttnDownBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            dropout=dropout,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            cross_attention_dim=cross_attention_dim,
+            attention_head_dim=attention_head_dim,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
+            only_cross_attention=only_cross_attention,
+            cross_attention_norm=cross_attention_norm,
+        )
+    elif down_block_type == "SkipDownBlock2D":
+        return SkipDownBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            dropout=dropout,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            downsample_padding=downsample_padding,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+        )
+    elif down_block_type == "AttnSkipDownBlock2D":
+        return AttnSkipDownBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            dropout=dropout,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            attention_head_dim=attention_head_dim,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+        )
+    elif down_block_type == "DownEncoderBlock2D":
+        return DownEncoderBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            dropout=dropout,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            downsample_padding=downsample_padding,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+        )
+    elif down_block_type == "AttnDownEncoderBlock2D":
+        return AttnDownEncoderBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            dropout=dropout,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            downsample_padding=downsample_padding,
+            attention_head_dim=attention_head_dim,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+        )
+    elif down_block_type == "KDownBlock2D":
+        return KDownBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            dropout=dropout,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+        )
+    elif down_block_type == "KCrossAttnDownBlock2D":
+        return KCrossAttnDownBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            dropout=dropout,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            cross_attention_dim=cross_attention_dim,
+            attention_head_dim=attention_head_dim,
+            add_self_attention=not add_downsample,
+        )
+    raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+    up_block_type: str,
+    num_layers: int,
+    in_channels: int,
+    out_channels: int,
+    prev_output_channel: int,
+    temb_channels: int,
+    add_upsample: bool,
+    resnet_eps: float,
+    resnet_act_fn: str,
+    resolution_idx: Optional[int] = None,
+    transformer_layers_per_block: int = 1,
+    num_attention_heads: Optional[int] = None,
+    resnet_groups: Optional[int] = None,
+    cross_attention_dim: Optional[int] = None,
+    dual_cross_attention: bool = False,
+    use_linear_projection: bool = False,
+    only_cross_attention: bool = False,
+    upcast_attention: bool = False,
+    resnet_time_scale_shift: str = "default",
+    attention_type: str = "default",
+    resnet_skip_time_act: bool = False,
+    resnet_out_scale_factor: float = 1.0,
+    cross_attention_norm: Optional[str] = None,
+    attention_head_dim: Optional[int] = None,
+    upsample_type: Optional[str] = None,
+    dropout: float = 0.0,
+) -> nn.Module:
+    # If attn head dim is not defined, we default it to the number of heads
+    if attention_head_dim is None:
+        logger.warning(
+            f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+        )
+        attention_head_dim = num_attention_heads
+
+    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+    if up_block_type == "UpBlock2D":
+        return UpBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            resolution_idx=resolution_idx,
+            dropout=dropout,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+        )
+    elif up_block_type == "ResnetUpsampleBlock2D":
+        return ResnetUpsampleBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            resolution_idx=resolution_idx,
+            dropout=dropout,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
+        )
+    elif up_block_type == "CrossAttnUpBlock2D":
+        if cross_attention_dim is None:
+            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
+        return CrossAttnUpBlock2D(
+            num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            resolution_idx=resolution_idx,
+            dropout=dropout,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            cross_attention_dim=cross_attention_dim,
+            num_attention_heads=num_attention_heads,
+            dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
+            only_cross_attention=only_cross_attention,
+            upcast_attention=upcast_attention,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            attention_type=attention_type,
+        )
+    elif up_block_type == "SimpleCrossAttnUpBlock2D":
+        if cross_attention_dim is None:
+            raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
+        return SimpleCrossAttnUpBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            resolution_idx=resolution_idx,
+            dropout=dropout,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            cross_attention_dim=cross_attention_dim,
+            attention_head_dim=attention_head_dim,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
+            only_cross_attention=only_cross_attention,
+            cross_attention_norm=cross_attention_norm,
+        )
+    elif up_block_type == "AttnUpBlock2D":
+        if add_upsample is False:
+            upsample_type = None
+        else:
+            upsample_type = upsample_type or "conv"  # default to 'conv'
+
+        return AttnUpBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            resolution_idx=resolution_idx,
+            dropout=dropout,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            attention_head_dim=attention_head_dim,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            upsample_type=upsample_type,
+        )
+    elif up_block_type == "SkipUpBlock2D":
+        return SkipUpBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            resolution_idx=resolution_idx,
+            dropout=dropout,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+        )
+    elif up_block_type == "AttnSkipUpBlock2D":
+        return AttnSkipUpBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            resolution_idx=resolution_idx,
+            dropout=dropout,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            attention_head_dim=attention_head_dim,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+        )
+    elif up_block_type == "UpDecoderBlock2D":
+        return UpDecoderBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            resolution_idx=resolution_idx,
+            dropout=dropout,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            temb_channels=temb_channels,
+        )
+    elif up_block_type == "AttnUpDecoderBlock2D":
+        return AttnUpDecoderBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            resolution_idx=resolution_idx,
+            dropout=dropout,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            attention_head_dim=attention_head_dim,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            temb_channels=temb_channels,
+        )
+    elif up_block_type == "KUpBlock2D":
+        return KUpBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            resolution_idx=resolution_idx,
+            dropout=dropout,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+        )
+    elif up_block_type == "KCrossAttnUpBlock2D":
+        return KCrossAttnUpBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            resolution_idx=resolution_idx,
+            dropout=dropout,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            cross_attention_dim=cross_attention_dim,
+            attention_head_dim=attention_head_dim,
+        )
+
+    raise ValueError(f"{up_block_type} does not exist.")
+
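+# Usage sketch (illustrative only; channel sizes are arbitrary, and the positional
+# parameters defined above this excerpt are assumed to follow the usual diffusers-style
+# signature). The factory simply dispatches on the block-type string, e.g.:
+#
+#   up_block = get_up_block(
+#       up_block_type="CrossAttnUpBlock2D",
+#       num_layers=2,
+#       in_channels=640,
+#       out_channels=320,
+#       prev_output_channel=640,
+#       temb_channels=1280,
+#       resolution_idx=None,
+#       add_upsample=True,
+#       resnet_eps=1e-5,
+#       resnet_act_fn="silu",
+#       resnet_groups=32,
+#       cross_attention_dim=768,
+#       num_attention_heads=8,
+#   )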
+
+class AutoencoderTinyBlock(nn.Module):
+    """
+    Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
+    blocks.
+
+    Args:
+        in_channels (`int`): The number of input channels.
+        out_channels (`int`): The number of output channels.
+        act_fn (`str`):
+            The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
+
+    Returns:
+        `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
+        `out_channels`.
+    """
+
+    def __init__(self, in_channels: int, out_channels: int, act_fn: str):
+        super().__init__()
+        act_fn = get_activation(act_fn)
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            act_fn,
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            act_fn,
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+        )
+        self.skip = (
+            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+            if in_channels != out_channels
+            else nn.Identity()
+        )
+        self.fuse = nn.ReLU()
+
+    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+        return self.fuse(self.conv(x) + self.skip(x))
+
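+# Usage sketch (illustrative shapes): the block keeps the spatial resolution and maps
+# in_channels -> out_channels through three 3x3 convs plus a 1x1 (or identity) skip:
+#
+#   block = AutoencoderTinyBlock(in_channels=64, out_channels=64, act_fn="relu")
+#   y = block(torch.randn(1, 64, 32, 32))  # -> torch.Size([1, 64, 32, 32])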
+
+class UNetMidBlock2D(nn.Module):
+    """
+    A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
+
+    Args:
+        in_channels (`int`): The number of input channels.
+        temb_channels (`int`): The number of temporal embedding channels.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
+        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
+        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
+        resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
+            The type of normalization to apply to the time embeddings. This can help to improve the performance of the
+            model on tasks with long-range temporal dependencies.
+        resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
+        resnet_groups (`int`, *optional*, defaults to 32):
+            The number of groups to use in the group normalization layers of the resnet blocks.
+        attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
+        resnet_pre_norm (`bool`, *optional*, defaults to `True`):
+            Whether to use pre-normalization for the resnet blocks.
+        add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
+        attention_head_dim (`int`, *optional*, defaults to 1):
+            Dimension of a single attention head. The number of attention heads is determined based on this value and
+            the number of input channels.
+        output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
+
+    Returns:
+        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
+        in_channels, height, width)`.
+
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",  # default, spatial
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        attn_groups: Optional[int] = None,
+        resnet_pre_norm: bool = True,
+        add_attention: bool = True,
+        attention_head_dim: int = 1,
+        output_scale_factor: float = 1.0,
+    ):
+        super().__init__()
+        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+        self.add_attention = add_attention
+
+        if attn_groups is None:
+            attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
+
+        # there is always at least one resnet
+        resnets = [
+            ResnetBlock2D(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                temb_channels=temb_channels,
+                eps=resnet_eps,
+                groups=resnet_groups,
+                dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+            )
+        ]
+        attentions = []
+
+        if attention_head_dim is None:
+            logger.warning(
+                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
+            )
+            attention_head_dim = in_channels
+
+        for _ in range(num_layers):
+            if self.add_attention:
+                attentions.append(
+                    Attention(
+                        in_channels,
+                        heads=in_channels // attention_head_dim,
+                        dim_head=attention_head_dim,
+                        rescale_output_factor=output_scale_factor,
+                        eps=resnet_eps,
+                        norm_num_groups=attn_groups,
+                        spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
+                        residual_connection=True,
+                        bias=True,
+                        upcast_softmax=True,
+                        _from_deprecated_attn_block=True,
+                    )
+                )
+            else:
+                attentions.append(None)
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+        hidden_states = self.resnets[0](hidden_states, temb)
+        for attn, resnet in zip(self.attentions, self.resnets[1:]):
+            if attn is not None:
+                hidden_states = attn(hidden_states, temb=temb)
+            hidden_states = resnet(hidden_states, temb)
+
+        return hidden_states
+
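+# Usage sketch (illustrative sizes): the mid block preserves shape, applying
+# resnet -> [attention -> resnet] x num_layers:
+#
+#   mid = UNetMidBlock2D(in_channels=64, temb_channels=128, attention_head_dim=32)
+#   y = mid(torch.randn(1, 64, 16, 16), torch.randn(1, 128))  # -> (1, 64, 16, 16)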
+
+class UNetMidBlock2DCrossAttn(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        num_attention_heads: int = 1,
+        output_scale_factor: float = 1.0,
+        cross_attention_dim: int = 1280,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        upcast_attention: bool = False,
+        attention_type: str = "default",
+    ):
+        super().__init__()
+
+        self.has_cross_attention = True
+        self.num_attention_heads = num_attention_heads
+        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+        # there is always at least one resnet
+        resnets = [
+            ResnetBlock2D(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                temb_channels=temb_channels,
+                eps=resnet_eps,
+                groups=resnet_groups,
+                dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+            )
+        ]
+        attentions = []
+
+        for i in range(num_layers):
+            if not dual_cross_attention:
+                attentions.append(
+                    Transformer2DModel(
+                        num_attention_heads,
+                        in_channels // num_attention_heads,
+                        in_channels=in_channels,
+                        num_layers=transformer_layers_per_block[i],
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
+                        upcast_attention=upcast_attention,
+                        attention_type=attention_type,
+                    )
+                )
+            else:
+                attentions.append(
+                    DualTransformer2DModel(
+                        num_attention_heads,
+                        in_channels // num_attention_heads,
+                        in_channels=in_channels,
+                        num_layers=1,
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                    )
+                )
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
+        for attn, resnet in zip(self.attentions, self.resnets[1:]):
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+                hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+
+        return hidden_states
+
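+# Usage sketch (illustrative sizes; cross_attention_dim has to match the conditioning
+# encoder's hidden size, e.g. 768 for CLIP ViT-L embeddings):
+#
+#   mid = UNetMidBlock2DCrossAttn(
+#       in_channels=1280, temb_channels=1280, num_attention_heads=8, cross_attention_dim=768
+#   )
+#   y = mid(
+#       torch.randn(1, 1280, 8, 8),
+#       temb=torch.randn(1, 1280),
+#       encoder_hidden_states=torch.randn(1, 77, 768),
+#   )  # -> (1, 1280, 8, 8)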
+
+class UNetMidBlock2DSimpleCrossAttn(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attention_head_dim: int = 1,
+        output_scale_factor: float = 1.0,
+        cross_attention_dim: int = 1280,
+        skip_time_act: bool = False,
+        only_cross_attention: bool = False,
+        cross_attention_norm: Optional[str] = None,
+    ):
+        super().__init__()
+
+        self.has_cross_attention = True
+
+        self.attention_head_dim = attention_head_dim
+        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+        self.num_heads = in_channels // self.attention_head_dim
+
+        # there is always at least one resnet
+        resnets = [
+            ResnetBlock2D(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                temb_channels=temb_channels,
+                eps=resnet_eps,
+                groups=resnet_groups,
+                dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+                skip_time_act=skip_time_act,
+            )
+        ]
+        attentions = []
+
+        for _ in range(num_layers):
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
+            attentions.append(
+                Attention(
+                    query_dim=in_channels,
+                    cross_attention_dim=in_channels,
+                    heads=self.num_heads,
+                    dim_head=self.attention_head_dim,
+                    added_kv_proj_dim=cross_attention_dim,
+                    norm_num_groups=resnet_groups,
+                    bias=True,
+                    upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
+                    processor=processor,
+                )
+            )
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+        lora_scale = cross_attention_kwargs.get("scale", 1.0)
+
+        if attention_mask is None:
+            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+            mask = None if encoder_hidden_states is None else encoder_attention_mask
+        else:
+            # when attention_mask is defined: we don't even check for encoder_attention_mask.
+            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+            #       then we can simplify this whole if/else block to:
+            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+            mask = attention_mask
+
+        hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
+        for attn, resnet in zip(self.attentions, self.resnets[1:]):
+            # attn
+            hidden_states = attn(
+                hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=mask,
+                **cross_attention_kwargs,
+            )
+
+            # resnet
+            hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+
+        return hidden_states
+
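+# Note: unlike UNetMidBlock2DCrossAttn above, this variant injects the conditioning
+# through added key/value projections (AttnAddedKVProcessor / AttnAddedKVProcessor2_0)
+# and is parameterized by attention_head_dim (heads = in_channels // attention_head_dim)
+# rather than by an explicit head count.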
+
+class AttnDownBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attention_head_dim: int = 1,
+        output_scale_factor: float = 1.0,
+        downsample_padding: int = 1,
+        downsample_type: str = "conv",
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+        self.downsample_type = downsample_type
+
+        if attention_head_dim is None:
+            logger.warning(
+                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            attentions.append(
+                Attention(
+                    out_channels,
+                    heads=out_channels // attention_head_dim,
+                    dim_head=attention_head_dim,
+                    rescale_output_factor=output_scale_factor,
+                    eps=resnet_eps,
+                    norm_num_groups=resnet_groups,
+                    residual_connection=True,
+                    bias=True,
+                    upcast_softmax=True,
+                    _from_deprecated_attn_block=True,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if downsample_type == "conv":
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
+            )
+        elif downsample_type == "resnet":
+            self.downsamplers = nn.ModuleList(
+                [
+                    ResnetBlock2D(
+                        in_channels=out_channels,
+                        out_channels=out_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm=resnet_time_scale_shift,
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                        pre_norm=resnet_pre_norm,
+                        down=True,
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        upsample_size: Optional[int] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+        lora_scale = cross_attention_kwargs.get("scale", 1.0)
+
+        output_states = ()
+
+        for resnet, attn in zip(self.resnets, self.attentions):
+            cross_attention_kwargs.update({"scale": lora_scale})
+            hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+            hidden_states = attn(hidden_states, **cross_attention_kwargs)
+            output_states = output_states + (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                if self.downsample_type == "resnet":
+                    hidden_states = downsampler(hidden_states, temb=temb, scale=lora_scale)
+                else:
+                    hidden_states = downsampler(hidden_states, scale=lora_scale)
+
+            output_states += (hidden_states,)
+
+        return hidden_states, output_states
+
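+# Usage sketch (illustrative sizes): down blocks return both the final hidden states and
+# a tuple of per-layer residuals that the matching up blocks consume as skip connections:
+#
+#   down = AttnDownBlock2D(in_channels=64, out_channels=128, temb_channels=256, attention_head_dim=32)
+#   y, res = down(torch.randn(1, 64, 32, 32), torch.randn(1, 256))
+#   # y: (1, 128, 16, 16); res: ((1, 128, 32, 32), (1, 128, 16, 16))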
+
+class CrossAttnDownBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        downsample_padding: int = 1,
+        add_downsample: bool = True,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+        attention_type: str = "default",
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        self.has_cross_attention = True
+        self.num_attention_heads = num_attention_heads
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            if not dual_cross_attention:
+                attentions.append(
+                    Transformer2DModel(
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
+                        num_layers=transformer_layers_per_block[i],
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
+                        only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
+                        attention_type=attention_type,
+                    )
+                )
+            else:
+                attentions.append(
+                    DualTransformer2DModel(
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
+                        num_layers=1,
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                    )
+                )
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        additional_residuals: Optional[torch.FloatTensor] = None,
+    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        output_states = ()
+
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+        blocks = list(zip(self.resnets, self.attentions))
+
+        for i, (resnet, attn) in enumerate(blocks):
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+            else:
+                hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+
+            # apply additional residuals to the output of the last pair of resnet and attention blocks
+            if i == len(blocks) - 1 and additional_residuals is not None:
+                hidden_states = hidden_states + additional_residuals
+
+            output_states = output_states + (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states, scale=lora_scale)
+
+            output_states = output_states + (hidden_states,)
+
+        return hidden_states, output_states
+
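+# Usage sketch (illustrative sizes): like AttnDownBlock2D but with text/image
+# cross-attention; additional_residuals, when given, are added to the output of the last
+# resnet/attention pair (the hook used for adapter-style feature injection):
+#
+#   down = CrossAttnDownBlock2D(
+#       in_channels=320, out_channels=320, temb_channels=1280,
+#       num_attention_heads=8, cross_attention_dim=768,
+#   )
+#   y, res = down(
+#       torch.randn(1, 320, 32, 32),
+#       temb=torch.randn(1, 1280),
+#       encoder_hidden_states=torch.randn(1, 77, 768),
+#   )  # y: (1, 320, 16, 16)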
+
+class DownBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor: float = 1.0,
+        add_downsample: bool = True,
+        downsample_padding: int = 1,
+    ):
+        super().__init__()
+        resnets = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
+    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        output_states = ()
+
+        for resnet in self.resnets:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+            else:
+                hidden_states = resnet(hidden_states, temb, scale=scale)
+
+            output_states = output_states + (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states, scale=scale)
+
+            output_states = output_states + (hidden_states,)
+
+        return hidden_states, output_states
+
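+# Note: blocks that define self.gradient_checkpointing only take the checkpointed path
+# while self.training is True; the flag is typically flipped on every child block via the
+# parent model's enable_gradient_checkpointing() (a diffusers ModelMixin method).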
+
+class DownEncoderBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor: float = 1.0,
+        add_downsample: bool = True,
+        downsample_padding: int = 1,
+    ):
+        super().__init__()
+        resnets = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=None,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+    def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
+        for resnet in self.resnets:
+            hidden_states = resnet(hidden_states, temb=None, scale=scale)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states, scale)
+
+        return hidden_states
+
+
+class AttnDownEncoderBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attention_head_dim: int = 1,
+        output_scale_factor: float = 1.0,
+        add_downsample: bool = True,
+        downsample_padding: int = 1,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        if attention_head_dim is None:
+            logger.warning(
+                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=None,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            attentions.append(
+                Attention(
+                    out_channels,
+                    heads=out_channels // attention_head_dim,
+                    dim_head=attention_head_dim,
+                    rescale_output_factor=output_scale_factor,
+                    eps=resnet_eps,
+                    norm_num_groups=resnet_groups,
+                    residual_connection=True,
+                    bias=True,
+                    upcast_softmax=True,
+                    _from_deprecated_attn_block=True,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+    def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
+        for resnet, attn in zip(self.resnets, self.attentions):
+            hidden_states = resnet(hidden_states, temb=None, scale=scale)
+            cross_attention_kwargs = {"scale": scale}
+            hidden_states = attn(hidden_states, **cross_attention_kwargs)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states, scale)
+
+        return hidden_states
+
+
+class AttnSkipDownBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_pre_norm: bool = True,
+        attention_head_dim: int = 1,
+        output_scale_factor: float = np.sqrt(2.0),
+        add_downsample: bool = True,
+    ):
+        super().__init__()
+        self.attentions = nn.ModuleList([])
+        self.resnets = nn.ModuleList([])
+
+        if attention_head_dim is None:
+            logger.warning(
+                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            self.resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=min(in_channels // 4, 32),
+                    groups_out=min(out_channels // 4, 32),
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            self.attentions.append(
+                Attention(
+                    out_channels,
+                    heads=out_channels // attention_head_dim,
+                    dim_head=attention_head_dim,
+                    rescale_output_factor=output_scale_factor,
+                    eps=resnet_eps,
+                    norm_num_groups=32,
+                    residual_connection=True,
+                    bias=True,
+                    upcast_softmax=True,
+                    _from_deprecated_attn_block=True,
+                )
+            )
+
+        if add_downsample:
+            self.resnet_down = ResnetBlock2D(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                temb_channels=temb_channels,
+                eps=resnet_eps,
+                groups=min(out_channels // 4, 32),
+                dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+                use_in_shortcut=True,
+                down=True,
+                kernel="fir",
+            )
+            self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
+            self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+        else:
+            self.resnet_down = None
+            self.downsamplers = None
+            self.skip_conv = None
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        skip_sample: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
+        output_states = ()
+
+        for resnet, attn in zip(self.resnets, self.attentions):
+            hidden_states = resnet(hidden_states, temb, scale=scale)
+            cross_attention_kwargs = {"scale": scale}
+            hidden_states = attn(hidden_states, **cross_attention_kwargs)
+            output_states += (hidden_states,)
+
+        if self.downsamplers is not None:
+            hidden_states = self.resnet_down(hidden_states, temb, scale=scale)
+            for downsampler in self.downsamplers:
+                skip_sample = downsampler(skip_sample)
+
+            hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+            output_states += (hidden_states,)
+
+        return hidden_states, output_states, skip_sample
+
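+# Note: the skip-style blocks additionally thread a skip_sample tensor (an image-like,
+# 3-channel input, hence the Conv2d(3, out_channels) skip_conv) through FIR downsampling
+# and return it as a third output alongside the hidden states and residuals.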
+
+class SkipDownBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_pre_norm: bool = True,
+        output_scale_factor: float = np.sqrt(2.0),
+        add_downsample: bool = True,
+        downsample_padding: int = 1,
+    ):
+        super().__init__()
+        self.resnets = nn.ModuleList([])
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            self.resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=min(in_channels // 4, 32),
+                    groups_out=min(out_channels // 4, 32),
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        if add_downsample:
+            self.resnet_down = ResnetBlock2D(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                temb_channels=temb_channels,
+                eps=resnet_eps,
+                groups=min(out_channels // 4, 32),
+                dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+                use_in_shortcut=True,
+                down=True,
+                kernel="fir",
+            )
+            self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
+            self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+        else:
+            self.resnet_down = None
+            self.downsamplers = None
+            self.skip_conv = None
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        skip_sample: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
+        output_states = ()
+
+        for resnet in self.resnets:
+            hidden_states = resnet(hidden_states, temb, scale)
+            output_states += (hidden_states,)
+
+        if self.downsamplers is not None:
+            hidden_states = self.resnet_down(hidden_states, temb, scale)
+            for downsampler in self.downsamplers:
+                skip_sample = downsampler(skip_sample)
+
+            hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+            output_states += (hidden_states,)
+
+        return hidden_states, output_states, skip_sample
+
+
+class ResnetDownsampleBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor: float = 1.0,
+        add_downsample: bool = True,
+        skip_time_act: bool = False,
+    ):
+        super().__init__()
+        resnets = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    ResnetBlock2D(
+                        in_channels=out_channels,
+                        out_channels=out_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm=resnet_time_scale_shift,
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                        pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
+                        down=True,
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
+    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        output_states = ()
+
+        for resnet in self.resnets:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+            else:
+                hidden_states = resnet(hidden_states, temb, scale)
+
+            output_states = output_states + (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states, temb, scale)
+
+            output_states = output_states + (hidden_states,)
+
+        return hidden_states, output_states
+
+
+class SimpleCrossAttnDownBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attention_head_dim: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        add_downsample: bool = True,
+        skip_time_act: bool = False,
+        only_cross_attention: bool = False,
+        cross_attention_norm: Optional[str] = None,
+    ):
+        super().__init__()
+
+        self.has_cross_attention = True
+
+        resnets = []
+        attentions = []
+
+        self.attention_head_dim = attention_head_dim
+        self.num_heads = out_channels // self.attention_head_dim
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
+                )
+            )
+
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
+            attentions.append(
+                Attention(
+                    query_dim=out_channels,
+                    cross_attention_dim=out_channels,
+                    heads=self.num_heads,
+                    dim_head=attention_head_dim,
+                    added_kv_proj_dim=cross_attention_dim,
+                    norm_num_groups=resnet_groups,
+                    bias=True,
+                    upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
+                    processor=processor,
+                )
+            )
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    ResnetBlock2D(
+                        in_channels=out_channels,
+                        out_channels=out_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm=resnet_time_scale_shift,
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                        pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
+                        down=True,
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        output_states = ()
+        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+        lora_scale = cross_attention_kwargs.get("scale", 1.0)
+
+        if attention_mask is None:
+            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+            mask = None if encoder_hidden_states is None else encoder_attention_mask
+        else:
+            # when attention_mask is defined: we don't even check for encoder_attention_mask.
+            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+            #       then we can simplify this whole if/else block to:
+            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+            mask = attention_mask
+
+        for resnet, attn in zip(self.resnets, self.attentions):
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=mask,
+                    **cross_attention_kwargs,
+                )
+            else:
+                hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=mask,
+                    **cross_attention_kwargs,
+                )
+
+            output_states = output_states + (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states, temb, scale=lora_scale)
+
+            output_states = output_states + (hidden_states,)
+
+        return hidden_states, output_states
+
+
+class KDownBlock2D(nn.Module):
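+    """K-diffusion-style down block: a stack of `ResnetBlock2D` layers with `ada_group`
+    time-embedding norm and no attention, optionally followed by a `KDownsample2D`.
+    Per-layer outputs are returned for skip connections."""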
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 4,
+        resnet_eps: float = 1e-5,
+        resnet_act_fn: str = "gelu",
+        resnet_group_size: int = 32,
+        add_downsample: bool = False,
+    ):
+        super().__init__()
+        resnets = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            groups = in_channels // resnet_group_size
+            groups_out = out_channels // resnet_group_size
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    dropout=dropout,
+                    temb_channels=temb_channels,
+                    groups=groups,
+                    groups_out=groups_out,
+                    eps=resnet_eps,
+                    non_linearity=resnet_act_fn,
+                    time_embedding_norm="ada_group",
+                    conv_shortcut_bias=False,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            # YiYi's comment: might be able to use FirDownsample2D here; look into the details later.
+            self.downsamplers = nn.ModuleList([KDownsample2D()])
+        else:
+            self.downsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
+    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        output_states = ()
+
+        for resnet in self.resnets:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+            else:
+                hidden_states = resnet(hidden_states, temb, scale)
+
+            output_states += (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+        return hidden_states, output_states
+
+
+class KCrossAttnDownBlock2D(nn.Module):
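+    """K-diffusion-style down block that pairs `ada_group` `ResnetBlock2D` layers with
+    `KAttentionBlock` cross-attention, optionally followed by a `KDownsample2D`."""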
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        cross_attention_dim: int,
+        dropout: float = 0.0,
+        num_layers: int = 4,
+        resnet_group_size: int = 32,
+        add_downsample: bool = True,
+        attention_head_dim: int = 64,
+        add_self_attention: bool = False,
+        resnet_eps: float = 1e-5,
+        resnet_act_fn: str = "gelu",
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        self.has_cross_attention = True
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            groups = in_channels // resnet_group_size
+            groups_out = out_channels // resnet_group_size
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    dropout=dropout,
+                    temb_channels=temb_channels,
+                    groups=groups,
+                    groups_out=groups_out,
+                    eps=resnet_eps,
+                    non_linearity=resnet_act_fn,
+                    time_embedding_norm="ada_group",
+                    conv_shortcut_bias=False,
+                )
+            )
+            attentions.append(
+                KAttentionBlock(
+                    out_channels,
+                    out_channels // attention_head_dim,
+                    attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                    temb_channels=temb_channels,
+                    attention_bias=True,
+                    add_self_attention=add_self_attention,
+                    cross_attention_norm="layer_norm",
+                    group_size=resnet_group_size,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+        self.attentions = nn.ModuleList(attentions)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList([KDownsample2D()])
+        else:
+            self.downsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        output_states = ()
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+        for resnet, attn in zip(self.resnets, self.attentions):
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
+            else:
+                hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
+
+            if self.downsamplers is None:
+                output_states += (None,)
+            else:
+                output_states += (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+        return hidden_states, output_states
+
+
+class AttnUpBlock2D(nn.Module):
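+    """Upsampling block that pairs `ResnetBlock2D` layers (each fed one skip tensor from
+    `res_hidden_states_tuple`) with self-attention `Attention` layers. Upsampling is either a
+    convolutional `Upsample2D` (`upsample_type="conv"`) or a `ResnetBlock2D` with `up=True`
+    (`upsample_type="resnet"`)."""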
+    def __init__(
+        self,
+        in_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
+        temb_channels: int,
+        resolution_idx: Optional[int] = None,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attention_head_dim: int = 1,
+        output_scale_factor: float = 1.0,
+        upsample_type: str = "conv",
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        self.upsample_type = upsample_type
+
+        if attention_head_dim is None:
+            logger.warning(
+                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            attentions.append(
+                Attention(
+                    out_channels,
+                    heads=out_channels // attention_head_dim,
+                    dim_head=attention_head_dim,
+                    rescale_output_factor=output_scale_factor,
+                    eps=resnet_eps,
+                    norm_num_groups=resnet_groups,
+                    residual_connection=True,
+                    bias=True,
+                    upcast_softmax=True,
+                    _from_deprecated_attn_block=True,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if upsample_type == "conv":
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        elif upsample_type == "resnet":
+            self.upsamplers = nn.ModuleList(
+                [
+                    ResnetBlock2D(
+                        in_channels=out_channels,
+                        out_channels=out_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm=resnet_time_scale_shift,
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                        pre_norm=resnet_pre_norm,
+                        up=True,
+                    )
+                ]
+            )
+        else:
+            self.upsamplers = None
+
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        upsample_size: Optional[int] = None,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
+        for resnet, attn in zip(self.resnets, self.attentions):
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            hidden_states = resnet(hidden_states, temb, scale=scale)
+            cross_attention_kwargs = {"scale": scale}
+            hidden_states = attn(hidden_states, **cross_attention_kwargs)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                if self.upsample_type == "resnet":
+                    hidden_states = upsampler(hidden_states, temb=temb, scale=scale)
+                else:
+                    hidden_states = upsampler(hidden_states, scale=scale)
+
+        return hidden_states
+
+
+class CrossAttnUpBlock2D(nn.Module):
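+    """Upsampling block that pairs `ResnetBlock2D` layers with `Transformer2DModel` cross-attention
+    (or `DualTransformer2DModel` when `dual_cross_attention=True`). Each layer consumes one skip
+    tensor from `res_hidden_states_tuple`; an optional `Upsample2D` doubles the spatial size.
+
+    Hedged usage sketch (channel sizes below are illustrative assumptions, not values taken from
+    this repository)::
+
+        block = CrossAttnUpBlock2D(
+            in_channels=320, out_channels=640, prev_output_channel=640,
+            temb_channels=1280, num_attention_heads=8, cross_attention_dim=768,
+        )
+        out = block(
+            torch.randn(1, 640, 16, 16),                    # hidden_states
+            (torch.randn(1, 320, 16, 16),),                 # one skip tensor per layer (num_layers=1)
+            temb=torch.randn(1, 1280),
+            encoder_hidden_states=torch.randn(1, 77, 768),
+        )                                                   # -> torch.Size([1, 640, 32, 32])
+    """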
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        prev_output_channel: int,
+        temb_channels: int,
+        resolution_idx: Optional[int] = None,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+        attention_type: str = "default",
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        self.has_cross_attention = True
+        self.num_attention_heads = num_attention_heads
+
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            if not dual_cross_attention:
+                attentions.append(
+                    Transformer2DModel(
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
+                        num_layers=transformer_layers_per_block[i],
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
+                        only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
+                        attention_type=attention_type,
+                    )
+                )
+            else:
+                attentions.append(
+                    DualTransformer2DModel(
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
+                        num_layers=1,
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                    )
+                )
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        upsample_size: Optional[int] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        is_freeu_enabled = (
+            getattr(self, "s1", None)
+            and getattr(self, "s2", None)
+            and getattr(self, "b1", None)
+            and getattr(self, "b2", None)
+        )
+
+        for resnet, attn in zip(self.resnets, self.attentions):
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+            # FreeU: Only operate on the first two stages
+            if is_freeu_enabled:
+                hidden_states, res_hidden_states = apply_freeu(
+                    self.resolution_idx,
+                    hidden_states,
+                    res_hidden_states,
+                    s1=self.s1,
+                    s2=self.s2,
+                    b1=self.b1,
+                    b2=self.b2,
+                )
+
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+            else:
+                hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
+
+        return hidden_states
+
+
+class UpBlock2D(nn.Module):
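+    """Plain upsampling block: `ResnetBlock2D` layers over concatenated skip connections, with no
+    attention, followed by an optional convolutional `Upsample2D`. Supports FreeU scaling and
+    gradient checkpointing."""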
+    def __init__(
+        self,
+        in_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
+        temb_channels: int,
+        resolution_idx: Optional[int] = None,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+    ):
+        super().__init__()
+        resnets = []
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        upsample_size: Optional[int] = None,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
+        is_freeu_enabled = (
+            getattr(self, "s1", None)
+            and getattr(self, "s2", None)
+            and getattr(self, "b1", None)
+            and getattr(self, "b2", None)
+        )
+
+        for resnet in self.resnets:
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+            # FreeU: Only operate on the first two stages
+            if is_freeu_enabled:
+                hidden_states, res_hidden_states = apply_freeu(
+                    self.resolution_idx,
+                    hidden_states,
+                    res_hidden_states,
+                    s1=self.s1,
+                    s2=self.s2,
+                    b1=self.b1,
+                    b2=self.b2,
+                )
+
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+            else:
+                hidden_states = resnet(hidden_states, temb, scale=scale)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
+
+        return hidden_states
+
+
+class UpDecoderBlock2D(nn.Module):
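+    """Decoder-style up block (e.g. for a VAE decoder): a stack of `ResnetBlock2D` layers with no
+    skip connections, followed by an optional `Upsample2D`."""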
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        resolution_idx: Optional[int] = None,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",  # default, spatial
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        temb_channels: Optional[int] = None,
+    ):
+        super().__init__()
+        resnets = []
+
+        for i in range(num_layers):
+            input_channels = in_channels if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=input_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
+    ) -> torch.FloatTensor:
+        for resnet in self.resnets:
+            hidden_states = resnet(hidden_states, temb=temb, scale=scale)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
+
+        return hidden_states
+
+
+class AttnUpDecoderBlock2D(nn.Module):
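+    """Decoder-style up block that pairs `ResnetBlock2D` layers with self-attention `Attention`
+    layers (supporting spatial time-embedding norm), followed by an optional `Upsample2D`."""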
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        resolution_idx: Optional[int] = None,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attention_head_dim: int = 1,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        temb_channels: Optional[int] = None,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        if attention_head_dim is None:
+            logger.warning(
+                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
+
+        for i in range(num_layers):
+            input_channels = in_channels if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=input_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            attentions.append(
+                Attention(
+                    out_channels,
+                    heads=out_channels // attention_head_dim,
+                    dim_head=attention_head_dim,
+                    rescale_output_factor=output_scale_factor,
+                    eps=resnet_eps,
+                    norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None,
+                    spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
+                    residual_connection=True,
+                    bias=True,
+                    upcast_softmax=True,
+                    _from_deprecated_attn_block=True,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
+    ) -> torch.FloatTensor:
+        for resnet, attn in zip(self.resnets, self.attentions):
+            hidden_states = resnet(hidden_states, temb=temb, scale=scale)
+            cross_attention_kwargs = {"scale": scale}
+            hidden_states = attn(hidden_states, temb=temb, **cross_attention_kwargs)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states, scale=scale)
+
+        return hidden_states
+
+
+class AttnSkipUpBlock2D(nn.Module):
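+    """Score-SDE/NCSN-style up block: `ResnetBlock2D` layers over skip connections, a single
+    self-attention layer, FIR upsampling of the running `skip_sample`, and (when
+    `add_upsample=True`) a FIR resnet upsampler plus a GroupNorm/SiLU/Conv2d branch that projects
+    hidden states to the 3-channel skip output."""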
+    def __init__(
+        self,
+        in_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
+        temb_channels: int,
+        resolution_idx: Optional[int] = None,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_pre_norm: bool = True,
+        attention_head_dim: int = 1,
+        output_scale_factor: float = np.sqrt(2.0),
+        add_upsample: bool = True,
+    ):
+        super().__init__()
+        self.attentions = nn.ModuleList([])
+        self.resnets = nn.ModuleList([])
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            self.resnets.append(
+                ResnetBlock2D(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+                    groups_out=min(out_channels // 4, 32),
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        if attention_head_dim is None:
+            logger.warning(
+                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
+
+        self.attentions.append(
+            Attention(
+                out_channels,
+                heads=out_channels // attention_head_dim,
+                dim_head=attention_head_dim,
+                rescale_output_factor=output_scale_factor,
+                eps=resnet_eps,
+                norm_num_groups=32,
+                residual_connection=True,
+                bias=True,
+                upcast_softmax=True,
+                _from_deprecated_attn_block=True,
+            )
+        )
+
+        self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+        if add_upsample:
+            self.resnet_up = ResnetBlock2D(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                temb_channels=temb_channels,
+                eps=resnet_eps,
+                groups=min(out_channels // 4, 32),
+                groups_out=min(out_channels // 4, 32),
+                dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+                use_in_shortcut=True,
+                up=True,
+                kernel="fir",
+            )
+            self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+            self.skip_norm = torch.nn.GroupNorm(
+                num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+            )
+            self.act = nn.SiLU()
+        else:
+            self.resnet_up = None
+            self.skip_conv = None
+            self.skip_norm = None
+            self.act = None
+
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        skip_sample=None,
+        scale: float = 1.0,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        for resnet in self.resnets:
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            hidden_states = resnet(hidden_states, temb, scale=scale)
+
+        cross_attention_kwargs = {"scale": scale}
+        hidden_states = self.attentions[0](hidden_states, **cross_attention_kwargs)
+
+        if skip_sample is not None:
+            skip_sample = self.upsampler(skip_sample)
+        else:
+            skip_sample = 0
+
+        if self.resnet_up is not None:
+            skip_sample_states = self.skip_norm(hidden_states)
+            skip_sample_states = self.act(skip_sample_states)
+            skip_sample_states = self.skip_conv(skip_sample_states)
+
+            skip_sample = skip_sample + skip_sample_states
+
+            hidden_states = self.resnet_up(hidden_states, temb, scale=scale)
+
+        return hidden_states, skip_sample
+
+
+class SkipUpBlock2D(nn.Module):
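+    """Same structure as `AttnSkipUpBlock2D` but without the attention layer: resnets over skip
+    connections, FIR upsampling of `skip_sample`, and an optional FIR resnet upsampler with a
+    3-channel skip projection."""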
+    def __init__(
+        self,
+        in_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
+        temb_channels: int,
+        resolution_idx: Optional[int] = None,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_pre_norm: bool = True,
+        output_scale_factor: float = np.sqrt(2.0),
+        add_upsample: bool = True,
+        upsample_padding: int = 1,
+    ):
+        super().__init__()
+        self.resnets = nn.ModuleList([])
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            self.resnets.append(
+                ResnetBlock2D(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+                    groups_out=min(out_channels // 4, 32),
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+        if add_upsample:
+            self.resnet_up = ResnetBlock2D(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                temb_channels=temb_channels,
+                eps=resnet_eps,
+                groups=min(out_channels // 4, 32),
+                groups_out=min(out_channels // 4, 32),
+                dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+                use_in_shortcut=True,
+                up=True,
+                kernel="fir",
+            )
+            self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+            self.skip_norm = torch.nn.GroupNorm(
+                num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+            )
+            self.act = nn.SiLU()
+        else:
+            self.resnet_up = None
+            self.skip_conv = None
+            self.skip_norm = None
+            self.act = None
+
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        skip_sample=None,
+        scale: float = 1.0,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        for resnet in self.resnets:
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            hidden_states = resnet(hidden_states, temb, scale=scale)
+
+        if skip_sample is not None:
+            skip_sample = self.upsampler(skip_sample)
+        else:
+            skip_sample = 0
+
+        if self.resnet_up is not None:
+            skip_sample_states = self.skip_norm(hidden_states)
+            skip_sample_states = self.act(skip_sample_states)
+            skip_sample_states = self.skip_conv(skip_sample_states)
+
+            skip_sample = skip_sample + skip_sample_states
+
+            hidden_states = self.resnet_up(hidden_states, temb, scale=scale)
+
+        return hidden_states, skip_sample
+
+
+class ResnetUpsampleBlock2D(nn.Module):
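+    """Up block of `ResnetBlock2D` layers over skip connections, upsampled by a `ResnetBlock2D`
+    with `up=True` instead of a convolutional `Upsample2D`."""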
+    def __init__(
+        self,
+        in_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
+        temb_channels: int,
+        resolution_idx: Optional[int] = None,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        skip_time_act: bool = False,
+    ):
+        super().__init__()
+        resnets = []
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList(
+                [
+                    ResnetBlock2D(
+                        in_channels=out_channels,
+                        out_channels=out_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm=resnet_time_scale_shift,
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                        pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
+                        up=True,
+                    )
+                ]
+            )
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        upsample_size: Optional[int] = None,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
+        for resnet in self.resnets:
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+            else:
+                hidden_states = resnet(hidden_states, temb, scale=scale)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states, temb, scale=scale)
+
+        return hidden_states
+
+
+class SimpleCrossAttnUpBlock2D(nn.Module):
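+    """Up-block counterpart of `SimpleCrossAttnDownBlock2D`: resnets over skip connections
+    interleaved with added-KV `Attention`, upsampled by a `ResnetBlock2D` with `up=True`."""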
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        prev_output_channel: int,
+        temb_channels: int,
+        resolution_idx: Optional[int] = None,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attention_head_dim: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        skip_time_act: bool = False,
+        only_cross_attention: bool = False,
+        cross_attention_norm: Optional[str] = None,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        self.has_cross_attention = True
+        self.attention_head_dim = attention_head_dim
+
+        self.num_heads = out_channels // self.attention_head_dim
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
+                )
+            )
+
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
+            attentions.append(
+                Attention(
+                    query_dim=out_channels,
+                    cross_attention_dim=out_channels,
+                    heads=self.num_heads,
+                    dim_head=self.attention_head_dim,
+                    added_kv_proj_dim=cross_attention_dim,
+                    norm_num_groups=resnet_groups,
+                    bias=True,
+                    upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
+                    processor=processor,
+                )
+            )
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList(
+                [
+                    ResnetBlock2D(
+                        in_channels=out_channels,
+                        out_channels=out_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm=resnet_time_scale_shift,
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                        pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
+                        up=True,
+                    )
+                ]
+            )
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        upsample_size: Optional[int] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+        lora_scale = cross_attention_kwargs.get("scale", 1.0)
+        if attention_mask is None:
+            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+            mask = None if encoder_hidden_states is None else encoder_attention_mask
+        else:
+            # when attention_mask is defined: we don't even check for encoder_attention_mask.
+            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+            #       then we can simplify this whole if/else block to:
+            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+            mask = attention_mask
+
+        for resnet, attn in zip(self.resnets, self.attentions):
+            # resnet
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=mask,
+                    **cross_attention_kwargs,
+                )
+            else:
+                hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=mask,
+                    **cross_attention_kwargs,
+                )
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states, temb, scale=lora_scale)
+
+        return hidden_states
+
+
+class KUpBlock2D(nn.Module):
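+    """K-diffusion-style up block: the last skip tensor is concatenated once, then a stack of
+    `ada_group` `ResnetBlock2D` layers is applied, followed by an optional `KUpsample2D`."""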
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        resolution_idx: int,
+        dropout: float = 0.0,
+        num_layers: int = 5,
+        resnet_eps: float = 1e-5,
+        resnet_act_fn: str = "gelu",
+        resnet_group_size: Optional[int] = 32,
+        add_upsample: bool = True,
+    ):
+        super().__init__()
+        resnets = []
+        k_in_channels = 2 * out_channels
+        k_out_channels = in_channels
+        num_layers = num_layers - 1
+
+        for i in range(num_layers):
+            in_channels = k_in_channels if i == 0 else out_channels
+            groups = in_channels // resnet_group_size
+            groups_out = out_channels // resnet_group_size
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=k_out_channels if (i == num_layers - 1) else out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=groups,
+                    groups_out=groups_out,
+                    dropout=dropout,
+                    non_linearity=resnet_act_fn,
+                    time_embedding_norm="ada_group",
+                    conv_shortcut_bias=False,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([KUpsample2D()])
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        upsample_size: Optional[int] = None,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
+        res_hidden_states_tuple = res_hidden_states_tuple[-1]
+        if res_hidden_states_tuple is not None:
+            hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
+
+        for resnet in self.resnets:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+            else:
+                hidden_states = resnet(hidden_states, temb, scale=scale)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
+
+        return hidden_states
+
+
+class KCrossAttnUpBlock2D(nn.Module):
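+    """Cross-attention counterpart of `KUpBlock2D`: every ResNet block is followed by a `KAttentionBlock`
+    attending to `encoder_hidden_states`; self-attention is additionally enabled for the first block
+    (where `in_channels == out_channels == temb_channels`)."""
+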
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        resolution_idx: int,
+        dropout: float = 0.0,
+        num_layers: int = 4,
+        resnet_eps: float = 1e-5,
+        resnet_act_fn: str = "gelu",
+        resnet_group_size: int = 32,
+        attention_head_dim: int = 1,  # attention dim_head
+        cross_attention_dim: int = 768,
+        add_upsample: bool = True,
+        upcast_attention: bool = False,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        is_first_block = in_channels == out_channels == temb_channels
+        is_middle_block = in_channels != out_channels
+        add_self_attention = is_first_block
+
+        self.has_cross_attention = True
+        self.attention_head_dim = attention_head_dim
+
+        # in_channels, and out_channels for the block (k-unet)
+        k_in_channels = out_channels if is_first_block else 2 * out_channels
+        k_out_channels = in_channels
+
+        num_layers = num_layers - 1
+
+        for i in range(num_layers):
+            in_channels = k_in_channels if i == 0 else out_channels
+            groups = in_channels // resnet_group_size
+            groups_out = out_channels // resnet_group_size
+
+            if is_middle_block and (i == num_layers - 1):
+                conv_2d_out_channels = k_out_channels
+            else:
+                conv_2d_out_channels = None
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    conv_2d_out_channels=conv_2d_out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=groups,
+                    groups_out=groups_out,
+                    dropout=dropout,
+                    non_linearity=resnet_act_fn,
+                    time_embedding_norm="ada_group",
+                    conv_shortcut_bias=False,
+                )
+            )
+            attentions.append(
+                KAttentionBlock(
+                    k_out_channels if (i == num_layers - 1) else out_channels,
+                    k_out_channels // attention_head_dim
+                    if (i == num_layers - 1)
+                    else out_channels // attention_head_dim,
+                    attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                    temb_channels=temb_channels,
+                    attention_bias=True,
+                    add_self_attention=add_self_attention,
+                    cross_attention_norm="layer_norm",
+                    upcast_attention=upcast_attention,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+        self.attentions = nn.ModuleList(attentions)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([KUpsample2D()])
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        upsample_size: Optional[int] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        res_hidden_states_tuple = res_hidden_states_tuple[-1]
+        if res_hidden_states_tuple is not None:
+            hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
+
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        for resnet, attn in zip(self.resnets, self.attentions):
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
+            else:
+                hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
+
+        return hidden_states
+
+
+# can potentially later be renamed to `No-feed-forward` attention
+class KAttentionBlock(nn.Module):
+    r"""
+    A basic Transformer block.
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Configure if the attention layers should contain a bias parameter.
+        upcast_attention (`bool`, *optional*, defaults to `False`):
+            Set to `True` to upcast the attention computation to `float32`.
+        temb_channels (`int`, *optional*, defaults to 768):
+            The number of channels in the timestep embedding consumed by the `AdaGroupNorm` layers.
+        add_self_attention (`bool`, *optional*, defaults to `False`):
+            Set to `True` to add self-attention to the block.
+        cross_attention_norm (`str`, *optional*, defaults to `None`):
+            The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+        group_size (`int`, *optional*, defaults to 32):
+            The number of groups to separate the channels into for group normalization.
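+
+    Example (illustrative sketch only; shapes and dimensions below are arbitrary placeholders):
+
+        block = KAttentionBlock(
+            dim=64, num_attention_heads=8, attention_head_dim=8, temb_channels=768, add_self_attention=True
+        )
+        hidden_states = torch.randn(1, 64, 16, 16)  # (batch, channels, height, width)
+        temb = torch.randn(1, 768)                  # conditioning embedding for the AdaGroupNorm layers
+        out = block(hidden_states, emb=temb)        # -> same shape as `hidden_states`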
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        dropout: float = 0.0,
+        cross_attention_dim: Optional[int] = None,
+        attention_bias: bool = False,
+        upcast_attention: bool = False,
+        temb_channels: int = 768,  # for ada_group_norm
+        add_self_attention: bool = False,
+        cross_attention_norm: Optional[str] = None,
+        group_size: int = 32,
+    ):
+        super().__init__()
+        self.add_self_attention = add_self_attention
+
+        # 1. Self-Attn
+        if add_self_attention:
+            self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
+            self.attn1 = Attention(
+                query_dim=dim,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                cross_attention_dim=None,
+                cross_attention_norm=None,
+            )
+
+        # 2. Cross-Attn
+        self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
+        self.attn2 = Attention(
+            query_dim=dim,
+            cross_attention_dim=cross_attention_dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            upcast_attention=upcast_attention,
+            cross_attention_norm=cross_attention_norm,
+        )
+
+    def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
+        return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1)
+
+    def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
+        return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        # TODO: mark emb as non-optional (self.norm2 requires it).
+        #       requires assessing impact of change to positional param interface.
+        emb: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+        # 1. Self-Attention
+        if self.add_self_attention:
+            norm_hidden_states = self.norm1(hidden_states, emb)
+
+            height, weight = norm_hidden_states.shape[2:]
+            norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)
+
+            attn_output = self.attn1(
+                norm_hidden_states,
+                encoder_hidden_states=None,
+                attention_mask=attention_mask,
+                **cross_attention_kwargs,
+            )
+            attn_output = self._to_4d(attn_output, height, weight)
+
+            hidden_states = attn_output + hidden_states
+
+        # 2. Cross-Attention/None
+        norm_hidden_states = self.norm2(hidden_states, emb)
+
+        height, weight = norm_hidden_states.shape[2:]
+        norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)
+        attn_output = self.attn2(
+            norm_hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask,
+            **cross_attention_kwargs,
+        )
+        attn_output = self._to_4d(attn_output, height, weight)
+
+        hidden_states = attn_output + hidden_states
+
+        return hidden_states
\ No newline at end of file
diff --git a/foleycrafter/models/auffusion_unet.py b/foleycrafter/models/auffusion_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..508b89dacd0ce137a8f1767397d07925b0daab01
--- /dev/null
+++ b/foleycrafter/models/auffusion_unet.py
@@ -0,0 +1,1260 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils.import_utils import is_xformers_available, is_torch_version
+from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.models.activations import get_activation
+# from diffusers import StableDiffusionGLIGENPipeline
+from diffusers.models.attention_processor import (
+    ADDED_KV_ATTENTION_PROCESSORS,
+    CROSS_ATTENTION_PROCESSORS,
+    Attention,
+    AttentionProcessor,
+    AttnAddedKVProcessor,
+    AttnProcessor,
+    XFormersAttnProcessor,
+)
+from diffusers.models.embeddings import (
+    GaussianFourierProjection,
+    ImageHintTimeEmbedding,
+    ImageProjection,
+    ImageTimeEmbedding,
+    PositionNet,
+    TextImageProjection,
+    TextImageTimeEmbedding,
+    TextTimeEmbedding,
+    TimestepEmbedding,
+    Timesteps,
+)
+from diffusers.models.modeling_utils import ModelMixin
+
+from foleycrafter.models.auffusion.unet_2d_blocks import (
+    UNetMidBlock2D,
+    UNetMidBlock2DCrossAttn,
+    UNetMidBlock2DSimpleCrossAttn,
+    get_down_block,
+    get_up_block,
+)
+
+from foleycrafter.models.auffusion.attention_processor import AttnProcessor2_0
+from foleycrafter.models.adapters.ip_adapter import TimeProjModel
+from foleycrafter.models.auffusion.loaders.unet import UNet2DConditionLoadersMixin
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+    """
+    The output of [`UNet2DConditionModel`].
+
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+    """
+
+    sample: torch.FloatTensor = None
+
+
+class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+    r"""
+    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+    shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+    for all models (such as downloading or saving).
+
+    Parameters:
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
+        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+            Whether to flip the sin to cos in the time embedding.
+        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+            The tuple of downsample blocks to use.
+        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
+            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+            The tuple of upsample blocks to use.
+        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
+            Whether to include self-attention in the basic transformer blocks, see
+            [`~models.attention.BasicTransformerBlock`].
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+            The tuple of output channels for each block.
+        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+            If `None`, normalization and activation layers are skipped in post-processing.
+        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+            The dimension of the cross attention features.
+        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+        reverse_transformer_layers_per_block (`Tuple[Tuple]`, *optional*, defaults to `None`):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
+            blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
+            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+        encoder_hid_dim (`int`, *optional*, defaults to None):
+            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+            dimension to `cross_attention_dim`.
+        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
+        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+        num_attention_heads (`int`, *optional*):
+            The number of attention heads. If not defined, defaults to `attention_head_dim`
+        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+        class_embed_type (`str`, *optional*, defaults to `None`):
+            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+        addition_embed_type (`str`, *optional*, defaults to `None`):
+            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+            "text". "text" will use the `TextTimeEmbedding` layer.
+        addition_time_embed_dim (`int`, *optional*, defaults to `None`):
+            Dimension for the timestep embeddings.
+        num_class_embeds (`int`, *optional*, defaults to `None`):
+            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+            class conditioning with `class_embed_type` equal to `None`.
+        time_embedding_type (`str`, *optional*, defaults to `positional`):
+            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+        time_embedding_dim (`int`, *optional*, defaults to `None`):
+            An optional override for the dimension of the projected time embedding.
+        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+            Optional activation function to use only once on the time embeddings before they are passed to the rest of
+            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+        timestep_post_act (`str`, *optional*, defaults to `None`):
+            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+            The dimension of `cond_proj` layer in the timestep embedding.
+        conv_in_kernel (`int`, *optional*, defaults to `3`):
+            The kernel size of the `conv_in` layer.
+        conv_out_kernel (`int`, *optional*, defaults to `3`):
+            The kernel size of the `conv_out` layer.
+        projection_class_embeddings_input_dim (`int`, *optional*):
+            The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
+            `class_embed_type="projection"`.
+        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+            embeddings with the class embeddings.
+        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
+            otherwise.
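+
+    Example (illustrative sketch; the checkpoint path is a placeholder and tensor shapes must match the loaded config):
+
+        unet = UNet2DConditionModel.from_pretrained("path/to/checkpoint", subfolder="unet")
+        sample = torch.randn(1, unet.config.in_channels, 32, 32)
+        text_emb = torch.randn(1, 77, 768)  # last dim must equal `cross_attention_dim`
+        noise_pred = unet(sample, timestep=10, encoder_hidden_states=text_emb).sample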
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: Optional[int] = None,
+        in_channels: int = 4,
+        out_channels: int = 4,
+        center_input_sample: bool = False,
+        flip_sin_to_cos: bool = True,
+        freq_shift: int = 0,
+        down_block_types: Tuple[str] = (
+            "CrossAttnDownBlock2D",
+            "CrossAttnDownBlock2D",
+            "CrossAttnDownBlock2D",
+            "DownBlock2D",
+        ),
+        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+        only_cross_attention: Union[bool, Tuple[bool]] = False,
+        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+        layers_per_block: Union[int, Tuple[int]] = 2,
+        downsample_padding: int = 1,
+        mid_block_scale_factor: float = 1,
+        dropout: float = 0.0,
+        act_fn: str = "silu",
+        norm_num_groups: Optional[int] = 32,
+        norm_eps: float = 1e-5,
+        cross_attention_dim: Union[int, Tuple[int]] = 1280,
+        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+        encoder_hid_dim: Optional[int] = None,
+        encoder_hid_dim_type: Optional[str] = None,
+        attention_head_dim: Union[int, Tuple[int]] = 8,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        class_embed_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
+        addition_time_embed_dim: Optional[int] = None,
+        num_class_embeds: Optional[int] = None,
+        upcast_attention: bool = False,
+        resnet_time_scale_shift: str = "default",
+        resnet_skip_time_act: bool = False,
+        resnet_out_scale_factor: float = 1.0,
+        time_embedding_type: str = "positional",
+        time_embedding_dim: Optional[int] = None,
+        time_embedding_act_fn: Optional[str] = None,
+        timestep_post_act: Optional[str] = None,
+        time_cond_proj_dim: Optional[int] = None,
+        conv_in_kernel: int = 3,
+        conv_out_kernel: int = 3,
+        projection_class_embeddings_input_dim: Optional[int] = None,
+        attention_type: str = "default",
+        class_embeddings_concat: bool = False,
+        mid_block_only_cross_attention: Optional[bool] = None,
+        cross_attention_norm: Optional[str] = None,
+        addition_embed_type_num_heads: int = 64,
+
+        # parameters for joint video conditioning
+        video_feature_dim: Tuple[int, ...] = (320, 640, 1280, 1280),
+        video_cross_attn_dim: int = 1024,
+        video_frame_nums: int = 16,
+    ):
+        super().__init__()
+
+        self.sample_size = sample_size
+
+        if num_attention_heads is not None:
+            raise ValueError(
+                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+            )
+
+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = num_attention_heads or attention_head_dim
+
+        # Check inputs
+        if len(down_block_types) != len(up_block_types):
+            raise ValueError(
+                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+            )
+
+        if len(block_out_channels) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+            )
+
+        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+            )
+        if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
+            for layer_number_per_block in transformer_layers_per_block:
+                if isinstance(layer_number_per_block, list):
+                    raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
+
+        # input
+        conv_in_padding = (conv_in_kernel - 1) // 2
+        self.conv_in = nn.Conv2d(
+            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+        )
+
+        # time
+        if time_embedding_type == "fourier":
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+            if time_embed_dim % 2 != 0:
+                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+            self.time_proj = GaussianFourierProjection(
+                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+            )
+            timestep_input_dim = time_embed_dim
+        elif time_embedding_type == "positional":
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+            timestep_input_dim = block_out_channels[0]
+        else:
+            raise ValueError(
+                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+            )
+
+        self.time_embedding = TimestepEmbedding(
+            timestep_input_dim,
+            time_embed_dim,
+            act_fn=act_fn,
+            post_act_fn=timestep_post_act,
+            cond_proj_dim=time_cond_proj_dim,
+        )
+
+        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+            encoder_hid_dim_type = "text_proj"
+            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+            )
+
+        if encoder_hid_dim_type == "text_proj":
+            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+        elif encoder_hid_dim_type == "text_image_proj":
+            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+            # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1).
+            self.encoder_hid_proj = TextImageProjection(
+                text_embed_dim=encoder_hid_dim,
+                image_embed_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim,
+            )
+        elif encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2
+            self.encoder_hid_proj = ImageProjection(
+                image_embed_dim=encoder_hid_dim,
+                cross_attention_dim=cross_attention_dim,
+            )
+        elif encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+            )
+        else:
+            self.encoder_hid_proj = None
+
+        # class embedding
+        if class_embed_type is None and num_class_embeds is not None:
+            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+        elif class_embed_type == "timestep":
+            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+        elif class_embed_type == "identity":
+            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+        elif class_embed_type == "projection":
+            if projection_class_embeddings_input_dim is None:
+                raise ValueError(
+                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+                )
+            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+            # 2. it projects from an arbitrary input dimension.
+            #
+            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
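+            # e.g. `class_labels` of shape (batch, projection_class_embeddings_input_dim) is projected to
+            # (batch, time_embed_dim) and then summed with (or, when `class_embeddings_concat=True`, concatenated
+            # to) the timestep embedding.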
+            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+        elif class_embed_type == "simple_projection":
+            if projection_class_embeddings_input_dim is None:
+                raise ValueError(
+                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+                )
+            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+        else:
+            self.class_embedding = None
+
+        if addition_embed_type == "text":
+            if encoder_hid_dim is not None:
+                text_time_embedding_from_dim = encoder_hid_dim
+            else:
+                text_time_embedding_from_dim = cross_attention_dim
+
+            self.add_embedding = TextTimeEmbedding(
+                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+            )
+        elif addition_embed_type == "text_image":
+            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1).
+            self.add_embedding = TextImageTimeEmbedding(
+                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+            )
+        elif addition_embed_type == "text_time":
+            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+        elif addition_embed_type == "image":
+            # Kandinsky 2.2
+            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+        elif addition_embed_type == "image_hint":
+            # Kandinsky 2.2 ControlNet
+            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+        elif addition_embed_type is not None:
+            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+        if time_embedding_act_fn is None:
+            self.time_embed_act = None
+        else:
+            self.time_embed_act = get_activation(time_embedding_act_fn)
+
+        self.down_blocks = nn.ModuleList([])
+        self.up_blocks = nn.ModuleList([])
+
+        if isinstance(only_cross_attention, bool):
+            if mid_block_only_cross_attention is None:
+                mid_block_only_cross_attention = only_cross_attention
+
+            only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+        if mid_block_only_cross_attention is None:
+            mid_block_only_cross_attention = False
+
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+        if isinstance(cross_attention_dim, int):
+            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+        if isinstance(layers_per_block, int):
+            layers_per_block = [layers_per_block] * len(down_block_types)
+
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+        if class_embeddings_concat:
+            # The time embeddings are concatenated with the class embeddings. The dimension of the
+            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+            # regular time embeddings
+            blocks_time_embed_dim = time_embed_dim * 2
+        else:
+            blocks_time_embed_dim = time_embed_dim
+
+        # down
+        output_channel = block_out_channels[0]
+        for i, down_block_type in enumerate(down_block_types):
+            input_channel = output_channel
+            output_channel = block_out_channels[i]
+            is_final_block = i == len(block_out_channels) - 1
+
+            down_block = get_down_block(
+                down_block_type,
+                num_layers=layers_per_block[i],
+                transformer_layers_per_block=transformer_layers_per_block[i],
+                in_channels=input_channel,
+                out_channels=output_channel,
+                temb_channels=blocks_time_embed_dim,
+                add_downsample=not is_final_block,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
+                cross_attention_dim=cross_attention_dim[i],
+                num_attention_heads=num_attention_heads[i],
+                downsample_padding=downsample_padding,
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_type=attention_type,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                dropout=dropout,
+            )
+            self.down_blocks.append(down_block)
+
+        # mid
+        if mid_block_type == "UNetMidBlock2DCrossAttn":
+            self.mid_block = UNetMidBlock2DCrossAttn(
+                transformer_layers_per_block=transformer_layers_per_block[-1],
+                in_channels=block_out_channels[-1],
+                temb_channels=blocks_time_embed_dim,
+                dropout=dropout,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                cross_attention_dim=cross_attention_dim[-1],
+                num_attention_heads=num_attention_heads[-1],
+                resnet_groups=norm_num_groups,
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                upcast_attention=upcast_attention,
+                attention_type=attention_type,
+            )
+        elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+            self.mid_block = UNetMidBlock2DSimpleCrossAttn(
+                in_channels=block_out_channels[-1],
+                temb_channels=blocks_time_embed_dim,
+                dropout=dropout,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                cross_attention_dim=cross_attention_dim[-1],
+                attention_head_dim=attention_head_dim[-1],
+                resnet_groups=norm_num_groups,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                skip_time_act=resnet_skip_time_act,
+                only_cross_attention=mid_block_only_cross_attention,
+                cross_attention_norm=cross_attention_norm,
+            )
+        elif mid_block_type == "UNetMidBlock2D":
+            self.mid_block = UNetMidBlock2D(
+                in_channels=block_out_channels[-1],
+                temb_channels=blocks_time_embed_dim,
+                dropout=dropout,
+                num_layers=0,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                resnet_groups=norm_num_groups,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                add_attention=False,
+            )
+        elif mid_block_type is None:
+            self.mid_block = None
+        else:
+            raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+        # count how many layers upsample the images
+        self.num_upsamplers = 0
+
+        # up
+        reversed_block_out_channels = list(reversed(block_out_channels))
+        reversed_num_attention_heads = list(reversed(num_attention_heads))
+        reversed_layers_per_block = list(reversed(layers_per_block))
+        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+        reversed_transformer_layers_per_block = (
+            list(reversed(transformer_layers_per_block))
+            if reverse_transformer_layers_per_block is None
+            else reverse_transformer_layers_per_block
+        )
+        only_cross_attention = list(reversed(only_cross_attention))
+
+        output_channel = reversed_block_out_channels[0]
+        for i, up_block_type in enumerate(up_block_types):
+            is_final_block = i == len(block_out_channels) - 1
+
+            prev_output_channel = output_channel
+            output_channel = reversed_block_out_channels[i]
+            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+            # add upsample block for all BUT final layer
+            if not is_final_block:
+                add_upsample = True
+                self.num_upsamplers += 1
+            else:
+                add_upsample = False
+
+            up_block = get_up_block(
+                up_block_type,
+                num_layers=reversed_layers_per_block[i] + 1,
+                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+                in_channels=input_channel,
+                out_channels=output_channel,
+                prev_output_channel=prev_output_channel,
+                temb_channels=blocks_time_embed_dim,
+                add_upsample=add_upsample,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                resolution_idx=i,
+                resnet_groups=norm_num_groups,
+                cross_attention_dim=reversed_cross_attention_dim[i],
+                num_attention_heads=reversed_num_attention_heads[i],
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_type=attention_type,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                dropout=dropout,
+            )
+            self.up_blocks.append(up_block)
+            prev_output_channel = output_channel
+
+        # out
+        if norm_num_groups is not None:
+            self.conv_norm_out = nn.GroupNorm(
+                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+            )
+
+            self.conv_act = get_activation(act_fn)
+
+        else:
+            self.conv_norm_out = None
+            self.conv_act = None
+
+        conv_out_padding = (conv_out_kernel - 1) // 2
+        self.conv_out = nn.Conv2d(
+            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+        )
+
+        if attention_type in ["gated", "gated-text-image"]:
+            positive_len = 768
+            if isinstance(cross_attention_dim, int):
+                positive_len = cross_attention_dim
+            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+                positive_len = cross_attention_dim[0]
+
+            feature_type = "text-only" if attention_type == "gated" else "text-image"
+            self.position_net = TimeProjModel(
+                positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
+            )
+
+        # additional settings
+        self.video_feature_dim    = video_feature_dim 
+        self.cross_attention_dim  = cross_attention_dim
+        self.video_cross_attn_dim = video_cross_attn_dim
+        self.video_frame_nums     = video_frame_nums
+
+        self.multi_frames_condition = False
+
+    def load_attention(self):
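+        """Swap in attention processors: self-attention layers ("attn1") keep a standard processor
+        (xformers when available, otherwise the default), while all other attention layers use this
+        repository's `AttnProcessor2_0`."""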
+        attn_dict = {}
+        for name in self.attn_processors.keys():
+            # if self-attention, save feature 
+            if name.endswith("attn1.processor"):
+                if is_xformers_available():
+                    attn_dict[name] = XFormersAttnProcessor()
+                else:
+                    attn_dict[name] = AttnProcessor() 
+            else:
+                attn_dict[name] = AttnProcessor2_0()
+        self.set_attn_processor(attn_dict)
+
+    def get_writer_feature(self):
+        return self.attn_feature_writer.get_cross_attention_feature()
+    
+    def clear_writer_feature(self):
+        self.attn_feature_writer.clear_cross_attention_feature()
+
+    def disable_feature_adapters(self):
+        raise NotImplementedError
+    
+    def set_reader_feature(self, features:list):
+        return self.attn_feature_reader.set_cross_attention_feature(features)
+
+    @property
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
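+
+        Example (illustrative):
+
+            for name, processor in unet.attn_processors.items():
+                print(name, processor.__class__.__name__)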
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
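+        Example (illustrative; `AttnProcessor` is already imported at the top of this module):
+
+            # one processor shared by every attention layer
+            unet.set_attn_processor(AttnProcessor())
+
+            # or one processor per layer, keyed exactly like `unet.attn_processors`
+            unet.set_attn_processor({name: AttnProcessor() for name in unet.attn_processors.keys()})
+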
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor, _remove_lora=_remove_lora)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+            processor = AttnAddedKVProcessor()
+        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+            processor = AttnProcessor()
+        else:
+            raise ValueError(
+                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+            )
+
+        self.set_attn_processor(processor, _remove_lora=True)
+
+    def set_attention_slice(self, slice_size):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+        several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+        Args:
+            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+                must be a multiple of `slice_size`.
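+
+        Example (illustrative):
+
+            unet.set_attention_slice("auto")  # compute attention in two steps per layer
+            unet.set_attention_slice("max")   # lowest memory: one slice at a time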
+        """
+        sliceable_head_dims = []
+
+        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+            if hasattr(module, "set_attention_slice"):
+                sliceable_head_dims.append(module.sliceable_head_dim)
+
+            for child in module.children():
+                fn_recursive_retrieve_sliceable_dims(child)
+
+        # retrieve number of attention layers
+        for module in self.children():
+            fn_recursive_retrieve_sliceable_dims(module)
+
+        num_sliceable_layers = len(sliceable_head_dims)
+
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = [dim // 2 for dim in sliceable_head_dims]
+        elif slice_size == "max":
+            # make smallest slice possible
+            slice_size = num_sliceable_layers * [1]
+
+        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+        if len(slice_size) != len(sliceable_head_dims):
+            raise ValueError(
+                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+            )
+
+        for i in range(len(slice_size)):
+            size = slice_size[i]
+            dim = sliceable_head_dims[i]
+            if size is not None and size > dim:
+                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+        # Recursively walk through all the children.
+        # Any children which exposes the set_attention_slice method
+        # gets the message
+        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+            if hasattr(module, "set_attention_slice"):
+                module.set_attention_slice(slice_size.pop())
+
+            for child in module.children():
+                fn_recursive_set_attention_slice(child, slice_size)
+
+        reversed_slice_size = list(reversed(slice_size))
+        for module in self.children():
+            fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    def enable_freeu(self, s1, s2, b1, b2):
+        r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+        The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+        are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+        Args:
+            s1 (`float`):
+                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+                mitigate the "oversmoothing effect" in the enhanced denoising process.
+            s2 (`float`):
+                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+                mitigate the "oversmoothing effect" in the enhanced denoising process.
+            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+        """
+        for i, upsample_block in enumerate(self.up_blocks):
+            setattr(upsample_block, "s1", s1)
+            setattr(upsample_block, "s2", s2)
+            setattr(upsample_block, "b1", b1)
+            setattr(upsample_block, "b2", b2)
+
+    def disable_freeu(self):
+        """Disables the FreeU mechanism."""
+        freeu_keys = {"s1", "s2", "b1", "b2"}
+        for i, upsample_block in enumerate(self.up_blocks):
+            for k in freeu_keys:
+                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+                    setattr(upsample_block, k, None)
+
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
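+
+        Example (illustrative):
+
+            unet.fuse_qkv_projections()
+            # ... run inference with fused projections ...
+            unet.unfuse_qkv_projections()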
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        class_labels: Optional[torch.Tensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        mid_block_additional_residual: Optional[torch.Tensor] = None,
+        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ) -> Union[UNet2DConditionOutput, Tuple]:
+        r"""
+        The [`UNet2DConditionModel`] forward method.
+
+        Args:
+            sample (`torch.FloatTensor`):
+                The noisy input tensor with the following shape `(batch, channel, height, width)`.
+            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.FloatTensor`):
+                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+            timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
+                Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
+                through the `self.time_embedding` layer to obtain the timestep embeddings.
+            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+                negative values to the attention scores corresponding to "discard" tokens.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            added_cond_kwargs (`dict`, *optional*):
+                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+                are passed along to the UNet blocks.
+            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                Additional residuals to be added to the UNet long skip connections from down blocks to up blocks, for
+                example from ControlNet side model(s).
+            mid_block_additional_residual (`torch.Tensor`, *optional*):
+                Additional residual to be added to the UNet mid block output, for example from a ControlNet side model.
+            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                Additional residuals to be added within the UNet down blocks, for example from T2I-Adapter side model(s).
+            encoder_attention_mask (`torch.Tensor`, *optional*):
+                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+                which adds large negative values to the attention scores corresponding to "discard" tokens.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+                a `tuple` is returned where the first element is the sample tensor.
+        """
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        for dim in sample.shape[-2:]:
+            if dim % default_overall_up_factor != 0:
+                # Forward upsample size to force interpolation output size.
+                forward_upsample_size = True
+                break
+
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch,                    1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None:
+            # assume that mask is expressed as:
+            #   (1 = keep,      0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #       (keep = +0,     discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
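+            # Worked example: a key-padding mask [[1, 1, 0]] (keep, keep, discard) becomes the
+            # additive bias [[[0.0, 0.0, -10000.0]]] of shape (batch, 1, key_tokens), which
+            # drives the masked position to ~0 probability after the attention softmax.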
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 0. center input if necessary
+        if self.config.center_input_sample:
+            sample = 2 * sample - 1.0
+
+        # 1. time
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(timestep, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps.expand(sample.shape[0])
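+        # Example: a plain Python `timestep=999` with a batch of 4 latents ends up as
+        # tensor([999, 999, 999, 999]) on `sample.device` before the sinusoidal projection.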
+
+        t_emb = self.time_proj(timesteps)
+
+        # `Timesteps` does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=sample.dtype)
+
+        emb = self.time_embedding(t_emb, timestep_cond)
+        aug_emb = None
+
+        if self.class_embedding is not None:
+            if class_labels is None:
+                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+            if self.config.class_embed_type == "timestep":
+                class_labels = self.time_proj(class_labels)
+
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
+            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+
+            if self.config.class_embeddings_concat:
+                emb = torch.cat([emb, class_emb], dim=-1)
+            else:
+                emb = emb + class_emb
+
+        if self.config.addition_embed_type == "text":
+            aug_emb = self.add_embedding(encoder_hidden_states)
+        elif self.config.addition_embed_type == "text_image":
+            # Kandinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+
+            image_embs = added_cond_kwargs.get("image_embeds")
+            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+            aug_emb = self.add_embedding(text_embs, image_embs)
+        elif self.config.addition_embed_type == "text_time":
+            # SDXL - style
+            if "text_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if "time_ids" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            time_embeds = self.add_time_proj(time_ids.flatten())
+            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+            add_embeds = add_embeds.to(emb.dtype)
+            aug_emb = self.add_embedding(add_embeds)
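+            # For this SDXL-style branch the caller is expected to supply, roughly (a sketch;
+            # the names and the (batch, 6) time_ids shape follow the usual SDXL convention and
+            # are assumptions here, not values defined in this file):
+            #     added_cond_kwargs = {"text_embeds": pooled_text_emb,  # (batch, pooled_dim)
+            #                          "time_ids": add_time_ids}        # (batch, 6)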
+        elif self.config.addition_embed_type == "image":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            aug_emb = self.add_embedding(image_embs)
+        elif self.config.addition_embed_type == "image_hint":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            hint = added_cond_kwargs.get("hint")
+            aug_emb, hint = self.add_embedding(image_embs, hint)
+            sample = torch.cat([sample, hint], dim=1)
+
+        emb = emb + aug_emb if aug_emb is not None else emb
+
+        if self.time_embed_act is not None:
+            emb = self.time_embed_act(emb)
+
+        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+            # Kandinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
+                )
+
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
+                )
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
+                )
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            image_embeds = self.encoder_hid_proj(image_embeds)
+            if isinstance(image_embeds, list):
+                image_embeds = [image_embed.to(encoder_hidden_states.dtype) for image_embed in image_embeds]
+            else:
+                image_embeds = image_embeds.to(encoder_hidden_states.dtype)
+            encoder_hidden_states = (encoder_hidden_states, image_embeds) 
+            # encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
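+            # Note: text and image embeddings are carried forward as a (text, image) tuple here
+            # (the sequence-concatenation variant is left commented out above); the attention
+            # processors installed on this UNet are presumably expected to unpack that tuple,
+            # IP-Adapter style.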
+        # 2. pre-process
+        sample = self.conv_in(sample)
+
+        # 2.5 GLIGEN position net
+        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+            cross_attention_kwargs = cross_attention_kwargs.copy()
+            gligen_args = cross_attention_kwargs.pop("gligen")
+            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
+        # 3. down
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+
+        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+        # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+        is_adapter = down_intrablock_additional_residuals is not None
+        # maintain backward compatibility for legacy usage, where
+        #       T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+        #       but can only use one or the other
+        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+            deprecate(
+                "T2I should not use down_block_additional_residuals",
+                "1.3.0",
+                "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+                       and will be removed in diffusers 1.3.0.  `down_block_additional_residuals` should only be used \
+                       for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
+                standard_warn=False,
+            )
+            down_intrablock_additional_residuals = down_block_additional_residuals
+            is_adapter = True
+        down_block_res_samples = (sample,)
+        for downsample_block in self.down_blocks:
+            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+                # For t2i-adapter CrossAttnDownBlock2D
+                additional_residuals = {}
+                if is_adapter and len(down_intrablock_additional_residuals) > 0:
+                    additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+
+                sample, res_samples = downsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                    **additional_residuals,
+                )
+            else:
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
+                if is_adapter and len(down_intrablock_additional_residuals) > 0:
+                    sample += down_intrablock_additional_residuals.pop(0)
+
+            down_block_res_samples += res_samples
+
+        if is_controlnet:
+            new_down_block_res_samples = ()
+
+            for down_block_res_sample, down_block_additional_residual in zip(
+                down_block_res_samples, down_block_additional_residuals
+            ):
+                down_block_res_sample = down_block_res_sample + down_block_additional_residual
+                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+            down_block_res_samples = new_down_block_res_samples
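+            # Sketch of how these residuals are usually produced (not part of this module;
+            # `controlnet`, `unet`, and `cond_image` are illustrative names):
+            #     down_res, mid_res = controlnet(sample, t, encoder_hidden_states,
+            #                                    cond_image, return_dict=False)
+            #     unet(sample, t, encoder_hidden_states,
+            #          down_block_additional_residuals=down_res,
+            #          mid_block_additional_residual=mid_res)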
+        # 4. mid
+        if self.mid_block is not None:
+            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+                sample = self.mid_block(
+                    sample,
+                    emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
+            else:
+                sample = self.mid_block(sample, emb)
+
+            # To support T2I-Adapter-XL
+            if (
+                is_adapter
+                and len(down_intrablock_additional_residuals) > 0
+                and sample.shape == down_intrablock_additional_residuals[0].shape
+            ):
+                sample += down_intrablock_additional_residuals.pop(0)
+
+        if is_controlnet:
+            sample = sample + mid_block_additional_residual
+        # 5. up
+        for i, upsample_block in enumerate(self.up_blocks):
+            is_final_block = i == len(self.up_blocks) - 1
+
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+            # if we have not reached the final block and need to forward the
+            # upsample size, we do it here
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
+            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    upsample_size=upsample_size,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
+            else:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    upsample_size=upsample_size,
+                    scale=lora_scale,
+                )
+        # 6. post-process
+        if self.conv_norm_out:
+            sample = self.conv_norm_out(sample)
+            sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (sample,)
+        return UNet2DConditionOutput(sample=sample)
\ No newline at end of file
diff --git a/foleycrafter/models/specvqgan/data/greatesthit.py b/foleycrafter/models/specvqgan/data/greatesthit.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c4ac159e0d21de91d0752557b4b03a905855dba
--- /dev/null
+++ b/foleycrafter/models/specvqgan/data/greatesthit.py
@@ -0,0 +1,993 @@
+from matplotlib import collections
+import json
+import os
+import copy
+import matplotlib.pyplot as plt
+import torch
+from torchvision import transforms
+import numpy as np
+from tqdm import tqdm
+from random import sample
+import torchaudio
+import logging
+import collections
+from glob import glob
+import sys
+import math
+import random
+import traceback
+import albumentations
+import soundfile
+from PIL import Image
+
+sys.path.insert(0, '.')  # nopep8
+from train import instantiate_from_config
+from foleycrafter.models.specvqgan.data.transforms import *
+
+torchaudio.set_audio_backend("sox_io")
+logger = logging.getLogger(f'main.{__name__}')
+
+SR = 22050
+FPS = 15
+MAX_SAMPLE_ITER = 10
+
+def non_negative(x): return int(np.round(max(0, x), 0))
+
+def rms(x): return np.sqrt(np.mean(x**2))
+
+def get_GH_data_identifier(video_name, start_idx, split='_'):
+    if isinstance(start_idx, str):
+        return video_name + split + start_idx
+    elif isinstance(start_idx, int):
+        return video_name + split + str(start_idx)
+    else:
+        raise NotImplementedError
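+
+# Example (the video name is illustrative): get_GH_data_identifier('2015-02-16-1', 66150)
+# returns '2015-02-16-1_66150', i.e. '<video_name>_<audio start sample at 22050 Hz>'.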
+
+
+class Crop(object):
+
+    def __init__(self, cropped_shape=None, random_crop=False):
+        self.cropped_shape = cropped_shape
+        if cropped_shape is not None:
+            mel_num, spec_len = cropped_shape
+            if random_crop:
+                self.cropper = albumentations.RandomCrop
+            else:
+                self.cropper = albumentations.CenterCrop
+            self.preprocessor = albumentations.Compose([self.cropper(mel_num, spec_len)])
+        else:
+            self.preprocessor = lambda **kwargs: kwargs
+
+    def __call__(self, item):
+        item['image'] = self.preprocessor(image=item['image'])['image']
+        if 'cond_image' in item.keys():
+            item['cond_image'] = self.preprocessor(image=item['cond_image'])['image']
+        return item
+
+class CropImage(Crop):
+    def __init__(self, *crop_args):
+        super().__init__(*crop_args)
+
+class CropFeats(Crop):
+    def __init__(self, *crop_args):
+        super().__init__(*crop_args)
+
+    def __call__(self, item):
+        item['feature'] = self.preprocessor(image=item['feature'])['image']
+        return item
+
+class CropCoords(Crop):
+    def __init__(self, *crop_args):
+        super().__init__(*crop_args)
+
+    def __call__(self, item):
+        item['coord'] = self.preprocessor(image=item['coord'])['image']
+        return item
+
+class ResampleFrames(object):
+    def __init__(self, feat_sample_size, times_to_repeat_after_resample=None):
+        self.feat_sample_size = feat_sample_size
+        self.times_to_repeat_after_resample = times_to_repeat_after_resample
+
+    def __call__(self, item):
+        feat_len = item['feature'].shape[0]
+
+        ## resample
+        assert feat_len >= self.feat_sample_size
+        # evenly spaced points (abcdefghkl -> aoooofoooo)
+        idx = np.linspace(0, feat_len, self.feat_sample_size, dtype=int, endpoint=False)
+        # xoooo xoooo -> ooxoo ooxoo
+        shift = feat_len // (self.feat_sample_size + 1)
+        idx = idx + shift
+
+        ## repeat after resampling (abc -> aaaabbbbcccc)
+        if self.times_to_repeat_after_resample is not None and self.times_to_repeat_after_resample > 1:
+            idx = np.repeat(idx, self.times_to_repeat_after_resample)
+
+        item['feature'] = item['feature'][idx, :]
+        return item
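+
+    # Worked example for __call__ above: with feat_len=10 and feat_sample_size=2,
+    # np.linspace gives indices [0, 5], the shift is 10 // 3 == 3, so rows [3, 8] are kept;
+    # with times_to_repeat_after_resample=2 this becomes [3, 3, 8, 8].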
+
+
+class GreatestHitSpecs(torch.utils.data.Dataset):
+
+    def __init__(self, split, spec_dir_path, spec_len, random_crop, mel_num, 
+                spec_crop_len, L=2.0, rand_shift=False, spec_transforms=None, splits_path='./data', 
+                meta_path='./data/info_r2plus1d_dim1024_15fps.json'):
+        super().__init__()
+        self.split = split
+        self.specs_dir = spec_dir_path
+        self.spec_transforms = spec_transforms
+        self.splits_path = splits_path
+        self.meta_path = meta_path
+        self.spec_len = spec_len
+        self.rand_shift = rand_shift
+        self.L = L
+        self.spec_take_first = int(math.ceil(860 * (L / 10.) / 32) * 32)
+        self.spec_take_first = 860 if self.spec_take_first > 860 else self.spec_take_first
+
+        greatesthit_meta = json.load(open(self.meta_path, 'r'))
+        unique_classes = sorted(list(set(ht for ht in greatesthit_meta['hit_type'])))
+        self.label2target = {label: target for target, label in enumerate(unique_classes)}
+        self.target2label = {target: label for label, target in self.label2target.items()}
+        self.video_idx2label = {
+            get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]): 
+            greatesthit_meta['hit_type'][i] for i in range(len(greatesthit_meta['video_name']))
+        }
+        self.available_video_hit = list(self.video_idx2label.keys())
+        self.video_idx2path = {
+            vh: os.path.join(self.specs_dir, 
+                vh.replace('_', '_denoised_') + '_' + self.video_idx2label[vh].replace(' ', '_') +'_mel.npy')
+            for vh in self.available_video_hit
+        }
+        self.video_idx2idx = {
+            get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]):
+            i for i in range(len(greatesthit_meta['video_name']))
+        }
+
+        split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
+        if not os.path.exists(split_clip_ids_path):
+            raise NotImplementedError()
+        clip_video_hit = json.load(open(split_clip_ids_path, 'r'))
+        self.dataset = clip_video_hit
+        spec_crop_len = self.spec_take_first if self.spec_take_first <= spec_crop_len else spec_crop_len
+        self.spec_transforms = transforms.Compose([
+            CropImage([mel_num, spec_crop_len], random_crop),
+            # transforms.RandomApply([FrequencyMasking(freq_mask_param=20)], p=0),
+            # transforms.RandomApply([TimeMasking(time_mask_param=int(32 * self.L))], p=0)
+        ])
+
+        self.video2indexes = {}
+        for video_idx in self.dataset:
+            video, start_idx = video_idx.split('_')
+            if video not in self.video2indexes.keys():
+                self.video2indexes[video] = []
+            self.video2indexes[video].append(start_idx)
+        for video in self.video2indexes.keys():
+            if len(self.video2indexes[video]) == 1: # given video contains only one hit
+                self.dataset.remove(
+                    get_GH_data_identifier(video, self.video2indexes[video][0])
+                )
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        item = {}
+
+        video_idx = self.dataset[idx]
+        spec_path = self.video_idx2path[video_idx]
+        spec = np.load(spec_path) # (80, 860)
+
+        if self.rand_shift:
+            shift = random.uniform(0, 0.5)
+            spec_shift = int(shift * spec.shape[1] // 10)
+            # Since only the first second is used
+            spec = np.roll(spec, -spec_shift, 1)
+
+        # concat spec outside dataload
+        item['image'] = 2 * spec - 1 # (80, 860)
+        item['image'] = item['image'][:, :self.spec_take_first]
+        item['file_path'] = spec_path
+
+        item['label'] = self.video_idx2label[video_idx]
+        item['target'] = self.label2target[item['label']]
+
+        if self.spec_transforms is not None:
+            item = self.spec_transforms(item)
+
+        return item
+
+
+class GreatestHitSpecsTrain(GreatestHitSpecs):
+    def __init__(self, specs_dataset_cfg):
+        super().__init__('train', **specs_dataset_cfg)
+
+class GreatestHitSpecsValidation(GreatestHitSpecs):
+    def __init__(self, specs_dataset_cfg):
+        super().__init__('val', **specs_dataset_cfg)
+
+class GreatestHitSpecsTest(GreatestHitSpecs):
+    def __init__(self, specs_dataset_cfg):
+        super().__init__('test', **specs_dataset_cfg)
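+
+# Illustrative construction of the spectrogram datasets above (a sketch; the directory name
+# and values are assumptions, not defaults shipped with this file):
+#     cfg = dict(spec_dir_path='data/greatesthit/melspec_10s_22050hz', spec_len=860,
+#                random_crop=False, mel_num=80, spec_crop_len=160, L=2.0)
+#     ds = GreatestHitSpecsTrain(cfg)
+#     item = ds[0]   # dict with 'image' (mel spectrogram crop), 'label', 'target', 'file_path'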
+
+
+
+class GreatestHitWave(torch.utils.data.Dataset):
+
+    def __init__(self, split, wav_dir, random_crop, mel_num, spec_crop_len, spec_len,
+                L=2.0, splits_path='./data', rand_shift=True,
+                data_path='data/greatesthit/greatesthit-process-resized'):
+        super().__init__()
+        self.split = split
+        self.wav_dir = wav_dir
+        self.splits_path = splits_path
+        self.data_path = data_path
+        self.L = L
+        self.rand_shift = rand_shift
+
+        split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
+        if not os.path.exists(split_clip_ids_path):
+            raise NotImplementedError()
+        clip_video_hit = json.load(open(split_clip_ids_path, 'r'))
+
+        video_name = list(set([vidx.split('_')[0] for vidx in clip_video_hit]))
+
+        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) // 2 for v in video_name}
+        self.left_over = int(FPS * L + 1)
+        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_denoised_resampled.wav') for v in video_name}
+        self.dataset = clip_video_hit
+
+        self.video2indexes = {}
+        for video_idx in self.dataset:
+            video, start_idx = video_idx.split('_')
+            if video not in self.video2indexes.keys():
+                self.video2indexes[video] = []
+            self.video2indexes[video].append(start_idx)
+        for video in self.video2indexes.keys():
+            if len(self.video2indexes[video]) == 1: # given video contains only one hit
+                self.dataset.remove(
+                    get_GH_data_identifier(video, self.video2indexes[video][0])
+                )
+
+        self.wav_transforms = transforms.Compose([
+            MakeMono(),
+            Padding(target_len=int(SR * self.L)),
+        ])
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        item = {}
+        video_idx = self.dataset[idx]
+        video, start_idx = video_idx.split('_')
+        start_idx = int(start_idx)
+        if self.rand_shift:
+            shift = int(random.uniform(-0.5, 0.5) * SR)
+            start_idx = non_negative(start_idx + shift)
+
+        wave_path = self.video_audio_path[video]
+        wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_idx)
+        assert sr == SR
+        wav = self.wav_transforms(wav)
+
+        item['image'] = wav # (44100,)
+        # item['wav'] = wav
+        item['file_path_wav_'] = wave_path
+
+        item['label'] = 'None'
+        item['target'] = 'None'
+
+        return item
+
+
+class GreatestHitWaveTrain(GreatestHitWave):
+    def __init__(self, specs_dataset_cfg):
+        super().__init__('train', **specs_dataset_cfg)
+
+class GreatestHitWaveValidation(GreatestHitWave):
+    def __init__(self, specs_dataset_cfg):
+        super().__init__('val', **specs_dataset_cfg)
+
+class GreatestHitWaveTest(GreatestHitWave):
+    def __init__(self, specs_dataset_cfg):
+        super().__init__('test', **specs_dataset_cfg)
+
+
+class CondGreatestHitSpecsCondOnImage(torch.utils.data.Dataset):
+
+    def __init__(self, split, specs_dir, spec_len, feat_len, feat_depth, feat_crop_len, random_crop, mel_num, spec_crop_len,
+                vqgan_L=10.0, L=1.0, rand_shift=False, spec_transforms=None, frame_transforms=None, splits_path='./data', 
+                meta_path='./data/info_r2plus1d_dim1024_15fps.json', frame_path='data/greatesthit/greatesthit_processed',
+                p_outside_cond=0., p_audio_aug=0.5):
+        super().__init__()
+        self.split = split
+        self.specs_dir = specs_dir
+        self.spec_transforms = spec_transforms
+        self.frame_transforms = frame_transforms
+        self.splits_path = splits_path
+        self.meta_path = meta_path
+        self.frame_path = frame_path
+        self.feat_len = feat_len
+        self.feat_depth = feat_depth
+        self.feat_crop_len = feat_crop_len
+        self.spec_len = spec_len
+        self.rand_shift = rand_shift
+        self.L = L
+        self.spec_take_first = int(math.ceil(860 * (vqgan_L / 10.) / 32) * 32)
+        self.spec_take_first = 860 if self.spec_take_first > 860 else self.spec_take_first
+        self.p_outside_cond = torch.tensor(p_outside_cond)
+
+        greatesthit_meta = json.load(open(self.meta_path, 'r'))
+        unique_classes = sorted(list(set(ht for ht in greatesthit_meta['hit_type'])))
+        self.label2target = {label: target for target, label in enumerate(unique_classes)}
+        self.target2label = {target: label for label, target in self.label2target.items()}
+        self.video_idx2label = {
+            get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]): 
+            greatesthit_meta['hit_type'][i] for i in range(len(greatesthit_meta['video_name']))
+        }
+        self.available_video_hit = list(self.video_idx2label.keys())
+        self.video_idx2path = {
+            vh: os.path.join(self.specs_dir, 
+                vh.replace('_', '_denoised_') + '_' + self.video_idx2label[vh].replace(' ', '_') +'_mel.npy')
+            for vh in self.available_video_hit
+        }
+        for value in self.video_idx2path.values():
+            assert os.path.exists(value)
+        self.video_idx2idx = {
+            get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]):
+            i for i in range(len(greatesthit_meta['video_name']))
+        }
+
+        split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
+        if not os.path.exists(split_clip_ids_path):
+            self.make_split_files()
+        clip_video_hit = json.load(open(split_clip_ids_path, 'r'))
+        self.dataset = clip_video_hit
+        spec_crop_len = self.spec_take_first if self.spec_take_first <= spec_crop_len else spec_crop_len
+        self.spec_transforms = transforms.Compose([
+            CropImage([mel_num, spec_crop_len], random_crop),
+            # transforms.RandomApply([FrequencyMasking(freq_mask_param=20)], p=p_audio_aug),
+            # transforms.RandomApply([TimeMasking(time_mask_param=int(32 * self.L))], p=p_audio_aug)
+        ])
+        if self.frame_transforms is None:
+            self.frame_transforms = transforms.Compose([
+                Resize3D(128),
+                RandomResizedCrop3D(112, scale=(0.5, 1.0)),
+                RandomHorizontalFlip3D(),
+                ColorJitter3D(brightness=0.1, saturation=0.1),
+                ToTensor3D(),
+                Normalize3D(mean=[0.485, 0.456, 0.406],
+                            std=[0.229, 0.224, 0.225]),
+            ])
+
+        self.video2indexes = {}
+        for video_idx in self.dataset:
+            video, start_idx = video_idx.split('_')
+            if video not in self.video2indexes.keys():
+                self.video2indexes[video] = []
+            self.video2indexes[video].append(start_idx)
+        for video in self.video2indexes.keys():
+            if len(self.video2indexes[video]) == 1: # given video contains only one hit
+                self.dataset.remove(
+                    get_GH_data_identifier(video, self.video2indexes[video][0])
+                )
+
+        clip_classes = [self.label2target[self.video_idx2label[vh]] for vh in clip_video_hit]
+        class2count = collections.Counter(clip_classes)
+        self.class_counts = torch.tensor([class2count[cls] for cls in range(len(class2count))])
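+        # Example: if targets 0 and 1 occur 120 and 80 times in this split (illustrative
+        # counts), class_counts becomes tensor([120, 80]), ordered by target index.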
+        if self.L != 1.0:
+            print(split, L)
+            self.validate_data()
+        self.video2indexes = {}
+        for video_idx in self.dataset:
+            video, start_idx = video_idx.split('_')
+            if video not in self.video2indexes.keys():
+                self.video2indexes[video] = []
+            self.video2indexes[video].append(start_idx)
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        item = {}
+
+        try:
+            video_idx = self.dataset[idx]
+            spec_path = self.video_idx2path[video_idx]
+            spec = np.load(spec_path) # (80, 860)
+
+            video, start_idx = video_idx.split('_')
+            frame_path = os.path.join(self.frame_path, video, 'frames')
+            start_frame_idx = non_negative(FPS * int(start_idx)/SR)
+            end_frame_idx = non_negative(start_frame_idx + FPS * self.L)
+
+            if self.rand_shift:
+                shift = random.uniform(0, 0.5)
+                spec_shift = int(shift * spec.shape[1] // 10)
+                # Since only the first second is used
+                spec = np.roll(spec, -spec_shift, 1)
+                start_frame_idx += int(FPS * shift)
+                end_frame_idx += int(FPS * shift)
+
+            frames = [Image.open(os.path.join(
+                frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in 
+                range(start_frame_idx, end_frame_idx)]
+
+            # Sample condition
+            if torch.all(torch.bernoulli(self.p_outside_cond) == 1.):
+                # Sample condition from outside video
+                all_idx = set(list(range(len(self.dataset))))
+                all_idx.remove(idx)
+                cond_video_idx = self.dataset[sample(list(all_idx), k=1)[0]]
+                cond_video, cond_start_idx = cond_video_idx.split('_')
+            else:
+                cond_video = video
+                video_hits_idx = copy.copy(self.video2indexes[video])
+                video_hits_idx.remove(start_idx)
+                cond_start_idx = sample(video_hits_idx, k=1)[0]
+                cond_video_idx = get_GH_data_identifier(cond_video, cond_start_idx)
+
+            cond_spec_path = self.video_idx2path[cond_video_idx]
+            cond_spec = np.load(cond_spec_path) # (80, 860)
+
+            cond_video, cond_start_idx = cond_video_idx.split('_')
+            cond_frame_path = os.path.join(self.frame_path, cond_video, 'frames')
+            cond_start_frame_idx = non_negative(FPS * int(cond_start_idx)/SR)
+            cond_end_frame_idx = non_negative(cond_start_frame_idx + FPS * self.L)
+
+            if self.rand_shift:
+                cond_shift = random.uniform(0, 0.5)
+                cond_spec_shift = int(cond_shift * cond_spec.shape[1] // 10)
+                # Since only the first second is used
+                cond_spec = np.roll(cond_spec, -cond_spec_shift, 1)
+                cond_start_frame_idx += int(FPS * cond_shift)
+                cond_end_frame_idx += int(FPS * cond_shift)
+
+            cond_frames = [Image.open(os.path.join(
+                cond_frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in 
+                range(cond_start_frame_idx, cond_end_frame_idx)]
+
+            # concat spec outside dataload
+            item['image'] = 2 * spec - 1 # (80, 860)
+            item['cond_image'] = 2 * cond_spec - 1 # (80, 860)
+            item['image'] = item['image'][:, :self.spec_take_first]
+            item['cond_image'] = item['cond_image'][:, :self.spec_take_first]
+            item['file_path_specs_'] = spec_path
+            item['file_path_cond_specs_'] = cond_spec_path
+
+            if self.frame_transforms is not None:
+                cond_frames = self.frame_transforms(cond_frames)
+                frames = self.frame_transforms(frames)
+
+            item['feature'] = np.stack(cond_frames + frames, axis=0) # (30 * L, 112, 112, 3)
+            item['file_path_feats_'] = (frame_path, start_frame_idx)
+            item['file_path_cond_feats_'] = (cond_frame_path, cond_start_frame_idx)
+
+            item['label'] = self.video_idx2label[video_idx]
+            item['target'] = self.label2target[item['label']]
+
+            if self.spec_transforms is not None:
+                item = self.spec_transforms(item)
+        except Exception:
+            traceback.print_exc()
+            print('!!!!!!!!!!!!!!!!!!!!', video_idx, cond_video_idx)
+            print('!!!!!!!!!!!!!!!!!!!!', end_frame_idx, cond_end_frame_idx)
+            exit(1)
+
+        return item
+
+
+    def validate_data(self):
+        original_len = len(self.dataset)
+        valid_dataset = []
+        for video_idx in tqdm(self.dataset):
+            video, start_idx = video_idx.split('_')
+            frame_path = os.path.join(self.frame_path, video, 'frames')
+            start_frame_idx = non_negative(FPS * int(start_idx)/SR)
+            end_frame_idx = non_negative(start_frame_idx + FPS * (self.L + 0.6))
+            if os.path.exists(os.path.join(frame_path, f'frame{end_frame_idx:0>6d}.jpg')):
+                valid_dataset.append(video_idx)
+            else:
+                self.video2indexes[video].remove(start_idx)
+        for video_idx in list(valid_dataset):
+            video, start_idx = video_idx.split('_')
+            if len(self.video2indexes[video]) == 1:
+                valid_dataset.remove(video_idx)
+        if original_len != len(valid_dataset):
+            print(f'Validated dataset with enough frames: {len(valid_dataset)}')
+        self.dataset = valid_dataset
+        split_clip_ids_path = os.path.join(self.splits_path, f'greatesthit_{self.split}_{self.L:.2f}.json')
+        if not os.path.exists(split_clip_ids_path):
+            with open(split_clip_ids_path, 'w') as f:
+                json.dump(valid_dataset, f)
+
+
+    def make_split_files(self, ratio=[0.85, 0.1, 0.05]):
+        random.seed(1337)
+        print(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')
+        # The downloaded videos (some went missing on YouTube and no longer available)
+        available_mel_paths = set(glob(os.path.join(self.specs_dir, '*_mel.npy')))
+        self.available_video_hit = [vh for vh in self.available_video_hit if self.video_idx2path[vh] in available_mel_paths]
+
+        all_video = list(self.video2indexes.keys())
+
+        print(f'The number of clips available after download: {len(self.available_video_hit)}')
+        print(f'The number of videos available after download: {len(all_video)}')
+
+        available_idx = list(range(len(all_video)))
+        random.shuffle(available_idx)
+        assert sum(ratio) == 1.
+        cut_train = int(ratio[0] * len(all_video))
+        cut_test = cut_train + int(ratio[1] * len(all_video))
+
+        train_idx = available_idx[:cut_train]
+        test_idx = available_idx[cut_train:cut_test]
+        valid_idx = available_idx[cut_test:]
+
+        train_video = [all_video[i] for i in train_idx]
+        test_video = [all_video[i] for i in test_idx]
+        valid_video = [all_video[i] for i in valid_idx]
+
+        train_video_hit = []
+        for v in train_video:
+            train_video_hit += [get_GH_data_identifier(v, hit_idx) for hit_idx in self.video2indexes[v]]
+        test_video_hit = []
+        for v in test_video:
+            test_video_hit += [get_GH_data_identifier(v, hit_idx) for hit_idx in self.video2indexes[v]]
+        valid_video_hit = []
+        for v in valid_video:
+            valid_video_hit += [get_GH_data_identifier(v, hit_idx) for hit_idx in self.video2indexes[v]]
+
+        # mix train and valid for better validation loss
+        mixed = train_video_hit + valid_video_hit
+        random.shuffle(mixed)
+        split = int(len(mixed) * ratio[0] / (ratio[0] + ratio[2]))
+        train_video_hit = mixed[:split]
+        valid_video_hit = mixed[split:]
+
+        with open(os.path.join(self.splits_path, 'greatesthit_train.json'), 'w') as train_file,\
+             open(os.path.join(self.splits_path, 'greatesthit_test.json'), 'w') as test_file,\
+             open(os.path.join(self.splits_path, 'greatesthit_valid.json'), 'w') as valid_file:
+            json.dump(train_video_hit, train_file)
+            json.dump(test_video_hit, test_file)
+            json.dump(valid_video_hit, valid_file)
+
+        print(f'Put {len(train_idx)} clips to the train set and saved it to ./data/greatesthit_train.json')
+        print(f'Put {len(test_idx)} clips to the test set and saved it to ./data/greatesthit_test.json')
+        print(f'Put {len(valid_idx)} clips to the valid set and saved it to ./data/greatesthit_valid.json')
+
+
+class CondGreatestHitSpecsCondOnImageTrain(CondGreatestHitSpecsCondOnImage):
+    def __init__(self, dataset_cfg):
+        train_transforms = transforms.Compose([
+            Resize3D(256),
+            RandomResizedCrop3D(224, scale=(0.5, 1.0)),
+            RandomHorizontalFlip3D(),
+            ColorJitter3D(brightness=0.1, saturation=0.1),
+            ToTensor3D(),
+            Normalize3D(mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225]),
+        ])
+        super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)
+
+class CondGreatestHitSpecsCondOnImageValidation(CondGreatestHitSpecsCondOnImage):
+    def __init__(self, dataset_cfg):
+        valid_transforms = transforms.Compose([
+            Resize3D(256),
+            CenterCrop3D(224),
+            ToTensor3D(),
+            Normalize3D(mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225]),
+        ])
+        super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)
+
+class CondGreatestHitSpecsCondOnImageTest(CondGreatestHitSpecsCondOnImage):
+    def __init__(self, dataset_cfg):
+        test_transforms = transforms.Compose([
+            Resize3D(256),
+            CenterCrop3D(224),
+            ToTensor3D(),
+            Normalize3D(mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225]),
+        ])
+        super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
+
+
+class CondGreatestHitWaveCondOnImage(torch.utils.data.Dataset):
+
+    def __init__(self, split, wav_dir, spec_len, random_crop, mel_num, spec_crop_len,
+                L=2.0, frame_transforms=None, splits_path='./data',
+                data_path='data/greatesthit/greatesthit-process-resized',
+                p_outside_cond=0., p_audio_aug=0.5, rand_shift=True):
+        super().__init__()
+        self.split = split
+        self.wav_dir = wav_dir
+        self.frame_transforms = frame_transforms
+        self.splits_path = splits_path
+        self.data_path = data_path
+        self.spec_len = spec_len
+        self.L = L
+        self.rand_shift = rand_shift
+        self.p_outside_cond = torch.tensor(p_outside_cond)
+
+        split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
+        if not os.path.exists(split_clip_ids_path):
+            raise NotImplementedError()
+        clip_video_hit = json.load(open(split_clip_ids_path, 'r'))
+
+        video_name = list(set([vidx.split('_')[0] for vidx in clip_video_hit]))
+
+        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames')))//2 for v in video_name}
+        self.left_over = int(FPS * L + 1)
+        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_denoised_resampled.wav') for v in video_name}
+        self.dataset = clip_video_hit
+
+        self.video2indexes = {}
+        for video_idx in self.dataset:
+            video, start_idx = video_idx.split('_')
+            if video not in self.video2indexes.keys():
+                self.video2indexes[video] = []
+            self.video2indexes[video].append(start_idx)
+        for video in self.video2indexes.keys():
+            if len(self.video2indexes[video]) == 1: # given video contains only one hit
+                self.dataset.remove(
+                    get_GH_data_identifier(video, self.video2indexes[video][0])
+                )
+
+        self.wav_transforms = transforms.Compose([
+            MakeMono(),
+            Padding(target_len=int(SR * self.L)),
+        ])
+        if self.frame_transforms is None:
+            self.frame_transforms = transforms.Compose([
+                Resize3D(256),
+                RandomResizedCrop3D(224, scale=(0.5, 1.0)),
+                RandomHorizontalFlip3D(),
+                ColorJitter3D(brightness=0.1, saturation=0.1),
+                ToTensor3D(),
+                Normalize3D(mean=[0.485, 0.456, 0.406],
+                            std=[0.229, 0.224, 0.225]),
+            ])
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        item = {}
+        video_idx = self.dataset[idx]
+        video, start_idx = video_idx.split('_')
+        start_idx = int(start_idx)
+        frame_path = os.path.join(self.data_path, video, 'frames')
+        start_frame_idx = non_negative(FPS * int(start_idx)/SR)
+        if self.rand_shift:
+            shift = random.uniform(-0.5, 0.5)
+            start_frame_idx = non_negative(start_frame_idx + int(FPS * shift))
+            start_idx = non_negative(start_idx + int(SR * shift))
+        if start_frame_idx > self.video_frame_cnt[video] - self.left_over:
+            start_frame_idx = self.video_frame_cnt[video] - self.left_over
+            start_idx = non_negative(SR * (start_frame_idx / FPS))
+
+        end_frame_idx = non_negative(start_frame_idx + FPS * self.L)
+
+        # target
+        wave_path = self.video_audio_path[video]
+        frames = [Image.open(os.path.join(
+            frame_path, f'frame{i+1:0>6d}')).convert('RGB') for i in
+            range(start_frame_idx, end_frame_idx)]
+        wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_idx)
+        assert sr == SR
+        wav = self.wav_transforms(wav)
+
+        # cond
+        if torch.all(torch.bernoulli(self.p_outside_cond) == 1.):
+            all_idx = set(list(range(len(self.dataset))))
+            all_idx.remove(idx)
+            cond_video_idx = self.dataset[sample(list(all_idx), k=1)[0]]
+            cond_video, cond_start_idx = cond_video_idx.split('_')
+        else:
+            cond_video = video
+            video_hits_idx = copy.copy(self.video2indexes[video])
+            if str(start_idx) in video_hits_idx:
+                video_hits_idx.remove(str(start_idx))
+            cond_start_idx = sample(video_hits_idx, k=1)[0]
+            cond_video_idx = get_GH_data_identifier(cond_video, cond_start_idx)
+
+        cond_video, cond_start_idx = cond_video_idx.split('_')
+        cond_start_idx = int(cond_start_idx)
+        cond_frame_path = os.path.join(self.data_path, cond_video, 'frames')
+        cond_start_frame_idx = non_negative(FPS * int(cond_start_idx)/SR)
+        cond_wave_path = self.video_audio_path[cond_video]
+
+        if self.rand_shift:
+            cond_shift = random.uniform(-0.5, 0.5)
+            cond_start_frame_idx = non_negative(cond_start_frame_idx + int(FPS * cond_shift))
+            cond_start_idx = non_negative(cond_start_idx + int(cond_shift * SR))
+        if cond_start_frame_idx > self.video_frame_cnt[cond_video] - self.left_over:
+            cond_start_frame_idx = self.video_frame_cnt[cond_video] - self.left_over
+            cond_start_idx = non_negative(SR * (cond_start_frame_idx / FPS))
+        cond_end_frame_idx = non_negative(cond_start_frame_idx + FPS * self.L)
+
+        cond_frames = [Image.open(os.path.join(
+                cond_frame_path, f'frame{i+1:0>6d}')).convert('RGB') for i in 
+                range(cond_start_frame_idx, cond_end_frame_idx)]
+        cond_wav, _ = soundfile.read(cond_wave_path, frames=int(SR * self.L), start=cond_start_idx)
+        cond_wav = self.wav_transforms(cond_wav)
+
+        item['image'] = wav # (44100,)
+        item['cond_image'] = cond_wav # (44100,)
+        item['file_path_wav_'] = wave_path
+        item['file_path_cond_wav_'] = cond_wave_path
+
+        if self.frame_transforms is not None:
+            cond_frames = self.frame_transforms(cond_frames)
+            frames = self.frame_transforms(frames)
+
+        item['feature'] = np.stack(cond_frames + frames, axis=0) # (30 * L, 112, 112, 3)
+        item['file_path_feats_'] = (frame_path, start_idx)
+        item['file_path_cond_feats_'] = (cond_frame_path, cond_start_idx)
+
+        item['label'] = 'None'
+        item['target'] = 'None'
+
+        return item
+
+    def validate_data(self):
+        raise NotImplementedError()
+
+    def make_split_files(self, ratio=[0.85, 0.1, 0.05]):
+        random.seed(1337)
+        print(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')
+
+        all_video = sorted(os.listdir(self.data_path))
+        print(f'The number of videos available after download: {len(all_video)}')
+
+        available_idx = list(range(len(all_video)))
+        random.shuffle(available_idx)
+        assert sum(ratio) == 1.
+        cut_train = int(ratio[0] * len(all_video))
+        cut_test = cut_train + int(ratio[1] * len(all_video))
+
+        train_idx = available_idx[:cut_train]
+        test_idx = available_idx[cut_train:cut_test]
+        valid_idx = available_idx[cut_test:]
+
+        train_video = [all_video[i] for i in train_idx]
+        test_video = [all_video[i] for i in test_idx]
+        valid_video = [all_video[i] for i in valid_idx]
+
+        with open(os.path.join(self.splits_path, 'greatesthit_video_train.json'), 'w') as train_file,\
+             open(os.path.join(self.splits_path, 'greatesthit_video_test.json'), 'w') as test_file,\
+             open(os.path.join(self.splits_path, 'greatesthit_video_valid.json'), 'w') as valid_file:
+            json.dump(train_video, train_file)
+            json.dump(test_video, test_file)
+            json.dump(valid_video, valid_file)
+
+        print(f'Put {len(train_idx)} videos to the train set and saved it to ./data/greatesthit_video_train.json')
+        print(f'Put {len(test_idx)} videos to the test set and saved it to ./data/greatesthit_video_test.json')
+        print(f'Put {len(valid_idx)} videos to the valid set and saved it to ./data/greatesthit_video_valid.json')
+
+
+class CondGreatestHitWaveCondOnImageTrain(CondGreatestHitWaveCondOnImage):
+    def __init__(self, dataset_cfg):
+        train_transforms = transforms.Compose([
+            Resize3D(128),
+            RandomResizedCrop3D(112, scale=(0.5, 1.0)),
+            RandomHorizontalFlip3D(),
+            ColorJitter3D(brightness=0.4, saturation=0.4, contrast=0.2, hue=0.1),
+            ToTensor3D(),
+            Normalize3D(mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225]),
+        ])
+        super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)
+
+class CondGreatestHitWaveCondOnImageValidation(CondGreatestHitWaveCondOnImage):
+    def __init__(self, dataset_cfg):
+        valid_transforms = transforms.Compose([
+            Resize3D(128),
+            CenterCrop3D(112),
+            ToTensor3D(),
+            Normalize3D(mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225]),
+        ])
+        super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)
+
+class CondGreatestHitWaveCondOnImageTest(CondGreatestHitWaveCondOnImage):
+    def __init__(self, dataset_cfg):
+        test_transforms = transforms.Compose([
+            Resize3D(128),
+            CenterCrop3D(112),
+            ToTensor3D(),
+            Normalize3D(mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225]),
+        ])
+        super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
+
+
+
+class GreatestHitWaveCondOnImage(torch.utils.data.Dataset):
+
+    def __init__(self, split, wav_dir, spec_len, random_crop, mel_num, spec_crop_len,
+                L=2.0, frame_transforms=None, splits_path='./data',
+                data_path='data/greatesthit/greatesthit-process-resized',
+                p_outside_cond=0., p_audio_aug=0.5, rand_shift=True):
+        super().__init__()
+        self.split = split
+        self.wav_dir = wav_dir
+        self.frame_transforms = frame_transforms
+        self.splits_path = splits_path
+        self.data_path = data_path
+        self.spec_len = spec_len
+        self.L = L
+        self.rand_shift = rand_shift
+        self.p_outside_cond = torch.tensor(p_outside_cond)
+
+        split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
+        if not os.path.exists(split_clip_ids_path):
+            raise NotImplementedError()
+        clip_video_hit = json.load(open(split_clip_ids_path, 'r'))
+
+        video_name = list(set([vidx.split('_')[0] for vidx in clip_video_hit]))
+
+        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames')))//2 for v in video_name}
+        self.left_over = int(FPS * L + 1)
+        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_denoised_resampled.wav') for v in video_name}
+        self.dataset = clip_video_hit
+
+        self.video2indexes = {}
+        for video_idx in self.dataset:
+            video, start_idx = video_idx.split('_')
+            if video not in self.video2indexes.keys():
+                self.video2indexes[video] = []
+            self.video2indexes[video].append(start_idx)
+        for video in self.video2indexes.keys():
+            if len(self.video2indexes[video]) == 1: # given video contains only one hit
+                self.dataset.remove(
+                    get_GH_data_identifier(video, self.video2indexes[video][0])
+                )
+
+        self.wav_transforms = transforms.Compose([
+            MakeMono(),
+            Padding(target_len=int(SR * self.L)),
+        ])
+        if self.frame_transforms is None:
+            self.frame_transforms = transforms.Compose([
+                Resize3D(256),
+                RandomResizedCrop3D(224, scale=(0.5, 1.0)),
+                RandomHorizontalFlip3D(),
+                ColorJitter3D(brightness=0.1, saturation=0.1),
+                ToTensor3D(),
+                Normalize3D(mean=[0.485, 0.456, 0.406],
+                            std=[0.229, 0.224, 0.225]),
+            ])
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        item = {}
+        video_idx = self.dataset[idx]
+        video, start_idx = video_idx.split('_')
+        start_idx = int(start_idx)
+        frame_path = os.path.join(self.data_path, video, 'frames')
+        start_frame_idx = non_negative(FPS * int(start_idx)/SR)
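+        # Optionally jitter the window start by up to +/-0.5 s (frames and audio
+        # stay aligned), then clamp so a full L-second window still fits.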
+        if self.rand_shift:
+            shift = random.uniform(-0.5, 0.5)
+            start_frame_idx = non_negative(start_frame_idx + int(FPS * shift))
+            start_idx = non_negative(start_idx + int(SR * shift))
+        if start_frame_idx > self.video_frame_cnt[video] - self.left_over:
+            start_frame_idx = self.video_frame_cnt[video] - self.left_over
+            start_idx = non_negative(SR * (start_frame_idx / FPS))
+
+        end_frame_idx = non_negative(start_frame_idx + FPS * self.L)
+
+        # target
+        wave_path = self.video_audio_path[video]
+        frames = [Image.open(os.path.join(
+            frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
+            range(start_frame_idx, end_frame_idx)]
+        wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_idx)
+        assert sr == SR
+        wav = self.wav_transforms(wav)
+
+        item['image'] = wav # (44100,)
+        item['file_path_wav_'] = wave_path
+
+        if self.frame_transforms is not None:
+            frames = self.frame_transforms(frames)
+
+        item['feature'] = torch.stack(frames, dim=0) # (15 * L, 112, 112, 3)
+        item['file_path_feats_'] = (frame_path, start_idx)
+
+        item['label'] = 'None'
+        item['target'] = 'None'
+
+        return item
+
+    def validate_data(self):
+        raise NotImplementedError()
+
+    def make_split_files(self, ratio=[0.85, 0.1, 0.05]):
+        random.seed(1337)
+        print(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')
+
+        all_video = sorted(os.listdir(self.data_path))
+        print(f'The number of videos available after download: {len(all_video)}')
+
+        available_idx = list(range(len(all_video)))
+        random.shuffle(available_idx)
+        assert sum(ratio) == 1.
+        cut_train = int(ratio[0] * len(all_video))
+        cut_test = cut_train + int(ratio[1] * len(all_video))
+
+        train_idx = available_idx[:cut_train]
+        test_idx = available_idx[cut_train:cut_test]
+        valid_idx = available_idx[cut_test:]
+
+        train_video = [all_video[i] for i in train_idx]
+        test_video = [all_video[i] for i in test_idx]
+        valid_video = [all_video[i] for i in valid_idx]
+
+        with open(os.path.join(self.splits_path, 'greatesthit_video_train.json'), 'w') as train_file,\
+             open(os.path.join(self.splits_path, 'greatesthit_video_test.json'), 'w') as test_file,\
+             open(os.path.join(self.splits_path, 'greatesthit_video_valid.json'), 'w') as valid_file:
+            json.dump(train_video, train_file)
+            json.dump(test_video, test_file)
+            json.dump(valid_video, valid_file)
+
+        print(f'Put {len(train_idx)} videos to the train set and saved it to ./data/greatesthit_video_train.json')
+        print(f'Put {len(test_idx)} videos to the test set and saved it to ./data/greatesthit_video_test.json')
+        print(f'Put {len(valid_idx)} videos to the valid set and saved it to ./data/greatesthit_video_valid.json')
+
+
+class GreatestHitWaveCondOnImageTrain(GreatestHitWaveCondOnImage):
+    def __init__(self, dataset_cfg):
+        train_transforms = transforms.Compose([
+            Resize3D(128),
+            RandomResizedCrop3D(112, scale=(0.5, 1.0)),
+            RandomHorizontalFlip3D(),
+            ColorJitter3D(brightness=0.4, saturation=0.4, contrast=0.2, hue=0.1),
+            ToTensor3D(),
+            Normalize3D(mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225]),
+        ])
+        super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)
+
+class GreatestHitWaveCondOnImageValidation(GreatestHitWaveCondOnImage):
+    def __init__(self, dataset_cfg):
+        valid_transforms = transforms.Compose([
+            Resize3D(128),
+            CenterCrop3D(112),
+            ToTensor3D(),
+            Normalize3D(mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225]),
+        ])
+        super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)
+
+class GreatestHitWaveCondOnImageTest(GreatestHitWaveCondOnImage):
+    def __init__(self, dataset_cfg):
+        test_transforms = transforms.Compose([
+            Resize3D(128),
+            CenterCrop3D(112),
+            ToTensor3D(),
+            Normalize3D(mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225]),
+        ])
+        super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
+
+
+def draw_spec(spec, dest, cmap='magma'):
+    plt.imshow(spec, cmap=cmap, origin='lower')
+    plt.axis('off')
+    plt.savefig(dest, bbox_inches='tight', pad_inches=0., dpi=300)
+    plt.close()
+
+if __name__ == '__main__':
+    import sys
+
+    from omegaconf import OmegaConf
+
+    # cfg = OmegaConf.load('configs/greatesthit_transformer_with_vNet_randshift_2s_GH_vqgan_no_earlystop.yaml')
+    cfg = OmegaConf.load('configs/greatesthit_codebook.yaml')
+    data = instantiate_from_config(cfg.data)
+    data.prepare_data()
+    data.setup()
+    print(len(data.datasets['train']))
+    print(data.datasets['train'][24])
+
diff --git a/foleycrafter/models/specvqgan/data/impactset.py b/foleycrafter/models/specvqgan/data/impactset.py
new file mode 100644
index 0000000000000000000000000000000000000000..039dc764260c05ab816c2c79098eba9ef1ffd442
--- /dev/null
+++ b/foleycrafter/models/specvqgan/data/impactset.py
@@ -0,0 +1,778 @@
+import json
+import os
+import matplotlib.pyplot as plt
+import torch
+from torchvision import transforms
+import numpy as np
+from tqdm import tqdm
+import random
+from random import sample
+from PIL import Image
+import torchaudio
+import logging
+from glob import glob
+import sys
+import soundfile
+import copy
+import csv
+import noisereduce as nr
+
+sys.path.insert(0, '.')  # nopep8
+from train import instantiate_from_config
+from foleycrafter.models.specvqgan.data.transforms import *
+
+torchaudio.set_audio_backend("sox_io")
+logger = logging.getLogger(f'main.{__name__}')
+
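+# Shared constants for the datasets below: audio sample rate (Hz) and video frame rate (fps).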
+SR = 22050
+FPS = 15
+MAX_SAMPLE_ITER = 10
+
+def non_negative(x): return int(np.round(max(0, x), 0))
+
+def rms(x): return np.sqrt(np.mean(x**2))
+
+def get_GH_data_identifier(video_name, start_idx, split='_'):
+    if isinstance(start_idx, str):
+        return video_name + split + start_idx
+    elif isinstance(start_idx, int):
+        return video_name + split + str(start_idx)
+    else:
+        raise NotImplementedError
+
+def draw_spec(spec, dest, cmap='magma'):
+    plt.imshow(spec, cmap=cmap, origin='lower')
+    plt.axis('off')
+    plt.savefig(dest, bbox_inches='tight', pad_inches=0., dpi=300)
+    plt.close()
+
+def convert_to_decibel(arr):
+    ref = 1
+    return 20 * np.log10(abs(arr + 1e-4) / ref)
+
+class ResampleFrames(object):
+    def __init__(self, feat_sample_size, times_to_repeat_after_resample=None):
+        self.feat_sample_size = feat_sample_size
+        self.times_to_repeat_after_resample = times_to_repeat_after_resample
+
+    def __call__(self, item):
+        feat_len = item['feature'].shape[0]
+
+        ## resample
+        assert feat_len >= self.feat_sample_size
+        # evenly spaced points (abcdefghkl -> aoooofoooo)
+        idx = np.linspace(0, feat_len, self.feat_sample_size, dtype=int, endpoint=False)
+        # xoooo xoooo -> ooxoo ooxoo
+        shift = feat_len // (self.feat_sample_size + 1)
+        idx = idx + shift
+
+        ## repeat after resampling (abc -> aaaabbbbcccc)
+        if self.times_to_repeat_after_resample is not None and self.times_to_repeat_after_resample > 1:
+            idx = np.repeat(idx, self.times_to_repeat_after_resample)
+
+        item['feature'] = item['feature'][idx, :]
+        return item
+
+
+class ImpactSetWave(torch.utils.data.Dataset):
+
+    def __init__(self, split, random_crop, mel_num, spec_crop_len,
+                L=2.0, denoise=False, splits_path='./data',
+                data_path='data/ImpactSet/impactset-proccess-resize'):
+        super().__init__()
+        self.split = split
+        self.splits_path = splits_path
+        self.data_path = data_path
+        self.L = L
+        self.denoise = denoise
+
+        video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
+        if not os.path.exists(video_name_split_path):
+            self.make_split_files()
+        video_name = json.load(open(video_name_split_path, 'r'))
+        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
+        self.left_over = int(FPS * L + 1)
+        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
+        self.dataset = video_name
+
+        self.wav_transforms = transforms.Compose([
+            MakeMono(),
+            Padding(target_len=int(SR * self.L)),
+        ])
+        
+        self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        item = {}
+        video = self.dataset[idx]
+
+        available_frame_idx = self.video_frame_cnt[video] - self.left_over
+        wav = None
+        spec = None
+        max_db = -np.inf
+        wave_path = ''
+        cur_wave_path = self.video_audio_path[video]
+        if self.denoise:
+            cur_wave_path = cur_wave_path.replace('.wav', '_denoised.wav')
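+        # Draw up to 10 random L-second windows and keep the loudest one
+        # (highest mean dB); stop early once the mean level reaches -40 dB.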
+        for _ in range(10):
+            start_idx = torch.randint(0, available_frame_idx, (1,)).tolist()[0]
+            # target
+            start_t = (start_idx + 0.5) / FPS
+            start_audio_idx = non_negative(start_t * SR)
+
+            cur_wav, _ = soundfile.read(cur_wave_path, frames=int(SR * self.L), start=start_audio_idx)
+
+            decibel = convert_to_decibel(cur_wav)
+            if float(np.mean(decibel)) > max_db:
+                wav = cur_wav
+                wave_path = cur_wave_path
+                max_db = float(np.mean(decibel))
+            if max_db >= -40:
+                break
+
+        # print(max_db)
+        wav = self.wav_transforms(wav)
+        item['image'] = wav # (44100,)
+        # item['wav'] = wav
+        item['file_path_wav_'] = wave_path
+
+        item['label'] = 'None'
+        item['target'] = 'None'
+
+        return item
+
+    def make_split_files(self):
+        raise NotImplementedError
+
+class ImpactSetWaveTrain(ImpactSetWave):
+    def __init__(self, specs_dataset_cfg):
+        super().__init__('train', **specs_dataset_cfg)
+
+class ImpactSetWaveValidation(ImpactSetWave):
+    def __init__(self, specs_dataset_cfg):
+        super().__init__('val', **specs_dataset_cfg)
+
+class ImpactSetWaveTest(ImpactSetWave):
+    def __init__(self, specs_dataset_cfg):
+        super().__init__('test', **specs_dataset_cfg)
+
+
+class ImpactSetSpec(torch.utils.data.Dataset):
+
+    def __init__(self, split, random_crop, mel_num, spec_crop_len,
+                L=2.0, denoise=False, splits_path='./data',
+                data_path='data/ImpactSet/impactset-proccess-resize'):
+        super().__init__()
+        self.split = split
+        self.splits_path = splits_path
+        self.data_path = data_path
+        self.L = L
+        self.denoise = denoise
+
+        video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
+        if not os.path.exists(video_name_split_path):
+            self.make_split_files()
+        video_name = json.load(open(video_name_split_path, 'r'))
+        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
+        self.left_over = int(FPS * L + 1)
+        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
+        self.dataset = video_name
+
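+        # Waveform -> mel spectrogram pipeline: 1024-point STFT (hop 256), 80 mel bins,
+        # log-compressed and rescaled to [0, 1], then trimmed to at most 173 time frames.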
+        self.wav_transforms = transforms.Compose([
+            MakeMono(),
+            SpectrogramTorchAudio(nfft=1024, hoplen=1024//4, spec_power=1),
+            MelScaleTorchAudio(sr=SR, stft=513, fmin=125, fmax=7600, nmels=80),
+            LowerThresh(1e-5),
+            Log10(),
+            Multiply(20),
+            Subtract(20),
+            Add(100),
+            Divide(100),
+            Clip(0, 1.0),
+            TrimSpec(173),
+        ])
+        
+        self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        item = {}
+        video = self.dataset[idx]
+
+        available_frame_idx = self.video_frame_cnt[video] - self.left_over
+        wav = None
+        spec = None
+        max_rms = -np.inf
+        wave_path = ''
+        cur_wave_path = self.video_audio_path[video]
+        if self.denoise:
+            cur_wave_path = cur_wave_path.replace('.wav', '_denoised.wav')
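+        # Draw up to 10 random L-second windows and keep the one whose mel
+        # spectrogram has the highest RMS; stop early once the RMS reaches 0.1.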
+        for _ in range(10):
+            start_idx = torch.randint(0, available_frame_idx, (1,)).tolist()[0]
+            # target
+            start_t = (start_idx + 0.5) / FPS
+            start_audio_idx = non_negative(start_t * SR)
+
+            cur_wav, _ = soundfile.read(cur_wave_path, frames=int(SR * self.L), start=start_audio_idx)
+
+            if self.wav_transforms is not None:
+                spec_tensor = self.wav_transforms(torch.tensor(cur_wav).float())
+                cur_spec = spec_tensor.numpy()
+            # zeros padding if not enough spec t steps
+            if cur_spec.shape[1] < 173:
+                pad = np.zeros((80, 173), dtype=cur_spec.dtype)
+                pad[:, :cur_spec.shape[1]] = cur_spec
+                cur_spec = pad
+            rms_val = rms(cur_spec)
+            if rms_val > max_rms:
+                wav = cur_wav
+                spec = cur_spec
+                wave_path = cur_wave_path
+                max_rms = rms_val
+            # print(rms_val)
+            if max_rms >= 0.1:
+                break
+
+        item['image'] = 2 * spec - 1 # (80, 173)
+        # item['wav'] = wav
+        item['file_path_wav_'] = wave_path
+
+        item['label'] = 'None'
+        item['target'] = 'None'
+
+        if self.spec_transforms is not None:
+            item = self.spec_transforms(item)
+        return item
+
+    def make_split_files(self):
+        raise NotImplementedError
+
+class ImpactSetSpecTrain(ImpactSetSpec):
+    def __init__(self, specs_dataset_cfg):
+        super().__init__('train', **specs_dataset_cfg)
+
+class ImpactSetSpecValidation(ImpactSetSpec):
+    def __init__(self, specs_dataset_cfg):
+        super().__init__('val', **specs_dataset_cfg)
+
+class ImpactSetSpecTest(ImpactSetSpec):
+    def __init__(self, specs_dataset_cfg):
+        super().__init__('test', **specs_dataset_cfg)
+
+
+
+class ImpactSetWaveTestTime(torch.utils.data.Dataset):
+
+    def __init__(self, split, random_crop, mel_num, spec_crop_len,
+                L=2.0, denoise=False, splits_path='./data',
+                data_path='data/ImpactSet/impactset-proccess-resize'):
+        super().__init__()
+        self.split = split
+        self.splits_path = splits_path
+        self.data_path = data_path
+        self.L = L
+        self.denoise = denoise
+
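+        # Fixed evaluation list: stock-video recordings, a few hand-picked
+        # YouTube clips, and self-recorded clips, all used directly at test time.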
+        self.video_list = glob('data/ImpactSet/RawVideos/StockVideo_sound/*.wav') + [
+            'data/ImpactSet/RawVideos/YouTube-impact-ccl/1_ckbCU5aQs/1_ckbCU5aQs_0013_0016_resize.wav', 
+            'data/ImpactSet/RawVideos/YouTube-impact-ccl/GFmuVBiwz6k/GFmuVBiwz6k_0034_0054_resize.wav',
+            'data/ImpactSet/RawVideos/YouTube-impact-ccl/OsPcY316h1M/OsPcY316h1M_0000_0005_resize.wav',
+            'data/ImpactSet/RawVideos/YouTube-impact-ccl/SExIpBIBj_k/SExIpBIBj_k_0009_0019_resize.wav',
+            'data/ImpactSet/RawVideos/YouTube-impact-ccl/S6TkbV4B4QI/S6TkbV4B4QI_0028_0036_resize.wav',
+            'data/ImpactSet/RawVideos/YouTube-impact-ccl/2Ld24pPIn3k/2Ld24pPIn3k_0005_0011_resize.wav',
+            'data/ImpactSet/RawVideos/YouTube-impact-ccl/6d1YS7fdBK4/6d1YS7fdBK4_0007_0019_resize.wav',
+            'data/ImpactSet/RawVideos/YouTube-impact-ccl/JnBsmJgEkiw/JnBsmJgEkiw_0008_0016_resize.wav',
+            'data/ImpactSet/RawVideos/YouTube-impact-ccl/xcUyiXt0gjo/xcUyiXt0gjo_0015_0021_resize.wav',
+            'data/ImpactSet/RawVideos/YouTube-impact-ccl/4DRFJnZjpMM/4DRFJnZjpMM_0000_0010_resize.wav'
+        ] + glob('data/ImpactSet/RawVideos/self_recorded/*_resize.wav')
+
+        self.wav_transforms = transforms.Compose([
+            MakeMono(),
+            SpectrogramTorchAudio(nfft=1024, hoplen=1024//4, spec_power=1),
+            MelScaleTorchAudio(sr=SR, stft=513, fmin=125, fmax=7600, nmels=80),
+            LowerThresh(1e-5),
+            Log10(),
+            Multiply(20),
+            Subtract(20),
+            Add(100),
+            Divide(100),
+            Clip(0, 1.0),
+            TrimSpec(173),
+        ])
+        self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)
+
+    def __len__(self):
+        return len(self.video_list)
+
+    def __getitem__(self, idx):
+        item = {}
+
+        wave_path = self.video_list[idx]
+
+        wav, _ = soundfile.read(wave_path)
+        start_idx = random.randint(0, min(4, wav.shape[0] - int(SR * self.L)))
+        wav = wav[start_idx:start_idx+int(SR * self.L)]
+
+        if self.denoise:
+            if len(wav.shape) == 1:
+                wav = wav[None, :]
+            wav = nr.reduce_noise(y=wav, sr=SR, n_fft=1024, hop_length=1024//4)
+            wav = wav.squeeze()
+        if self.wav_transforms is not None:
+            spec_tensor = self.wav_transforms(torch.tensor(wav).float())
+            spec = spec_tensor.numpy()
+        if spec.shape[1] < 173:
+            pad = np.zeros((80, 173), dtype=spec.dtype)
+            pad[:, :spec.shape[1]] = spec
+            spec = pad
+
+        item['image'] = 2 * spec - 1 # (80, 173)
+        # item['wav'] = wav
+        item['file_path_wav_'] = wave_path
+
+        item['label'] = 'None'
+        item['target'] = 'None'
+
+        if self.spec_transforms is not None:
+            item = self.spec_transforms(item)
+        return item
+
+    def make_split_files(self):
+        raise NotImplementedError
+
+class ImpactSetWaveTestTimeTrain(ImpactSetWaveTestTime):
+    def __init__(self, specs_dataset_cfg):
+        super().__init__('train', **specs_dataset_cfg)
+
+class ImpactSetWaveTestTimeValidation(ImpactSetWaveTestTime):
+    def __init__(self, specs_dataset_cfg):
+        super().__init__('val', **specs_dataset_cfg)
+
+class ImpactSetWaveTestTimeTest(ImpactSetWaveTestTime):
+    def __init__(self, specs_dataset_cfg):
+        super().__init__('test', **specs_dataset_cfg)
+
+
+class ImpactSetWaveWithSilent(torch.utils.data.Dataset):
+
+    def __init__(self, split, random_crop, mel_num, spec_crop_len,
+                L=2.0, denoise=False, splits_path='./data',
+                data_path='data/ImpactSet/impactset-proccess-resize'):
+        super().__init__()
+        self.split = split
+        self.splits_path = splits_path
+        self.data_path = data_path
+        self.L = L
+        self.denoise = denoise
+
+        video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
+        if not os.path.exists(video_name_split_path):
+            self.make_split_files()
+        video_name = json.load(open(video_name_split_path, 'r'))
+        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
+        self.left_over = int(FPS * L + 1)
+        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
+        self.dataset = video_name
+
+        self.wav_transforms = transforms.Compose([
+            MakeMono(),
+            Padding(target_len=int(SR * self.L)),
+        ])
+        
+        self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        item = {}
+        video = self.dataset[idx]
+
+        available_frame_idx = self.video_frame_cnt[video] - self.left_over
+        wave_path = self.video_audio_path[video]
+        if self.denoise:
+            wave_path = wave_path.replace('.wav', '_denoised.wav')
+        start_idx = torch.randint(0, available_frame_idx, (1,)).tolist()[0]
+        # target
+        start_t = (start_idx + 0.5) / FPS
+        start_audio_idx = non_negative(start_t * SR)
+
+        wav, _ = soundfile.read(wave_path, frames=int(SR * self.L), start=start_audio_idx)
+
+        wav = self.wav_transforms(wav)
+
+        item['image'] = wav # (44100,)
+        # item['wav'] = wav
+        item['file_path_wav_'] = wave_path
+
+        item['label'] = 'None'
+        item['target'] = 'None'
+        return item
+
+    def make_split_files(self):
+        raise NotImplementedError
+
+class ImpactSetWaveWithSilentTrain(ImpactSetWaveWithSilent):
+    def __init__(self, specs_dataset_cfg):
+        super().__init__('train', **specs_dataset_cfg)
+
+class ImpactSetWaveWithSilentValidation(ImpactSetWaveWithSilent):
+    def __init__(self, specs_dataset_cfg):
+        super().__init__('val', **specs_dataset_cfg)
+
+class ImpactSetWaveWithSilentTest(ImpactSetWaveWithSilent):
+    def __init__(self, specs_dataset_cfg):
+        super().__init__('test', **specs_dataset_cfg)
+
+
+class ImpactSetWaveCondOnImage(torch.utils.data.Dataset):
+
+    def __init__(self, split,
+                L=2.0, frame_transforms=None, denoise=False, splits_path='./data',
+                data_path='data/ImpactSet/impactset-proccess-resize',
+                p_outside_cond=0.):
+        super().__init__()
+        self.split = split
+        self.splits_path = splits_path
+        self.frame_transforms = frame_transforms
+        self.data_path = data_path
+        self.L = L
+        self.denoise = denoise
+        self.p_outside_cond = torch.tensor(p_outside_cond)
+
+        video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
+        if not os.path.exists(video_name_split_path):
+            self.make_split_files()
+        video_name = json.load(open(video_name_split_path, 'r'))
+        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
+        self.left_over = int(FPS * L + 1)
+        for v, cnt in self.video_frame_cnt.items():
+            if cnt - (3*self.left_over) <= 0:
+                video_name.remove(v)
+        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
+        self.dataset = video_name
+
+        video_timing_split_path = os.path.join(splits_path, f'countixAV_{split}_timing.json')
+        self.video_timing = json.load(open(video_timing_split_path, 'r'))
+        self.video_timing = {v: [int(float(t) * FPS) for t in ts] for v, ts in self.video_timing.items()}
+
+        if split != 'test':
+            video_class_path = os.path.join(splits_path, f'countixAV_{split}_class.json')
+            if not os.path.exists(video_class_path):
+                self.make_video_class()
+            self.video_class = json.load(open(video_class_path, 'r'))
+            self.class2video = {}
+            for v, c in self.video_class.items():
+                if c not in self.class2video.keys():
+                    self.class2video[c] = []
+                self.class2video[c].append(v)
+
+        self.wav_transforms = transforms.Compose([
+            MakeMono(),
+            Padding(target_len=int(SR * self.L)),
+        ])
+        if self.frame_transforms is None:
+            self.frame_transforms = transforms.Compose([
+                Resize3D(128),
+                RandomResizedCrop3D(112, scale=(0.5, 1.0)),
+                RandomHorizontalFlip3D(),
+                ColorJitter3D(brightness=0.1, saturation=0.1),
+                ToTensor3D(),
+                Normalize3D(mean=[0.485, 0.456, 0.406],
+                            std=[0.229, 0.224, 0.225]),
+            ])
+
+    def make_video_class(self):
+        meta_path = f'data/ImpactSet/data-info/CountixAV_{self.split}.csv'
+        video_class = {}
+        with open(meta_path, 'r') as f:
+            reader = csv.reader(f)
+            for i, row in enumerate(reader):
+                if i == 0:
+                    continue
+                vid, k_st, k_et = row[:3]
+                video_name = f'{vid}_{int(k_st):0>4d}_{int(k_et):0>4d}'
+                if video_name not in self.dataset:
+                    continue
+                video_class[video_name] = row[-1]
+        with open(os.path.join(self.splits_path, f'countixAV_{self.split}_class.json'), 'w') as f:
+            json.dump(video_class, f)
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        item = {}
+        video = self.dataset[idx]
+
+        available_frame_idx = self.video_frame_cnt[video] - self.left_over
+        rep_start_idx, rep_end_idx = self.video_timing[video]
+        rep_end_idx = min(available_frame_idx, rep_end_idx)
+        if available_frame_idx <= rep_start_idx + self.L * FPS:
+            idx_set = list(range(0, available_frame_idx))
+        else:
+            idx_set = list(range(rep_start_idx, rep_end_idx))
+        start_idx = sample(idx_set, k=1)[0]
+
+        wave_path = self.video_audio_path[video]
+        if self.denoise:
+            wave_path = wave_path.replace('.wav', '_denoised.wav')
+
+        # target
+        start_t = (start_idx + 0.5) / FPS
+        end_idx = non_negative(start_idx + FPS * self.L)
+        start_audio_idx = non_negative(start_t * SR)
+        wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_audio_idx)
+        assert sr == SR
+        wav = self.wav_transforms(wav)
+        frame_path = os.path.join(self.data_path, video, 'frames')
+        frames = [Image.open(os.path.join(
+            frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
+            range(start_idx, end_idx)]
+
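+        # With probability p_outside_cond (never in the test split), condition on a
+        # clip from another video of the same class; otherwise use a non-overlapping
+        # window from the same video.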
+        if torch.all(torch.bernoulli(self.p_outside_cond) == 1.) and self.split != 'test':
+            # outside from the same class
+            cur_class = self.video_class[video]
+            tmp_video = copy.copy(self.class2video[cur_class])
+            if len(tmp_video) > 1:
+                # if only 1 video in the class, use itself
+                tmp_video.remove(video)
+            cond_video = sample(tmp_video, k=1)[0]
+            cond_available_frame_idx = self.video_frame_cnt[cond_video] - self.left_over
+            cond_start_idx = torch.randint(0, cond_available_frame_idx, (1,)).tolist()[0]
+        else:
+            cond_video = video
+            idx_set = list(range(0, start_idx)) + list(range(end_idx, available_frame_idx))
+            cond_start_idx = random.sample(idx_set, k=1)[0]
+
+        cond_end_idx = non_negative(cond_start_idx + FPS * self.L)
+        cond_start_t = (cond_start_idx + 0.5) / FPS
+        cond_audio_idx = non_negative(cond_start_t * SR)
+        cond_frame_path = os.path.join(self.data_path, cond_video, 'frames')
+        cond_wave_path = self.video_audio_path[cond_video]
+
+        cond_frames = [Image.open(os.path.join(
+            cond_frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
+            range(cond_start_idx, cond_end_idx)]
+        cond_wav, sr = soundfile.read(cond_wave_path, frames=int(SR * self.L), start=cond_audio_idx)
+        assert sr == SR
+        cond_wav = self.wav_transforms(cond_wav)
+
+        item['image'] = wav # (44100,)
+        item['cond_image'] = cond_wav # (44100,)
+        item['file_path_wav_'] = wave_path
+        item['file_path_cond_wav_'] = cond_wave_path
+
+        if self.frame_transforms is not None:
+            cond_frames = self.frame_transforms(cond_frames)
+            frames = self.frame_transforms(frames)
+
+        item['feature'] = np.stack(cond_frames + frames, axis=0) # (30 * L, 112, 112, 3)
+        item['file_path_feats_'] = (frame_path, start_idx)
+        item['file_path_cond_feats_'] = (cond_frame_path, cond_start_idx)
+
+        item['label'] = 'None'
+        item['target'] = 'None'
+
+        return item
+
+    def make_split_files(self):
+        raise NotImplementedError
+
+
+class ImpactSetWaveCondOnImageTrain(ImpactSetWaveCondOnImage):
+    def __init__(self, dataset_cfg):
+        train_transforms = transforms.Compose([
+            Resize3D(128),
+            RandomResizedCrop3D(112, scale=(0.5, 1.0)),
+            RandomHorizontalFlip3D(),
+            ColorJitter3D(brightness=0.4, saturation=0.4, contrast=0.2, hue=0.1),
+            ToTensor3D(),
+            Normalize3D(mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225]),
+        ])
+        super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)
+
+class ImpactSetWaveCondOnImageValidation(ImpactSetWaveCondOnImage):
+    def __init__(self, dataset_cfg):
+        valid_transforms = transforms.Compose([
+            Resize3D(128),
+            CenterCrop3D(112),
+            ToTensor3D(),
+            Normalize3D(mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225]),
+        ])
+        super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)
+
+class ImpactSetWaveCondOnImageTest(ImpactSetWaveCondOnImage):
+    def __init__(self, dataset_cfg):
+        test_transforms = transforms.Compose([
+            Resize3D(128),
+            CenterCrop3D(112),
+            ToTensor3D(),
+            Normalize3D(mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225]),
+        ])
+        super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
+
+
+
+class ImpactSetCleanWaveCondOnImage(ImpactSetWaveCondOnImage):
+    def __init__(self, split, L=2.0, frame_transforms=None, denoise=False, splits_path='./data',
+                data_path='data/ImpactSet/impactset-proccess-resize', p_outside_cond=0.):
+        super().__init__(split, L, frame_transforms, denoise, splits_path, data_path, p_outside_cond)
+        pred_timing_path = f'data/countixAV_{split}_timing_processed_0.20.json'
+        assert os.path.exists(pred_timing_path)
+        self.pred_timing = json.load(open(pred_timing_path, 'r'))
+
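+        # Each item is a (video, predicted start time) pair from the pre-computed
+        # timing file, restricted to videos present in this split.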
+        self.dataset = []
+        for v, ts in self.pred_timing.items():
+            if v in self.video_audio_path.keys():
+                for t in ts:
+                    self.dataset.append([v, t])
+
+    def __getitem__(self, idx):
+        item = {}
+        video, start_t = self.dataset[idx]
+        available_frame_idx = self.video_frame_cnt[video] - self.left_over
+        available_timing = (available_frame_idx + 0.5) / FPS
+        start_t = float(start_t)
+        start_t = min(start_t, available_timing)
+
+        start_idx = non_negative(start_t * FPS - 0.5)
+
+        wave_path = self.video_audio_path[video]
+        if self.denoise:
+            wave_path = wave_path.replace('.wav', '_denoised.wav')
+
+        # target
+        end_idx = non_negative(start_idx + FPS * self.L)
+        start_audio_idx = non_negative(start_t * SR)
+        wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_audio_idx)
+        assert sr == SR
+        wav = self.wav_transforms(wav)
+        frame_path = os.path.join(self.data_path, video, 'frames')
+        frames = [Image.open(os.path.join(
+            frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
+            range(start_idx, end_idx)]
+
+        if torch.all(torch.bernoulli(self.p_outside_cond) == 1.):
+            other_video = list(self.pred_timing.keys())
+            other_video.remove(video)
+            cond_video = sample(other_video, k=1)[0]
+            cond_available_frame_idx = self.video_frame_cnt[cond_video] - self.left_over
+            cond_available_timing = (cond_available_frame_idx + 0.5) / FPS
+        else:
+            cond_video = video
+            cond_available_timing = available_timing
+
+        cond_start_t = sample(self.pred_timing[cond_video], k=1)[0]
+        cond_start_t = float(cond_start_t)
+        cond_start_t = min(cond_start_t, cond_available_timing)
+        cond_start_idx = non_negative(cond_start_t * FPS - 0.5)
+        cond_end_idx = non_negative(cond_start_idx + FPS * self.L)
+        cond_audio_idx = non_negative(cond_start_t * SR)
+        cond_frame_path = os.path.join(self.data_path, cond_video, 'frames')
+        cond_wave_path = self.video_audio_path[cond_video]
+
+        cond_frames = [Image.open(os.path.join(
+            cond_frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
+            range(cond_start_idx, cond_end_idx)]
+        cond_wav, sr = soundfile.read(cond_wave_path, frames=int(SR * self.L), start=cond_audio_idx)
+        assert sr == SR
+        cond_wav = self.wav_transforms(cond_wav)
+
+        item['image'] = wav # (44100,)
+        item['cond_image'] = cond_wav # (44100,)
+        item['file_path_wav_'] = wave_path
+        item['file_path_cond_wav_'] = cond_wave_path
+
+        if self.frame_transforms is not None:
+            cond_frames = self.frame_transforms(cond_frames)
+            frames = self.frame_transforms(frames)
+
+        item['feature'] = np.stack(cond_frames + frames, axis=0) # (30 * L, 112, 112, 3)
+        item['file_path_feats_'] = (frame_path, start_idx)
+        item['file_path_cond_feats_'] = (cond_frame_path, cond_start_idx)
+
+        item['label'] = 'None'
+        item['target'] = 'None'
+
+        return item
+
+
+class ImpactSetCleanWaveCondOnImageTrain(ImpactSetCleanWaveCondOnImage):
+    def __init__(self, dataset_cfg):
+        train_transforms = transforms.Compose([
+            Resize3D(128),
+            RandomResizedCrop3D(112, scale=(0.5, 1.0)),
+            RandomHorizontalFlip3D(),
+            ColorJitter3D(brightness=0.4, saturation=0.4, contrast=0.2, hue=0.1),
+            ToTensor3D(),
+            Normalize3D(mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225]),
+        ])
+        super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)
+
+class ImpactSetCleanWaveCondOnImageValidation(ImpactSetCleanWaveCondOnImage):
+    def __init__(self, dataset_cfg):
+        valid_transforms = transforms.Compose([
+            Resize3D(128),
+            CenterCrop3D(112),
+            ToTensor3D(),
+            Normalize3D(mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225]),
+        ])
+        super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)
+
+class ImpactSetCleanWaveCondOnImageTest(ImpactSetCleanWaveCondOnImage):
+    def __init__(self, dataset_cfg):
+        test_transforms = transforms.Compose([
+            Resize3D(128),
+            CenterCrop3D(112),
+            ToTensor3D(),
+            Normalize3D(mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225]),
+        ])
+        super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
+
+
+if __name__ == '__main__':
+    import sys
+
+    from omegaconf import OmegaConf
+    cfg = OmegaConf.load('configs/countixAV_transformer_denoise_clean.yaml')
+    data = instantiate_from_config(cfg.data)
+    data.prepare_data()
+    data.setup()
+
+    print(data.datasets['train'])
+    print(len(data.datasets['train']))
+    # print(data.datasets['train'][24])
+    exit()
+
+    stats = []
+    torch.manual_seed(0)
+    np.random.seed(0)
+    random.seed(0)
+    for k in range(1):
+        x = np.arange(SR * 2)
+        for i in tqdm(range(len(data.datasets['train']))):
+            wav = data.datasets['train'][i]['wav']
+            spec = data.datasets['train'][i]['image']
+            spec = 0.5 * (spec + 1)
+            spec_rms = rms(spec)
+            stats.append(float(spec_rms))
+            # plt.plot(x, wav)
+            # plt.ylim(-1, 1)
+            # plt.savefig(f'tmp/th0.1_wav_e_{k}_{i}_{mean_val:.3f}_{spec_rms:.3f}.png')
+            # plt.close()
+            # plt.cla()
+            soundfile.write(f'tmp/wav_e_{k}_{i}_{spec_rms:.3f}.wav', wav, SR)
+            draw_spec(spec, f'tmp/wav_spec_e_{k}_{i}_{spec_rms:.3f}.png')
+            if i == 100:
+                break
+    # plt.hist(stats, bins=50)
+    # plt.savefig(f'tmp/rms_spec_stats.png')
diff --git a/foleycrafter/models/specvqgan/data/transforms.py b/foleycrafter/models/specvqgan/data/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2b5e022b1f4c3ae4bc62dc0e88240c919417f23
--- /dev/null
+++ b/foleycrafter/models/specvqgan/data/transforms.py
@@ -0,0 +1,685 @@
+import torch
+import torchaudio
+import torchaudio.functional
+from torchvision import transforms
+import torchvision.transforms.functional as F
+import torch.nn as nn
+from PIL import Image
+import numpy as np
+import math
+import random
+import soundfile
+import os
+import librosa
+import albumentations
+from torch_pitch_shift import *
+
+SR = 22050
+
+class ResizeShortSide(object):
+    def __init__(self, size):
+        super().__init__()
+        self.size = size
+
+    def __call__(self, x):
+        '''
+        x must be PIL.Image
+        '''
+        w, h = x.size
+        short_side = min(w, h)
+        w_target = int((w / short_side) * self.size)
+        h_target = int((h / short_side) * self.size)
+        return x.resize((w_target, h_target))
+
+
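+# Center- (or random-) crops the spectrogram stored under item['image'] (and
+# item['cond_image'] if present) to cropped_shape = (mel_num, spec_len).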
+class Crop(object):
+    def __init__(self, cropped_shape=None, random_crop=False):
+        self.cropped_shape = cropped_shape
+        if cropped_shape is not None:
+            mel_num, spec_len = cropped_shape
+            if random_crop:
+                self.cropper = albumentations.RandomCrop
+            else:
+                self.cropper = albumentations.CenterCrop
+            self.preprocessor = albumentations.Compose([self.cropper(mel_num, spec_len)])
+        else:
+            self.preprocessor = lambda **kwargs: kwargs
+
+    def __call__(self, item):
+        item['image'] = self.preprocessor(image=item['image'])['image']
+        if 'cond_image' in item.keys():
+            item['cond_image'] = self.preprocessor(image=item['cond_image'])['image']
+        return item
+
+class CropImage(Crop):
+    def __init__(self, *crop_args):
+        super().__init__(*crop_args)
+
+class CropFeats(Crop):
+    def __init__(self, *crop_args):
+        super().__init__(*crop_args)
+
+    def __call__(self, item):
+        item['feature'] = self.preprocessor(image=item['feature'])['image']
+        return item
+
+class CropCoords(Crop):
+    def __init__(self, *crop_args):
+        super().__init__(*crop_args)
+
+    def __call__(self, item):
+        item['coord'] = self.preprocessor(image=item['coord'])['image']
+        return item
+
+
+class RandomResizedCrop3D(nn.Module):
+    """Crop the given series of images to random size and aspect ratio.
+    The image can be a PIL Images or a Tensor, in which case it is expected
+    to have [N, ..., H, W] shape, where ... means an arbitrary number of leading dimensions
+
+    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
+    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
+    is finally resized to given size.
+    This is popularly used to train the Inception networks.
+
+    Args:
+      size (int or sequence): expected output size of each edge. If size is an
+        int instead of sequence like (h, w), a square output size ``(size, size)`` is
+        made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
+      scale (tuple of float): range of size of the origin size cropped
+      ratio (tuple of float): range of aspect ratio of the origin aspect ratio cropped.
+      interpolation (int): Desired interpolation enum defined by `filters`_.
+        Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR``
+        and ``PIL.Image.BICUBIC`` are supported.
+    """
+
+    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=transforms.InterpolationMode.BILINEAR):
+        super().__init__()
+        if isinstance(size, tuple) and len(size) == 2:
+            self.size = size
+        else:
+            self.size = (size, size)
+
+        self.interpolation = interpolation
+        self.scale = scale
+        self.ratio = ratio
+
+    @staticmethod
+    def get_params(img, scale, ratio):
+        """Get parameters for ``crop`` for a random sized crop.
+
+        Args:
+          img (PIL Image or Tensor): Input image.
+          scale (list): range of scale of the origin size cropped
+          ratio (list): range of aspect ratio of the origin aspect ratio cropped
+
+        Returns:
+          tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+            sized crop.
+        """
+        width, height = img.size
+        area = height * width
+
+        for _ in range(10):
+            target_area = area * \
+                torch.empty(1).uniform_(scale[0], scale[1]).item()
+            log_ratio = torch.log(torch.tensor(ratio))
+            aspect_ratio = torch.exp(
+                torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
+            ).item()
+
+            w = int(round(math.sqrt(target_area * aspect_ratio)))
+            h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+            if 0 < w <= width and 0 < h <= height:
+                i = torch.randint(0, height - h + 1, size=(1,)).item()
+                j = torch.randint(0, width - w + 1, size=(1,)).item()
+                return i, j, h, w
+
+        # Fallback to central crop
+        in_ratio = float(width) / float(height)
+        if in_ratio < min(ratio):
+            w = width
+            h = int(round(w / min(ratio)))
+        elif in_ratio > max(ratio):
+            h = height
+            w = int(round(h * max(ratio)))
+        else:  # whole image
+            w = width
+            h = height
+        i = (height - h) // 2
+        j = (width - w) // 2
+        return i, j, h, w
+
+    def forward(self, imgs):
+        """
+        Args:
+          img (PIL Image or Tensor): Image to be cropped and resized.
+
+        Returns:
+          PIL Image or Tensor: Randomly cropped and resized image.
+        """
+        i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio)
+        return [F.resized_crop(img, i, j, h, w, self.size, self.interpolation) for img in imgs]
+
+
+class Resize3D(object):
+    def __init__(self, size):
+        super().__init__()
+        self.size = size
+
+    def __call__(self, imgs):
+        '''
+        x must be PIL.Image
+        '''
+        return [x.resize((self.size, self.size)) for x in imgs]
+
+
+class RandomHorizontalFlip3D(object):
+    def __init__(self, p=0.5):
+        super().__init__()
+        self.p = p
+
+    def __call__(self, imgs):
+        '''
+        x must be PIL.Image
+        '''
+        if np.random.rand() < self.p:
+            return [x.transpose(Image.FLIP_LEFT_RIGHT) for x in imgs]
+        else:
+            return imgs
+
+
+class ColorJitter3D(torch.nn.Module):
+    """Randomly change the brightness, contrast and saturation of an image.
+
+    Args:
+    brightness (float or tuple of float (min, max)): How much to jitter brightness.
+        brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
+        or the given [min, max]. Should be non negative numbers.
+    contrast (float or tuple of float (min, max)): How much to jitter contrast.
+        contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
+        or the given [min, max]. Should be non negative numbers.
+    saturation (float or tuple of float (min, max)): How much to jitter saturation.
+        saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
+        or the given [min, max]. Should be non negative numbers.
+    hue (float or tuple of float (min, max)): How much to jitter hue.
+        hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
+        Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
+    """
+
+    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
+        super().__init__()
+        self.brightness = (1-brightness, 1+brightness)
+        self.contrast = (1-contrast, 1+contrast)
+        self.saturation = (1-saturation, 1+saturation)
+        self.hue = (0-hue, 0+hue)
+
+    @staticmethod
+    def get_params(brightness, contrast, saturation, hue):
+        """Get a randomized transform to be applied on image.
+
+        Arguments are same as that of __init__.
+
+        Returns:
+            Transform which randomly adjusts brightness, contrast and
+            saturation in a random order.
+        """
+        tfs = []
+
+        if brightness is not None:
+            brightness_factor = random.uniform(brightness[0], brightness[1])
+            tfs.append(transforms.Lambda(
+                lambda img: F.adjust_brightness(img, brightness_factor)))
+
+        if contrast is not None:
+            contrast_factor = random.uniform(contrast[0], contrast[1])
+            tfs.append(transforms.Lambda(
+                lambda img: F.adjust_contrast(img, contrast_factor)))
+
+        if saturation is not None:
+            saturation_factor = random.uniform(saturation[0], saturation[1])
+            tfs.append(transforms.Lambda(
+                lambda img: F.adjust_saturation(img, saturation_factor)))
+
+        if hue is not None:
+            hue_factor = random.uniform(hue[0], hue[1])
+            tfs.append(transforms.Lambda(
+                lambda img: F.adjust_hue(img, hue_factor)))
+
+        random.shuffle(tfs)
+        transform = transforms.Compose(tfs)
+
+        return transform
+
+    def forward(self, imgs):
+        """
+        Args:
+          img (PIL Image or Tensor): Input image.
+
+        Returns:
+          PIL Image or Tensor: Color jittered image.
+        """
+        transform = self.get_params(
+            self.brightness, self.contrast, self.saturation, self.hue)
+        return [transform(img) for img in imgs]
+
+
+class ToTensor3D(object):
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, imgs):
+        '''
+        x must be PIL.Image
+        '''
+        return [F.to_tensor(img) for img in imgs]
+
+
+class Normalize3D(object):
+    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False):
+        super().__init__()
+        self.mean = mean
+        self.std = std
+        self.inplace = inplace
+
+    def __call__(self, imgs):
+        '''
+        imgs must be a list of tensors (i.e. after ToTensor3D)
+        '''
+        return [F.normalize(img, self.mean, self.std, self.inplace) for img in imgs]
+
+
+class CenterCrop3D(object):
+    def __init__(self, size):
+        super().__init__()
+        self.size = size
+
+    def __call__(self, imgs):
+        '''
+        x must be PIL.Image
+        '''
+        return [F.center_crop(img, self.size) for img in imgs]
+
+
+class FrequencyMasking(object):
+    def __init__(self, freq_mask_param: int, iid_masks: bool = False):
+        super().__init__()
+        self.masking = torchaudio.transforms.FrequencyMasking(freq_mask_param, iid_masks)
+
+    def __call__(self, item):
+        if 'cond_image' in item.keys():
+            batched_spec = torch.stack(
+                [torch.tensor(item['image']), torch.tensor(item['cond_image'])], dim=0
+            )[:, None] # (2, 1, H, W)
+            masked = self.masking(batched_spec).numpy()
+            item['image'] = masked[0, 0]
+            item['cond_image'] = masked[1, 0]
+        elif 'image' in item.keys():
+            inp = torch.tensor(item['image'])
+            item['image'] = self.masking(inp).numpy()
+        else:
+            raise NotImplementedError()
+        return item
+
+
+class TimeMasking(object):
+    def __init__(self, time_mask_param: int, iid_masks: bool = False):
+        super().__init__()
+        self.masking = torchaudio.transforms.TimeMasking(time_mask_param, iid_masks)
+
+    def __call__(self, item):
+        if 'cond_image' in item.keys():
+            batched_spec = torch.stack(
+                [torch.tensor(item['image']), torch.tensor(item['cond_image'])], dim=0
+            )[:, None] # (2, 1, H, W)
+            masked = self.masking(batched_spec).numpy()
+            item['image'] = masked[0, 0]
+            item['cond_image'] = masked[1, 0]
+        elif 'image' in item.keys():
+            inp = torch.tensor(item['image'])
+            item['image'] = self.masking(inp).numpy()
+        else:
+            raise NotImplementedError()
+        return item
+
+
+class PitchShift(nn.Module):
+
+    def __init__(self, up=12, down=-12, sample_rate=SR):
+        super().__init__()
+        self.range = (down, up)
+        self.sr = sample_rate
+
+    def forward(self, x):
+        assert len(x.shape) == 2
+        x = x[:, None, :]
+        ratio = float(random.randint(self.range[0], self.range[1]) / 12.)
+        shifted = pitch_shift(x, ratio, self.sr)
+        return shifted.squeeze()
+
+
+class MelSpectrogram(object):
+    def __init__(self, sr, nfft, fmin, fmax, nmels, hoplen, spec_power, inverse=False):
+        self.sr = sr
+        self.nfft = nfft
+        self.fmin = fmin
+        self.fmax = fmax
+        self.nmels = nmels
+        self.hoplen = hoplen
+        self.spec_power = spec_power
+        self.inverse = inverse
+
+        self.mel_basis = librosa.filters.mel(sr=sr, n_fft=nfft, fmin=fmin, fmax=fmax, n_mels=nmels)
+
+    def __call__(self, x):
+        x = x.numpy()
+        if self.inverse:
+            spec = librosa.feature.inverse.mel_to_stft(
+                x, sr=self.sr, n_fft=self.nfft, fmin=self.fmin, fmax=self.fmax, power=self.spec_power
+            )
+            wav = librosa.griffinlim(spec, hop_length=self.hoplen)
+            return torch.FloatTensor(wav)
+        else:
+            spec = np.abs(librosa.stft(x, n_fft=self.nfft, hop_length=self.hoplen)) ** self.spec_power
+            mel_spec = np.dot(self.mel_basis, spec)
+            return torch.FloatTensor(mel_spec)
+
+class SpectrogramTorchAudio(object):
+    def __init__(self, nfft, hoplen, spec_power, inverse=False):
+        self.nfft = nfft
+        self.hoplen = hoplen
+        self.spec_power = spec_power
+        self.inverse = inverse
+
+        self.spec_trans = torchaudio.transforms.Spectrogram(
+            n_fft=self.nfft,
+            hop_length=self.hoplen,
+            power=self.spec_power,
+        )
+        self.inv_spec_trans = torchaudio.transforms.GriffinLim(
+            n_fft=self.nfft,
+            hop_length=self.hoplen,
+            power=self.spec_power,
+        )
+
+    def __call__(self, x):
+        if self.inverse:
+            wav = self.inv_spec_trans(x)
+            return wav
+        else:
+            spec = torch.abs(self.spec_trans(x))
+            return spec
+
+
+class MelScaleTorchAudio(object):
+    def __init__(self, sr, stft, fmin, fmax, nmels, inverse=False):
+        self.sr = sr
+        self.stft = stft
+        self.fmin = fmin
+        self.fmax = fmax
+        self.nmels = nmels
+        self.inverse = inverse
+
+        self.mel_trans = torchaudio.transforms.MelScale(
+            n_mels=self.nmels,
+            sample_rate=self.sr,
+            f_min=self.fmin,
+            f_max=self.fmax,
+            n_stft=self.stft,
+            norm='slaney'
+        )
+        self.inv_mel_trans = torchaudio.transforms.InverseMelScale(
+            n_mels=self.nmels,
+            sample_rate=self.sr,
+            f_min=self.fmin,
+            f_max=self.fmax,
+            n_stft=self.stft,
+            norm='slaney'
+        )
+
+    def __call__(self, x):
+        if self.inverse:
+            spec = self.inv_mel_trans(x)
+            return spec
+        else:
+            mel_spec = self.mel_trans(x)
+            return mel_spec
+
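+# Right-pads a 1-D waveform with zeros up to target_len samples; inputs longer
+# than target_len are not handled.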
+class Padding(object):
+    def __init__(self, target_len, inverse=False):
+        self.target_len=int(target_len)
+        self.inverse = inverse
+
+    def __call__(self, x):
+        if self.inverse:
+            return x
+        else:
+            x = x.squeeze()
+            if x.shape[0] < self.target_len:
+                pad = torch.zeros((self.target_len,), dtype=x.dtype, device=x.device)
+                pad[:x.shape[0]] = x
+                x = pad
+            elif x.shape[0] > self.target_len:
+                raise NotImplementedError()
+            return x
+
+class MakeMono(object):
+    def __init__(self, inverse=False):
+        self.inverse = inverse
+    
+    def __call__(self, x):
+        if self.inverse:
+            return x
+        else:
+            x = x.squeeze()
+            if len(x.shape) == 1:
+                return torch.FloatTensor(x)
+            elif len(x.shape) == 2:
+                # average over the channel dimension; accept numpy arrays as well as tensors
+                target_dim = int(torch.argmin(torch.tensor(x.shape)))
+                return torch.mean(torch.as_tensor(x, dtype=torch.float32), dim=target_dim)
+            else:
+                raise NotImplementedError
+
+class LowerThresh(object):
+    def __init__(self, min_val, inverse=False):
+        self.min_val = torch.tensor(min_val)
+        self.inverse = inverse
+
+    def __call__(self, x):
+        if self.inverse:
+            return x
+        else:
+            return torch.maximum(self.min_val, x)
+
+class Add(object):
+    def __init__(self, val, inverse=False):
+        self.inverse = inverse
+        self.val = val
+
+    def __call__(self, x):
+        if self.inverse:
+            return x - self.val
+        else:
+            return x + self.val
+
+class Subtract(Add):
+    def __init__(self, val, inverse=False):
+        self.inverse = inverse
+        self.val = val
+
+    def __call__(self, x):
+        if self.inverse:
+            return x + self.val
+        else:
+            return x - self.val
+
+class Multiply(object):
+    def __init__(self, val, inverse=False) -> None:
+        self.val = val
+        self.inverse = inverse
+
+    def __call__(self, x):
+        if self.inverse:
+            return x / self.val
+        else:
+            return x * self.val
+
+class Divide(Multiply):
+    def __init__(self, val, inverse=False):
+        self.inverse = inverse
+        self.val = val
+
+    def __call__(self, x):
+        if self.inverse:
+            return x * self.val
+        else:
+            return x / self.val
+
+
+class Log10(object):
+    def __init__(self, inverse=False):
+        self.inverse = inverse
+
+    def __call__(self, x):
+        if self.inverse:
+            return 10 ** x
+        else:
+            return torch.log10(x)
+
+class Clip(object):
+    def __init__(self, min_val, max_val, inverse=False):
+        self.min_val = min_val
+        self.max_val = max_val
+        self.inverse = inverse
+
+    def __call__(self, x):
+        if self.inverse:
+            return x
+        else:
+            return torch.clip(x, self.min_val, self.max_val)
+
+class TrimSpec(object):
+    def __init__(self, max_len, inverse=False):
+        self.max_len = max_len
+        self.inverse = inverse
+
+    def __call__(self, x):
+        if self.inverse:
+            return x
+        else:
+            return x[:, :self.max_len]
+
+class MaxNorm(object):
+    def __init__(self, inverse=False):
+        self.inverse = inverse
+        self.eps = 1e-10
+
+    def __call__(self, x):
+        if self.inverse:
+            return x
+        else:
+            return x / (x.max() + self.eps)
+
+
+class NormalizeAudio(object):
+    def __init__(self, inverse=False, desired_rms=0.1, eps=1e-4):
+        self.inverse = inverse
+        self.desired_rms = desired_rms
+        self.eps = torch.tensor(eps)
+
+    def __call__(self, x):
+        if self.inverse:
+            return x
+        else:
+            rms = torch.maximum(self.eps, torch.sqrt(torch.mean(x**2)))
+            x = x * (self.desired_rms / rms)
+            x[x > 1.] = 1.
+            x[x < -1.] = -1.
+            return x
+
+
+class RandomNormalizeAudio(object):
+    def __init__(self, inverse=False, rms_range=[0.05, 0.2], eps=1e-4):
+        self.inverse = inverse
+        self.rms_low, self.rms_high = rms_range
+        self.eps = torch.tensor(eps)
+
+    def __call__(self, x):
+        if self.inverse:
+            return x
+        else:
+            rms = torch.maximum(self.eps, torch.sqrt(torch.mean(x**2)))
+            desired_rms = (torch.rand(1) * (self.rms_high - self.rms_low)) + self.rms_low
+            x = x * (desired_rms / rms)
+            x[x > 1.] = 1.
+            x[x < -1.] = -1.
+            return x
+
+
+class MakeDouble(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x.to(torch.double)
+
+
+class MakeFloat(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x.to(torch.float)
+
+
+class Wave2Spectrogram(nn.Module):
+    def __init__(self, mel_num, spec_crop_len):
+        super().__init__()
+        self.trans = transforms.Compose([
+            LowerThresh(1e-5),
+            Log10(),
+            Multiply(20),
+            Subtract(20),
+            Add(100),
+            Divide(100),
+            Clip(0, 1.0),
+            TrimSpec(173),
+            transforms.CenterCrop((mel_num, spec_crop_len))
+        ])
+
+    def forward(self, x):
+        return self.trans(x)
+
+
+
+TRANSFORMS = transforms.Compose([
+    SpectrogramTorchAudio(nfft=1024, hoplen=1024//4, spec_power=1),
+    MelScaleTorchAudio(sr=22050, stft=513, fmin=125, fmax=7600, nmels=80),
+    LowerThresh(1e-5),
+    Log10(),
+    Multiply(20),
+    Subtract(20),
+    Add(100),
+    Divide(100),
+    Clip(0, 1.0),
+])
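+
+# Example (sketch): the full waveform -> normalized mel pipeline defined above.
+# Shapes assume a mono 22.05 kHz input and the settings in TRANSFORMS.
+#   wav = torch.randn(22050)      # 1 s of audio
+#   mel = TRANSFORMS(wav)         # -> roughly (80, 87) tensor with values in [0, 1]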
+
+def get_spectrogram_torch(audio_path, save_dir, length, save_results=True):
+    wav, _ = soundfile.read(audio_path)
+    wav = torch.FloatTensor(wav)
+    y = torch.zeros(length)
+    if wav.shape[0] < length:
+        y[:len(wav)] = wav
+    else:
+        y = wav[:length]
+    
+    mel_spec = TRANSFORMS(y).numpy()
+    y = y.numpy()
+    if save_results:
+        os.makedirs(save_dir, exist_ok=True)
+        audio_name = os.path.basename(audio_path).split('.')[0]
+        np.save(os.path.join(save_dir, audio_name + '_mel.npy'), mel_spec)
+        np.save(os.path.join(save_dir, audio_name + '_audio.npy'), y)
+    else:
+        return y, mel_spec
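+
+# Example (sketch, with a hypothetical input file):
+#   audio, mel = get_spectrogram_torch('clip.wav', save_dir='out',
+#                                      length=22050 * 10, save_results=False)
+#   # audio: zero-padded/truncated waveform, mel: (80, T) normalized mel spectrogram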
diff --git a/foleycrafter/models/specvqgan/data/utils.py b/foleycrafter/models/specvqgan/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e1f221f3415bf66a376e23aef7c9905181f6557
--- /dev/null
+++ b/foleycrafter/models/specvqgan/data/utils.py
@@ -0,0 +1,265 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import json
+from random import shuffle, choice, sample
+
+from moviepy.editor import VideoFileClip
+import librosa
+from scipy import signal
+from scipy.io import wavfile
+import torchaudio
+torchaudio.set_audio_backend("sox_io")
+
+INTERVAL = 1000
+
+# discard
+stft = torchaudio.transforms.MelSpectrogram(
+    sample_rate=16000, hop_length=161, n_mels=64).cuda()
+
+
+def log10(x): return torch.log(x)/torch.log(torch.tensor(10.))
+
+
+def norm_range(x, min_val, max_val):
+    return 2.*(x - min_val)/float(max_val - min_val) - 1.
+
+
+def normalize_spec(spec, spec_min, spec_max):
+    return norm_range(spec, spec_min, spec_max)
+
+
+def db_from_amp(x, cuda=False):
+    # rescale the audio
+    if cuda:
+        return 20. * log10(torch.max(torch.tensor(1e-5).to('cuda'), x.float()))
+    else:
+        return 20. * log10(torch.max(torch.tensor(1e-5), x.float()))
+
+
+def audio_stft(audio, stft=stft):
+    # We'll apply stft to the audio samples to convert it to a HxW matrix
+    N, C, A = audio.size()
+    audio = audio.view(N * C, A)
+    spec = stft(audio)
+    spec = spec.transpose(-1, -2)
+    spec = db_from_amp(spec, cuda=True)
+    spec = normalize_spec(spec, -100., 100.)
+    _, T, F = spec.size()
+    spec = spec.view(N, C, T, F)
+    return spec
+
+
+# discard
+# def get_spec(
+#     wavs,
+#     sample_rate=16000,
+#     use_volume_jittering=False,
+#     center=False,
+# ):
+#     # Volume  jittering - scale volume by factor in range (0.9, 1.1)
+#     if use_volume_jittering:
+#         wavs = [wav * np.random.uniform(0.9, 1.1) for wav in wavs]
+#     if center:
+#         wavs = [center_only(wav) for wav in wavs]
+
+#     # Convert to log filterbank
+#     specs = [logfbank(
+#         wav,
+#         sample_rate,
+#         winlen=0.009,
+#         winstep=0.005,  # if num_sec==1 else 0.01,
+#         nfilt=256,
+#         nfft=1024
+#     ).astype('float32').T for wav in wavs]
+
+#     # Convert to 32-bit float and expand dim
+#     specs = np.stack(specs, axis=0)
+#     specs = np.expand_dims(specs, 1)
+#     specs = torch.as_tensor(specs)  # Nx1xFxT
+
+#     return specs
+
+
+def center_only(audio, sr=16000, L=1.0):
+    # center_wav = np.arange(0, L, L/(0.5*sr)) ** 2
+    # center_wav = np.concatenate([center_wav, center_wav[::-1]])
+    # center_wav[L*sr//2:3*L*sr//4] = 1
+    # only take 0.3 sec audio
+    center_wav = np.zeros(int(L * sr))
+    center_wav[int(0.4*L*sr):int(0.7*L*sr)] = 1
+
+    return audio * center_wav
+
+def get_spec_librosa(
+    wavs,
+    sample_rate=16000,
+    use_volume_jittering=False,
+    center=False,
+):
+    # Volume  jittering - scale volume by factor in range (0.9, 1.1)
+    if use_volume_jittering:
+        wavs = [wav * np.random.uniform(0.9, 1.1) for wav in wavs]
+    if center:
+        wavs = [center_only(wav) for wav in wavs]
+
+    # Convert to log filterbank
+    specs = [librosa.feature.melspectrogram(
+        y=wav,
+        sr=sample_rate,
+        n_fft=400,
+        hop_length=126,
+        n_mels=128,
+    ).astype('float32') for wav in wavs]
+
+    # Convert to 32-bit float and expand dim
+    specs = [librosa.power_to_db(spec) for spec in specs]
+    specs = np.stack(specs, axis=0)
+    specs = np.expand_dims(specs, 1)
+    specs = torch.as_tensor(specs)  # Nx1xFxT
+
+    return specs
+
+
+def calcEuclideanDistance_Mat(X, Y):
+    """
+    Inputs:
+    - X: A torch tensor of shape (N, F)
+    - Y: A torch tensor of shape (M, F)
+
+    Returns:
+    A torch tensor D of shape (N, M) where D[i, j] is the Euclidean distance
+    between X[i] and Y[j], computed via ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y.
+    """
+    return ((torch.sum(X ** 2, axis=1, keepdims=True)) + (torch.sum(Y ** 2, axis=1, keepdims=True)).T - 2 * X @ Y.T) ** 0.5
+
+
+def calcEuclideanDistance(x1, x2):
+    return torch.sum((x1 - x2)**2, dim=1)**0.5
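+
+# Example (sketch): pairwise distances between two random feature sets.
+#   X, Y = torch.randn(4, 128), torch.randn(6, 128)
+#   D = calcEuclideanDistance_Mat(X, Y)   # -> (4, 6) tensor, D[i, j] = ||X[i] - Y[j]||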
+
+
+def split_data(in_list, portion=(0.9, 0.95), is_shuffle=True):
+    if type(in_list) == str:
+        with open(in_list) as l:
+            fw_list = json.load(l)
+    elif type(in_list) == list:
+        fw_list = in_list
+    else:
+        print(type(in_list))
+        raise TypeError('Invalid input list type')
+    # shuffle after loading so that a JSON path (str) argument is handled correctly
+    if is_shuffle:
+        shuffle(fw_list)
+    c1, c2 = int(len(fw_list) * portion[0]), int(len(fw_list) * portion[1])
+    tr_list, va_list, te_list = fw_list[:c1], fw_list[c1:c2], fw_list[c2:]
+    print(
+        f'==> train set: {len(tr_list)}, validation set: {len(va_list)}, test set: {len(te_list)}')
+    return tr_list, va_list, te_list
+
+
+def load_one_clip(video_path):
+    v = VideoFileClip(video_path)
+    fps = int(v.fps)
+    frames = [f for f in v.iter_frames()][:-1]
+    frame_cnt = len(frames)
+    frame_length = 1000./fps
+    total_length = int(1000 * (frame_cnt / fps))
+
+    a = v.audio
+    sr = a.fps
+    a = np.array([fa for fa in a.iter_frames()])
+    if len(a.shape) > 1:
+        # average the channels to mono before resampling (librosa expects time on the last axis)
+        a = np.mean(a, axis=1)
+    a = librosa.resample(a, orig_sr=sr, target_sr=48000)
+
+    while True:
+        idx = np.random.choice(np.arange(frame_cnt - 1), 1)[0]
+        frame_clip = frames[idx]
+        start_time = int(idx * frame_length + 0.5 * frame_length - 500)
+        end_time = start_time + INTERVAL
+        if start_time < 0 or end_time > total_length:
+            continue
+        wave_clip = a[48 * start_time: 48 * end_time]
+        if wave_clip.shape[0] != 48000:
+            continue
+        break
+    return frame_clip, wave_clip
+
+
+def resize_frame(frame):
+    # PIL's Image.size is (width, height)
+    W, H = frame.size
+    short_edge = min(H, W)
+    scale = 256 / short_edge
+    W_tar, H_tar = int(np.round(W * scale)), int(np.round(H * scale))
+    return frame.resize((W_tar, H_tar))
+
+
+def get_spectrogram(wave, amp_jitter, amp_jitter_range, log_scale=True, sr=48000):
+    # random clip-level amplitude jittering
+    if amp_jitter:
+        amplified = wave * np.random.uniform(*amp_jitter_range)
+        if wave.dtype == np.int16:
+            amplified[amplified >= 32767] = 32767
+            amplified[amplified <= -32768] = -32768
+            wave = amplified.astype('int16')
+        elif wave.dtype == np.float32 or wave.dtype == np.float64:
+            amplified[amplified >= 1] = 1
+            amplified[amplified <= -1] = -1
+            wave = amplified
+
+    # fr, ts, spectrogram = signal.spectrogram(wave[:48000], fs=sr, nperseg=480, noverlap=240, nfft=512)
+    # spectrogram = librosa.feature.melspectrogram(S=spectrogram, n_mels=257) # Try log-mel spectrogram?
+    spectrogram = librosa.feature.melspectrogram(
+        y=wave[:48000], sr=sr, hop_length=240, win_length=480, n_mels=257)
+    if log_scale:
+        spectrogram = librosa.power_to_db(spectrogram, ref=np.max)
+    assert spectrogram.shape[0] == 257
+
+    return spectrogram
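+
+# Example (sketch): 1 s of 48 kHz mono audio -> a (257, T) log-mel spectrogram.
+#   wav = np.random.randn(48000).astype(np.float32)
+#   spec = get_spectrogram(wav, amp_jitter=False, amp_jitter_range=None)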
+
+
+def cropAudio(audio, sr, f_idx, fps=10, length=1., left_shift=0):
+    time_per_frame = 1./fps
+    assert audio.shape[0] > sr * length
+    start_time = f_idx * time_per_frame - left_shift
+    start_time = 0 if start_time < 0 else start_time
+    start_idx = int(np.round(sr * start_time))
+    end_idx = int(np.round(start_idx + (sr * length)))
+    if end_idx > audio.shape[0]:
+        end_idx = audio.shape[0]
+        start_idx = int(end_idx - (sr * length))
+    try:
+        assert audio[start_idx:end_idx].shape[0] == sr * length
+    except AssertionError:
+        print(audio.shape, start_idx, end_idx, end_idx - start_idx)
+        exit(1)
+    return audio[start_idx:end_idx]
+
+
+def pick_async_frame_idx(idx, total_frames, fps=10, gap=2.0, length=1.0, cnt=1):
+    assert idx < total_frames - fps * length
+    lower_bound = idx - int((length + gap) * fps)
+    upper_bound = idx + int((length + gap) * fps)
+    proposal = list(range(0, lower_bound)) + \
+        list(range(upper_bound, int(total_frames - fps * length)))
+    # assert len(proposal) >= cnt
+    avail_cnt = len(proposal)
+    try:
+        for i in range(cnt - avail_cnt):
+            proposal.append(proposal[i % avail_cnt])
+    except Exception as e:
+        print(idx, total_frames, proposal)
+        raise e
+    return sample(proposal, k=cnt)
+
+
+def adjust_learning_rate(optimizer, epoch, args):
+    """Decay the learning rate based on schedule"""
+    lr = args.lr
+    if args.cos:  # cosine lr schedule
+        lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epoch))
+    else:  # stepwise lr schedule
+        for milestone in args.schedule:
+            lr *= 0.1 if epoch >= milestone else 1.
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
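+
+# Example (sketch, assuming an argparse-style namespace with the fields used above):
+#   from types import SimpleNamespace
+#   args = SimpleNamespace(lr=1e-3, cos=True, epoch=100, schedule=[])
+#   adjust_learning_rate(optimizer, epoch=10, args=args)  # sets a cosine-decayed lr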
diff --git a/foleycrafter/models/specvqgan/models/av_cond_transformer.py b/foleycrafter/models/specvqgan/models/av_cond_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..feb67b0a33456e4157822329a04d857dc61975e5
--- /dev/null
+++ b/foleycrafter/models/specvqgan/models/av_cond_transformer.py
@@ -0,0 +1,528 @@
+import sys
+
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import transforms
+import torchaudio
+from omegaconf.listconfig import ListConfig
+
+sys.path.insert(0, '.')  # nopep8
+from foleycrafter.models.specvqgan.modules.transformer.mingpt import (GPTClass, GPTFeats, GPTFeatsClass)
+from foleycrafter.models.specvqgan.data.transforms import Wave2Spectrogram, PitchShift, NormalizeAudio
+from train import instantiate_from_config
+
+SR = 22050
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+class Net2NetTransformerAVCond(pl.LightningModule):
+    def __init__(self, transformer_config, first_stage_config,
+                 cond_stage_config, 
+                 drop_condition=False, drop_video=False, drop_cond_video=False,
+                 first_stage_permuter_config=None, cond_stage_permuter_config=None,
+                 ckpt_path=None, ignore_keys=[],
+                 first_stage_key="image",
+                 cond_first_stage_key="cond_image",
+                 cond_stage_key="depth",
+                 downsample_cond_size=-1,
+                 pkeep=1.0,
+                 clip=30,
+                 p_audio_aug=0.5,
+                 p_pitch_shift=0.,
+                 p_normalize=0.,
+                 mel_num=80, 
+                 spec_crop_len=160):
+
+        super().__init__()
+        self.init_first_stage_from_ckpt(first_stage_config)
+        self.init_cond_stage_from_ckpt(cond_stage_config)
+        if first_stage_permuter_config is None:
+            first_stage_permuter_config = {"target": "foleycrafter.models.specvqgan.modules.transformer.permuter.Identity"}
+        if cond_stage_permuter_config is None:
+            cond_stage_permuter_config = {"target": "foleycrafter.models.specvqgan.modules.transformer.permuter.Identity"}
+        self.first_stage_permuter = instantiate_from_config(config=first_stage_permuter_config)
+        self.cond_stage_permuter = instantiate_from_config(config=cond_stage_permuter_config)
+        self.transformer = instantiate_from_config(config=transformer_config)
+
+        self.wav_transforms = nn.Sequential(
+            transforms.RandomApply([NormalizeAudio()], p=p_normalize),
+            transforms.RandomApply([PitchShift()], p=p_pitch_shift),
+            torchaudio.transforms.Spectrogram(
+                n_fft=1024,
+                hop_length=1024//4,
+                power=1,
+            ),
+            # transforms.RandomApply([
+            #     torchaudio.transforms.FrequencyMasking(freq_mask_param=40, iid_masks=False)
+            # ], p=p_audio_aug),
+            # transforms.RandomApply([
+            #     torchaudio.transforms.TimeMasking(time_mask_param=int(32 * 2), iid_masks=False)
+            # ], p=p_audio_aug),
+            torchaudio.transforms.MelScale(
+                n_mels=80,
+                sample_rate=SR,
+                f_min=125,
+                f_max=7600,
+                n_stft=513,
+                norm='slaney'
+            ),
+            Wave2Spectrogram(mel_num, spec_crop_len),
+        )
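+        # Note: this overrides the `ignore_keys` argument passed to __init__ so that
+        # checkpoint loading below never tries to restore the wav_transforms pipeline.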
+        ignore_keys = ['wav_transforms']
+
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+        self.drop_condition = drop_condition
+        self.drop_video = drop_video
+        self.drop_cond_video = drop_cond_video
+        print(f'>>> Feature setting: all cond: {self.drop_condition}, video: {self.drop_video}, cond video: {self.drop_cond_video}')
+        self.first_stage_key = first_stage_key
+        self.cond_first_stage_key = cond_first_stage_key
+        self.cond_stage_key = cond_stage_key
+        self.downsample_cond_size = downsample_cond_size
+        self.pkeep = pkeep
+        self.clip = clip
+        print('>>> model init done.')
+
+    def init_from_ckpt(self, path, ignore_keys=list()):
+        sd = torch.load(path, map_location="cpu")["state_dict"]
+        # iterate over a snapshot of the keys so entries can be deleted while looping
+        for k in list(sd.keys()):
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    self.print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+                    break
+        self.load_state_dict(sd, strict=False)
+        print(f"Restored from {path}")
+
+    def init_first_stage_from_ckpt(self, config):
+        model = instantiate_from_config(config)
+        model = model.eval()
+        model.train = disabled_train
+        self.first_stage_model = model
+
+    def init_cond_stage_from_ckpt(self, config):
+        model = instantiate_from_config(config)
+        model = model.eval()
+        model.train = disabled_train
+        self.cond_stage_model = model
+
+    def forward(self, x, c, xp):
+        # one step to produce the logits
+        _, z_indices = self.encode_to_z(x) # VQ-GAN encoding
+        _, zp_indices = self.encode_to_z(xp)
+        _, c_indices = self.encode_to_c(c) # Conv1-1 down dim + col-major permuter
+        z_indices = z_indices[:, :self.clip]
+        zp_indices = zp_indices[:, :self.clip]
+        if not self.drop_condition:
+            z_indices = torch.cat([zp_indices, z_indices], dim=1)
+
+        if self.training and self.pkeep < 1.0:
+            mask = torch.bernoulli(self.pkeep * torch.ones(z_indices.shape, device=z_indices.device))
+            mask = mask.round().to(dtype=torch.int64)
+            r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
+            a_indices = mask*z_indices+(1-mask)*r_indices
+        else:
+            a_indices = z_indices
+
+        # target includes all sequence elements (no need to handle first one
+        # differently because we are conditioning)
+        if self.drop_condition:
+            target = z_indices
+        else:
+            target = z_indices[:, self.clip:]
+
+        # in the case we do not want to encode condition anyhow (e.g. inputs are features)
+        if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)):
+            # make the prediction
+            logits, _, _ = self.transformer(z_indices[:, :-1], c)
+            # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
+            if isinstance(self.transformer, GPTFeatsClass):
+                cond_size = c['feature'].size(-1) + c['target'].size(-1)
+            else:
+                cond_size = c.size(-1)
+            if self.drop_condition:
+                logits = logits[:, cond_size-1:]
+            else:
+                logits = logits[:, cond_size-1:][:, self.clip:]
+        else:
+            cz_indices = torch.cat((c_indices, a_indices), dim=1)
+            # make the prediction
+            logits, _, _ = self.transformer(cz_indices[:, :-1])
+            # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
+            logits = logits[:, c_indices.shape[1]-1:]
+
+        return logits, target
+
+    def top_k_logits(self, logits, k):
+        v, ix = torch.topk(logits, k)
+        out = logits.clone()
+        out[out < v[..., [-1]]] = -float('Inf')
+        return out
+
+    @torch.no_grad()
+    def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
+               callback=lambda k: None):
+        x = x if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)) else torch.cat((c, x), dim=1)
+        block_size = self.transformer.get_block_size()
+        assert not self.transformer.training
+        if self.pkeep <= 0.0:
+            raise NotImplementedError('Implement for GPTFeatsCLass')
+            raise NotImplementedError('Implement for GPTFeats')
+            raise NotImplementedError('Implement for GPTClass')
+            raise NotImplementedError('also the model outputs attention')
+            # one pass suffices since input is pure noise anyway
+            assert len(x.shape)==2
+            # noise_shape = (x.shape[0], steps-1)
+            # noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
+            noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
+            x = torch.cat((x,noise),dim=1)
+            logits, _ = self.transformer(x)
+            # take all logits for now and scale by temp
+            logits = logits / temperature
+            # optionally crop probabilities to only the top k options
+            if top_k is not None:
+                logits = self.top_k_logits(logits, top_k)
+            # apply softmax to convert to probabilities
+            probs = F.softmax(logits, dim=-1)
+            # sample from the distribution or take the most likely
+            if sample:
+                shape = probs.shape
+                probs = probs.reshape(shape[0]*shape[1],shape[2])
+                ix = torch.multinomial(probs, num_samples=1)
+                probs = probs.reshape(shape[0],shape[1],shape[2])
+                ix = ix.reshape(shape[0],shape[1])
+            else:
+                _, ix = torch.topk(probs, k=1, dim=-1)
+            # cut off conditioning
+            x = ix[:, c.shape[1]-1:]
+        else:
+            for k in range(steps):
+                callback(k)
+                if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)):
+                    # if the assert is removed, you need to make sure the combined length is not longer than block_size
+                    if isinstance(self.transformer, GPTFeatsClass):
+                        cond_size = c['feature'].size(-1) + c['target'].size(-1)
+                    else:
+                        cond_size = c.size(-1)
+                    assert x.size(1) + cond_size <= block_size
+
+                    x_cond = x
+                    c_cond = c
+                    logits, _, att = self.transformer(x_cond, c_cond)
+                else:
+                    assert x.size(1) <= block_size  # make sure model can see conditioning
+                    x_cond = x if x.size(1) <= block_size else x[:, -block_size:]  # crop context if needed
+                    logits, _, att = self.transformer(x_cond)
+                # pluck the logits at the final step and scale by temperature
+                logits = logits[:, -1, :] / temperature
+                # optionally crop probabilities to only the top k options
+                if top_k is not None:
+                    logits = self.top_k_logits(logits, top_k)
+                # apply softmax to convert to probabilities
+                probs = F.softmax(logits, dim=-1)
+                # sample from the distribution or take the most likely
+                if sample:
+                    ix = torch.multinomial(probs, num_samples=1)
+                else:
+                    _, ix = torch.topk(probs, k=1, dim=-1)
+                # append to the sequence and continue
+                x = torch.cat((x, ix), dim=1)
+            # cut off conditioning
+            x = x if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)) else x[:, c.shape[1]:]
+        return x, att.detach().cpu()
+
+    @torch.no_grad()
+    def encode_to_z(self, x):
+        quant_z, _, info = self.first_stage_model.encode(x)
+        indices = info[2].view(quant_z.shape[0], -1)
+        indices = self.first_stage_permuter(indices)
+        return quant_z, indices
+
+    @torch.no_grad()
+    def encode_to_c(self, c):
+        if self.downsample_cond_size > -1:
+            c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
+        quant_c, _, info = self.cond_stage_model.encode(c)
+        if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)):
+            # these are not indices but raw features or a class
+            indices = info[2]
+        else:
+            indices = info[2].view(quant_c.shape[0], -1)
+            indices = self.cond_stage_permuter(indices)
+        return quant_c, indices
+
+    @torch.no_grad()
+    def decode_to_img(self, index, zshape, stage='first'):
+        if stage == 'first':
+            index = self.first_stage_permuter(index, reverse=True)
+        elif stage == 'cond':
+            print('in cond stage in decode_to_img, which is unexpected')
+            index = self.cond_stage_permuter(index, reverse=True)
+        else:
+            raise NotImplementedError
+
+        bhwc = (zshape[0], zshape[2], zshape[3], zshape[1])
+        quant_z = self.first_stage_model.quantize.get_codebook_entry(index.reshape(-1), shape=bhwc)
+        x = self.first_stage_model.decode(quant_z)
+        return x
+
+    @torch.no_grad()
+    def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
+        log = dict()
+
+        N = 4
+        if lr_interface:
+            x, c, xp = self.get_xcxp(batch, N, diffuse=False, upsample_factor=8)
+        else:
+            x, c, xp = self.get_xcxp(batch, N)
+        x = x.to(device=self.device)
+        xp = xp.to(device=self.device)
+        # c = c.to(device=self.device)
+        if isinstance(c, dict):
+            c = {k: v.to(self.device) for k, v in c.items()}
+        else:
+            c = c.to(self.device)
+
+        quant_z, z_indices = self.encode_to_z(x)
+        quant_zp, zp_indices = self.encode_to_z(xp)
+        quant_c, c_indices = self.encode_to_c(c)  # output can be features or a single class or a featcls dict
+        z_indices_rec = z_indices.clone()
+        zp_indices_clip = zp_indices[:, :self.clip]
+        z_indices_clip = z_indices[:, :self.clip]
+
+        # create a "half"" sample
+        z_start_indices = z_indices_clip[:, :z_indices_clip.shape[1]//2]
+        if self.drop_condition:
+            steps = z_indices_clip.shape[1]-z_start_indices.shape[1]
+        else:
+            z_start_indices = torch.cat([zp_indices_clip, z_start_indices], dim=-1)
+            steps = 2*z_indices_clip.shape[1]-z_start_indices.shape[1]
+        index_sample, att_half = self.sample(z_start_indices, c_indices,
+                                   steps=steps,
+                                   temperature=temperature if temperature is not None else 1.0,
+                                   sample=True,
+                                   top_k=top_k if top_k is not None else 100,
+                                   callback=callback if callback is not None else lambda k: None)
+        if self.drop_condition:
+            z_indices_rec[:, :self.clip] = index_sample
+        else:
+            z_indices_rec[:, :self.clip] = index_sample[:, self.clip:]
+        x_sample = self.decode_to_img(z_indices_rec, quant_z.shape)
+
+        # sample
+        z_start_indices = z_indices_clip[:, :0]
+        if not self.drop_condition:
+            z_start_indices = torch.cat([zp_indices_clip, z_start_indices], dim=-1)
+        index_sample, att_nopix = self.sample(z_start_indices, c_indices,
+                                              steps=z_indices_clip.shape[1],
+                                              temperature=temperature if temperature is not None else 1.0,
+                                              sample=True,
+                                              top_k=top_k if top_k is not None else 100,
+                                              callback=callback if callback is not None else lambda k: None)
+        if self.drop_condition:
+            z_indices_rec[:, :self.clip] = index_sample
+        else:
+            z_indices_rec[:, :self.clip] = index_sample[:, self.clip:]
+        x_sample_nopix = self.decode_to_img(z_indices_rec, quant_z.shape)
+
+        # det sample
+        z_start_indices = z_indices_clip[:, :0]
+        if not self.drop_condition:
+            z_start_indices = torch.cat([zp_indices_clip, z_start_indices], dim=-1)
+        index_sample, att_det = self.sample(z_start_indices, c_indices,
+                                            steps=z_indices_clip.shape[1],
+                                            sample=False,
+                                            callback=callback if callback is not None else lambda k: None)
+        if self.drop_condition:
+            z_indices_rec[:, :self.clip] = index_sample
+        else:
+            z_indices_rec[:, :self.clip] = index_sample[:, self.clip:]
+        x_sample_det = self.decode_to_img(z_indices_rec, quant_z.shape)
+
+        # reconstruction
+        x_rec = self.decode_to_img(z_indices, quant_z.shape)
+
+        log["inputs"] = x
+        log["reconstructions"] = x_rec
+
+        if isinstance(self.cond_stage_key, str):
+            cond_is_not_image = self.cond_stage_key != "image"
+            cond_has_segmentation = self.cond_stage_key == "segmentation"
+        elif isinstance(self.cond_stage_key, ListConfig):
+            cond_is_not_image = 'image' not in self.cond_stage_key
+            cond_has_segmentation = 'segmentation' in self.cond_stage_key
+        else:
+            raise NotImplementedError
+
+        if cond_is_not_image:
+            cond_rec = self.cond_stage_model.decode(quant_c)
+            if cond_has_segmentation:
+                # get image from segmentation mask
+                num_classes = cond_rec.shape[1]
+
+                c = torch.argmax(c, dim=1, keepdim=True)
+                c = F.one_hot(c, num_classes=num_classes)
+                c = c.squeeze(1).permute(0, 3, 1, 2).float()
+                c = self.cond_stage_model.to_rgb(c)
+
+                cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
+                cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
+                cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
+                cond_rec = self.cond_stage_model.to_rgb(cond_rec)
+            log["conditioning_rec"] = cond_rec
+            log["conditioning"] = c
+
+        log["samples_half"] = x_sample
+        log["samples_nopix"] = x_sample_nopix
+        log["samples_det"] = x_sample_det
+        log["att_half"] = att_half
+        log["att_nopix"] = att_nopix
+        log["att_det"] = att_det
+        return log
+
+    def spec_transform(self, batch):
+        wav = batch[self.first_stage_key]
+        wav_cond = batch[self.cond_first_stage_key]
+        N = wav.shape[0]
+        wav_cat = torch.cat([wav, wav_cond], dim=0)
+        self.wav_transforms.to(wav_cat.device)
+        spec = self.wav_transforms(wav_cat.to(torch.float32))
+        batch[self.first_stage_key] = 2 * spec[:N] - 1
+        batch[self.cond_first_stage_key] = 2 * spec[N:] - 1
+        return batch
+
+    def get_input(self, key, batch):
+        if isinstance(key, str):
+            # if batch[key] is 1D; else the batch[key] is 2D
+            if key in ['feature', 'target']:
+                if self.drop_condition or self.drop_cond_video:
+                    cond_size = batch[key].shape[1] // 2
+                    batch[key] = batch[key][:, cond_size:]
+                x = self.cond_stage_model.get_input(
+                    batch, key, drop_cond=(self.drop_condition or self.drop_cond_video)
+                )
+            else:
+                x = batch[key]
+                if len(x.shape) == 3:
+                    x = x[..., None]
+                x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
+            if x.dtype == torch.double:
+                x = x.float()
+        elif isinstance(key, ListConfig):
+            x = self.cond_stage_model.get_input(batch, key)
+            for k, v in x.items():
+                if v.dtype == torch.double:
+                    x[k] = v.float()
+        return x
+
+    def get_xcxp(self, batch, N=None):
+        if len(batch[self.first_stage_key].shape) == 2:
+            batch = self.spec_transform(batch)
+        x = self.get_input(self.first_stage_key, batch)
+        c = self.get_input(self.cond_stage_key, batch)
+        xp = self.get_input(self.cond_first_stage_key, batch)
+        if N is not None:
+            x = x[:N]
+            xp = xp[:N]
+            if isinstance(self.cond_stage_key, ListConfig):
+                c = {k: v[:N] for k, v in c.items()}
+            else:
+                c = c[:N]
+        # Drop additional information during training
+        if self.drop_condition:
+            xp[:] = 0
+        if self.drop_video:
+            c[:] = 0
+        return x, c, xp
+
+    def shared_step(self, batch, batch_idx):
+        x, c, xp = self.get_xcxp(batch)
+        logits, target = self(x, c, xp)
+        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
+        return loss
+
+    def training_step(self, batch, batch_idx):
+        loss = self.shared_step(batch, batch_idx)
+        self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        loss = self.shared_step(batch, batch_idx)
+        self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+        return loss
+
+    def configure_optimizers(self):
+        """
+        Following minGPT:
+        This long function is unfortunately doing something very simple and is being very defensive:
+        We are separating out all parameters of the model into two buckets: those that will experience
+        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+        We are then returning the PyTorch optimizer object.
+        """
+        # separate out all parameters to those that will and won't experience regularizing weight decay
+        decay = set()
+        no_decay = set()
+        whitelist_weight_modules = (torch.nn.Linear, )
+
+        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, torch.nn.Conv1d, torch.nn.LSTM, torch.nn.GRU)
+        for mn, m in self.transformer.named_modules():
+            for pn, p in m.named_parameters():
+                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+
+                if pn.endswith('bias'):
+                    # all biases will not be decayed
+                    no_decay.add(fpn)
+                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+                    # weights of whitelist modules will be weight decayed
+                    decay.add(fpn)
+                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+                    # weights of blacklist modules will NOT be weight decayed
+                    no_decay.add(fpn)
+                elif ('weight' in pn or 'bias' in pn) and isinstance(m, (torch.nn.LSTM, torch.nn.GRU)):
+                    no_decay.add(fpn)
+
+        # special case the position embedding parameter in the root GPT module as not decayed
+        no_decay.add('pos_emb')
+
+        # validate that we considered every parameter
+        param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
+        inter_params = decay & no_decay
+        union_params = decay | no_decay
+        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
+        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
+                                                    % (str(param_dict.keys() - union_params), )
+
+        # create the pytorch optimizer object
+        optim_groups = [
+            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
+            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+        ]
+        optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
+        return optimizer
+
+
+if __name__ == '__main__':
+    from omegaconf import OmegaConf
+
+    cfg_image = OmegaConf.load('./configs/vggsound_transformer.yaml')
+    cfg_image.model.params.first_stage_config.params.ckpt_path = './logs/2021-05-19T22-16-54_vggsound_codebook/checkpoints/last.ckpt'
+
+    transformer_cfg = cfg_image.model.params.transformer_config
+    first_stage_cfg = cfg_image.model.params.first_stage_config
+    cond_stage_cfg = cfg_image.model.params.cond_stage_config
+    permuter_cfg = cfg_image.model.params.permuter_config
+    transformer = Net2NetTransformerAVCond(
+        transformer_cfg, first_stage_cfg, cond_stage_cfg,
+        first_stage_permuter_config=permuter_cfg
+    )
+
+    c = torch.rand(2, 2048, 212)
+    x = torch.rand(2, 1, 80, 848)
+    xp = torch.rand(2, 1, 80, 848)
+
+    logits, target = transformer(x, c, xp)
+    print(logits.shape, target.shape)
diff --git a/foleycrafter/models/specvqgan/models/cond_transformer.py b/foleycrafter/models/specvqgan/models/cond_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..62e5168e511df7940f0a0933bb4cd7d6cf6da873
--- /dev/null
+++ b/foleycrafter/models/specvqgan/models/cond_transformer.py
@@ -0,0 +1,455 @@
+import sys
+
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from omegaconf.listconfig import ListConfig
+from torchvision import transforms
+from foleycrafter.models.specvqgan.data.transforms import Wave2Spectrogram
+import torchaudio
+
+sys.path.insert(0, '.')  # nopep8
+from foleycrafter.models.specvqgan.modules.transformer.mingpt import (GPTClass, GPTFeats, GPTFeatsClass)
+from train import instantiate_from_config
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+class Net2NetTransformer(pl.LightningModule):
+    def __init__(self, transformer_config, first_stage_config,
+                 cond_stage_config,
+                 first_stage_permuter_config=None, cond_stage_permuter_config=None,
+                 ckpt_path=None, ignore_keys=[],
+                 first_stage_key="image",
+                 cond_stage_key="depth",
+                 downsample_cond_size=-1,
+                 pkeep=1.0,
+                 mel_num=80,
+                 spec_crop_len=160):
+
+        super().__init__()
+        self.init_first_stage_from_ckpt(first_stage_config)
+        self.init_cond_stage_from_ckpt(cond_stage_config)
+        if first_stage_permuter_config is None:
+            first_stage_permuter_config = {"target": "foleycrafter.models.specvqgan.modules.transformer.permuter.Identity"}
+        if cond_stage_permuter_config is None:
+            cond_stage_permuter_config = {"target": "foleycrafter.models.specvqgan.modules.transformer.permuter.Identity"}
+        self.first_stage_permuter = instantiate_from_config(config=first_stage_permuter_config)
+        self.cond_stage_permuter = instantiate_from_config(config=cond_stage_permuter_config)
+        self.transformer = instantiate_from_config(config=transformer_config)
+
+        self.wav_transforms = nn.Sequential(
+            torchaudio.transforms.Spectrogram(
+                n_fft=1024,
+                hop_length=1024//4,
+                power=1,
+            ),
+            torchaudio.transforms.MelScale(
+                n_mels=80,
+                sample_rate=22050,
+                f_min=125,
+                f_max=7600,
+                n_stft=513,
+                norm='slaney'
+            ),
+            Wave2Spectrogram(mel_num, spec_crop_len),
+        )
+        ignore_keys = ['wav_transforms']
+
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+        self.first_stage_key = first_stage_key
+        self.cond_stage_key = cond_stage_key
+        self.downsample_cond_size = downsample_cond_size
+        self.pkeep = pkeep
+        print('>>> model init done.')
+
+    def init_from_ckpt(self, path, ignore_keys=list()):
+        sd = torch.load(path, map_location="cpu")["state_dict"]
+        # iterate over a snapshot of the keys so entries can be deleted while looping
+        for k in list(sd.keys()):
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    self.print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+                    break
+        self.load_state_dict(sd, strict=False)
+        print(f"Restored from {path}")
+
+    def init_first_stage_from_ckpt(self, config):
+        model = instantiate_from_config(config)
+        model = model.eval()
+        model.train = disabled_train
+        self.first_stage_model = model
+
+    def init_cond_stage_from_ckpt(self, config):
+        model = instantiate_from_config(config)
+        model = model.eval()
+        model.train = disabled_train
+        self.cond_stage_model = model
+
+    def forward(self, x, c):
+        # one step to produce the logits
+        _, z_indices = self.encode_to_z(x)
+        _, c_indices = self.encode_to_c(c)
+
+        if self.training and self.pkeep < 1.0:
+            mask = torch.bernoulli(self.pkeep * torch.ones(z_indices.shape, device=z_indices.device))
+            mask = mask.round().to(dtype=torch.int64)
+            r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
+            a_indices = mask*z_indices+(1-mask)*r_indices
+        else:
+            a_indices = z_indices
+
+        # target includes all sequence elements (no need to handle first one
+        # differently because we are conditioning)
+        target = z_indices
+
+        # in the case we do not want to encode condition anyhow (e.g. inputs are features)
+        if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)):
+            # make the prediction
+            logits, _, _ = self.transformer(z_indices[:, :-1], c)
+            # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
+            if isinstance(self.transformer, GPTFeatsClass):
+                cond_size = c['feature'].size(-1) + c['target'].size(-1)
+            else:
+                cond_size = c.size(-1)
+            logits = logits[:, cond_size-1:]
+        else:
+            cz_indices = torch.cat((c_indices, a_indices), dim=1)
+            # make the prediction
+            logits, _, _ = self.transformer(cz_indices[:, :-1])
+            # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
+            logits = logits[:, c_indices.shape[1]-1:]
+
+        return logits, target
+
+    def top_k_logits(self, logits, k):
+        v, ix = torch.topk(logits, k)
+        out = logits.clone()
+        out[out < v[..., [-1]]] = -float('Inf')
+        return out
+
+    @torch.no_grad()
+    def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
+               callback=lambda k: None):
+        x = x if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)) else torch.cat((c, x), dim=1)
+        block_size = self.transformer.get_block_size()
+        assert not self.transformer.training
+        if self.pkeep <= 0.0:
+            raise NotImplementedError('Implement for GPTFeatsCLass')
+            raise NotImplementedError('Implement for GPTFeats')
+            raise NotImplementedError('Implement for GPTClass')
+            raise NotImplementedError('also the model outputs attention')
+            # one pass suffices since input is pure noise anyway
+            assert len(x.shape)==2
+            # noise_shape = (x.shape[0], steps-1)
+            # noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
+            noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
+            x = torch.cat((x,noise),dim=1)
+            logits, _ = self.transformer(x)
+            # take all logits for now and scale by temp
+            logits = logits / temperature
+            # optionally crop probabilities to only the top k options
+            if top_k is not None:
+                logits = self.top_k_logits(logits, top_k)
+            # apply softmax to convert to probabilities
+            probs = F.softmax(logits, dim=-1)
+            # sample from the distribution or take the most likely
+            if sample:
+                shape = probs.shape
+                probs = probs.reshape(shape[0]*shape[1],shape[2])
+                ix = torch.multinomial(probs, num_samples=1)
+                probs = probs.reshape(shape[0],shape[1],shape[2])
+                ix = ix.reshape(shape[0],shape[1])
+            else:
+                _, ix = torch.topk(probs, k=1, dim=-1)
+            # cut off conditioning
+            x = ix[:, c.shape[1]-1:]
+        else:
+            for k in range(steps):
+                callback(k)
+                if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)):
+                    # if the assert is removed, you need to make sure the combined length is not longer than block_size
+                    if isinstance(self.transformer, GPTFeatsClass):
+                        cond_size = c['feature'].size(-1) + c['target'].size(-1)
+                    else:
+                        cond_size = c.size(-1)
+                    assert x.size(1) + cond_size <= block_size
+
+                    x_cond = x
+                    c_cond = c
+                    logits, _, att = self.transformer(x_cond, c_cond)
+                else:
+                    assert x.size(1) <= block_size  # make sure model can see conditioning
+                    x_cond = x if x.size(1) <= block_size else x[:, -block_size:]  # crop context if needed
+                    logits, _, att = self.transformer(x_cond)
+                # pluck the logits at the final step and scale by temperature
+                logits = logits[:, -1, :] / temperature
+                # optionally crop probabilities to only the top k options
+                if top_k is not None:
+                    logits = self.top_k_logits(logits, top_k)
+                # apply softmax to convert to probabilities
+                probs = F.softmax(logits, dim=-1)
+                # sample from the distribution or take the most likely
+                if sample:
+                    ix = torch.multinomial(probs, num_samples=1)
+                else:
+                    _, ix = torch.topk(probs, k=1, dim=-1)
+                # append to the sequence and continue
+                x = torch.cat((x, ix), dim=1)
+            # cut off conditioning
+            x = x if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)) else x[:, c.shape[1]:]
+        return x, att.detach().cpu()
+
+    @torch.no_grad()
+    def encode_to_z(self, x):
+        quant_z, _, info = self.first_stage_model.encode(x)
+        indices = info[2].view(quant_z.shape[0], -1)
+        indices = self.first_stage_permuter(indices)
+        return quant_z, indices
+
+    @torch.no_grad()
+    def encode_to_c(self, c):
+        if self.downsample_cond_size > -1:
+            c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
+        quant_c, _, info = self.cond_stage_model.encode(c)
+        if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)):
+            # these are not indices but raw features or a class
+            indices = info[2]
+        else:
+            indices = info[2].view(quant_c.shape[0], -1)
+            indices = self.cond_stage_permuter(indices)
+        return quant_c, indices
+
+    @torch.no_grad()
+    def decode_to_img(self, index, zshape, stage='first'):
+        if stage == 'first':
+            index = self.first_stage_permuter(index, reverse=True)
+        elif stage == 'cond':
+            print('in cond stage in decode_to_img, which is unexpected')
+            index = self.cond_stage_permuter(index, reverse=True)
+        else:
+            raise NotImplementedError
+
+        bhwc = (zshape[0], zshape[2], zshape[3], zshape[1])
+        quant_z = self.first_stage_model.quantize.get_codebook_entry(index.reshape(-1), shape=bhwc)
+        x = self.first_stage_model.decode(quant_z)
+        return x
+
+    @torch.no_grad()
+    def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
+        log = dict()
+
+        N = 4
+        if lr_interface:
+            x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
+        else:
+            x, c = self.get_xc(batch, N)
+        x = x.to(device=self.device)
+        # c = c.to(device=self.device)
+        if isinstance(c, dict):
+            c = {k: v.to(self.device) for k, v in c.items()}
+        else:
+            c = c.to(self.device)
+
+        quant_z, z_indices = self.encode_to_z(x)
+        quant_c, c_indices = self.encode_to_c(c)  # output can be features or a single class or a featcls dict
+
+        # create a "half"" sample
+        z_start_indices = z_indices[:, :z_indices.shape[1]//2]
+        index_sample, att_half = self.sample(z_start_indices, c_indices,
+                                   steps=z_indices.shape[1]-z_start_indices.shape[1],
+                                   temperature=temperature if temperature is not None else 1.0,
+                                   sample=True,
+                                   top_k=top_k if top_k is not None else 100,
+                                   callback=callback if callback is not None else lambda k: None)
+        x_sample = self.decode_to_img(index_sample, quant_z.shape)
+
+        # sample
+        z_start_indices = z_indices[:, :0]
+        index_sample, att_nopix = self.sample(z_start_indices, c_indices,
+                                              steps=z_indices.shape[1],
+                                              temperature=temperature if temperature is not None else 1.0,
+                                              sample=True,
+                                              top_k=top_k if top_k is not None else 100,
+                                              callback=callback if callback is not None else lambda k: None)
+        x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
+
+        # det sample
+        z_start_indices = z_indices[:, :0]
+        index_sample, att_det = self.sample(z_start_indices, c_indices,
+                                            steps=z_indices.shape[1],
+                                            sample=False,
+                                            callback=callback if callback is not None else lambda k: None)
+        x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
+
+        # reconstruction
+        x_rec = self.decode_to_img(z_indices, quant_z.shape)
+
+        log["inputs"] = x
+        log["reconstructions"] = x_rec
+
+        if isinstance(self.cond_stage_key, str):
+            cond_is_not_image = self.cond_stage_key != "image"
+            cond_has_segmentation = self.cond_stage_key == "segmentation"
+        elif isinstance(self.cond_stage_key, ListConfig):
+            cond_is_not_image = 'image' not in self.cond_stage_key
+            cond_has_segmentation = 'segmentation' in self.cond_stage_key
+        else:
+            raise NotImplementedError
+
+        if cond_is_not_image:
+            cond_rec = self.cond_stage_model.decode(quant_c)
+            if cond_has_segmentation:
+                # get image from segmentation mask
+                num_classes = cond_rec.shape[1]
+
+                c = torch.argmax(c, dim=1, keepdim=True)
+                c = F.one_hot(c, num_classes=num_classes)
+                c = c.squeeze(1).permute(0, 3, 1, 2).float()
+                c = self.cond_stage_model.to_rgb(c)
+
+                cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
+                cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
+                cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
+                cond_rec = self.cond_stage_model.to_rgb(cond_rec)
+            log["conditioning_rec"] = cond_rec
+            log["conditioning"] = c
+
+        log["samples_half"] = x_sample
+        log["samples_nopix"] = x_sample_nopix
+        log["samples_det"] = x_sample_det
+        log["att_half"] = att_half
+        log["att_nopix"] = att_nopix
+        log["att_det"] = att_det
+        return log
+
+    def spec_transform(self, batch):
+        wav = batch[self.first_stage_key]
+        N = wav.shape[0]
+        self.wav_transforms.to(wav.device)
+        spec = self.wav_transforms(wav.to(torch.float32))
+        batch[self.first_stage_key] = 2 * spec[:N] - 1
+        return batch
+    
+    def get_input(self, key, batch):
+        if isinstance(key, str):
+            # if batch[key] is 1D; else the batch[key] is 2D
+            if key in ['feature', 'target']:
+                x = self.cond_stage_model.get_input(batch, key)
+            else:
+                x = batch[key]
+                if len(x.shape) == 3:
+                    x = x[..., None]
+                x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
+            if x.dtype == torch.double:
+                x = x.float()
+        elif isinstance(key, ListConfig):
+            x = self.cond_stage_model.get_input(batch, key)
+            for k, v in x.items():
+                if v.dtype == torch.double:
+                    x[k] = v.float()
+        return x
+
+    def get_xc(self, batch, N=None):
+        if len(batch[self.first_stage_key].shape) == 2:
+            batch = self.spec_transform(batch)
+        x = self.get_input(self.first_stage_key, batch)
+        c = self.get_input(self.cond_stage_key, batch)
+        if N is not None:
+            x = x[:N]
+            if isinstance(self.cond_stage_key, ListConfig):
+                c = {k: v[:N] for k, v in c.items()}
+            else:
+                c = c[:N]
+        return x, c
+
+    def shared_step(self, batch, batch_idx):
+        x, c = self.get_xc(batch)
+        logits, target = self(x, c)
+        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
+        return loss
+
+    def training_step(self, batch, batch_idx):
+        loss = self.shared_step(batch, batch_idx)
+        self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        loss = self.shared_step(batch, batch_idx)
+        self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+        return loss
+
+    def configure_optimizers(self):
+        """
+        Following minGPT:
+        This long function is unfortunately doing something very simple and is being very defensive:
+        We are separating out all parameters of the model into two buckets: those that will experience
+        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+        We are then returning the PyTorch optimizer object.
+        """
+        # separate out all parameters to those that will and won't experience regularizing weight decay
+        decay = set()
+        no_decay = set()
+        whitelist_weight_modules = (torch.nn.Linear, )
+
+        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, torch.nn.Conv1d, torch.nn.LSTM, torch.nn.GRU)
+        for mn, m in self.transformer.named_modules():
+            for pn, p in m.named_parameters():
+                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+
+                if pn.endswith('bias'):
+                    # all biases will not be decayed
+                    no_decay.add(fpn)
+                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+                    # weights of whitelist modules will be weight decayed
+                    decay.add(fpn)
+                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+                    # weights of blacklist modules will NOT be weight decayed
+                    no_decay.add(fpn)
+                elif ('weight' in pn or 'bias' in pn) and isinstance(m, (torch.nn.LSTM, torch.nn.GRU)):
+                    no_decay.add(fpn)
+
+        # special case the position embedding parameter in the root GPT module as not decayed
+        no_decay.add('pos_emb')
+
+        # validate that we considered every parameter
+        param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
+        inter_params = decay & no_decay
+        union_params = decay | no_decay
+        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
+        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
+                                                    % (str(param_dict.keys() - union_params), )
+
+        # create the pytorch optimizer object
+        optim_groups = [
+            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
+            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+        ]
+        optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
+        return optimizer
+
+
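+# Quick smoke test: build the transformer from the VGGSound config (the paths below assume a
+# local SpecVQGAN-style checkpoint layout) and run random conditioning features together with
+# a random spectrogram through a single forward pass.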
+if __name__ == '__main__':
+    from omegaconf import OmegaConf
+
+    cfg_image = OmegaConf.load('./configs/vggsound_transformer.yaml')
+    cfg_image.model.params.first_stage_config.params.ckpt_path = './logs/2021-05-19T22-16-54_vggsound_codebook/checkpoints/last.ckpt'
+
+    transformer_cfg = cfg_image.model.params.transformer_config
+    first_stage_cfg = cfg_image.model.params.first_stage_config
+    cond_stage_cfg = cfg_image.model.params.cond_stage_config
+    permuter_cfg = cfg_image.model.params.permuter_config
+    transformer = Net2NetTransformer(
+        transformer_cfg, first_stage_cfg, cond_stage_cfg, permuter_cfg
+    )
+
+    c = torch.rand(2, 2048, 212)
+    x = torch.rand(2, 1, 80, 160)
+
+    logits, target = transformer(x, c)
+    print(logits.shape, target.shape)
diff --git a/foleycrafter/models/specvqgan/models/vqgan.py b/foleycrafter/models/specvqgan/models/vqgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..58e7273b3153dc0f370a763de11165169cc2db91
--- /dev/null
+++ b/foleycrafter/models/specvqgan/models/vqgan.py
@@ -0,0 +1,397 @@
+import torch
+import torch.nn as nn
+import torchaudio
+from torchvision import transforms
+import torch.nn.functional as F
+import pytorch_lightning as pl
+
+import sys
+import math
+sys.path.insert(0, '.')  # nopep8
+from train import instantiate_from_config
+from foleycrafter.models.specvqgan.data.transforms import Wave2Spectrogram, NormalizeAudio
+
+from foleycrafter.models.specvqgan.modules.diffusionmodules.model import Encoder, Decoder, Encoder1d, Decoder1d
+from foleycrafter.models.specvqgan.modules.vqvae.quantize import VectorQuantizer, VectorQuantizer1d
+
+
+class VQModel(pl.LightningModule):
+    def __init__(self,
+                 ddconfig,
+                 lossconfig,
+                 n_embed,
+                 embed_dim,
+                 ckpt_path=None,
+                 ignore_keys=[],
+                 image_key="image",
+                 colorize_nlabels=None,
+                 monitor=None,
+                 L=10.,
+                 mel_num=80,
+                 spec_crop_len=160,
+                 normalize=False,
+                 freeze_encoder=False,
+                 ):
+        super().__init__()
+        self.image_key = image_key
+        # we need this one for compatibility in train.ImageLogger.log_img if statement
+        self.first_stage_key = image_key
+        self.encoder = Encoder(**ddconfig)
+        self.decoder = Decoder(**ddconfig)
+        self.loss = instantiate_from_config(lossconfig)
+        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
+        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+        
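+        # Waveform-to-spectrogram chain: magnitude STFT (n_fft=1024) -> 80-bin mel scale at
+        # 22.05 kHz -> Wave2Spectrogram(mel_num, spec_crop_len); spec_trans() later rescales
+        # the result to [-1, 1].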
+        aug_list = [
+            torchaudio.transforms.Spectrogram(
+                n_fft=1024,
+                hop_length=1024//4,
+                power=1,
+            ),
+            torchaudio.transforms.MelScale(
+                n_mels=80,
+                sample_rate=22050,
+                f_min=125,
+                f_max=7600,
+                n_stft=513,
+                norm='slaney'
+            ),
+            Wave2Spectrogram(mel_num, spec_crop_len),
+        ]
+        if normalize:
+            # always applied inside this branch, so RandomApply is used with p=1.
+            aug_list = [transforms.RandomApply([NormalizeAudio()], p=1.)] + aug_list
+
+        if not freeze_encoder:
+            self.wav_transforms = nn.Sequential(*aug_list)
+        ignore_keys += ['first_stage_model.wav_transforms', 'wav_transforms']
+        
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+        if colorize_nlabels is not None:
+            assert type(colorize_nlabels)==int
+            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+        if monitor is not None:
+            self.monitor = monitor
+        self.used_codes = []
+        self.counts = [0 for _ in range(self.quantize.n_e)]
+
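+        # Optionally freeze the encoder path (encoder, quantizer and quant_conv); the decoder
+        # and post_quant_conv stay trainable.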
+        if freeze_encoder:
+            for param in self.encoder.parameters():
+                param.requires_grad = False
+            for param in self.quantize.parameters():
+                param.requires_grad = False
+            for param in self.quant_conv.parameters():
+                param.requires_grad = False
+
+    def init_from_ckpt(self, path, ignore_keys=list()):
+        sd = torch.load(path, map_location="cpu")["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        self.load_state_dict(sd, strict=False)
+        print(f"Restored from {path}")
+
+    def encode(self, x):
+        h = self.encoder(x)  # 2d: (B, 256, 16, 16) <- (B, 3, 256, 256)
+        h = self.quant_conv(h)  # 2d: (B, 256, 16, 16)
+        quant, emb_loss, info = self.quantize(h)  # (B, 256, 16, 16), (), ((), (768, 1024), (768, 1))
+        if not self.training:
+            self.counts = [info[2].squeeze().tolist().count(i) + self.counts[i] for i in range(self.quantize.n_e)]
+        return quant, emb_loss, info
+
+    def decode(self, quant):
+        quant = self.post_quant_conv(quant)
+        dec = self.decoder(quant)
+        return dec
+
+    def decode_code(self, code_b):
+        quant_b = self.quantize.embed_code(code_b)
+        dec = self.decode(quant_b)
+        return dec
+
+    def forward(self, input):
+        quant, diff, _ = self.encode(input)
+        dec = self.decode(quant)
+        return dec, diff
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        if len(x.shape) == 2:
+            x = self.spec_trans(x)
+        if len(x.shape) == 3:
+            x = x[..., None]
+        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
+        return x.float()
+
+    def spec_trans(self, wav):
+        self.wav_transforms.to(wav.device)
+        spec = self.wav_transforms(wav.to(torch.float32))
+        return 2 * spec - 1
+
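+    # Alternating GAN-style training: optimizer_idx 0 updates the autoencoder,
+    # optimizer_idx 1 updates the discriminator inside self.loss.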
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        x = self.get_input(batch, self.image_key)
+        xrec, qloss = self(x)
+
+        if optimizer_idx == 0:
+            # autoencode
+            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+                                            last_layer=self.get_last_layer(), split="train")
+
+            self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+            return aeloss
+
+        if optimizer_idx == 1:
+            # discriminator
+            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+                                            last_layer=self.get_last_layer(), split="train")
+            self.log("train/disc_loss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+            return discloss
+
+    def validation_step(self, batch, batch_idx):
+        if batch_idx == 0 and self.global_step != 0 and sum(self.counts) > 0:
+            zero_hit_codes = len([1 for count in self.counts if count == 0])
+            used_codes = []
+            for c, count in enumerate(self.counts):
+                used_codes.extend([c] * count)
+            self.logger.experiment.add_histogram('val/code_hits', torch.tensor(used_codes), self.global_step)
+            self.logger.experiment.add_scalar('val/zero_hit_codes', zero_hit_codes, self.global_step)
+            self.counts = [0 for _ in range(self.quantize.n_e)]
+        x = self.get_input(batch, self.image_key)
+        xrec, qloss = self(x)
+        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
+                                        last_layer=self.get_last_layer(), split="val")
+
+        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
+                                            last_layer=self.get_last_layer(), split="val")
+        rec_loss = log_dict_ae['val/rec_loss']
+        self.log('val/rec_loss', rec_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
+        self.log('val/aeloss', aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
+        self.log_dict(log_dict_ae)
+        self.log_dict(log_dict_disc)
+        return self.log_dict
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
+                                  list(self.decoder.parameters()) +
+                                  list(self.quantize.parameters()) +
+                                  list(self.quant_conv.parameters()) +
+                                  list(self.post_quant_conv.parameters()),
+                                  lr=lr, betas=(0.5, 0.9))
+        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+                                    lr=lr, betas=(0.5, 0.9))
+        return [opt_ae, opt_disc], []
+
+    def get_last_layer(self):
+        return self.decoder.conv_out.weight
+
+    def log_images(self, batch, **kwargs):
+        log = dict()
+        x = self.get_input(batch, self.image_key)
+        x = x.to(self.device)
+        xrec, _ = self(x)
+        if x.shape[1] > 3:
+            # colorize with random projection
+            assert xrec.shape[1] > 3
+            x = self.to_rgb(x)
+            xrec = self.to_rgb(xrec)
+        log["inputs"] = x
+        log["reconstructions"] = xrec
+        return log
+
+    def to_rgb(self, x):
+        assert self.image_key == "segmentation"
+        if not hasattr(self, "colorize"):
+            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+        x = F.conv2d(x, weight=self.colorize)
+        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+        return x
+
+
+class VQModel1d(VQModel):
+    def __init__(self, ddconfig, lossconfig, n_embed, embed_dim, ckpt_path=None, ignore_keys=[],
+                 image_key='feature', colorize_nlabels=None, monitor=None):
+        # ckpt_path is not forwarded to super().__init__, otherwise it would try to load the
+        # 1D checkpoint into the 2D modules built there
+        super().__init__(ddconfig, lossconfig, n_embed, embed_dim)
+        self.image_key = image_key
+        # we need this one for compatibility in train.ImageLogger.log_img if statement
+        self.first_stage_key = image_key
+        self.encoder = Encoder1d(**ddconfig)
+        self.decoder = Decoder1d(**ddconfig)
+        self.loss = instantiate_from_config(lossconfig)
+        self.quantize = VectorQuantizer1d(n_embed, embed_dim, beta=0.25)
+        self.quant_conv = torch.nn.Conv1d(ddconfig['z_channels'], embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv1d(embed_dim, ddconfig['z_channels'], 1)
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+        if colorize_nlabels is not None:
+            assert type(colorize_nlabels)==int
+            self.register_buffer('colorize', torch.randn(3, colorize_nlabels, 1, 1))
+        if monitor is not None:
+            self.monitor = monitor
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        if self.image_key == 'feature':
+            x = x.permute(0, 2, 1)
+        elif self.image_key == 'image':
+            x = x.unsqueeze(1)
+        x = x.to(memory_format=torch.contiguous_format)
+        return x.float()
+
+    def forward(self, input):
+        if self.image_key == 'image':
+            input = input.squeeze(1)
+        quant, diff, _ = self.encode(input)
+        dec = self.decode(quant)
+        if self.image_key == 'image':
+            dec = dec.unsqueeze(1)
+        return dec, diff
+
+    def log_images(self, batch, **kwargs):
+        if self.image_key == 'image':
+            log = dict()
+            x = self.get_input(batch, self.image_key)
+            x = x.to(self.device)
+            xrec, _ = self(x)
+            if x.shape[1] > 3:
+                # colorize with random projection
+                assert xrec.shape[1] > 3
+                x = self.to_rgb(x)
+                xrec = self.to_rgb(xrec)
+            log['inputs'] = x
+            log['reconstructions'] = xrec
+            return log
+        else:
+            raise NotImplementedError('1d input should be treated differently')
+
+    def to_rgb(self, batch, **kwargs):
+        raise NotImplementedError('1d input should be treated differently')
+
+
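+# Segmentation variant: reconstructs one-hot label maps, so it trains with a single
+# autoencoder optimizer and no discriminator.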
+class VQSegmentationModel(VQModel):
+    def __init__(self, n_labels, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+                                  list(self.decoder.parameters())+
+                                  list(self.quantize.parameters())+
+                                  list(self.quant_conv.parameters())+
+                                  list(self.post_quant_conv.parameters()),
+                                  lr=lr, betas=(0.5, 0.9))
+        return opt_ae
+
+    def training_step(self, batch, batch_idx):
+        x = self.get_input(batch, self.image_key)
+        xrec, qloss = self(x)
+        aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
+        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+        return aeloss
+
+    def validation_step(self, batch, batch_idx):
+        x = self.get_input(batch, self.image_key)
+        xrec, qloss = self(x)
+        aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
+        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+        total_loss = log_dict_ae["val/total_loss"]
+        self.log("val/total_loss", total_loss,
+                 prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
+        return aeloss
+
+    @torch.no_grad()
+    def log_images(self, batch, **kwargs):
+        log = dict()
+        x = self.get_input(batch, self.image_key)
+        x = x.to(self.device)
+        xrec, _ = self(x)
+        if x.shape[1] > 3:
+            # colorize with random projection
+            assert xrec.shape[1] > 3
+            # convert logits to indices
+            xrec = torch.argmax(xrec, dim=1, keepdim=True)
+            xrec = F.one_hot(xrec, num_classes=x.shape[1])
+            xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
+            x = self.to_rgb(x)
+            xrec = self.to_rgb(xrec)
+        log["inputs"] = x
+        log["reconstructions"] = xrec
+        return log
+
+
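+# VQ autoencoder trained without a GAN discriminator: a single reconstruction/codebook
+# loss term and a single Adam optimizer.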
+class VQNoDiscModel(VQModel):
+    def __init__(self,
+                 ddconfig,
+                 lossconfig,
+                 n_embed,
+                 embed_dim,
+                 ckpt_path=None,
+                 ignore_keys=[],
+                 image_key="image",
+                 colorize_nlabels=None
+                 ):
+        super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
+                         ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
+                         colorize_nlabels=colorize_nlabels)
+
+    def training_step(self, batch, batch_idx):
+        x = self.get_input(batch, self.image_key)
+        xrec, qloss = self(x)
+        # autoencode (no discriminator, so there is a single loss term)
+        aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
+        self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+        return aeloss
+
+    def validation_step(self, batch, batch_idx):
+        x = self.get_input(batch, self.image_key)
+        xrec, qloss = self(x)
+        aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
+        rec_loss = log_dict_ae["val/rec_loss"]
+        self.log("val/rec_loss", rec_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+        self.log("val/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+        self.log_dict(log_dict_ae)
+        return aeloss
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(list(self.encoder.parameters()) +
+                                     list(self.decoder.parameters()) +
+                                     list(self.quantize.parameters()) +
+                                     list(self.quant_conv.parameters()) +
+                                     list(self.post_quant_conv.parameters()),
+                                     lr=self.learning_rate, betas=(0.5, 0.9))
+        return optimizer
+
+
+if __name__ == '__main__':
+    from omegaconf import OmegaConf
+    from train import instantiate_from_config
+
+    image_key = 'image'
+    cfg_audio = OmegaConf.load('./configs/vggsound_codebook.yaml')
+    model = VQModel(cfg_audio.model.params.ddconfig,
+                    cfg_audio.model.params.lossconfig,
+                    cfg_audio.model.params.n_embed,
+                    cfg_audio.model.params.embed_dim,
+                    image_key='image')
+    batch = {
+        'image': torch.rand((4, 80, 848)),
+        'file_path_': ['data/vggsound/mel123.npy'] * 4,
+        'class': [1, 1, 1, 1],
+    }
+    xrec, qloss = model(model.get_input(batch, image_key))
+    print(xrec.shape, qloss.shape)
diff --git a/foleycrafter/models/specvqgan/modules/diffusionmodules/model.py b/foleycrafter/models/specvqgan/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a1ceb026e9be0cd864287800daff4df37f432c1
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/diffusionmodules/model.py
@@ -0,0 +1,999 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+    """
+    This matches the implementation in Denoising Diffusion Probabilistic Models:
+    From Fairseq.
+    Build sinusoidal embeddings.
+    This matches the implementation in tensor2tensor, but differs slightly
+    from the description in Section 3.5 of "Attention Is All You Need".
+    """
+    assert len(timesteps.shape) == 1
+
+    half_dim = embedding_dim // 2
+    emb = math.log(10000) / (half_dim - 1)
+    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+    emb = emb.to(device=timesteps.device)
+    emb = timesteps.float()[:, None] * emb[None, :]
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+    if embedding_dim % 2 == 1:  # zero pad
+        emb = torch.nn.functional.pad(emb, (0,1,0,0))
+    return emb
+
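+# Example: get_timestep_embedding(torch.tensor([0, 1, 2]), 64) returns a (3, 64) tensor whose
+# first 32 columns are sine terms and whose last 32 columns are cosine terms.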
+
+def nonlinearity(x):
+    # swish
+    return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+class Upsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            self.conv = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        if self.with_conv:
+            x = self.conv(x)
+        return x
+
+class Upsample1d(Upsample):
+    def __init__(self, in_channels, with_conv):
+        super().__init__(in_channels, with_conv)
+        if self.with_conv:
+            self.conv = torch.nn.Conv1d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+class Downsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+            self.pad = (0, 1, 0, 1)
+        else:
+            self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
+
+    def forward(self, x):
+        if self.with_conv:  # bp: check self.avgpool and self.pad
+            x = torch.nn.functional.pad(x, self.pad, mode="constant", value=0)
+            x = self.conv(x)
+        else:
+            x = self.avg_pool(x)
+        return x
+
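+# Like Upsample1d above, Downsample1d reuses the inherited 2d forward pass and only swaps the
+# convolution/pooling layers and the padding for their 1d counterparts.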
+class Downsample1d(Downsample):
+
+    def __init__(self, in_channels, with_conv):
+        super().__init__(in_channels, with_conv)
+        if self.with_conv:
+            # no asymmetric padding in torch conv, must do it ourselves
+            # TODO: can we replace this with a Conv1d with padding=1 (no manual padding)?
+            self.conv = torch.nn.Conv1d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+            self.pad = (1, 1)
+        else:
+            self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2)
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+                 dropout, temb_channels=512):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+
+        self.norm1 = Normalize(in_channels)
+        self.conv1 = torch.nn.Conv2d(in_channels,
+                                     out_channels,
+                                     kernel_size=3,
+                                     stride=1,
+                                     padding=1)
+        if temb_channels > 0:
+            self.temb_proj = torch.nn.Linear(temb_channels,
+                                             out_channels)
+        self.norm2 = Normalize(out_channels)
+        self.dropout = torch.nn.Dropout(dropout)
+        self.conv2 = torch.nn.Conv2d(out_channels,
+                                     out_channels,
+                                     kernel_size=3,
+                                     stride=1,
+                                     padding=1)
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = torch.nn.Conv2d(in_channels,
+                                                     out_channels,
+                                                     kernel_size=3,
+                                                     stride=1,
+                                                     padding=1)
+            else:
+                self.nin_shortcut = torch.nn.Conv2d(in_channels,
+                                                    out_channels,
+                                                    kernel_size=1,
+                                                    stride=1,
+                                                    padding=0)
+
+    def forward(self, x, temb):
+        h = x
+        h = self.norm1(h)
+        h = nonlinearity(h)
+        h = self.conv1(h)
+
+        if temb is not None:
+            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+        h = self.norm2(h)
+        h = nonlinearity(h)
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x+h
+
+class ResnetBlock1d(ResnetBlock):
+    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+                 dropout, temb_channels=512):
+        super().__init__(in_channels=in_channels, out_channels=out_channels,
+                         conv_shortcut=conv_shortcut, dropout=dropout, temb_channels=temb_channels)
+        # redefine the layers as 1d; forward() is inherited unchanged from ResnetBlock
+        if temb_channels > 0:
+            # ResnetBlock.forward broadcasts temb_proj over two spatial dims, so 1d support
+            # would need its own forward
+            raise NotImplementedError('timestep embeddings are not supported in ResnetBlock1d')
+            # self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+
+        self.conv1 = torch.nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = torch.nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = torch.nn.Conv1d(in_channels, out_channels, kernel_size=3,
+                                                     stride=1, padding=1)
+            else:
+                self.nin_shortcut = torch.nn.Conv1d(in_channels, out_channels, kernel_size=1,
+                                                    stride=1, padding=0)
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.k = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.v = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0)
+
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b,c,h,w = q.shape
+        q = q.reshape(b,c,h*w)
+        q = q.permute(0,2,1)   # b,hw,c
+        k = k.reshape(b,c,h*w) # b,c,hw
+        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+        w_ = w_ * (int(c)**(-0.5))
+        w_ = torch.nn.functional.softmax(w_, dim=2)
+
+        # attend to values
+        v = v.reshape(b,c,h*w)
+        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
+        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+        h_ = h_.reshape(b,c,h,w)
+
+        h_ = self.proj_out(h_)
+
+        return x+h_
+
+class AttnBlock1d(nn.Module):
+
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.k = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.v = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.proj_out = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b, c, t = q.shape
+        q = q.permute(0, 2, 1)   # b,t,c
+        w_ = torch.bmm(q, k)     # b,t,t    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+        w_ = w_ * (int(c) ** (-0.5))
+        w_ = torch.nn.functional.softmax(w_, dim=2)
+
+        # attend to values
+        w_ = w_.permute(0, 2, 1)  # b,t,t (first t of k, second of q)
+        h_ = torch.bmm(v, w_)  # b,c,t (t of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+
+        h_ = self.proj_out(h_)
+
+        return x + h_
+
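+# Both attention blocks implement single-head, full self-attention with 1x1 convolutions as
+# the q/k/v/output projections and a residual connection around the output; AttnBlock attends
+# over the H*W spatial positions, AttnBlock1d over the time axis.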
+
+class Model(nn.Module):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                 resolution, use_timestep=True):
+        super().__init__()
+        self.ch = ch
+        self.temb_ch = self.ch*4
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+
+        self.use_timestep = use_timestep
+        if self.use_timestep:
+            # timestep embedding
+            self.temb = nn.Module()
+            self.temb.dense = nn.ModuleList([
+                torch.nn.Linear(self.ch,
+                                self.temb_ch),
+                torch.nn.Linear(self.temb_ch,
+                                self.temb_ch),
+            ])
+
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(in_channels,
+                                       self.ch,
+                                       kernel_size=3,
+                                       stride=1,
+                                       padding=1)
+
+        curr_res = resolution
+        in_ch_mult = (1,)+tuple(ch_mult)
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch*in_ch_mult[i_level]
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(AttnBlock(block_in))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions-1:
+                down.downsample = Downsample(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+        self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch*ch_mult[i_level]
+            skip_in = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks+1):
+                if i_block == self.num_res_blocks:
+                    skip_in = ch*in_ch_mult[i_level]
+                block.append(ResnetBlock(in_channels=block_in+skip_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(AttnBlock(block_in))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up) # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        out_ch,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+
+    def forward(self, x, t=None):
+        #assert x.shape[2] == x.shape[3] == self.resolution
+
+        if self.use_timestep:
+            # timestep embedding
+            assert t is not None
+            temb = get_timestep_embedding(t, self.ch)
+            temb = self.temb.dense[0](temb)
+            temb = nonlinearity(temb)
+            temb = self.temb.dense[1](temb)
+        else:
+            temb = None
+
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1], temb)
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions-1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks+1):
+                h = self.up[i_level].block[i_block](
+                    torch.cat([h, hs.pop()], dim=1), temb)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class Encoder(nn.Module):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                 resolution, z_channels, double_z=True, **ignore_kwargs):
+        super().__init__()
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(in_channels,
+                                       self.ch,
+                                       kernel_size=3,
+                                       stride=1,
+                                       padding=1)
+
+        curr_res = resolution
+        in_ch_mult = (1,)+tuple(ch_mult)
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch*in_ch_mult[i_level]
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(AttnBlock(block_in))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions-1:
+                down.downsample = Downsample(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+        self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        2*z_channels if double_z else z_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+
+    def forward(self, x):
+        #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
+
+        # timestep embedding
+        temb = None
+
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1], temb)
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions-1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+class Encoder1d(Encoder):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                 resolution, z_channels, double_z=True, **ignore_kwargs):
+        super().__init__(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
+                         attn_resolutions=attn_resolutions, dropout=dropout,
+                         resamp_with_conv=resamp_with_conv,
+                         in_channels=in_channels, resolution=resolution, z_channels=z_channels,
+                         double_z=double_z, **ignore_kwargs)
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+
+        # downsampling
+        self.conv_in = torch.nn.Conv1d(in_channels,
+                                       self.ch,
+                                       kernel_size=3,
+                                       stride=1,
+                                       padding=1)
+
+        curr_res = resolution
+        in_ch_mult = (1,)+tuple(ch_mult)
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch*in_ch_mult[i_level]
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(ResnetBlock1d(in_channels=block_in,
+                                           out_channels=block_out,
+                                           temb_channels=self.temb_ch,
+                                           dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(AttnBlock1d(block_in))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions-1:
+                down.downsample = Downsample1d(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock1d(in_channels=block_in,
+                                         out_channels=block_in,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout)
+        self.mid.attn_1 = AttnBlock1d(block_in)
+        self.mid.block_2 = ResnetBlock1d(in_channels=block_in,
+                                         out_channels=block_in,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout)
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv1d(block_in,
+                                        2*z_channels if double_z else z_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+
+class Decoder(nn.Module):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                 resolution, z_channels, give_pre_end=False, **ignorekwargs):
+        super().__init__()
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.give_pre_end = give_pre_end
+
+        # compute in_ch_mult, block_in and curr_res at lowest res
+        in_ch_mult = (1,)+tuple(ch_mult)
+        block_in = ch*ch_mult[self.num_resolutions-1]
+        curr_res = resolution // 2**(self.num_resolutions-1)
+        # self.z_shape = (1,z_channels,curr_res,curr_res)
+        # print("Working with z of shape {} = {} dimensions.".format(
+        #     self.z_shape, np.prod(self.z_shape)))
+
+        # z to block_in
+        self.conv_in = torch.nn.Conv2d(z_channels,
+                                       block_in,
+                                       kernel_size=3,
+                                       stride=1,
+                                       padding=1)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+        self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks+1):
+                block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(AttnBlock(block_in))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up) # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        out_ch,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, z):
+        #assert z.shape[1:] == self.z_shape[1:]
+        self.last_z_shape = z.shape
+
+        # timestep embedding
+        temb = None
+
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks+1):
+                h = self.up[i_level].block[i_block](h, temb)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        if self.give_pre_end:
+            return h
+
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+class Decoder1d(Decoder):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                 resolution, z_channels, give_pre_end=False, **ignorekwargs):
+        super().__init__(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
+                         attn_resolutions=attn_resolutions, dropout=dropout,
+                         resamp_with_conv=resamp_with_conv,
+                         in_channels=in_channels, resolution=resolution, z_channels=z_channels,
+                         give_pre_end=give_pre_end, **ignorekwargs)
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.give_pre_end = give_pre_end
+
+        # compute in_ch_mult, block_in and curr_res at lowest res
+        in_ch_mult = (1,) + tuple(ch_mult)
+        block_in = ch * ch_mult[self.num_resolutions-1]
+        curr_res = resolution // 2**(self.num_resolutions-1)
+        # self.z_shape = (1,z_channels,curr_res,curr_res)
+        # print("Working with z of shape {} = {} dimensions.".format(
+        #     self.z_shape, np.prod(self.z_shape)))
+
+        # z to block_in
+        self.conv_in = torch.nn.Conv1d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock1d(in_channels=block_in, out_channels=block_in,
+                                         temb_channels=self.temb_ch, dropout=dropout)
+        self.mid.attn_1 = AttnBlock1d(block_in)
+        self.mid.block_2 = ResnetBlock1d(in_channels=block_in, out_channels=block_in,
+                                         temb_channels=self.temb_ch, dropout=dropout)
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks+1):
+                block.append(ResnetBlock1d(in_channels=block_in, out_channels=block_out,
+                                           temb_channels=self.temb_ch, dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(AttnBlock1d(block_in))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample1d(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up)  # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv1d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+
+class VUNet(nn.Module):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
+                 in_channels, c_channels,
+                 resolution, z_channels, use_timestep=False, **ignore_kwargs):
+        super().__init__()
+        self.ch = ch
+        self.temb_ch = self.ch*4
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+
+        self.use_timestep = use_timestep
+        if self.use_timestep:
+            # timestep embedding
+            self.temb = nn.Module()
+            self.temb.dense = nn.ModuleList([
+                torch.nn.Linear(self.ch,
+                                self.temb_ch),
+                torch.nn.Linear(self.temb_ch,
+                                self.temb_ch),
+            ])
+
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(c_channels,
+                                       self.ch,
+                                       kernel_size=3,
+                                       stride=1,
+                                       padding=1)
+
+        curr_res = resolution
+        in_ch_mult = (1,)+tuple(ch_mult)
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch*in_ch_mult[i_level]
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(AttnBlock(block_in))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions-1:
+                down.downsample = Downsample(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        self.z_in = torch.nn.Conv2d(z_channels,
+                                    block_in,
+                                    kernel_size=1,
+                                    stride=1,
+                                    padding=0)
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+        self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch*ch_mult[i_level]
+            skip_in = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks+1):
+                if i_block == self.num_res_blocks:
+                    skip_in = ch*in_ch_mult[i_level]
+                block.append(ResnetBlock(in_channels=block_in+skip_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(AttnBlock(block_in))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up) # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        out_ch,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+
+    def forward(self, x, z, t=None):
+        #assert x.shape[2] == x.shape[3] == self.resolution
+
+        if self.use_timestep:
+            # timestep embedding
+            assert t is not None
+            temb = get_timestep_embedding(t, self.ch)
+            temb = self.temb.dense[0](temb)
+            temb = nonlinearity(temb)
+            temb = self.temb.dense[1](temb)
+        else:
+            temb = None
+
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1], temb)
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions-1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        z = self.z_in(z)
+        h = torch.cat((h,z),dim=1)
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks+1):
+                h = self.up[i_level].block[i_block](
+                    torch.cat([h, hs.pop()], dim=1), temb)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class SimpleDecoder(nn.Module):
+    def __init__(self, in_channels, out_channels, *args, **kwargs):
+        super().__init__()
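+        # lightweight decoder: 1x1 conv, three ResnetBlocks that widen then narrow
+        # the channels, a 1x1 conv back to in_channels, and a single 2x Upsample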
+        self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+                                     ResnetBlock(in_channels=in_channels,
+                                                 out_channels=2 * in_channels,
+                                                 temb_channels=0, dropout=0.0),
+                                     ResnetBlock(in_channels=2 * in_channels,
+                                                out_channels=4 * in_channels,
+                                                temb_channels=0, dropout=0.0),
+                                     ResnetBlock(in_channels=4 * in_channels,
+                                                out_channels=2 * in_channels,
+                                                temb_channels=0, dropout=0.0),
+                                     nn.Conv2d(2*in_channels, in_channels, 1),
+                                     Upsample(in_channels, with_conv=True)])
+        # end
+        self.norm_out = Normalize(in_channels)
+        self.conv_out = torch.nn.Conv2d(in_channels,
+                                        out_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        for i, layer in enumerate(self.model):
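+            # ResnetBlocks at indices 1-3 take a timestep embedding; pass None since temb_channels=0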
+            if i in [1,2,3]:
+                x = layer(x, None)
+            else:
+                x = layer(x)
+
+        h = self.norm_out(x)
+        h = nonlinearity(h)
+        x = self.conv_out(h)
+        return x
+
+
+class UpsampleDecoder(nn.Module):
+    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+                 ch_mult=(2,2), dropout=0.0):
+        super().__init__()
+        # upsampling
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        block_in = in_channels
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        self.res_blocks = nn.ModuleList()
+        self.upsample_blocks = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            res_block = []
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                res_block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+            self.res_blocks.append(nn.ModuleList(res_block))
+            if i_level != self.num_resolutions - 1:
+                self.upsample_blocks.append(Upsample(block_in, True))
+                curr_res = curr_res * 2
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        out_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        # upsampling
+        h = x
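+        # k mirrors i_level; upsample_blocks holds one entry fewer than res_blocks,
+        # since the last resolution level is not upsampled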
+        for k, i_level in enumerate(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.res_blocks[i_level][i_block](h, None)
+            if i_level != self.num_resolutions - 1:
+                h = self.upsample_blocks[k](h)
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+if __name__ == '__main__':
+    ddconfig = {
+        'ch': 128,
+        'num_res_blocks': 2,
+        'dropout': 0.0,
+        'z_channels': 256,
+        'double_z': False,
+    }
+
+    # Audio example ##
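+    # an 80-bin mel spectrogram with 848 frames; the 4 downsampling stages
+    # (len(ch_mult) - 1) reduce it to a 5 x 53 latent with z_channels = 256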
+    ddconfig['in_channels'] = 1
+    ddconfig['resolution'] = 848
+    ddconfig['attn_resolutions'] = [53]
+    ddconfig['ch_mult'] = [1, 1, 2, 2, 4]
+    ddconfig['out_ch'] = 1
+    # input
+    inputs = torch.rand(4, 1, 80, 848)
+    print('Input:', inputs.shape)
+    # Encoder
+    encoder = Encoder(**ddconfig)
+    enc_outs = encoder(inputs)
+    print('Encoder out:', enc_outs.shape)
+    # Decoder
+    decoder = Decoder(**ddconfig)
+    quant_outs = torch.rand(4, 256, 5, 53)
+    dec_outs = decoder(quant_outs)
+    print('Decoder out:', dec_outs.shape)
diff --git a/foleycrafter/models/specvqgan/modules/discriminator/model.py b/foleycrafter/models/specvqgan/modules/discriminator/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5263368a5e74d9d07840399469ca12a54e7fecbc
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/discriminator/model.py
@@ -0,0 +1,295 @@
+import functools
+
+import torch
+import torch.nn as nn
+
+
+class ActNorm(nn.Module):
+    def __init__(self, num_features, logdet=False, affine=True,
+                 allow_reverse_init=False):
+        assert affine
+        super().__init__()
+        self.logdet = logdet
+        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
+        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
+        self.allow_reverse_init = allow_reverse_init
+
+        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
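+        # data-dependent initialization (as in Glow): loc/scale are set from the
+        # statistics of the first batch seen during training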
+
+    def initialize(self, input):
+        with torch.no_grad():
+            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
+            mean = (
+                flatten.mean(1)
+                .unsqueeze(1)
+                .unsqueeze(2)
+                .unsqueeze(3)
+                .permute(1, 0, 2, 3)
+            )
+            std = (
+                flatten.std(1)
+                .unsqueeze(1)
+                .unsqueeze(2)
+                .unsqueeze(3)
+                .permute(1, 0, 2, 3)
+            )
+
+            self.loc.data.copy_(-mean)
+            self.scale.data.copy_(1 / (std + 1e-6))
+
+    def forward(self, input, reverse=False):
+        if reverse:
+            return self.reverse(input)
+        if len(input.shape) == 2:
+            input = input[:, :, None, None]
+            squeeze = True
+        else:
+            squeeze = False
+
+        _, _, height, width = input.shape
+
+        if self.training and self.initialized.item() == 0:
+            self.initialize(input)
+            self.initialized.fill_(1)
+
+        h = self.scale * (input + self.loc)
+
+        if squeeze:
+            h = h.squeeze(-1).squeeze(-1)
+
+        if self.logdet:
+            log_abs = torch.log(torch.abs(self.scale))
+            logdet = height * width * torch.sum(log_abs)
+            logdet = logdet * torch.ones(input.shape[0]).to(input)
+            return h, logdet
+
+        return h
+
+    def reverse(self, output):
+        if self.training and self.initialized.item() == 0:
+            if not self.allow_reverse_init:
+                raise RuntimeError(
+                    "Initializing ActNorm in reverse direction is "
+                    "disabled by default. Use allow_reverse_init=True to enable."
+                )
+            else:
+                self.initialize(output)
+                self.initialized.fill_(1)
+
+        if len(output.shape) == 2:
+            output = output[:, :, None, None]
+            squeeze = True
+        else:
+            squeeze = False
+
+        h = output / self.scale - self.loc
+
+        if squeeze:
+            h = h.squeeze(-1).squeeze(-1)
+        return h
+
+def weights_init(m):
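+    # DCGAN-style init: N(0, 0.02) for conv weights, N(1, 0.02) for BatchNorm scale, zero bias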
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:
+        nn.init.normal_(m.weight.data, 1.0, 0.02)
+        nn.init.constant_(m.bias.data, 0)
+
+
+class NLayerDiscriminator(nn.Module):
+    """Defines a PatchGAN discriminator as in Pix2Pix
+        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
+    """
+    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
+        """Construct a PatchGAN discriminator
+        Parameters:
+            input_nc (int)  -- the number of channels in input images
+            ndf (int)       -- the number of filters in the last conv layer
+            n_layers (int)  -- the number of conv layers in the discriminator
+            norm_layer      -- normalization layer
+        """
+        super(NLayerDiscriminator, self).__init__()
+        if not use_actnorm:
+            norm_layer = nn.BatchNorm2d
+        else:
+            norm_layer = ActNorm
+        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
+            use_bias = norm_layer.func != nn.BatchNorm2d
+        else:
+            use_bias = norm_layer != nn.BatchNorm2d
+
+        kw = 4
+        padw = 1
+        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+        nf_mult = 1
+        nf_mult_prev = 1
+        for n in range(1, n_layers):  # gradually increase the number of filters
+            nf_mult_prev = nf_mult
+            nf_mult = min(2 ** n, 8)
+            sequence += [
+                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+                norm_layer(ndf * nf_mult),
+                nn.LeakyReLU(0.2, True)
+            ]
+
+        nf_mult_prev = nf_mult
+        nf_mult = min(2 ** n_layers, 8)
+        sequence += [
+            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+            norm_layer(ndf * nf_mult),
+            nn.LeakyReLU(0.2, True)
+        ]
+        # output 1 channel prediction map
+        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
+        self.main = nn.Sequential(*sequence)
+
+    def forward(self, input):
+        """Standard forward."""
+        return self.main(input)
+
+class NLayerDiscriminator1dFeats(NLayerDiscriminator):
+    """Defines a PatchGAN discriminator as in Pix2Pix
+        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
+    """
+    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
+        """Construct a PatchGAN discriminator
+        Parameters:
+            input_nc (int)  -- the number of channels in input feats
+            ndf (int)       -- the number of filters in the last conv layer
+            n_layers (int)  -- the number of conv layers in the discriminator
+            norm_layer      -- normalization layer
+        """
+        super().__init__(input_nc=input_nc, ndf=64, n_layers=n_layers, use_actnorm=use_actnorm)
+
+        if not use_actnorm:
+            norm_layer = nn.BatchNorm1d
+        else:
+            norm_layer = ActNorm
+        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm has affine parameters
+            use_bias = norm_layer.func != nn.BatchNorm1d
+        else:
+            use_bias = norm_layer != nn.BatchNorm1d
+
+        kw = 4
+        padw = 1
+        sequence = [nn.Conv1d(input_nc, input_nc//2, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+        nf_mult = input_nc//2
+        nf_mult_prev = 1
+        for n in range(1, n_layers):  # gradually decrease the number of filters
+            nf_mult_prev = nf_mult
+            nf_mult = max(nf_mult_prev // (2 ** n), 8)
+            sequence += [
+                nn.Conv1d(nf_mult_prev, nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+                norm_layer(nf_mult),
+                nn.LeakyReLU(0.2, True)
+            ]
+
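+        # note: `n` below keeps its last value from the loop above (n_layers - 1, assuming n_layers >= 2)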
+        nf_mult_prev = nf_mult
+        nf_mult = max(nf_mult_prev // (2 ** n), 8)
+        sequence += [
+            nn.Conv1d(nf_mult_prev, nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+            norm_layer(nf_mult),
+            nn.LeakyReLU(0.2, True)
+        ]
+        nf_mult_prev = nf_mult
+        nf_mult = max(nf_mult_prev // (2 ** n), 8)
+        sequence += [
+            nn.Conv1d(nf_mult_prev, nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+            norm_layer(nf_mult),
+            nn.LeakyReLU(0.2, True)
+        ]
+        # output 1 channel prediction map
+        sequence += [nn.Conv1d(nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
+        self.main = nn.Sequential(*sequence)
+
+
+class NLayerDiscriminator1dSpecs(NLayerDiscriminator):
+    """Defines a PatchGAN discriminator as in Pix2Pix
+        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
+    """
+    def __init__(self, input_nc=80, ndf=64, n_layers=3, use_actnorm=False):
+        """Construct a PatchGAN discriminator
+        Parameters:
+            input_nc (int)  -- the number of channels in input specs
+            ndf (int)       -- the number of filters in the last conv layer
+            n_layers (int)  -- the number of conv layers in the discriminator
+            norm_layer      -- normalization layer
+        """
+        super().__init__(input_nc=input_nc, ndf=64, n_layers=n_layers, use_actnorm=use_actnorm)
+
+        if not use_actnorm:
+            norm_layer = nn.BatchNorm1d
+        else:
+            norm_layer = ActNorm
+        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm has affine parameters
+            use_bias = norm_layer.func != nn.BatchNorm1d
+        else:
+            use_bias = norm_layer != nn.BatchNorm1d
+
+        kw = 4
+        padw = 1
+        sequence = [nn.Conv1d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+        nf_mult = 1
+        nf_mult_prev = 1
+        for n in range(1, n_layers):  # gradually decrease the number of filters
+            nf_mult_prev = nf_mult
+            nf_mult = min(2 ** n, 8)
+            sequence += [
+                nn.Conv1d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+                norm_layer(ndf * nf_mult),
+                nn.LeakyReLU(0.2, True)
+            ]
+
+        nf_mult_prev = nf_mult
+        nf_mult = min(2 ** n_layers, 8)
+        sequence += [
+            nn.Conv1d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+            norm_layer(ndf * nf_mult),
+            nn.LeakyReLU(0.2, True)
+        ]
+        # output 1 channel prediction map
+        sequence += [nn.Conv1d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
+        self.main = nn.Sequential(*sequence)
+
+    def forward(self, input):
+        """Standard forward."""
+        # (B, 1, F, T) -> (B, F, T): Conv1d treats the F mel bins as channels
+        input = input.squeeze(1)
+        input = self.main(input)
+        return input
+
+
+if __name__ == '__main__':
+    import torch
+
+    ## FEATURES
+    disc_in_channels = 2048
+    disc_num_layers = 2
+    use_actnorm = False
+    disc_ndf = 64
+    discriminator = NLayerDiscriminator1dFeats(input_nc=disc_in_channels, n_layers=disc_num_layers,
+                                            use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init)
+    inputs = torch.rand((6, 2048, 212))
+    outputs = discriminator(inputs)
+    print(outputs.shape)
+
+    ## AUDIO
+    disc_in_channels = 1
+    disc_num_layers = 3
+    use_actnorm = False
+    disc_ndf = 64
+    discriminator = NLayerDiscriminator(input_nc=disc_in_channels, n_layers=disc_num_layers,
+                                        use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init)
+    inputs = torch.rand((6, 1, 80, 848))
+    outputs = discriminator(inputs)
+    print(outputs.shape)
+
+    ## IMAGE
+    disc_in_channels = 3
+    disc_num_layers = 3
+    use_actnorm = False
+    disc_ndf = 64
+    discriminator = NLayerDiscriminator(input_nc=disc_in_channels, n_layers=disc_num_layers,
+                                        use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init)
+    inputs = torch.rand((6, 3, 256, 256))
+    outputs = discriminator(inputs)
+    print(outputs.shape)
diff --git a/foleycrafter/models/specvqgan/modules/losses/__init__.py b/foleycrafter/models/specvqgan/modules/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..533c5aa92c87f32fd5676e02463c703b22130f73
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/__init__.py
@@ -0,0 +1,7 @@
+from foleycrafter.models.specvqgan.modules.losses.vqperceptual import DummyLoss
+
+# add the vggishish folder to sys.path so its modules can be imported by bare name
+import os
+import sys
+path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'vggishish')
+sys.path.append(path)
diff --git a/foleycrafter/models/specvqgan/modules/losses/lpaps.py b/foleycrafter/models/specvqgan/modules/losses/lpaps.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e2a3f861f8ae1024da40c71f57a5ddd5098cfab
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/lpaps.py
@@ -0,0 +1,152 @@
+"""
+    Based on https://github.com/CompVis/taming-transformers/blob/52720829/taming/modules/losses/lpips.py
+    Adapted for spectrograms by Vladimir Iashin (v-iashin)
+"""
+from collections import namedtuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+import sys
+sys.path.insert(0, '.')  # nopep8
+from foleycrafter.models.specvqgan.modules.losses.vggishish.model import VGGishish
+from foleycrafter.models.specvqgan.util import get_ckpt_path
+
+
+class LPAPS(nn.Module):
+    # Learned perceptual metric
+    def __init__(self, use_dropout=True):
+        super().__init__()
+        self.scaling_layer = ScalingLayer()
+        self.chns = [64, 128, 256, 512, 512]  # vggish16 features
+        self.net = vggishish16(pretrained=True, requires_grad=False)
+        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
+        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
+        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
+        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
+        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
+        self.load_from_pretrained()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def load_from_pretrained(self, name="vggishish_lpaps"):
+        ckpt = get_ckpt_path(name, "specvqgan/modules/autoencoder/lpaps")
+        self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+        print("loaded pretrained LPAPS loss from {}".format(ckpt))
+
+    @classmethod
+    def from_pretrained(cls, name="vggishish_lpaps"):
+        if name != "vggishish_lpaps":
+            raise NotImplementedError
+        model = cls()
+        ckpt = get_ckpt_path(name)
+        model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+        return model
+
+    def forward(self, input, target):
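+        # compare VGGishish features of both spectrograms: normalize each layer's
+        # activations, take squared differences, and sum the learned 1x1-conv weightings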
+        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
+        outs0, outs1 = self.net(in0_input), self.net(in1_input)
+        feats0, feats1, diffs = {}, {}, {}
+        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
+        for kk in range(len(self.chns)):
+            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
+            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
+
+        res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
+        val = res[0]
+        for l in range(1, len(self.chns)):
+            val += res[l]
+        return val
+
+class ScalingLayer(nn.Module):
+    def __init__(self):
+        super(ScalingLayer, self).__init__()
+        # get_ckpt_path is also used to download the normalization statistics
+        stat_path = get_ckpt_path('vggishish_mean_std_melspec_10s_22050hz', 'specvqgan/modules/autoencoder/lpaps')
+        # images are normalized over the channel dim; spectrograms are normalized over the frequency dim
+        means, stds = np.loadtxt(stat_path, dtype=np.float32).T
+        # the means and stds are computed for inputs in [0, 1], but specvqgan expects [-1, 1]:
+        means = 2 * means - 1
+        stds = 2 * stds
+        # input is expected to be (B, 1, F, T)
+        self.register_buffer('shift', torch.from_numpy(means)[None, None, :, None])
+        self.register_buffer('scale', torch.from_numpy(stds)[None, None, :, None])
+
+    def forward(self, inp):
+        return (inp - self.shift) / self.scale
+
+
+class NetLinLayer(nn.Module):
+    """ A single linear layer which does a 1x1 conv """
+    def __init__(self, chn_in, chn_out=1, use_dropout=False):
+        super(NetLinLayer, self).__init__()
+        layers = [nn.Dropout(), ] if (use_dropout) else []
+        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
+        self.model = nn.Sequential(*layers)
+
+class vggishish16(torch.nn.Module):
+    def __init__(self, requires_grad=False, pretrained=True):
+        super().__init__()
+        vgg_pretrained_features = self.vggishish16(pretrained=pretrained).features
+        self.slice1 = torch.nn.Sequential()
+        self.slice2 = torch.nn.Sequential()
+        self.slice3 = torch.nn.Sequential()
+        self.slice4 = torch.nn.Sequential()
+        self.slice5 = torch.nn.Sequential()
+        self.N_slices = 5
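+        # split the pretrained VGGishish backbone into 5 slices so that
+        # intermediate activations can be returned for the perceptual loss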
+        for x in range(4):
+            self.slice1.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(4, 9):
+            self.slice2.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(9, 16):
+            self.slice3.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(16, 23):
+            self.slice4.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(23, 30):
+            self.slice5.add_module(str(x), vgg_pretrained_features[x])
+        if not requires_grad:
+            for param in self.parameters():
+                param.requires_grad = False
+
+    def forward(self, X):
+        h = self.slice1(X)
+        h_relu1_2 = h
+        h = self.slice2(h)
+        h_relu2_2 = h
+        h = self.slice3(h)
+        h_relu3_3 = h
+        h = self.slice4(h)
+        h_relu4_3 = h
+        h = self.slice5(h)
+        h_relu5_3 = h
+        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
+        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+        return out
+
+    def vggishish16(self, pretrained: bool = True) -> VGGishish:
+        # loading vggishish pretrained on vggsound
+        num_classes_vggsound = 309
+        conv_layers = [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 'MP', 512, 512, 512, 'MP', 512, 512, 512]
+        model = VGGishish(conv_layers, use_bn=False, num_classes=num_classes_vggsound)
+        if pretrained:
+            ckpt_path = get_ckpt_path('vggishish_lpaps', "specvqgan/modules/autoencoder/lpaps")
+            ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))
+            model.load_state_dict(ckpt, strict=False)
+        return model
+
+def normalize_tensor(x, eps=1e-10):
+    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
+    return x / (norm_factor+eps)
+
+def spatial_average(x, keepdim=True):
+    return x.mean([2, 3], keepdim=keepdim)
+
+
+if __name__ == '__main__':
+    inputs = torch.rand((16, 1, 80, 848))
+    reconstructions = torch.rand((16, 1, 80, 848))
+    lpips = LPAPS().eval()
+    loss_p = lpips(inputs.contiguous(), reconstructions.contiguous())
+    # (16, 1, 1, 1)
+    print(loss_p.shape)
diff --git a/foleycrafter/models/specvqgan/modules/losses/vggishish/configs/melception.yaml b/foleycrafter/models/specvqgan/modules/losses/vggishish/configs/melception.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4c0316968a3e779804223d33e25f4574bea75392
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/vggishish/configs/melception.yaml
@@ -0,0 +1,24 @@
+seed: 1337
+log_code_state: True
+# patterns to ignore when backing up the code folder
+patterns_to_ignore: ['logs', '.git', '__pycache__', 'data', 'checkpoints', '*.pt']
+
+# data:
+mels_path: '/home/nvme/data/vggsound/features/melspec_10s_22050hz/'
+spec_shape: [80, 860]
+cropped_size: [80, 848]
+random_crop: False
+
+# train:
+device: 'cuda:0'
+batch_size: 8
+num_workers: 0
+optimizer: adam
+betas: [0.9, 0.999]
+momentum: 0.9
+learning_rate: 3e-4
+weight_decay: 0
+num_epochs: 100
+patience: 3
+logdir: './logs'
+cls_weights_in_loss: False
diff --git a/foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish.yaml b/foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f97359658fe257f995037e17b66244879a630498
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish.yaml
@@ -0,0 +1,34 @@
+seed: 1337
+log_code_state: True
+# patterns to ignore when backing up the code folder
+patterns_to_ignore: ['logs', '.git', '__pycache__']
+
+# data:
+mels_path: '/home/nvme/data/vggsound/features/melspec_10s_22050hz/'
+spec_shape: [80, 860]
+cropped_size: [80, 848]
+random_crop: False
+
+# model:
+# original vgg family except for MP is missing at the end
+# 'vggish': [64, 'MP', 128, 'MP', 256, 256, 'MP', 512, 512]
+# 'vgg11': [64, 'MP', 128, 'MP', 256, 256, 'MP', 512, 512, 'MP', 512, 512],
+# 'vgg13': [64, 64, 'MP', 128, 128, 'MP', 256, 256, 'MP', 512, 512, 'MP', 512, 512],
+# 'vgg16': [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 'MP', 512, 512, 512, 'MP', 512, 512, 512],
+# 'vgg19': [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 256, 'MP', 512, 512, 512, 512, 'MP', 512, 512, 512, 512],
+conv_layers: [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 'MP', 512, 512, 512, 'MP', 512, 512, 512]
+use_bn: False
+
+# train:
+device: 'cuda:0'
+batch_size: 32
+num_workers: 0
+optimizer: adam
+betas: [0.9, 0.999]
+momentum: 0.9
+learning_rate: 3e-4
+weight_decay: 0.0001
+num_epochs: 100
+patience: 3
+logdir: './logs'
+cls_weights_in_loss: False
diff --git a/foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish_gh.yaml b/foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish_gh.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..efa5f147cf88d1760f7004a7bea7f86902e7cc47
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish_gh.yaml
@@ -0,0 +1,25 @@
+seed: 1337
+log_code_state: True
+patterns_to_ignore: ['logs', '.git', '__pycache__']
+
+mels_path: '/home/duyxxd/SpecVQGAN/data/greatesthit/melspec_10s_22050hz'
+batch_size: 32
+num_workers: 8
+device: 'cuda:0'
+conv_layers: [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 'MP', 512, 512, 512, 'MP', 512, 512, 512]
+use_bn: False
+optimizer: adam
+learning_rate: 1e-4
+betas: [0.9, 0.999]
+cropped_size: [80, 160]
+momentum: 0.9
+weight_decay: 1e-4
+cls_weights_in_loss: False
+num_epochs: 100
+patience: 20
+logdir: '/home/duyxxd/SpecVQGAN/logs'
+exp_name: 'mix'
+action_only: False
+material_only: False
+
+load_model: /home/duyxxd/SpecVQGAN/logs/vggishish16.pt
\ No newline at end of file
diff --git a/foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish_gh_action.yaml b/foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish_gh_action.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bd7df483cf0ff1a0a62d0f84ee852511c94e73b9
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish_gh_action.yaml
@@ -0,0 +1,25 @@
+seed: 1337
+log_code_state: True
+patterns_to_ignore: ['logs', '.git', '__pycache__']
+
+mels_path: '/home/duyxxd/SpecVQGAN/data/greatesthit/melspec_10s_22050hz'
+batch_size: 32
+num_workers: 8
+device: 'cuda:0'
+conv_layers: [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 'MP', 512, 512, 512, 'MP', 512, 512, 512]
+use_bn: False
+optimizer: adam
+learning_rate: 1e-4
+betas: [0.9, 0.999]
+cropped_size: [80, 160]
+momentum: 0.9
+weight_decay: 1e-4
+cls_weights_in_loss: False
+num_epochs: 20
+patience: 20
+logdir: '/home/duyxxd/SpecVQGAN/logs'
+exp_name: 'action'
+action_only: True
+material_only: False
+
+load_model: /home/duyxxd/SpecVQGAN/logs/vggishish16.pt
\ No newline at end of file
diff --git a/foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish_gh_material.yaml b/foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish_gh_material.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..beba550c3f850279b42308a2613a8fae59de5377
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish_gh_material.yaml
@@ -0,0 +1,25 @@
+seed: 1337
+log_code_state: True
+patterns_to_ignore: ['logs', '.git', '__pycache__']
+
+mels_path: '/home/duyxxd/SpecVQGAN/data/greatesthit/melspec_10s_22050hz'
+batch_size: 32
+num_workers: 8
+device: 'cuda:0'
+conv_layers: [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 'MP', 512, 512, 512, 'MP', 512, 512, 512]
+use_bn: False
+optimizer: adam
+learning_rate: 1e-4
+betas: [0.9, 0.999]
+cropped_size: [80, 160]
+momentum: 0.9
+weight_decay: 1e-4
+cls_weights_in_loss: False
+num_epochs: 20
+patience: 20
+logdir: '/home/duyxxd/SpecVQGAN/logs'
+exp_name: 'material'
+action_only: False
+material_only: True
+
+load_model: /home/duyxxd/SpecVQGAN/logs/vggishish16.pt
\ No newline at end of file
diff --git a/foleycrafter/models/specvqgan/modules/losses/vggishish/dataset.py b/foleycrafter/models/specvqgan/modules/losses/vggishish/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b9603b9f4630079b0f0712c8ef78ef09044e325
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/vggishish/dataset.py
@@ -0,0 +1,295 @@
+import collections
+import csv
+import logging
+import os
+import random
+import math
+import json
+from glob import glob
+from pathlib import Path
+
+import numpy as np
+import torch
+import torchvision
+
+logger = logging.getLogger(f'main.{__name__}')
+
+
+class VGGSound(torch.utils.data.Dataset):
+
+    def __init__(self, split, specs_dir, transforms=None, splits_path='./data', meta_path='./data/vggsound.csv'):
+        super().__init__()
+        self.split = split
+        self.specs_dir = specs_dir
+        self.transforms = transforms
+        self.splits_path = splits_path
+        self.meta_path = meta_path
+
+        vggsound_meta = list(csv.reader(open(meta_path), quotechar='"'))
+        unique_classes = sorted(list(set(row[2] for row in vggsound_meta)))
+        self.label2target = {label: target for target, label in enumerate(unique_classes)}
+        self.target2label = {target: label for label, target in self.label2target.items()}
+        self.video2target = {row[0]: self.label2target[row[2]] for row in vggsound_meta}
+
+        split_clip_ids_path = os.path.join(splits_path, f'vggsound_{split}_partial.txt')
+        print(f'Using split file: {split_clip_ids_path}')
+        if not os.path.exists(split_clip_ids_path):
+            self.make_split_files()
+        clip_ids_with_timestamp = open(split_clip_ids_path).read().splitlines()
+        clip_paths = [os.path.join(specs_dir, v + '_mel.npy') for v in clip_ids_with_timestamp]
+        self.dataset = clip_paths
+        # self.dataset = clip_paths[:10000]  # overfit one batch
+
+        # 'zyTX_1BXKDE_16000_26000'[:11] -> 'zyTX_1BXKDE'
+        vid_classes = [self.video2target[Path(path).stem[:11]] for path in self.dataset]
+        class2count = collections.Counter(vid_classes)
+        self.class_counts = torch.tensor([class2count[cls] for cls in range(len(class2count))])
+        # self.sample_weights = [len(self.dataset) / class2count[self.video2target[Path(path).stem[:11]]] for path in self.dataset]
+
+    def __getitem__(self, idx):
+        item = {}
+
+        spec_path = self.dataset[idx]
+        # 'zyTX_1BXKDE_16000_26000' -> 'zyTX_1BXKDE'
+        video_name = Path(spec_path).stem[:11]
+
+        item['input'] = np.load(spec_path)
+        item['input_path'] = spec_path
+
+        # if self.split in ['train', 'valid']:
+        item['target'] = self.video2target[video_name]
+        item['label'] = self.target2label[item['target']]
+
+        if self.transforms is not None:
+            item = self.transforms(item)
+
+        return item
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def make_split_files(self):
+        random.seed(1337)
+        logger.info(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')
+        # The downloaded videos (some went missing on YouTube and no longer available)
+        available_vid_paths = sorted(glob(os.path.join(self.specs_dir, '*_mel.npy')))
+        logger.info(f'The number of clips available after download: {len(available_vid_paths)}')
+
+        # original (full) train and test sets
+        vggsound_meta = list(csv.reader(open(self.meta_path), quotechar='"'))
+        train_vids = {row[0] for row in vggsound_meta if row[3] == 'train'}
+        test_vids = {row[0] for row in vggsound_meta if row[3] == 'test'}
+        logger.info(f'The number of videos in vggsound train set: {len(train_vids)}')
+        logger.info(f'The number of videos in vggsound test set: {len(test_vids)}')
+
+        # class counts in test set. We would like to have the same distribution in valid
+        unique_classes = sorted(list(set(row[2] for row in vggsound_meta)))
+        label2target = {label: target for target, label in enumerate(unique_classes)}
+        video2target = {row[0]: label2target[row[2]] for row in vggsound_meta}
+        test_vid_classes = [video2target[vid] for vid in test_vids]
+        test_target2count = collections.Counter(test_vid_classes)
+
+        # now given the counts from test set, sample the same count for validation and the rest leave in train
+        train_vids_wo_valid, valid_vids = set(), set()
+        for target, label in enumerate(label2target.keys()):
+            class_train_vids = [vid for vid in train_vids if video2target[vid] == target]
+            random.shuffle(class_train_vids)
+            count = test_target2count[target]
+            valid_vids.update(class_train_vids[:count])
+            train_vids_wo_valid.update(class_train_vids[count:])
+
+        # make file with a list of available test videos (each video should contain timestamps as well)
+        train_i = valid_i = test_i = 0
+        with open(os.path.join(self.splits_path, 'vggsound_train.txt'), 'w') as train_file, \
+             open(os.path.join(self.splits_path, 'vggsound_valid.txt'), 'w') as valid_file, \
+             open(os.path.join(self.splits_path, 'vggsound_test.txt'), 'w') as test_file:
+            for path in available_vid_paths:
+                path = path.replace('_mel.npy', '')
+                vid_name = Path(path).name
+                # 'zyTX_1BXKDE_16000_26000'[:11] -> 'zyTX_1BXKDE'
+                if vid_name[:11] in train_vids_wo_valid:
+                    train_file.write(vid_name + '\n')
+                    train_i += 1
+                elif vid_name[:11] in valid_vids:
+                    valid_file.write(vid_name + '\n')
+                    valid_i += 1
+                elif vid_name[:11] in test_vids:
+                    test_file.write(vid_name + '\n')
+                    test_i += 1
+                else:
+                    raise Exception(f'Clip {vid_name} is neither in train, valid nor test. Strange.')
+
+        logger.info(f'Put {train_i} clips to the train set and saved it to ./data/vggsound_train.txt')
+        logger.info(f'Put {valid_i} clips to the valid set and saved it to ./data/vggsound_valid.txt')
+        logger.info(f'Put {test_i} clips to the test set and saved it to ./data/vggsound_test.txt')
+
+
+def get_GH_data_identifier(video_name, start_idx, split='_'):
+    if isinstance(start_idx, str):
+        return video_name + split + start_idx
+    elif isinstance(start_idx, int):
+        return video_name + split + str(start_idx)
+    else:
+        raise NotImplementedError
+
+
+class GreatestHit(torch.utils.data.Dataset):
+
+    def __init__(self, split, spec_dir_path, spec_transform=None, L=2.0, action_only=False,
+                material_only=False, splits_path='/home/duyxxd/SpecVQGAN/data', 
+                meta_path='/home/duyxxd/SpecVQGAN/data/info_r2plus1d_dim1024_15fps.json'):
+        super().__init__()
+        self.split = split
+        self.specs_dir = spec_dir_path
+        self.splits_path = splits_path
+        self.meta_path = meta_path
+        self.spec_transform = spec_transform
+        self.L = L
+        self.spec_take_first = int(math.ceil(860 * (L / 10.) / 32) * 32)
+        self.spec_take_first = 860 if self.spec_take_first > 860 else self.spec_take_first
+        self.spec_take_first = 173
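+        # override: hard-code 173 frames, matching the 2 s crops used in __getitem__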
+
+        greatesthit_meta = json.load(open(self.meta_path, 'r'))
+        self.video_idx2label = {
+            get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]): 
+            greatesthit_meta['hit_type'][i] for i in range(len(greatesthit_meta['video_name']))
+        }
+        self.available_video_hit = list(self.video_idx2label.keys())
+        self.video_idx2path = {
+            vh: os.path.join(self.specs_dir, 
+                vh.replace('_', '_denoised_') + '_' + self.video_idx2label[vh].replace(' ', '_') +'_mel.npy')
+            for vh in self.available_video_hit
+        }
+        self.video_idx2idx = {
+            get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]):
+            i for i in range(len(greatesthit_meta['video_name']))
+        }
+
+        split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}_2.00_single_type_only.json')
+        if not os.path.exists(split_clip_ids_path):
+            raise NotImplementedError()
+        clip_video_hit = json.load(open(split_clip_ids_path, 'r'))
+        self.dataset = list(clip_video_hit.keys())
+        if action_only:
+            self.video_idx2label = {k: v.split(' ')[1] for k, v in clip_video_hit.items()}
+        elif material_only:
+            self.video_idx2label = {k: v.split(' ')[0] for k, v in clip_video_hit.items()}
+        else:
+            self.video_idx2label = clip_video_hit
+
+
+        self.video2indexes = {}
+        for video_idx in self.dataset:
+            video, start_idx = video_idx.split('_')
+            if video not in self.video2indexes.keys():
+                self.video2indexes[video] = []
+            self.video2indexes[video].append(start_idx)
+        for video in self.video2indexes.keys():
+            if len(self.video2indexes[video]) == 1: # given video contains only one hit
+                self.dataset.remove(
+                    get_GH_data_identifier(video, self.video2indexes[video][0])
+                )
+
+        vid_classes = list(self.video_idx2label.values())
+        unique_classes = sorted(list(set(vid_classes)))
+        self.label2target = {label: target for target, label in enumerate(unique_classes)}
+        if action_only:
+            label2target_fix = {'hit': 0, 'scratch': 1}
+        elif material_only:
+            label2target_fix = {'carpet': 0, 'ceramic': 1, 'cloth': 2, 'dirt': 3, 'drywall': 4, 'glass': 5, 'grass': 6, 'gravel': 7, 'leaf': 8, 'metal': 9, 'paper': 10, 'plastic': 11, 'plastic-bag': 12, 'rock': 13, 'tile': 14, 'water': 15, 'wood': 16}
+        else:
+            label2target_fix = {'carpet hit': 0, 'carpet scratch': 1, 'ceramic hit': 2, 'ceramic scratch': 3, 'cloth hit': 4, 'cloth scratch': 5, 'dirt hit': 6, 'dirt scratch': 7, 'drywall hit': 8, 'drywall scratch': 9, 'glass hit': 10, 'glass scratch': 11, 'grass hit': 12, 'grass scratch': 13, 'gravel hit': 14, 'gravel scratch': 15, 'leaf hit': 16, 'leaf scratch': 17, 'metal hit': 18, 'metal scratch': 19, 'paper hit': 20, 'paper scratch': 21, 'plastic hit': 22, 'plastic scratch': 23, 'plastic-bag hit': 24, 'plastic-bag scratch': 25, 'rock hit': 26, 'rock scratch': 27, 'tile hit': 28, 'tile scratch': 29, 'water hit': 30, 'water scratch': 31, 'wood hit': 32, 'wood scratch': 33}
+        for k in self.label2target.keys():
+            assert k in label2target_fix.keys()
+        self.label2target = label2target_fix
+        self.target2label = {target: label for label, target in self.label2target.items()}
+        class2count = collections.Counter(vid_classes)
+        self.class_counts = torch.tensor([class2count[cls] for cls in range(len(class2count))])
+        print(self.label2target)
+        print(len(vid_classes), len(class2count), class2count)
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        item = {}
+
+        video_idx = self.dataset[idx]
+        spec_path = self.video_idx2path[video_idx]
+        spec = np.load(spec_path) # (80, 860)
+
+        # concat spec outside dataload
+        item['input'] = 2 * spec - 1 # (80, 860)
+        item['input'] = item['input'][:, :self.spec_take_first] # (80, 173): a 2 s clip corresponds to 173 spectrogram frames
+        item['file_path'] = spec_path
+
+        item['label'] = self.video_idx2label[video_idx]
+        item['target'] = self.label2target[item['label']]
+
+        if self.spec_transform is not None:
+            item = self.spec_transform(item)
+
+        return item
+
+
+
+class AMT_test(torch.utils.data.Dataset):
+
+    def __init__(self, spec_dir_path, spec_transform=None, action_only=False, material_only=False):
+        super().__init__()
+        self.specs_dir = spec_dir_path
+        self.spec_transform = spec_transform
+        self.spec_take_first = 173
+
+        self.dataset = sorted([os.path.join(self.specs_dir, f) for f in os.listdir(self.specs_dir)])
+        if action_only:
+            self.label2target = {'hit': 0, 'scratch': 1}
+        elif material_only:
+            self.label2target = {'carpet': 0, 'ceramic': 1, 'cloth': 2, 'dirt': 3, 'drywall': 4, 'glass': 5, 'grass': 6, 'gravel': 7, 'leaf': 8, 'metal': 9, 'paper': 10, 'plastic': 11, 'plastic-bag': 12, 'rock': 13, 'tile': 14, 'water': 15, 'wood': 16}
+        else:
+            self.label2target = {'carpet hit': 0, 'carpet scratch': 1, 'ceramic hit': 2, 'ceramic scratch': 3, 'cloth hit': 4, 'cloth scratch': 5, 'dirt hit': 6, 'dirt scratch': 7, 'drywall hit': 8, 'drywall scratch': 9, 'glass hit': 10, 'glass scratch': 11, 'grass hit': 12, 'grass scratch': 13, 'gravel hit': 14, 'gravel scratch': 15, 'leaf hit': 16, 'leaf scratch': 17, 'metal hit': 18, 'metal scratch': 19, 'paper hit': 20, 'paper scratch': 21, 'plastic hit': 22, 'plastic scratch': 23, 'plastic-bag hit': 24, 'plastic-bag scratch': 25, 'rock hit': 26, 'rock scratch': 27, 'tile hit': 28, 'tile scratch': 29, 'water hit': 30, 'water scratch': 31, 'wood hit': 32, 'wood scratch': 33}
+        self.target2label = {v: k for k, v in self.label2target.items()}
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        item = {}
+
+        spec_path = self.dataset[idx]
+        spec = np.load(spec_path) # (80, 860)
+
+        # concat spec outside dataload
+        item['input'] = 2 * spec - 1 # (80, 860)
+        item['input'] = item['input'][:, :self.spec_take_first] # (80, 173): a 2 s clip corresponds to 173 spectrogram frames
+        item['file_path'] = spec_path
+
+        if self.spec_transform is not None:
+            item = self.spec_transform(item)
+
+        return item
+
+
+if __name__ == '__main__':
+    from transforms import Crop, StandardNormalizeAudio, ToTensor
+    specs_path = '/home/nvme/data/vggsound/features/melspec_10s_22050hz/'
+
+    transforms = torchvision.transforms.transforms.Compose([
+        StandardNormalizeAudio(specs_path),
+        ToTensor(),
+        Crop([80, 848]),
+    ])
+
+    datasets = {
+        'train': VGGSound('train', specs_path, transforms),
+        'valid': VGGSound('valid', specs_path, transforms),
+        'test': VGGSound('test', specs_path, transforms),
+    }
+
+    print(datasets['train'][0])
+    print(datasets['valid'][0])
+    print(datasets['test'][0])
+
+    print(datasets['train'].class_counts)
+    print(datasets['valid'].class_counts)
+    print(datasets['test'].class_counts)
diff --git a/foleycrafter/models/specvqgan/modules/losses/vggishish/logger.py b/foleycrafter/models/specvqgan/modules/losses/vggishish/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6205dec53e29b62e2901fd899fcf02ee0eb8807
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/vggishish/logger.py
@@ -0,0 +1,90 @@
+import logging
+import os
+import time
+from shutil import copytree, ignore_patterns
+
+import torch
+from omegaconf import OmegaConf
+from torch.utils.tensorboard import SummaryWriter, summary
+
+
+class LoggerWithTBoard(SummaryWriter):
+
+    def __init__(self, cfg):
+        # current time stamp and experiment log directory
+        self.start_time = time.strftime('%y-%m-%dT%H-%M-%S', time.localtime())
+        if cfg.exp_name is not None:
+            self.logdir = os.path.join(cfg.logdir, self.start_time + f'_{cfg.exp_name}')
+        else:
+            self.logdir = os.path.join(cfg.logdir, self.start_time)
+        # init tboard
+        super().__init__(self.logdir)
+        # backup the cfg
+        OmegaConf.save(cfg, os.path.join(self.log_dir, 'cfg.yaml'))
+        # backup the code state
+        if cfg.log_code_state:
+            dest_dir = os.path.join(self.logdir, 'code')
+            copytree(os.getcwd(), dest_dir, ignore=ignore_patterns(*cfg.patterns_to_ignore))
+
+        # init a logger that prints to stdout and writes the same messages to a log file
+        self.print_logger = logging.getLogger('main')
+        self.print_logger.setLevel(logging.INFO)
+        msgfmt = '[%(levelname)s] %(asctime)s - %(name)s \n    %(message)s'
+        datefmt = '%d %b %Y %H:%M:%S'
+        formatter = logging.Formatter(msgfmt, datefmt)
+        # stdout
+        sh = logging.StreamHandler()
+        sh.setLevel(logging.DEBUG)
+        sh.setFormatter(formatter)
+        self.print_logger.addHandler(sh)
+        # log file
+        fh = logging.FileHandler(os.path.join(self.log_dir, 'log.txt'))
+        fh.setLevel(logging.INFO)
+        fh.setFormatter(formatter)
+        self.print_logger.addHandler(fh)
+
+        self.print_logger.info(f'Saving logs and checkpoints @ {self.logdir}')
+
+    def log_param_num(self, model):
+        param_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        self.print_logger.info(f'The number of parameters: {param_num/1e+6:.3f} mil')
+        self.add_scalar('num_params', param_num, 0)
+        return param_num
+
+    def log_iter_loss(self, loss, iter, phase):
+        self.add_scalar(f'{phase}/loss_iter', loss, iter)
+
+    def log_epoch_loss(self, loss, epoch, phase):
+        self.add_scalar(f'{phase}/loss', loss, epoch)
+        self.print_logger.info(f'{phase} ({epoch}): loss {loss:.3f};')
+
+    def log_epoch_metrics(self, metrics_dict, epoch, phase):
+        for metric, val in metrics_dict.items():
+            self.add_scalar(f'{phase}/{metric}', val, epoch)
+        metrics_dict = {k: round(v, 4) for k, v in metrics_dict.items()}
+        self.print_logger.info(f'{phase} ({epoch}) metrics: {metrics_dict};')
+
+    def log_test_metrics(self, metrics_dict, hparams_dict, best_epoch):
+        allowed_types = (int, float, str, bool, torch.Tensor)
+        hparams_dict = {k: v for k, v in hparams_dict.items() if isinstance(v, allowed_types)}
+        metrics_dict = {f'test/{k}': round(v, 4) for k, v in metrics_dict.items()}
+        exp, ssi, sei = summary.hparams(hparams_dict, metrics_dict)
+        self.file_writer.add_summary(exp)
+        self.file_writer.add_summary(ssi)
+        self.file_writer.add_summary(sei)
+        for k, v in metrics_dict.items():
+            self.add_scalar(k, v, best_epoch)
+        self.print_logger.info(f'test ({best_epoch}) metrics: {metrics_dict};')
+
+    def log_best_model(self, model, loss, epoch, optimizer, metrics_dict):
+        model_name = model.__class__.__name__
+        self.best_model_path = os.path.join(self.logdir, f'{model_name}-{self.start_time}.pt')
+        checkpoint = {
+            'loss': loss,
+            'metrics': metrics_dict,
+            'epoch': epoch,
+            'optimizer': optimizer.state_dict(),
+            'model': model.state_dict(),
+        }
+        torch.save(checkpoint, self.best_model_path)
+        self.print_logger.info(f'Saved model in {self.best_model_path}')
diff --git a/foleycrafter/models/specvqgan/modules/losses/vggishish/loss.py b/foleycrafter/models/specvqgan/modules/losses/vggishish/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..bae76571909eec571aaf075d58e3dea8f6424546
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/vggishish/loss.py
@@ -0,0 +1,41 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+
+class WeightedCrossEntropy(nn.CrossEntropyLoss):
+
+    def __init__(self, weights, **pytorch_ce_loss_args) -> None:
+        super().__init__(reduction='none', **pytorch_ce_loss_args)
+        self.weights = weights
+
+    def __call__(self, outputs, targets, to_weight=True):
+        loss = super().__call__(outputs, targets)
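+        # normalize the per-sample weights by their batch sum so the weighted loss
+        # matches the scale of nn.CrossEntropyLoss(weight=...)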
+        if to_weight:
+            return (loss * self.weights[targets]).sum() / self.weights[targets].sum()
+        else:
+            return loss.mean()
+
+
+if __name__ == '__main__':
+    x = torch.randn(10, 5)
+    target = torch.randint(0, 5, (10,))
+    weights = torch.tensor([1., 2., 3., 4., 5.])
+
+    # criterion_weighted = nn.CrossEntropyLoss(weight=weights)
+    # loss_weighted = criterion_weighted(x, target)
+
+    # criterion_weighted_manual = nn.CrossEntropyLoss(reduction='none')
+    # loss_weighted_manual = criterion_weighted_manual(x, target)
+    # print(loss_weighted, loss_weighted_manual.mean())
+    # loss_weighted_manual = (loss_weighted_manual * weights[target]).sum() / weights[target].sum()
+    # print(loss_weighted, loss_weighted_manual)
+    # print(torch.allclose(loss_weighted, loss_weighted_manual))
+
+    pytorch_weighted = nn.CrossEntropyLoss(weight=weights)
+    pytorch_unweighted = nn.CrossEntropyLoss()
+    custom = WeightedCrossEntropy(weights)
+
+    assert torch.allclose(pytorch_weighted(x, target), custom(x, target, to_weight=True))
+    assert torch.allclose(pytorch_unweighted(x, target), custom(x, target, to_weight=False))
+    print(custom(x, target, to_weight=True), custom(x, target, to_weight=False))
diff --git a/foleycrafter/models/specvqgan/modules/losses/vggishish/metrics.py b/foleycrafter/models/specvqgan/modules/losses/vggishish/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..16905224c665491b9869d7641c1fe17689816a4b
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/vggishish/metrics.py
@@ -0,0 +1,69 @@
+import logging
+
+import numpy as np
+import scipy
+import torch
+from sklearn.metrics import average_precision_score, roc_auc_score
+
+logger = logging.getLogger(f'main.{__name__}')
+
+def metrics(targets, outputs, topk=(1, 5)):
+    """
+    Adapted from https://github.com/hche11/VGGSound/blob/master/utils.py
+
+    Calculate statistics including mAP, AUC, and d-prime.
+        Args:
+            output: 2d tensors, (dataset_size, classes_num) - before softmax
+            target: 1d tensors, (dataset_size, )
+            topk: tuple
+        Returns:
+            metric_dict: a dict of metrics
+    """
+    metrics_dict = dict()
+
+    num_cls = outputs.shape[-1]
+
+    # accuracy@k
+    _, preds = torch.topk(outputs, k=max(topk), dim=1)
+    correct_for_maxtopk = preds == targets.view(-1, 1).expand_as(preds)
+    for k in topk:
+        metrics_dict[f'accuracy_{k}'] = float(correct_for_maxtopk[:, :k].sum() / correct_for_maxtopk.shape[0])
+
+    # avg precision, average roc_auc, and dprime
+    targets = torch.nn.functional.one_hot(targets, num_classes=num_cls)
+
+    # ids of the predicted classes (same as softmax)
+    targets_pred = torch.softmax(outputs, dim=1)
+
+    targets = targets.numpy()
+    targets_pred = targets_pred.numpy()
+
+    # one-vs-rest
+    avg_p = [average_precision_score(targets[:, c], targets_pred[:, c], average=None) for c in range(num_cls)]
+    try:
+        roc_aucs = [roc_auc_score(targets[:, c], targets_pred[:, c], average=None) for c in range(num_cls)]
+    except ValueError:
+        logger.warning('Weird... Some classes never occurred in targets. Do not trust the metrics.')
+        roc_aucs = np.array([0.5])
+        avg_p = np.array([0])
+
+    metrics_dict['mAP'] = np.mean(avg_p)
+    metrics_dict['mROCAUC'] = np.mean(roc_aucs)
+    # d-prime from the mean ROC-AUC: d' = sqrt(2) * ppf(mROCAUC),
+    # where ppf is the percent point function (the inverse of the standard normal CDF)
+    metrics_dict['dprime'] = scipy.stats.norm().ppf(metrics_dict['mROCAUC']) * np.sqrt(2)
+
+    return metrics_dict
+
+
+if __name__ == '__main__':
+    targets = torch.tensor([3, 3, 1, 2, 1, 0])
+    outputs = torch.tensor([
+        [1.2, 1.3, 1.1, 1.5],
+        [1.3, 1.4, 1.0, 1.1],
+        [1.5, 1.1, 1.4, 1.3],
+        [1.0, 1.2, 1.4, 1.5],
+        [1.2, 1.3, 1.1, 1.1],
+        [1.2, 1.1, 1.1, 1.1],
+    ]).float()
+    metrics_dict = metrics(targets, outputs, topk=(1, 3))
+    print(metrics_dict)
diff --git a/foleycrafter/models/specvqgan/modules/losses/vggishish/model.py b/foleycrafter/models/specvqgan/modules/losses/vggishish/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5069bad0d9311e6e2c082a63eca165f7a908675
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/vggishish/model.py
@@ -0,0 +1,77 @@
+import torch
+import torch.nn as nn
+
+
+class VGGishish(nn.Module):
+
+    def __init__(self, conv_layers, use_bn, num_classes):
+        '''
+        Mostly from
+            https://pytorch.org/vision/0.8/_modules/torchvision/models/vgg.html
+        '''
+        super().__init__()
+        layers = []
+        in_channels = 1
+
+        # a list of channels with 'MP' (maxpool) from config
+        for v in conv_layers:
+            if v == 'MP':
+                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
+            else:
+                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, stride=1)
+                if use_bn:
+                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
+                else:
+                    layers += [conv2d, nn.ReLU(inplace=True)]
+                in_channels = v
+        self.features = nn.Sequential(*layers)
+
+        self.avgpool = nn.AdaptiveAvgPool2d((5, 10))
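+        # adaptive pooling to a fixed (5, 10) grid keeps the classifier input size
+        # independent of the spectrogram's time length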
+
+        self.flatten = nn.Flatten()
+        self.classifier = nn.Sequential(
+            nn.Linear(512 * 5 * 10, 4096),
+            nn.ReLU(True),
+            nn.Linear(4096, 4096),
+            nn.ReLU(True),
+            nn.Linear(4096, num_classes)
+        )
+
+        # weight init
+        self.reset_parameters()
+
+    def forward(self, x):
+        # add a channel dim for conv2d: (B, 1, F, T) <- (B, F, T)
+        x = x.unsqueeze(1)
+        # backbone: (B, 512, 5, 53) <- (B, 1, 80, 860)
+        x = self.features(x)
+        # adaptive avg pooling: (B, 512, 5, 10) <- (B, 512, 5, 53), since no MP follows the last conv block
+        x = self.avgpool(x)
+        # flatten
+        x = self.flatten(x)
+        # classify
+        x = self.classifier(x)
+        return x
+
+    def reset_parameters(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)
+
+
+if __name__ == '__main__':
+    num_classes = 309
+    inputs = torch.rand(3, 80, 848)
+    conv_layers = [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 'MP', 512, 512, 512, 'MP', 512, 512, 512]
+    # conv_layers = [64, 'MP', 128, 'MP', 256, 256, 'MP', 512, 512, 'MP']
+    model = VGGishish(conv_layers, use_bn=False, num_classes=num_classes)
+    outputs = model(inputs)
+    print(outputs.shape)
diff --git a/foleycrafter/models/specvqgan/modules/losses/vggishish/predict.py b/foleycrafter/models/specvqgan/modules/losses/vggishish/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9d13f30153cd43a4a8bcfe2da4b9a53846bf1eb
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/vggishish/predict.py
@@ -0,0 +1,90 @@
+import os
+from torch.utils.data import DataLoader
+import torchvision
+from tqdm import tqdm
+from dataset import VGGSound
+import torch
+import torch.nn as nn
+from metrics import metrics
+from omegaconf import OmegaConf
+from model import VGGishish
+from transforms import Crop, StandardNormalizeAudio, ToTensor
+
+
+if __name__ == '__main__':
+    cfg_cli = OmegaConf.from_cli()
+    print(cfg_cli.config)
+    cfg_yml = OmegaConf.load(cfg_cli.config)
+    # the latter arguments are prioritized
+    cfg = OmegaConf.merge(cfg_yml, cfg_cli)
+    OmegaConf.set_readonly(cfg, True)
+    print(OmegaConf.to_yaml(cfg))
+
+    # logger = LoggerWithTBoard(cfg)
+    transforms = [
+        StandardNormalizeAudio(cfg.mels_path),
+        ToTensor(),
+    ]
+    if cfg.cropped_size not in [None, 'None', 'none']:
+        transforms.append(Crop(cfg.cropped_size))
+    transforms = torchvision.transforms.transforms.Compose(transforms)
+
+    datasets = {
+        'test': VGGSound('test', cfg.mels_path, transforms),
+    }
+
+    loaders = {
+        'test': DataLoader(datasets['test'], batch_size=cfg.batch_size,
+                           num_workers=cfg.num_workers, pin_memory=True)
+    }
+
+    device = torch.device(cfg.device if torch.cuda.is_available() else 'cpu')
+    model = VGGishish(cfg.conv_layers, cfg.use_bn, num_classes=len(datasets['test'].target2label))
+    model = model.to(device)
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
+    criterion = nn.CrossEntropyLoss()
+
+    # loading the best model
+    folder_name = os.path.split(cfg.config)[0].split('/')[-1]
+    print(folder_name)
+    ckpt = torch.load(f'./logs/{folder_name}/vggishish-{folder_name}.pt', map_location='cpu')
+    model.load_state_dict(ckpt['model'])
+    print((f'The model was trained for {ckpt["epoch"]} epochs. Loss: {ckpt["loss"]:.4f}'))
+
+    # Testing the model
+    model.eval()
+    running_loss = 0
+    preds_from_each_batch = []
+    targets_from_each_batch = []
+
+    for i, batch in enumerate(tqdm(loaders['test'])):
+        inputs = batch['input'].to(device)
+        targets = batch['target'].to(device)
+
+        # zero the parameter gradients
+        optimizer.zero_grad()
+
+        # forward only (evaluation, no gradients)
+        with torch.set_grad_enabled(False):
+            outputs = model(inputs)
+            loss = criterion(outputs, targets)
+
+        # loss
+        running_loss += loss.item()
+
+        # for metrics calculation later on
+        preds_from_each_batch += [outputs.detach().cpu()]
+        targets_from_each_batch += [targets.cpu()]
+
+    # logging metrics
+    preds_from_each_batch = torch.cat(preds_from_each_batch)
+    targets_from_each_batch = torch.cat(targets_from_each_batch)
+    test_metrics_dict = metrics(targets_from_each_batch, preds_from_each_batch)
+    test_metrics_dict['avg_loss'] = running_loss / len(loaders['test'])
+    test_metrics_dict['param_num'] = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    # TODO: I have no idea why tboard doesn't keep metrics (hparams) in a tensorboard when
+    # I run this experiment from cli: `python main.py config=./configs/vggish.yaml`
+    # while when I run it in vscode debugger the metrics are present in the tboard (weird)
+    print(test_metrics_dict)
diff --git a/foleycrafter/models/specvqgan/modules/losses/vggishish/predict_gh.py b/foleycrafter/models/specvqgan/modules/losses/vggishish/predict_gh.py
new file mode 100644
index 0000000000000000000000000000000000000000..c912d2f506febc0f67f1a7e7844d250f4743b6d8
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/vggishish/predict_gh.py
@@ -0,0 +1,66 @@
+import os
+import sys
+import json
+from torch.utils.data import DataLoader
+import torchvision
+from tqdm import tqdm
+from dataset import GreatestHit, AMT_test
+import torch
+import torch.nn as nn
+from metrics import metrics
+from omegaconf import OmegaConf
+from model import VGGishish
+from transforms import Crop, StandardNormalizeAudio, ToTensor
+
+
+if __name__ == '__main__':
+    cfg_cli = sys.argv[1]
+    target_path = sys.argv[2]
+    model_path = sys.argv[3]
+    cfg_yml = OmegaConf.load(cfg_cli)
+    # no CLI overrides here: the YAML config is used as-is
+    cfg = cfg_yml
+    OmegaConf.set_readonly(cfg, True)
+    # print(OmegaConf.to_yaml(cfg))
+
+    device = torch.device(cfg.device if torch.cuda.is_available() else 'cpu')
+    transforms = [
+        StandardNormalizeAudio(cfg.mels_path),
+    ]
+    if cfg.cropped_size not in [None, 'None', 'none']:
+        transforms.append(Crop(cfg.cropped_size))
+    transforms.append(ToTensor())
+    transforms = torchvision.transforms.transforms.Compose(transforms)
+
+    testset = AMT_test(target_path, transforms, action_only=cfg.action_only, material_only=cfg.material_only)
+    loader = DataLoader(testset, batch_size=cfg.batch_size,
+                        num_workers=cfg.num_workers, pin_memory=True)
+
+    model = VGGishish(cfg.conv_layers, cfg.use_bn, num_classes=len(testset.label2target))
+    ckpt = torch.load(model_path)['model']
+    model.load_state_dict(ckpt, strict=True)
+    model = model.to(device)
+
+    model.eval()
+
+    if cfg.cls_weights_in_loss:
+        weights = 1 / testset.class_counts
+    else:
+        weights = torch.ones(len(testset.label2target))
+
+    preds_from_each_batch = []
+    file_path_from_each_batch = []
+    for batch in tqdm(loader):
+        inputs = batch['input'].to(device)
+        file_path = batch['file_path']
+        with torch.set_grad_enabled(False):
+            outputs = model(inputs)
+        # for metrics calculation later on
+        preds_from_each_batch += [outputs.detach().cpu()]
+        file_path_from_each_batch += file_path
+    preds_from_each_batch = torch.cat(preds_from_each_batch)
+    _, preds = torch.topk(preds_from_each_batch, k=1)
+    pred_dict = {fp: int(p.item()) for fp, p in zip(file_path_from_each_batch, preds)}
+    mel_parent_dir = os.path.dirname(list(pred_dict.keys())[0])
+    pred_list = [pred_dict[os.path.join(mel_parent_dir, f'{i}.npy')] for i in range(len(pred_dict))]
+    json.dump(pred_list, open(target_path + f'_{cfg.exp_name}_preds.json', 'w'))
diff --git a/foleycrafter/models/specvqgan/modules/losses/vggishish/train_melception.py b/foleycrafter/models/specvqgan/modules/losses/vggishish/train_melception.py
new file mode 100644
index 0000000000000000000000000000000000000000..8adc5aa6e0e32a66cdbb7b449483a3b23d9b0ef9
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/vggishish/train_melception.py
@@ -0,0 +1,241 @@
+import random
+
+import numpy as np
+import torch
+import torchvision
+from omegaconf import OmegaConf
+from torch.utils.data.dataloader import DataLoader
+from torchvision.models.inception import BasicConv2d, Inception3
+from tqdm import tqdm
+
+from dataset import VGGSound
+from logger import LoggerWithTBoard
+from loss import WeightedCrossEntropy
+from metrics import metrics
+from transforms import Crop, StandardNormalizeAudio, ToTensor
+
+
+# TODO: refactor  ./evaluation/feature_extractors/melception.py to handle this class as well.
+# So far couldn't do it because of the difference in outputs
+class Melception(Inception3):
+
+    def __init__(self, num_classes, **kwargs):
+        # inception = Melception(num_classes=309)
+        super().__init__(num_classes=num_classes, **kwargs)
+        # the same as https://github.com/pytorch/vision/blob/5339e63148/torchvision/models/inception.py#L95
+        # but for 1-channel input instead of RGB.
+        self.Conv2d_1a_3x3 = BasicConv2d(1, 32, kernel_size=3, stride=2)
+        # also, the 'height' of the mel spec is 80 (vs 299 in RGB), so we remove all max pooling from Inception
+        self.maxpool1 = torch.nn.Identity()
+        self.maxpool2 = torch.nn.Identity()
+
+    def forward(self, x):
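+        # (B, 80, T) mel spectrogram -> add a channel dim; in training mode the Inception3 parent also returns auxiliary logits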
+        x = x.unsqueeze(1)
+        return super().forward(x)
+
+def train_inception_scorer(cfg):
+    logger = LoggerWithTBoard(cfg)
+
+    random.seed(cfg.seed)
+    np.random.seed(cfg.seed)
+    torch.manual_seed(cfg.seed)
+    torch.cuda.manual_seed_all(cfg.seed)
+    # makes iterations faster (in this case 30%) if your inputs are of a fixed size
+    # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3
+    torch.backends.cudnn.benchmark = True
+
+    meta_path = './data/vggsound.csv'
+    train_ids_path = './data/vggsound_train.txt'
+    cache_path = './data/'
+    splits_path = cache_path
+
+    transforms = [
+        StandardNormalizeAudio(cfg.mels_path, train_ids_path, cache_path),
+    ]
+    if cfg.cropped_size not in [None, 'None', 'none']:
+        logger.print_logger.info(f'Using cropping {cfg.cropped_size}')
+        transforms.append(Crop(cfg.cropped_size))
+    transforms.append(ToTensor())
+    transforms = torchvision.transforms.transforms.Compose(transforms)
+
+    datasets = {
+        'train': VGGSound('train', cfg.mels_path, transforms, splits_path, meta_path),
+        'valid': VGGSound('valid', cfg.mels_path, transforms, splits_path, meta_path),
+        'test': VGGSound('test', cfg.mels_path, transforms, splits_path, meta_path),
+    }
+
+    loaders = {
+        'train': DataLoader(datasets['train'], batch_size=cfg.batch_size, shuffle=True, drop_last=True,
+                            num_workers=cfg.num_workers, pin_memory=True),
+        'valid': DataLoader(datasets['valid'], batch_size=cfg.batch_size,
+                            num_workers=cfg.num_workers, pin_memory=True),
+        'test': DataLoader(datasets['test'], batch_size=cfg.batch_size,
+                           num_workers=cfg.num_workers, pin_memory=True),
+    }
+
+    device = torch.device(cfg.device if torch.cuda.is_available() else 'cpu')
+
+    model = Melception(num_classes=len(datasets['train'].target2label))
+    model = model.to(device)
+    param_num = logger.log_param_num(model)
+
+    if cfg.optimizer == 'adam':
+        optimizer = torch.optim.Adam(
+            model.parameters(), lr=cfg.learning_rate, betas=cfg.betas, weight_decay=cfg.weight_decay)
+    elif cfg.optimizer == 'sgd':
+        optimizer = torch.optim.SGD(
+            model.parameters(), lr=cfg.learning_rate, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
+    else:
+        raise NotImplementedError
+
+    if cfg.cls_weights_in_loss:
+        weights = 1 / datasets['train'].class_counts
+    else:
+        weights = torch.ones(len(datasets['train'].target2label))
+    criterion = WeightedCrossEntropy(weights.to(device))
+
+    # loop over the train and validation multiple times (typical PT boilerplate)
+    no_change_epochs = 0
+    best_valid_loss = float('inf')
+    early_stop_triggered = False
+
+    for epoch in range(cfg.num_epochs):
+
+        for phase in ['train', 'valid']:
+            if phase == 'train':
+                model.train()
+            else:
+                model.eval()
+
+            running_loss = 0
+            preds_from_each_batch = []
+            targets_from_each_batch = []
+
+            prog_bar = tqdm(loaders[phase], f'{phase} ({epoch})', ncols=0)
+            for i, batch in enumerate(prog_bar):
+                inputs = batch['input'].to(device)
+                targets = batch['target'].to(device)
+
+                # zero the parameter gradients
+                optimizer.zero_grad()
+
+                # forward + backward + optimize
+                with torch.set_grad_enabled(phase == 'train'):
+                    # inception v3
+                    if phase == 'train':
+                        outputs, aux_outputs = model(inputs)
+                        loss1 = criterion(outputs, targets, to_weight=True)
+                        loss2 = criterion(aux_outputs, targets, to_weight=True)
+                        # standard Inception v3 recipe: main loss + 0.4 * auxiliary-classifier loss
+                        loss = loss1 + 0.4 * loss2
+                    else:
+                        outputs = model(inputs)
+                        loss = criterion(outputs, targets, to_weight=False)
+
+                if phase == 'train':
+                    loss.backward()
+                    optimizer.step()
+
+                # loss
+                running_loss += loss.item()
+
+                # for metrics calculation later on
+                preds_from_each_batch += [outputs.detach().cpu()]
+                targets_from_each_batch += [targets.cpu()]
+
+                # iter logging
+                if i % 50 == 0:
+                    logger.log_iter_loss(loss.item(), epoch*len(loaders[phase])+i, phase)
+                    # tracks loss in the tqdm progress bar
+                    prog_bar.set_postfix(loss=loss.item())
+
+            # logging loss
+            epoch_loss = running_loss / len(loaders[phase])
+            logger.log_epoch_loss(epoch_loss, epoch, phase)
+
+            # logging metrics
+            preds_from_each_batch = torch.cat(preds_from_each_batch)
+            targets_from_each_batch = torch.cat(targets_from_each_batch)
+            metrics_dict = metrics(targets_from_each_batch, preds_from_each_batch)
+            logger.log_epoch_metrics(metrics_dict, epoch, phase)
+
+            # Early stopping
+            if phase == 'valid':
+                if epoch_loss < best_valid_loss:
+                    no_change_epochs = 0
+                    best_valid_loss = epoch_loss
+                    logger.log_best_model(model, epoch_loss, epoch, optimizer, metrics_dict)
+                else:
+                    no_change_epochs += 1
+                    logger.print_logger.info(
+                        f"Valid loss hasn't changed for {no_change_epochs} epochs (patience: {cfg.patience})"
+                    )
+                    if no_change_epochs >= cfg.patience:
+                        early_stop_triggered = True
+
+        if early_stop_triggered:
+            logger.print_logger.info(f'Training is early stopped @ {epoch}')
+            break
+
+    logger.print_logger.info('Finished Training')
+
+    # loading the best model
+    ckpt = torch.load(logger.best_model_path)
+    model.load_state_dict(ckpt['model'])
+    logger.print_logger.info(f'Loading the best model from {logger.best_model_path}')
+    logger.print_logger.info((f'The model was trained for {ckpt["epoch"]} epochs. Loss: {ckpt["loss"]:.4f}'))
+
+    # Testing the model
+    model.eval()
+    running_loss = 0
+    preds_from_each_batch = []
+    targets_from_each_batch = []
+
+    for i, batch in enumerate(loaders['test']):
+        inputs = batch['input'].to(device)
+        targets = batch['target'].to(device)
+
+        # zero the parameter gradients
+        optimizer.zero_grad()
+
+        # forward only (evaluation, no gradients)
+        with torch.set_grad_enabled(False):
+            outputs = model(inputs)
+            loss = criterion(outputs, targets, to_weight=False)
+
+        # loss
+        running_loss += loss.item()
+
+        # for metrics calculation later on
+        preds_from_each_batch += [outputs.detach().cpu()]
+        targets_from_each_batch += [targets.cpu()]
+
+    # logging metrics
+    preds_from_each_batch = torch.cat(preds_from_each_batch)
+    targets_from_each_batch = torch.cat(targets_from_each_batch)
+    test_metrics_dict = metrics(targets_from_each_batch, preds_from_each_batch)
+    test_metrics_dict['avg_loss'] = running_loss / len(loaders['test'])
+    test_metrics_dict['param_num'] = param_num
+    # TODO: I have no idea why tboard doesn't keep metrics (hparams) when
+    # I run this experiment from the cli: `python train_melception.py config=./configs/vggish.yaml`,
+    # while when I run it in the vscode debugger the metrics are logged (weird)
+    logger.log_test_metrics(test_metrics_dict, dict(cfg), ckpt['epoch'])
+
+    logger.print_logger.info('Finished the experiment')
+
+
+if __name__ == '__main__':
+    # input = torch.rand(16, 1, 80, 848)
+    # output, aux = inception(input)
+    # print(output.shape, aux.shape)
+    # Expected input size: (3, 299, 299) in RGB -> (1, 80, 848) in Mel Spec
+    # train_inception_scorer()
+
+    cfg_cli = OmegaConf.from_cli()
+    cfg_yml = OmegaConf.load(cfg_cli.config)
+    # the latter arguments are prioritized
+    cfg = OmegaConf.merge(cfg_yml, cfg_cli)
+    OmegaConf.set_readonly(cfg, True)
+    print(OmegaConf.to_yaml(cfg))
+
+    train_inception_scorer(cfg)
diff --git a/foleycrafter/models/specvqgan/modules/losses/vggishish/train_vggishish.py b/foleycrafter/models/specvqgan/modules/losses/vggishish/train_vggishish.py
new file mode 100644
index 0000000000000000000000000000000000000000..205668224ec87a9ce571f6428531080231b1c16b
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/vggishish/train_vggishish.py
@@ -0,0 +1,199 @@
+from loss import WeightedCrossEntropy
+import random
+
+import numpy as np
+import torch
+import torchvision
+from omegaconf import OmegaConf
+from torch.utils.data.dataloader import DataLoader
+from tqdm import tqdm
+
+from dataset import VGGSound
+from transforms import Crop, StandardNormalizeAudio, ToTensor
+from logger import LoggerWithTBoard
+from metrics import metrics
+from model import VGGishish
+
+if __name__ == "__main__":
+    cfg_cli = OmegaConf.from_cli()
+    cfg_yml = OmegaConf.load(cfg_cli.config)
+    # the latter arguments are prioritized
+    cfg = OmegaConf.merge(cfg_yml, cfg_cli)
+    OmegaConf.set_readonly(cfg, True)
+    print(OmegaConf.to_yaml(cfg))
+
+    logger = LoggerWithTBoard(cfg)
+
+    random.seed(cfg.seed)
+    np.random.seed(cfg.seed)
+    torch.manual_seed(cfg.seed)
+    torch.cuda.manual_seed_all(cfg.seed)
+    # makes iterations faster (in this case 30%) if your inputs are of a fixed size
+    # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3
+    torch.backends.cudnn.benchmark = True
+
+    transforms = [
+        StandardNormalizeAudio(cfg.mels_path),
+    ]
+    if cfg.cropped_size not in [None, 'None', 'none']:
+        logger.print_logger.info(f'Using cropping {cfg.cropped_size}')
+        transforms.append(Crop(cfg.cropped_size))
+    transforms.append(ToTensor())
+    transforms = torchvision.transforms.transforms.Compose(transforms)
+
+    datasets = {
+        'train': VGGSound('train', cfg.mels_path, transforms),
+        'valid': VGGSound('valid', cfg.mels_path, transforms),
+        'test': VGGSound('test', cfg.mels_path, transforms),
+    }
+
+    loaders = {
+        'train': DataLoader(datasets['train'], batch_size=cfg.batch_size, shuffle=True, drop_last=True,
+                            num_workers=cfg.num_workers, pin_memory=True),
+        'valid': DataLoader(datasets['valid'], batch_size=cfg.batch_size,
+                            num_workers=cfg.num_workers, pin_memory=True),
+        'test': DataLoader(datasets['test'], batch_size=cfg.batch_size,
+                           num_workers=cfg.num_workers, pin_memory=True),
+    }
+
+    device = torch.device(cfg.device if torch.cuda.is_available() else 'cpu')
+
+    model = VGGishish(cfg.conv_layers, cfg.use_bn, num_classes=len(datasets['train'].target2label))
+    model = model.to(device)
+    param_num = logger.log_param_num(model)
+
+    if cfg.optimizer == 'adam':
+        optimizer = torch.optim.Adam(
+            model.parameters(), lr=cfg.learning_rate, betas=cfg.betas, weight_decay=cfg.weight_decay)
+    elif cfg.optimizer == 'sgd':
+        optimizer = torch.optim.SGD(
+            model.parameters(), lr=cfg.learning_rate, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
+    else:
+        raise NotImplementedError
+
+    if cfg.cls_weights_in_loss:
+        weights = 1 / datasets['train'].class_counts
+    else:
+        weights = torch.ones(len(datasets['train'].target2label))
+    criterion = WeightedCrossEntropy(weights.to(device))
+
+    # loop over the train and validation multiple times (typical PT boilerplate)
+    no_change_epochs = 0
+    best_valid_loss = float('inf')
+    early_stop_triggered = False
+
+    for epoch in range(cfg.num_epochs):
+
+        for phase in ['train', 'valid']:
+            if phase == 'train':
+                model.train()
+            else:
+                model.eval()
+
+            running_loss = 0
+            preds_from_each_batch = []
+            targets_from_each_batch = []
+
+            prog_bar = tqdm(loaders[phase], f'{phase} ({epoch})', ncols=0)
+            for i, batch in enumerate(prog_bar):
+                inputs = batch['input'].to(device)
+                targets = batch['target'].to(device)
+
+                # zero the parameter gradients
+                optimizer.zero_grad()
+
+                # forward + backward + optimize
+                with torch.set_grad_enabled(phase == 'train'):
+                    outputs = model(inputs)
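+                    # class re-weighting is applied only on the train split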
+                    loss = criterion(outputs, targets, to_weight=phase == 'train')
+
+                if phase == 'train':
+                    loss.backward()
+                    optimizer.step()
+
+                # loss
+                running_loss += loss.item()
+
+                # for metrics calculation later on
+                preds_from_each_batch += [outputs.detach().cpu()]
+                targets_from_each_batch += [targets.cpu()]
+
+                # iter logging
+                if i % 50 == 0:
+                    logger.log_iter_loss(loss.item(), epoch*len(loaders[phase])+i, phase)
+                    # tracks loss in the tqdm progress bar
+                    prog_bar.set_postfix(loss=loss.item())
+
+            # logging loss
+            epoch_loss = running_loss / len(loaders[phase])
+            logger.log_epoch_loss(epoch_loss, epoch, phase)
+
+            # logging metrics
+            preds_from_each_batch = torch.cat(preds_from_each_batch)
+            targets_from_each_batch = torch.cat(targets_from_each_batch)
+            metrics_dict = metrics(targets_from_each_batch, preds_from_each_batch)
+            logger.log_epoch_metrics(metrics_dict, epoch, phase)
+
+            # Early stopping
+            if phase == 'valid':
+                if epoch_loss < best_valid_loss:
+                    no_change_epochs = 0
+                    best_valid_loss = epoch_loss
+                    logger.log_best_model(model, epoch_loss, epoch, optimizer, metrics_dict)
+                else:
+                    no_change_epochs += 1
+                    logger.print_logger.info(
+                        f"Valid loss hasn't changed for {no_change_epochs} epochs (patience: {cfg.patience})"
+                    )
+                    if no_change_epochs >= cfg.patience:
+                        early_stop_triggered = True
+
+        if early_stop_triggered:
+            logger.print_logger.info(f'Training is early stopped @ {epoch}')
+            break
+
+    logger.print_logger.info('Finished Training')
+
+    # loading the best model
+    ckpt = torch.load(logger.best_model_path)
+    model.load_state_dict(ckpt['model'])
+    logger.print_logger.info(f'Loading the best model from {logger.best_model_path}')
+    logger.print_logger.info((f'The model was trained for {ckpt["epoch"]} epochs. Loss: {ckpt["loss"]:.4f}'))
+
+    # Testing the model
+    model.eval()
+    running_loss = 0
+    preds_from_each_batch = []
+    targets_from_each_batch = []
+
+    for i, batch in enumerate(loaders['test']):
+        inputs = batch['input'].to(device)
+        targets = batch['target'].to(device)
+
+        # zero the parameter gradients
+        optimizer.zero_grad()
+
+        # forward only (evaluation, no gradients)
+        with torch.set_grad_enabled(False):
+            outputs = model(inputs)
+            loss = criterion(outputs, targets, to_weight=False)
+
+        # loss
+        running_loss += loss.item()
+
+        # for metrics calculation later on
+        preds_from_each_batch += [outputs.detach().cpu()]
+        targets_from_each_batch += [targets.cpu()]
+
+    # logging metrics
+    preds_from_each_batch = torch.cat(preds_from_each_batch)
+    targets_from_each_batch = torch.cat(targets_from_each_batch)
+    test_metrics_dict = metrics(targets_from_each_batch, preds_from_each_batch)
+    test_metrics_dict['avg_loss'] = running_loss / len(loaders['test'])
+    test_metrics_dict['param_num'] = param_num
+    # TODO: I have no idea why tboard doesn't keep metrics (hparams) when
+    # I run this experiment from the cli: `python train_vggishish.py config=./configs/vggish.yaml`,
+    # while when I run it in the vscode debugger the metrics are logged (weird)
+    logger.log_test_metrics(test_metrics_dict, dict(cfg), ckpt['epoch'])
+
+    logger.print_logger.info('Finished the experiment')
diff --git a/foleycrafter/models/specvqgan/modules/losses/vggishish/train_vggishish_gh.py b/foleycrafter/models/specvqgan/modules/losses/vggishish/train_vggishish_gh.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b879131f3f32589c09eb07e818157da21797bb7
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/vggishish/train_vggishish_gh.py
@@ -0,0 +1,218 @@
+from loss import WeightedCrossEntropy
+import random
+import os
+import sys
+import json
+
+import numpy as np
+import torch
+import torchvision
+from omegaconf import OmegaConf
+from torch.utils.data.dataloader import DataLoader
+from tqdm import tqdm
+
+from dataset import GreatestHit, AMT_test
+from transforms import Crop, StandardNormalizeAudio, ToTensor
+from logger import LoggerWithTBoard
+from metrics import metrics
+from model import VGGishish
+
+
+if __name__ == "__main__":
+    cfg_cli = sys.argv[1]
+    cfg_yml = OmegaConf.load(cfg_cli)
+    # no CLI overrides here: the YAML config is used as-is
+    cfg = cfg_yml
+    OmegaConf.set_readonly(cfg, True)
+    print(OmegaConf.to_yaml(cfg))
+
+    logger = LoggerWithTBoard(cfg)
+
+    random.seed(cfg.seed)
+    np.random.seed(cfg.seed)
+    torch.manual_seed(cfg.seed)
+    torch.cuda.manual_seed_all(cfg.seed)
+    # makes iterations faster (in this case 30%) if your inputs are of a fixed size
+    # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3
+    torch.backends.cudnn.benchmark = True
+
+    transforms = [
+        StandardNormalizeAudio(cfg.mels_path),
+    ]
+    if cfg.cropped_size not in [None, 'None', 'none']:
+        logger.print_logger.info(f'Using cropping {cfg.cropped_size}')
+        transforms.append(Crop(cfg.cropped_size))
+    transforms.append(ToTensor())
+    transforms = torchvision.transforms.transforms.Compose(transforms)
+
+    datasets = {
+        'train': GreatestHit('train', cfg.mels_path, transforms, action_only=cfg.action_only, material_only=cfg.material_only),
+        'valid': GreatestHit('valid', cfg.mels_path, transforms, action_only=cfg.action_only, material_only=cfg.material_only),
+        'test': GreatestHit('test', cfg.mels_path, transforms, action_only=cfg.action_only, material_only=cfg.material_only),
+    }
+
+    loaders = {
+        'train': DataLoader(datasets['train'], batch_size=cfg.batch_size, shuffle=True, drop_last=True,
+                            num_workers=cfg.num_workers, pin_memory=True),
+        'valid': DataLoader(datasets['valid'], batch_size=cfg.batch_size,
+                            num_workers=cfg.num_workers, pin_memory=True),
+        'test': DataLoader(datasets['test'], batch_size=cfg.batch_size,
+                           num_workers=cfg.num_workers, pin_memory=True),
+    }
+
+    device = torch.device(cfg.device if torch.cuda.is_available() else 'cpu')
+
+    model = VGGishish(cfg.conv_layers, cfg.use_bn, num_classes=len(datasets['train'].label2target))
+    model = model.to(device)
+    if cfg.load_model is not None:
+        state_dict = torch.load(cfg.load_model, map_location=device)['model']
+        target_dict = {}
+        # ignore the last layer
+        for key, v in state_dict.items():
+            # ignore classifier
+            if 'classifier' not in key:
+                target_dict[key] = v
+        model.load_state_dict(target_dict, strict=False)
+    param_num = logger.log_param_num(model)
+
+    if cfg.optimizer == 'adam':
+        optimizer = torch.optim.Adam(
+            model.parameters(), lr=cfg.learning_rate, betas=cfg.betas, weight_decay=cfg.weight_decay)
+    elif cfg.optimizer == 'sgd':
+        optimizer = torch.optim.SGD(
+            model.parameters(), lr=cfg.learning_rate, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
+    else:
+        raise NotImplementedError
+
+    if cfg.cls_weights_in_loss:
+        weights = 1 / datasets['train'].class_counts
+    else:
+        weights = torch.ones(len(datasets['train'].label2target))
+    criterion = WeightedCrossEntropy(weights.to(device))
+
+    # loop over the train and validation multiple times (typical PT boilerplate)
+    no_change_epochs = 0
+    best_valid_loss = float('inf')
+    early_stop_triggered = False
+
+    for epoch in range(cfg.num_epochs):
+
+        for phase in ['train', 'valid']:
+            if phase == 'train':
+                model.train()
+            else:
+                model.eval()
+
+            running_loss = 0
+            preds_from_each_batch = []
+            targets_from_each_batch = []
+
+            prog_bar = tqdm(loaders[phase], f'{phase} ({epoch})', ncols=0)
+            for i, batch in enumerate(prog_bar):
+                inputs = batch['input'].to(device)
+                targets = batch['target'].to(device)
+
+                # zero the parameter gradients
+                optimizer.zero_grad()
+
+                # forward + backward + optimize
+                with torch.set_grad_enabled(phase == 'train'):
+                    outputs = model(inputs)
+                    loss = criterion(outputs, targets, to_weight=phase == 'train')
+
+                if phase == 'train':
+                    loss.backward()
+                    optimizer.step()
+
+                # loss
+                running_loss += loss.item()
+
+                # for metrics calculation later on
+                preds_from_each_batch += [outputs.detach().cpu()]
+                targets_from_each_batch += [targets.cpu()]
+
+                # iter logging
+                if i % 50 == 0:
+                    logger.log_iter_loss(loss.item(), epoch*len(loaders[phase])+i, phase)
+                    # tracks loss in the tqdm progress bar
+                    prog_bar.set_postfix(loss=loss.item())
+
+            # logging loss
+            epoch_loss = running_loss / len(loaders[phase])
+            logger.log_epoch_loss(epoch_loss, epoch, phase)
+
+            # logging metrics
+            preds_from_each_batch = torch.cat(preds_from_each_batch)
+            targets_from_each_batch = torch.cat(targets_from_each_batch)
+            if cfg.action_only:
+                metrics_dict = metrics(targets_from_each_batch, preds_from_each_batch, topk=(1,))
+            else:
+                metrics_dict = metrics(targets_from_each_batch, preds_from_each_batch, topk=(1, 5))
+            logger.log_epoch_metrics(metrics_dict, epoch, phase)
+
+            # Early stopping
+            if phase == 'valid':
+                if epoch_loss < best_valid_loss:
+                    no_change_epochs = 0
+                    best_valid_loss = epoch_loss
+                    logger.log_best_model(model, epoch_loss, epoch, optimizer, metrics_dict)
+                else:
+                    no_change_epochs += 1
+                    logger.print_logger.info(
+                        f"Valid loss hasn't changed for {no_change_epochs} epochs (patience: {cfg.patience})"
+                    )
+                    if no_change_epochs >= cfg.patience:
+                        early_stop_triggered = True
+
+        if early_stop_triggered:
+            logger.print_logger.info(f'Training is early stopped @ {epoch}')
+            break
+
+    logger.print_logger.info('Finished Training')
+
+    # loading the best model
+    ckpt = torch.load(logger.best_model_path)
+    model.load_state_dict(ckpt['model'])
+    logger.print_logger.info(f'Loading the best model from {logger.best_model_path}')
+    logger.print_logger.info((f'The model was trained for {ckpt["epoch"]} epochs. Loss: {ckpt["loss"]:.4f}'))
+
+    # Testing the model
+    model.eval()
+    running_loss = 0
+    preds_from_each_batch = []
+    targets_from_each_batch = []
+
+    for i, batch in enumerate(loaders['test']):
+        inputs = batch['input'].to(device)
+        targets = batch['target'].to(device)
+
+        # zero the parameter gradients
+        optimizer.zero_grad()
+
+        # forward only (evaluation, no gradients)
+        with torch.set_grad_enabled(False):
+            outputs = model(inputs)
+            loss = criterion(outputs, targets, to_weight=False)
+
+        # loss
+        running_loss += loss.item()
+
+        # for metrics calculation later on
+        preds_from_each_batch += [outputs.detach().cpu()]
+        targets_from_each_batch += [targets.cpu()]
+
+    # logging metrics
+    preds_from_each_batch = torch.cat(preds_from_each_batch)
+    targets_from_each_batch = torch.cat(targets_from_each_batch)
+    if cfg.action_only:
+        test_metrics_dict = metrics(targets_from_each_batch, preds_from_each_batch, topk=(1,))
+    else:
+        test_metrics_dict = metrics(targets_from_each_batch, preds_from_each_batch, topk=(1, 5))
+    test_metrics_dict['avg_loss'] = running_loss / len(loaders['test'])
+    test_metrics_dict['param_num'] = param_num
+    # TODO: I have no idea why tboard doesn't keep metrics (hparams) when
+    # I run this experiment from the cli: `python train_vggishish.py config=./configs/vggish.yaml`,
+    # while when I run it in the vscode debugger the metrics are logged (weird)
+    logger.log_test_metrics(test_metrics_dict, dict(cfg), ckpt['epoch'])
+
+    logger.print_logger.info('Finished the experiment')
diff --git a/foleycrafter/models/specvqgan/modules/losses/vggishish/transforms.py b/foleycrafter/models/specvqgan/modules/losses/vggishish/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..551c4d95534a4c6f83484afcf06e1017baafc135
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/vggishish/transforms.py
@@ -0,0 +1,98 @@
+import logging
+import os
+from pathlib import Path
+
+import albumentations
+import numpy as np
+import torch
+from tqdm import tqdm
+
+logger = logging.getLogger(f'main.{__name__}')
+
+
+class StandardNormalizeAudio(object):
+    '''
+        Frequency-wise (per mel-bin) normalization using train-set mean/std statistics (cached on disk)
+    '''
+    def __init__(self, specs_dir, train_ids_path='./data/vggsound_train.txt', cache_path='./data/'):
+        self.specs_dir = specs_dir
+        self.train_ids_path = train_ids_path
+        # making the stats filename to match the specs dir name
+        self.cache_path = os.path.join(cache_path, f'train_means_stds_{Path(specs_dir).stem}.txt')
+        logger.info('Assuming that the input stats are calculated using preprocessed spectrograms (log)')
+        self.train_stats = self.calculate_or_load_stats()
+
+    def __call__(self, item):
+        # just to generalize the input handling. Useful for FID/IS eval and training other stuff
+        if isinstance(item, dict):
+            if 'input' in item:
+                input_key = 'input'
+            elif 'image' in item:
+                input_key = 'image'
+            else:
+                raise NotImplementedError
+            item[input_key] = (item[input_key] - self.train_stats['means']) / self.train_stats['stds']
+        elif isinstance(item, torch.Tensor):
+            # broadcasts np.ndarray (80, 1) to (1, 80, 1) because item is torch.Tensor (B, 80, T)
+            item = (item - self.train_stats['means']) / self.train_stats['stds']
+        else:
+            raise NotImplementedError
+        return item
+
+    def calculate_or_load_stats(self):
+        try:
+            # (F, 2)
+            train_stats = np.loadtxt(self.cache_path)
+            means, stds = train_stats.T
+            logger.info('Loaded train stats for Standard Normalization of inputs')
+        except OSError:
+            logger.info('Could not find the precalculated stats for Standard Normalization. Calculating...')
+            train_vid_ids = open(self.train_ids_path)
+            specs_paths = [os.path.join(self.specs_dir, f'{i.rstrip()}_mel.npy') for i in train_vid_ids]
+            means = [None] * len(specs_paths)
+            stds = [None] * len(specs_paths)
+            for i, path in enumerate(tqdm(specs_paths)):
+                spec = np.load(path)
+                means[i] = spec.mean(axis=1)
+                stds[i] = spec.std(axis=1)
+            # (F) <- (num_files, F)
+            means = np.array(means).mean(axis=0)
+            stds = np.array(stds).mean(axis=0)
+            # saving in two columns
+            np.savetxt(self.cache_path, np.vstack([means, stds]).T, fmt='%0.8f')
+        means = means.reshape(-1, 1)
+        stds = stds.reshape(-1, 1)
+        return {'means': means, 'stds': stds}
+
+class ToTensor(object):
+
+    def __call__(self, item):
+        item['input'] = torch.from_numpy(item['input']).float()
+        if 'target' in item:
+            item['target'] = torch.tensor(item['target'])
+        return item
+
+class Crop(object):
+
+    def __init__(self, cropped_shape=None, random_crop=False):
+        self.cropped_shape = cropped_shape
+        if cropped_shape is not None:
+            mel_num, spec_len = cropped_shape
+            if random_crop:
+                self.cropper = albumentations.RandomCrop
+            else:
+                self.cropper = albumentations.CenterCrop
+            self.preprocessor = albumentations.Compose([self.cropper(mel_num, spec_len)])
+        else:
+            self.preprocessor = lambda **kwargs: kwargs
+
+    def __call__(self, item):
+        item['input'] = self.preprocessor(image=item['input'])['image']
+        return item
+
+
+if __name__ == '__main__':
+    cropper = Crop([80, 848])
+    item = {'input': torch.rand([80, 860])}
+    outputs = cropper(item)
+    print(outputs['input'].shape)
diff --git a/foleycrafter/models/specvqgan/modules/losses/vqperceptual.py b/foleycrafter/models/specvqgan/modules/losses/vqperceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..80e8d4b445a9c4c3b6513c088c875153e9553151
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/losses/vqperceptual.py
@@ -0,0 +1,209 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import sys
+
+sys.path.insert(0, '.')  # nopep8
+from foleycrafter.models.specvqgan.modules.discriminator.model import (NLayerDiscriminator, NLayerDiscriminator1dFeats,
+                                                   NLayerDiscriminator1dSpecs,
+                                                   weights_init)
+from foleycrafter.models.specvqgan.modules.losses.lpaps import LPAPS
+
+
+class DummyLoss(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+
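+# returns `value` (0 by default) until global_step reaches `threshold`, i.e. the GAN term is disabled before disc_start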
+def adopt_weight(weight, global_step, threshold=0, value=0.):
+    if global_step < threshold:
+        weight = value
+    return weight
+
+
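+# hinge discriminator loss: 0.5 * (E[relu(1 - D(real))] + E[relu(1 + D(fake))])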
+def hinge_d_loss(logits_real, logits_fake):
+    loss_real = torch.mean(F.relu(1. - logits_real))
+    loss_fake = torch.mean(F.relu(1. + logits_fake))
+    d_loss = 0.5 * (loss_real + loss_fake)
+    return d_loss
+
+
+def vanilla_d_loss(logits_real, logits_fake):
+    d_loss = 0.5 * (
+        torch.mean(torch.nn.functional.softplus(-logits_real)) +
+        torch.mean(torch.nn.functional.softplus(logits_fake)))
+    return d_loss
+
+
+class VQLPAPSWithDiscriminator(nn.Module):
+    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
+                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+                 disc_ndf=64, disc_loss="hinge", min_adapt_weight=0.0, max_adapt_weight=1e4):
+        super().__init__()
+        assert disc_loss in ["hinge", "vanilla"]
+        self.codebook_weight = codebook_weight
+        self.pixel_weight = pixelloss_weight
+        self.perceptual_loss = LPAPS().eval()
+        self.perceptual_weight = perceptual_weight
+
+        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+                                                 n_layers=disc_num_layers,
+                                                 use_actnorm=use_actnorm,
+                                                 ndf=disc_ndf
+                                                 ).apply(weights_init)
+        self.discriminator_iter_start = disc_start
+        if disc_loss == "hinge":
+            self.disc_loss = hinge_d_loss
+        elif disc_loss == "vanilla":
+            self.disc_loss = vanilla_d_loss
+        else:
+            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
+        print(f"VQLPAPSWithDiscriminator running with {disc_loss} loss.")
+        self.disc_factor = disc_factor
+        self.discriminator_weight = disc_weight
+        self.disc_conditional = disc_conditional
+        self.min_adapt_weight = min_adapt_weight
+        self.max_adapt_weight = max_adapt_weight
+
+    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+        if last_layer is not None:
+            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+        else:
+            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
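+        # balance the GAN gradient against the reconstruction gradient at the last decoder layer (adaptive weight as in VQGAN / Taming Transformers)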
+        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+        d_weight = torch.clamp(d_weight, self.min_adapt_weight, self.max_adapt_weight).detach()
+        d_weight = d_weight * self.discriminator_weight
+        return d_weight
+
+    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
+                global_step, last_layer=None, cond=None, split="train"):
+        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+        if self.perceptual_weight > 0:
+            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+            rec_loss = rec_loss + self.perceptual_weight * p_loss
+        else:
+            p_loss = torch.tensor([0.0])
+
+        nll_loss = rec_loss
+        # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+        nll_loss = torch.mean(nll_loss)
+
+        # now the GAN part
+        if optimizer_idx == 0:
+            # generator update
+            if cond is None:
+                assert not self.disc_conditional
+                logits_fake = self.discriminator(reconstructions.contiguous())
+            else:
+                assert self.disc_conditional
+                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+            g_loss = -torch.mean(logits_fake)
+
+            try:
+                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+            except RuntimeError:
+                assert not self.training
+                d_weight = torch.tensor(0.0)
+
+            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
+
+            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
+                   "{}/quant_loss".format(split): codebook_loss.detach().mean(),
+                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
+                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
+                   "{}/p_loss".format(split): p_loss.detach().mean(),
+                   "{}/d_weight".format(split): d_weight.detach(),
+                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
+                   "{}/g_loss".format(split): g_loss.detach().mean(),
+                   }
+            return loss, log
+
+        if optimizer_idx == 1:
+            # second pass for discriminator update
+            if cond is None:
+                logits_real = self.discriminator(inputs.contiguous().detach())
+                logits_fake = self.discriminator(reconstructions.contiguous().detach())
+            else:
+                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+                   "{}/logits_real".format(split): logits_real.detach().mean(),
+                   "{}/logits_fake".format(split): logits_fake.detach().mean()
+                   }
+            return d_loss, log
+
+
+class VQLPAPSWithDiscriminator1dFeats(VQLPAPSWithDiscriminator):
+    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
+                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+                 disc_ndf=64, disc_loss="hinge", min_adapt_weight=0.0, max_adapt_weight=1e4):
+        super().__init__(disc_start=disc_start, codebook_weight=codebook_weight,
+                         pixelloss_weight=pixelloss_weight, disc_num_layers=disc_num_layers,
+                         disc_in_channels=disc_in_channels, disc_factor=disc_factor, disc_weight=disc_weight,
+                         perceptual_weight=perceptual_weight, use_actnorm=use_actnorm,
+                         disc_conditional=disc_conditional, disc_ndf=disc_ndf, disc_loss=disc_loss,
+                         min_adapt_weight=min_adapt_weight, max_adapt_weight=max_adapt_weight)
+
+        self.discriminator = NLayerDiscriminator1dFeats(input_nc=disc_in_channels, n_layers=disc_num_layers,
+                                                   use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init)
+
+class VQLPAPSWithDiscriminator1dSpecs(VQLPAPSWithDiscriminator):
+    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
+                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+                 disc_ndf=64, disc_loss="hinge", min_adapt_weight=0.0, max_adapt_weight=1e4):
+        super().__init__(disc_start=disc_start, codebook_weight=codebook_weight,
+                         pixelloss_weight=pixelloss_weight, disc_num_layers=disc_num_layers,
+                         disc_in_channels=disc_in_channels, disc_factor=disc_factor, disc_weight=disc_weight,
+                         perceptual_weight=perceptual_weight, use_actnorm=use_actnorm,
+                         disc_conditional=disc_conditional, disc_ndf=disc_ndf, disc_loss=disc_loss,
+                         min_adapt_weight=min_adapt_weight, max_adapt_weight=max_adapt_weight)
+
+        self.discriminator = NLayerDiscriminator1dSpecs(input_nc=disc_in_channels, n_layers=disc_num_layers,
+                                                   use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init)
+
+
+if __name__ == '__main__':
+    from foleycrafter.models.specvqgan.modules.diffusionmodules.model import Decoder, Decoder1d
+
+    optimizer_idx = 0
+    loss_config = {
+        'disc_conditional': False,
+        'disc_start': 30001,
+        'disc_weight': 0.8,
+        'codebook_weight': 1.0,
+    }
+    ddconfig = {
+        'ch': 128,
+        'num_res_blocks': 2,
+        'dropout': 0.0,
+        'z_channels': 256,
+        'double_z': False,
+    }
+    qloss = torch.rand(1, requires_grad=True)
+
+    ## AUDIO
+    loss_config['disc_in_channels'] = 1
+    ddconfig['in_channels'] = 1
+    ddconfig['resolution'] = 848
+    ddconfig['attn_resolutions'] = [53]
+    ddconfig['out_ch'] = 1
+    ddconfig['ch_mult'] = [1, 1, 2, 2, 4]
+    decoder = Decoder(**ddconfig)
+    loss = VQLPAPSWithDiscriminator(**loss_config)
+    x = torch.rand(16, 1, 80, 848)
+    # subtracting something which uses dec_conv_out so that it will be in a graph
+    xrec = torch.rand(16, 1, 80, 848) - decoder.conv_out(torch.rand(16, 128, 80, 848)).mean()
+    aeloss, log_dict_ae = loss(qloss, x, xrec, optimizer_idx, global_step=0, last_layer=decoder.conv_out.weight)
+    print(aeloss)
+    print(log_dict_ae)
diff --git a/foleycrafter/models/specvqgan/modules/misc/class_cond.py b/foleycrafter/models/specvqgan/modules/misc/class_cond.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7044573e685f24e2db3568148bc20e6f1536a31
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/misc/class_cond.py
@@ -0,0 +1,21 @@
+import torch
+
+class ClassOnlyStage(object):
+    def __init__(self):
+        pass
+
+    def eval(self):
+        return self
+
+    def encode(self, c):
+        """fake vqmodel interface because self.cond_stage_model should have something
+        similar to coord.py but even more `dummy`"""
+        # assert 0.0 <= c.min() and c.max() <= 1.0
+        info = None, None, c
+        return c, None, info
+
+    def decode(self, c):
+        return c
+
+    def get_input(self, batch, k):
+        return batch[k].unsqueeze(1).to(memory_format=torch.contiguous_format)
diff --git a/foleycrafter/models/specvqgan/modules/misc/coord.py b/foleycrafter/models/specvqgan/modules/misc/coord.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee69b0c897b6b382ae673622e420f55e494f5b09
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/misc/coord.py
@@ -0,0 +1,31 @@
+import torch
+
+class CoordStage(object):
+    def __init__(self, n_embed, down_factor):
+        self.n_embed = n_embed
+        self.down_factor = down_factor
+
+    def eval(self):
+        return self
+
+    def encode(self, c):
+        """fake vqmodel interface"""
+        assert 0.0 <= c.min() and c.max() <= 1.0
+        b, ch, h, w = c.shape
+        assert ch == 1
+
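+        # downsample the [0, 1] coordinate map by `down_factor`, then quantize to integer indices by scaling with n_embed and rounding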
+        c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
+                                            mode="area")
+        c = c.clamp(0.0, 1.0)
+        c = self.n_embed*c
+        c_quant = c.round()
+        c_ind = c_quant.to(dtype=torch.long)
+
+        info = None, None, c_ind
+        return c_quant, None, info
+
+    def decode(self, c):
+        c = c/self.n_embed
+        c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
+                                            mode="nearest")
+        return c
diff --git a/foleycrafter/models/specvqgan/modules/misc/feat_cluster.py b/foleycrafter/models/specvqgan/modules/misc/feat_cluster.py
new file mode 100644
index 0000000000000000000000000000000000000000..47b0527d25bdcdf56e7598c7522ac8f9a4c25854
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/misc/feat_cluster.py
@@ -0,0 +1,83 @@
+import os
+from glob import glob
+
+import joblib
+import numpy as np
+import torch
+from sklearn.cluster import MiniBatchKMeans
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from train import instantiate_from_config
+
+
+class FeatClusterStage(object):
+
+    def __init__(self, num_clusters=None, cached_kmeans_path=None, feats_dataset_config=None, num_workers=None):
+        if cached_kmeans_path is not None and os.path.exists(cached_kmeans_path):
+            print(f'Precalculated Clusterer already exists, loading from {cached_kmeans_path}')
+            self.clusterer = joblib.load(cached_kmeans_path)
+        elif feats_dataset_config is not None:
+            self.clusterer = self.load_or_precalculate_kmeans(num_clusters, feats_dataset_config, num_workers)
+        else:
+            raise Exception('Neither `feats_dataset_config` nor `cached_kmeans_path` are defined')
+
+    def eval(self):
+        return self
+
+    def encode(self, c):
+        # c_quant: cluster centers, c_ind: cluster index
+
+        B, D, T = c.shape
+        # (B*T, D) <- (B, T, D) <- (B, D, T)
+        c_flat = c.permute(0, 2, 1).view(B*T, D).cpu().numpy()
+
+        c_ind = self.clusterer.predict(c_flat)
+        c_quant = self.clusterer.cluster_centers_[c_ind]
+
+        c_ind = torch.from_numpy(c_ind).to(c.device)
+        c_quant = torch.from_numpy(c_quant).to(c.device)
+
+        c_ind = c_ind.long().unsqueeze(-1)
+        c_quant = c_quant.view(B, T, D).permute(0, 2, 1)
+
+        info = None, None, c_ind
+        # (B, D, T), (), ((), (768, 1024), (768, 1))
+        return c_quant, None, info
+
+    def decode(self, c):
+        return c
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        x = x.permute(0, 2, 1).to(memory_format=torch.contiguous_format)
+        return x.float()
+
+    def load_or_precalculate_kmeans(self, num_clusters, dataset_cfg, num_workers):
+        print(f'Calculating clustering K={num_clusters}')
+        batch_size = 64
+        dataset_name = dataset_cfg.target.split('.')[-1]
+        cached_path = os.path.join('./specvqgan/modules/misc/', f'kmeans_K{num_clusters}_{dataset_name}.sklearn')
+        feat_depth = dataset_cfg.params.condition_dataset_cfg.feat_depth
+        feat_crop_len = dataset_cfg.params.condition_dataset_cfg.feat_crop_len
+
+        feat_loading_dset = instantiate_from_config(dataset_cfg)
+        feat_loading_dset = DataLoader(feat_loading_dset, batch_size, num_workers=num_workers, shuffle=True)
+
+        clusterer = MiniBatchKMeans(num_clusters, batch_size=batch_size*feat_crop_len, random_state=0)
+
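+        # each loader batch is flattened to (batch_size * feat_crop_len, feat_depth) feature vectors for partial_fit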
+        for item in tqdm(feat_loading_dset):
+            batch = item['feature'].reshape(-1, feat_depth).float().numpy()
+            clusterer.partial_fit(batch)
+
+        joblib.dump(clusterer, cached_path)
+        print(f'Saved the calculated Clusterer @ {cached_path}')
+        return clusterer
+
+
+if __name__ == '__main__':
+    from omegaconf import OmegaConf
+
+    config = OmegaConf.load('./configs/vggsound_featcluster_transformer.yaml')
+    config.model.params.first_stage_config.params.ckpt_path = './logs/2021-05-19T22-16-54_vggsound_specs_vqgan/checkpoints/epoch_39.ckpt'
+    model = instantiate_from_config(config.model.params.cond_stage_config)
+    print(model)
diff --git a/foleycrafter/models/specvqgan/modules/misc/feats_class.py b/foleycrafter/models/specvqgan/modules/misc/feats_class.py
new file mode 100644
index 0000000000000000000000000000000000000000..72980972f919ceb63b3aeadb118e86c97ceb7f2b
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/misc/feats_class.py
@@ -0,0 +1,28 @@
+import torch
+
+class FeatsClassStage(object):
+    def __init__(self):
+        pass
+
+    def eval(self):
+        return self
+
+    def encode(self, c):
+        """fake vqmodel interface because self.cond_stage_model should have something
+        similar to coord.py but even more `dummy`"""
+        # assert 0.0 <= c.min() and c.max() <= 1.0
+        info = None, None, c
+        return c, None, info
+
+    def decode(self, c):
+        return c
+
+    def get_input(self, batch: dict, keys: dict) -> dict:
+        out = {}
+        for k in keys:
+            if k == 'target':
+                out[k] = batch[k].unsqueeze(1)
+            elif k == 'feature':
+                out[k] = batch[k].float().permute(0, 2, 1)
+            out[k] = out[k].to(memory_format=torch.contiguous_format)
+        return out
diff --git a/foleycrafter/models/specvqgan/modules/misc/raw_feats.py b/foleycrafter/models/specvqgan/modules/misc/raw_feats.py
new file mode 100644
index 0000000000000000000000000000000000000000..96b13f250abb0ac878026b207d1857084411caa5
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/misc/raw_feats.py
@@ -0,0 +1,23 @@
+import torch
+
+class RawFeatsStage(object):
+    def __init__(self):
+        pass
+
+    def eval(self):
+        return self
+
+    def encode(self, c):
+        """Fake VQ-model interface: self.cond_stage_model must expose encode/decode
+        like coord.py, but here both are simple pass-throughs."""
+        # assert 0.0 <= c.min() and c.max() <= 1.0
+        info = None, None, c
+        return c, None, info
+
+    def decode(self, c):
+        return c
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        x = x.permute(0, 2, 1).to(memory_format=torch.contiguous_format)
+        return x.float()
diff --git a/foleycrafter/models/specvqgan/modules/transformer/mingpt.py b/foleycrafter/models/specvqgan/modules/transformer/mingpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..d59f0fea2111fa8039d20cb3c04cd677b85d4115
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/transformer/mingpt.py
@@ -0,0 +1,535 @@
+"""
+taken from: https://github.com/karpathy/minGPT/
+GPT model:
+- the initial stem consists of a combination of token encoding and a positional encoding
+- the meat of it is a uniform sequence of Transformer blocks
+    - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
+    - all blocks feed into a central residual pathway similar to resnets
+- the final decoder is a linear projection into a vanilla Softmax classifier
+"""
+
+import math
+import logging
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import sys
+sys.path.insert(0, '.')  # nopep8
+from train import instantiate_from_config
+
+logger = logging.getLogger(__name__)
+
+
+class GPTConfig:
+    """ base GPT config, params common to all GPT versions """
+    embd_pdrop = 0.1
+    resid_pdrop = 0.1
+    attn_pdrop = 0.1
+
+    def __init__(self, vocab_size, block_size, **kwargs):
+        self.vocab_size = vocab_size
+        self.block_size = block_size
+        for k,v in kwargs.items():
+            setattr(self, k, v)
+
+
+class GPT1Config(GPTConfig):
+    """ GPT-1 like network roughly 125M params """
+    n_layer = 12
+    n_head = 12
+    n_embd = 768
+
+
+class GPT2Config(GPTConfig):
+    """ GPT-2 like network roughly 1.5B params """
+    # TODO
+
+
+class CausalSelfAttention(nn.Module):
+    """
+    A vanilla multi-head masked self-attention layer with a projection at the end.
+    It is possible to use torch.nn.MultiheadAttention here but I am including an
+    explicit implementation here to show that there is nothing too scary here.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        assert config.n_embd % config.n_head == 0
+        # key, query, value projections for all heads
+        self.key = nn.Linear(config.n_embd, config.n_embd)
+        self.query = nn.Linear(config.n_embd, config.n_embd)
+        self.value = nn.Linear(config.n_embd, config.n_embd)
+        # regularization
+        self.attn_drop = nn.Dropout(config.attn_pdrop)
+        self.resid_drop = nn.Dropout(config.resid_pdrop)
+        # output projection
+        self.proj = nn.Linear(config.n_embd, config.n_embd)
+        # causal mask to ensure that attention is only applied to the left in the input sequence
+        mask = torch.tril(torch.ones(config.block_size,
+                                     config.block_size))
+        if hasattr(config, "n_unmasked"):
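+            # the first n_unmasked positions form the conditioning prefix and may attend to
+            # each other bidirectionally, so the causal mask is lifted for that block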
+            mask[:config.n_unmasked, :config.n_unmasked] = 1
+        self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
+        self.n_head = config.n_head
+
+    def forward(self, x, layer_past=None):
+        B, T, C = x.size()
+
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
+        att = F.softmax(att, dim=-1)
+        y = self.attn_drop(att) @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+
+        # output projection
+        y = self.resid_drop(self.proj(y))
+
+        return y, att
+
+
+class Block(nn.Module):
+    """ an unassuming Transformer block """
+    def __init__(self, config):
+        super().__init__()
+        self.ln1 = nn.LayerNorm(config.n_embd)
+        self.ln2 = nn.LayerNorm(config.n_embd)
+        self.attn = CausalSelfAttention(config)
+        self.mlp = nn.Sequential(
+            nn.Linear(config.n_embd, 4 * config.n_embd),
+            nn.GELU(),  # nice
+            nn.Linear(4 * config.n_embd, config.n_embd),
+            nn.Dropout(config.resid_pdrop),
+        )
+
+    def forward(self, x):
+        # x = x + self.attn(self.ln1(x))
+
+        # x is a tuple (x, attention)
+        x, _ = x
+        res = x
+        x = self.ln1(x)
+        x, att = self.attn(x)
+        x = res + x
+
+        x = x + self.mlp(self.ln2(x))
+
+        return x, att
+
+
+class GPT(nn.Module):
+    """  the full GPT language model, with a context size of block_size """
+    def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
+                 embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
+        super().__init__()
+        config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
+                           embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
+                           n_layer=n_layer, n_head=n_head, n_embd=n_embd,
+                           n_unmasked=n_unmasked)
+        # input embedding stem
+        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
+        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
+        self.drop = nn.Dropout(config.embd_pdrop)
+        # transformer
+        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
+        # decoder head
+        self.ln_f = nn.LayerNorm(config.n_embd)
+        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.block_size = config.block_size
+        self.apply(self._init_weights)
+        self.config = config
+        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
+
+    def get_block_size(self):
+        return self.block_size
+
+    def _init_weights(self, module):
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            module.weight.data.normal_(mean=0.0, std=0.02)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    def forward(self, idx, embeddings=None, targets=None):
+        # forward the GPT model
+        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector
+
+        if embeddings is not None:  # prepend explicit embeddings
+            token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
+
+        t = token_embeddings.shape[1]
+        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
+        position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) vector
+        x = self.drop(token_embeddings + position_embeddings)
+
+        # returns only the last layer's attention
+        # the tuple (x, None) is passed because nn.Sequential takes a single input while each Block returns (x, attention)
+        # att is (B, H, T, T)
+        x, att = self.blocks((x, None))
+        x = self.ln_f(x)
+        logits = self.head(x)
+
+        # if we are given some desired targets also calculate the loss
+        loss = None
+        if targets is not None:
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+        return logits, loss, att
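+
+    # Minimal usage sketch (hypothetical sizes, codebook indices as input):
+    #     gpt = GPT(vocab_size=1024, block_size=512, n_layer=4, n_head=8, n_embd=256)
+    #     idx = torch.randint(0, 1024, (2, 53))        # (B, T) codebook indices
+    #     logits, loss, att = gpt(idx, targets=idx)    # logits: (B, T, vocab_size)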
+
+
+class DummyGPT(nn.Module):
+    # for debugging
+    def __init__(self, add_value=1):
+        super().__init__()
+        self.add_value = add_value
+
+    def forward(self, idx):
+        raise NotImplementedError('Model should output attention')
+        return idx + self.add_value, None
+
+
+class CodeGPT(nn.Module):
+    """Takes in semi-embeddings"""
+    def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
+                 embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
+        super().__init__()
+        config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
+                           embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
+                           n_layer=n_layer, n_head=n_head, n_embd=n_embd,
+                           n_unmasked=n_unmasked)
+        # input embedding stem
+        self.tok_emb = nn.Linear(in_channels, config.n_embd)
+        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
+        self.drop = nn.Dropout(config.embd_pdrop)
+        # transformer
+        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
+        # decoder head
+        self.ln_f = nn.LayerNorm(config.n_embd)
+        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.block_size = config.block_size
+        self.apply(self._init_weights)
+        self.config = config
+        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
+
+    def get_block_size(self):
+        return self.block_size
+
+    def _init_weights(self, module):
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            module.weight.data.normal_(mean=0.0, std=0.02)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, (nn.Conv1d, nn.Conv2d)):
+            torch.nn.init.xavier_uniform_(module.weight)  # in-place initializer (xavier_uniform is deprecated)
+            if module.bias is not None:
+                module.bias.data.fill_(0.01)
+
+    def forward(self, idx, embeddings=None, targets=None):
+        raise NotImplementedError('Model should output attention')
+        # forward the GPT model
+        token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
+
+        if embeddings is not None: # prepend explicit embeddings
+            token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
+
+        t = token_embeddings.shape[1]
+        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
+        position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
+        x = self.drop(token_embeddings + position_embeddings)
+        x = self.blocks(x)
+        x = self.ln_f(x)
+        logits = self.head(x)
+
+        # if we are given some desired targets also calculate the loss
+        loss = None
+        if targets is not None:
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+        return logits, loss
+
+class GPTFeats(GPT):
+
+    def __init__(self, feat_embedding_config, GPT_config):
+        super().__init__(**GPT_config)
+        # patching the config by removing the default parameters for Conv1d
+        if feat_embedding_config.target.split('.')[-1] in ['LSTM', 'GRU']:
+            for p in ['in_channels', 'out_channels', 'padding', 'kernel_size']:
+                if p in feat_embedding_config.params:
+                    feat_embedding_config.params.pop(p)
+        self.embedder = instantiate_from_config(config=feat_embedding_config)
+        if isinstance(self.embedder, nn.Linear):
+            print('Check cond_transformer.configure_optimizers: make sure weight decay is not applied to this Linear')
+
+    def forward(self, idx, feats):
+        if isinstance(self.embedder, nn.Linear):
+            feats = feats.permute(0, 2, 1)
+            feats = self.embedder(feats)
+        elif isinstance(self.embedder, (nn.LSTM, nn.GRU)):
+            feats = feats.permute(0, 2, 1)
+            feats, _ = self.embedder(feats)
+        elif isinstance(self.embedder, (nn.Conv1d, nn.Identity)):
+            # (B, D', T) <- (B, D, T)
+            feats = self.embedder(feats)
+            # (B, T, D') <- (B, T, D)
+            feats = feats.permute(0, 2, 1)
+        else:
+            raise NotImplementedError
+        # calling forward from super
+        return super().forward(idx, embeddings=feats)
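+
+    # whichever embedder is configured, `feats` ends up as (B, T, n_embd) and is prepended
+    # to the token embeddings inside GPT.forward as the visual conditioning prefix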
+
+class GPTFeatsPosEnc(GPT):
+    def __init__(self, feat_embedding_config, GPT_config, PosEnc_config):
+        super().__init__(**GPT_config)
+        # patching the config by removing the default parameters for Conv1d
+        if feat_embedding_config.target.split('.')[-1] in ['LSTM', 'GRU']:
+            for p in ['in_channels', 'out_channels', 'padding', 'kernel_size']:
+                if p in feat_embedding_config.params:
+                    feat_embedding_config.params.pop(p)
+        self.embedder = instantiate_from_config(config=feat_embedding_config)
+
+        self.pos_emb_vis = nn.Parameter(torch.zeros(1, PosEnc_config['block_size_v'], PosEnc_config['n_embd']))
+        self.pos_emb_aud = nn.Parameter(torch.zeros(1, PosEnc_config['block_size_a'], PosEnc_config['n_embd']))
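+        # separate learnable positional tables for the visual prefix and the audio tokens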
+
+        if isinstance(self.embedder, nn.Linear):
+            print('Check cond_transformer.configure_optimizers: make sure weight decay is not applied to this Linear')
+
+    def forward(self, idx, feats):
+        if isinstance(self.embedder, nn.Linear):
+            feats = feats.permute(0, 2, 1)
+            feats = self.embedder(feats)
+        elif isinstance(self.embedder, (nn.LSTM, nn.GRU)):
+            feats = feats.permute(0, 2, 1)
+            feats, _ = self.embedder(feats)
+        elif isinstance(self.embedder, (nn.Conv1d, nn.Identity)):
+            # (B, D', T) <- (B, D, T)
+            feats = self.embedder(feats)
+            # (B, T, D') <- (B, T, D)
+            feats = feats.permute(0, 2, 1)
+        else:
+            raise NotImplementedError
+        # calling forward from super
+        # forward the GPT model
+        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector
+
+        if feats is not None:  # prepend explicit feats
+            token_embeddings = torch.cat((feats, token_embeddings), dim=1)
+
+        t = token_embeddings.shape[1]
+        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
+        vis_t = self.pos_emb_vis.shape[1]
+        position_embeddings = torch.cat([self.pos_emb_vis, self.pos_emb_aud[:, :t-vis_t, :]], dim=1)  # concat along the sequence dim
+        x = self.drop(token_embeddings + position_embeddings)
+
+        # returns only the last layer's attention
+        # the tuple (x, None) is passed because nn.Sequential takes a single input while each Block returns (x, attention)
+        # att is (B, H, T, T)
+        x, att = self.blocks((x, None))
+        x = self.ln_f(x)
+        logits = self.head(x)
+
+        # if we are given some desired targets also calculate the loss
+        loss = None
+
+        return logits, loss, att
+
+
+
+class GPTClass(GPT):
+
+    def __init__(self, token_embedding_config, GPT_config):
+        super().__init__(**GPT_config)
+        self.embedder = instantiate_from_config(config=token_embedding_config)
+
+    def forward(self, idx, token):
+        token = self.embedder(token)
+        # calling forward from super
+        return super().forward(idx, embeddings=token)
+
+class GPTFeatsClass(GPT):
+
+    def __init__(self, feat_embedding_config, token_embedding_config, GPT_config):
+        super().__init__(**GPT_config)
+
+        # patching the config by removing the default parameters for Conv1d
+        if feat_embedding_config.target.split('.')[-1] in ['LSTM', 'GRU']:
+            for p in ['in_channels', 'out_channels', 'padding', 'kernel_size']:
+                if p in feat_embedding_config.params:
+                    feat_embedding_config.params.pop(p)
+
+        self.feat_embedder = instantiate_from_config(config=feat_embedding_config)
+        self.cls_embedder = instantiate_from_config(config=token_embedding_config)
+
+        if isinstance(self.feat_embedder, nn.Linear):
+            print('Check cond_transformer.configure_optimizers: make sure weight decay is not applied to this Linear')
+
+    def forward(self, idx, feats_token_dict: dict):
+        feats = feats_token_dict['feature']
+        token = feats_token_dict['target']
+
+        # Features. Output size: (B, T, D')
+        if isinstance(self.feat_embedder, nn.Linear):
+            feats = feats.permute(0, 2, 1)
+            feats = self.feat_embedder(feats)
+        elif isinstance(self.feat_embedder, (nn.LSTM, nn.GRU)):
+            feats = feats.permute(0, 2, 1)
+            feats, _ = self.feat_embedder(feats)
+        elif isinstance(self.feat_embedder, (nn.Conv1d, nn.Identity)):
+            # (B, D', T) <- (B, D, T)
+            feats = self.feat_embedder(feats)
+            # (B, T, D') <- (B, T, D)
+            feats = feats.permute(0, 2, 1)
+        else:
+            raise NotImplementedError
+
+        # Class. Output size: (B, 1, D')
+        token = self.cls_embedder(token)
+
+        # Concat
+        condition_emb = torch.cat([feats, token], dim=1)
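+        # (B, T+1, D'): T visual-feature tokens followed by a single class token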
+
+        # calling forward from super
+        return super().forward(idx, embeddings=condition_emb)
+
+
+#### sampling utils
+
+def top_k_logits(logits, k):
+    v, ix = torch.topk(logits, k)
+    out = logits.clone()
+    out[out < v[:, [-1]]] = -float('Inf')
+    return out
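+# e.g. top_k_logits(torch.tensor([[1., 3., 2., 0.]]), k=2) -> [[-inf, 3., 2., -inf]],
+# so the softmax in `sample` only distributes probability mass over the top-k tokens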
+
+@torch.no_grad()
+def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
+    """
+    take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
+    the sequence, feeding the predictions back into the model each time. Clearly the sampling
+    has quadratic complexity unlike an RNN that is only linear, and has a finite context window
+    of block_size, unlike an RNN that has an infinite context window.
+    """
+    block_size = model.get_block_size()
+    model.eval()
+    for k in range(steps):
+        x_cond = x if x.size(1) <= block_size else x[:, -block_size:]  # crop context if needed
+        raise NotImplementedError('v-iashin: the model outputs (logits, loss, attention)')
+        logits, _ = model(x_cond)
+        # pluck the logits at the final step and scale by temperature
+        logits = logits[:, -1, :] / temperature
+        # optionally crop probabilities to only the top k options
+        if top_k is not None:
+            logits = top_k_logits(logits, top_k)
+        # apply softmax to convert to probabilities
+        probs = F.softmax(logits, dim=-1)
+        # sample from the distribution or take the most likely
+        if sample:
+            ix = torch.multinomial(probs, num_samples=1)
+        else:
+            _, ix = torch.topk(probs, k=1, dim=-1)
+        # append to the sequence and continue
+        x = torch.cat((x, ix), dim=1)
+
+    return x
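+
+# NOTE: the GPT variants above return (logits, loss, att), so if this helper were re-enabled
+# the call in the loop would need to read `logits, _, _ = model(x_cond)`; the
+# NotImplementedError guard is kept to prevent accidental use.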
+
+
+
+#### clustering utils
+
+class KMeans(nn.Module):
+    def __init__(self, ncluster=512, nc=3, niter=10):
+        super().__init__()
+        self.ncluster = ncluster
+        self.nc = nc
+        self.niter = niter
+        self.shape = (3,32,32)
+        self.register_buffer("C", torch.zeros(self.ncluster,nc))
+        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
+
+    def is_initialized(self):
+        return self.initialized.item() == 1
+
+    @torch.no_grad()
+    def initialize(self, x):
+        N, D = x.shape
+        assert D == self.nc, D
+        c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random
+        for i in range(self.niter):
+            # assign all pixels to the closest codebook element
+            a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
+            # move each codebook element to be the mean of the pixels that assigned to it
+            c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)])
+            # re-assign any poorly positioned codebook elements
+            nanix = torch.any(torch.isnan(c), dim=1)
+            ndead = nanix.sum().item()
+            print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
+            c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters
+
+        self.C.copy_(c)
+        self.initialized.fill_(1)
+
+
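+    # forward(x): map an image batch (bs, nc, h, w) to per-pixel nearest-centroid indices
+    # of shape (bs, h*w); forward(ix, reverse=True): look the indices up in the codebook
+    # and reshape to `shape` (default: self.shape)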
+    def forward(self, x, reverse=False, shape=None):
+        if not reverse:
+            # flatten
+            bs,c,h,w = x.shape
+            assert c == self.nc
+            x = x.reshape(bs,c,h*w,1)
+            C = self.C.permute(1,0)
+            C = C.reshape(1,c,1,self.ncluster)
+            a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices
+            return a
+        else:
+            # flatten
+            bs, HW = x.shape
+            """
+            c = self.C.reshape( 1, self.nc,  1, self.ncluster)
+            c = c[bs*[0],:,:,:]
+            c = c[:,:,HW*[0],:]
+            x =      x.reshape(bs,       1, HW,             1)
+            x = x[:,3*[0],:,:]
+            x = torch.gather(c, dim=3, index=x)
+            """
+            x = self.C[x]
+            x = x.permute(0,2,1)
+            shape = shape if shape is not None else self.shape
+            x = x.reshape(bs, *shape)
+
+            return x
+
+
+if __name__ == '__main__':
+    import torch
+    from omegaconf import OmegaConf
+    import numpy as np
+    from tqdm import tqdm
+
+    device = torch.device('cuda:2')
+    torch.cuda.set_device(device)
+
+    cfg = OmegaConf.load('./configs/vggsound_transformer.yaml')
+
+    model = instantiate_from_config(cfg.model.params.transformer_config)
+    model = model.to(device)
+
+    mel_num = cfg.data.params.mel_num
+    spec_crop_len = cfg.data.params.spec_crop_len
+    feat_depth = cfg.data.params.feat_depth
+    feat_crop_len = cfg.data.params.feat_crop_len
+
+    gcd = np.gcd(mel_num, spec_crop_len)
+    z_idx_size = (2, int(mel_num / gcd) * int(spec_crop_len / gcd))
+
+    for i in tqdm(range(300)):
+        z_indices = torch.randint(0, cfg.model.params.transformer_config.params.GPT_config.vocab_size, z_idx_size).to(device)
+        c = torch.rand(2, feat_depth, feat_crop_len).to(device)
+        logits, loss, att = model(z_indices[:, :-1], feats=c)
diff --git a/foleycrafter/models/specvqgan/modules/transformer/permuter.py b/foleycrafter/models/specvqgan/modules/transformer/permuter.py
new file mode 100644
index 0000000000000000000000000000000000000000..94375a55efc302ec04da16676f19046e58aefa05
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/transformer/permuter.py
@@ -0,0 +1,295 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+TO_WARN_USER_ONCE = True
+
+class AbstractPermuter(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+    def forward(self, x, reverse=False):
+        raise NotImplementedError
+
+
+class Identity(AbstractPermuter):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, reverse=False):
+        return x
+
+class ColumnMajor(AbstractPermuter):
+    '''Useful for spectrograms which are from left to right (features, time)'''
+    def __init__(self, H, W):
+        super().__init__()
+        self.H = H
+        self.W = W
+        idx = self.make_idx(H, W)
+        self.register_buffer('forward_shuffle_idx', idx)
+        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+    def forward(self, x, reverse=False):
+        B, L = x.shape
+        L_idx = len(self.forward_shuffle_idx)
+        if L > L_idx:
+            # patch for "infinite" sampling: the cached self.*_shuffle_idx are shorter than L,
+            # so the index is rebuilt on the fly here; this branch is only hit during sampling.
+            assert L % L_idx == 0 and L / L_idx == int(L / L_idx), f'L: {L}, L_idx: {L_idx}'
+            W_scale = L // L_idx
+            # print(f'Permuter is making a guess on the temp scale: {W_scale}. Ignore on "infinite" sampling')
+            idx = self.make_idx(self.H, self.W * W_scale)
+            if not reverse:
+                return x[:, idx]
+            else:
+                return x[:, torch.argsort(idx)]
+        else:
+            if not reverse:
+                return x[:, self.forward_shuffle_idx]
+            else:
+                return x[:, self.backward_shuffle_idx]
+
+    def make_idx(self, H, W):
+        idx = np.arange(H * W).reshape(H, W)
+        idx = idx.T
+        idx = torch.tensor(idx.ravel())
+        return idx
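+    # e.g. make_idx(2, 3) -> tensor([0, 3, 1, 4, 2, 5]): the row-major grid is read out
+    # column by column, i.e. time-major for a (features, time) spectrogram code grid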
+
+class Subsample(AbstractPermuter):
+    def __init__(self, H, W):
+        super().__init__()
+        C = 1
+        indices = np.arange(H*W).reshape(C,H,W)
+        while min(H, W) > 1:
+            indices = indices.reshape(C,H//2,2,W//2,2)
+            indices = indices.transpose(0,2,4,1,3)
+            indices = indices.reshape(C*4,H//2, W//2)
+            H = H//2
+            W = W//2
+            C = C*4
+        assert H == W == 1
+        idx = torch.tensor(indices.ravel())
+        self.register_buffer('forward_shuffle_idx', idx)
+        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+    def forward(self, x, reverse=False):
+        if not reverse:
+            return x[:, self.forward_shuffle_idx]
+        else:
+            return x[:, self.backward_shuffle_idx]
+
+
+def mortonify(i, j):
+    """(i,j) index to linear morton code"""
+    i = np.uint64(i)
+    j = np.uint64(j)
+
+    z = np.uint64(0)
+
+    for pos in range(32):
+        z = (z |
+             ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
+             ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
+             )
+    return z
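+# mortonify interleaves the bits of i and j (j in the even bit positions, i in the odd ones),
+# i.e. it returns the Z-order / Morton code of cell (i, j)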
+
+
+class ZCurve(AbstractPermuter):
+    def __init__(self, H, W):
+        super().__init__()
+        reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
+        idx = np.argsort(reverseidx)
+        idx = torch.tensor(idx)
+        reverseidx = torch.tensor(reverseidx)
+        self.register_buffer('forward_shuffle_idx',
+                             idx)
+        self.register_buffer('backward_shuffle_idx',
+                             reverseidx)
+
+    def forward(self, x, reverse=False):
+        if not reverse:
+            return x[:, self.forward_shuffle_idx]
+        else:
+            return x[:, self.backward_shuffle_idx]
+
+
+class SpiralOut(AbstractPermuter):
+    def __init__(self, H, W):
+        super().__init__()
+        assert H == W
+        size = W
+        indices = np.arange(size*size).reshape(size,size)
+
+        i0 = size//2
+        j0 = size//2-1
+
+        i = i0
+        j = j0
+
+        idx = [indices[i0, j0]]
+        step_mult = 0
+        for c in range(1, size//2+1):
+            step_mult += 1
+            # steps left
+            for k in range(step_mult):
+                i = i - 1
+                j = j
+                idx.append(indices[i, j])
+
+            # step down
+            for k in range(step_mult):
+                i = i
+                j = j + 1
+                idx.append(indices[i, j])
+
+            step_mult += 1
+            if c < size//2:
+                # step right
+                for k in range(step_mult):
+                    i = i + 1
+                    j = j
+                    idx.append(indices[i, j])
+
+                # step up
+                for k in range(step_mult):
+                    i = i
+                    j = j - 1
+                    idx.append(indices[i, j])
+            else:
+                # end reached
+                for k in range(step_mult-1):
+                    i = i + 1
+                    idx.append(indices[i, j])
+
+        assert len(idx) == size*size
+        idx = torch.tensor(idx)
+        self.register_buffer('forward_shuffle_idx', idx)
+        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+    def forward(self, x, reverse=False):
+        if not reverse:
+            return x[:, self.forward_shuffle_idx]
+        else:
+            return x[:, self.backward_shuffle_idx]
+
+
+class SpiralIn(AbstractPermuter):
+    def __init__(self, H, W):
+        super().__init__()
+        assert H == W
+        size = W
+        indices = np.arange(size*size).reshape(size,size)
+
+        i0 = size//2
+        j0 = size//2-1
+
+        i = i0
+        j = j0
+
+        idx = [indices[i0, j0]]
+        step_mult = 0
+        for c in range(1, size//2+1):
+            step_mult += 1
+            # steps left
+            for k in range(step_mult):
+                i = i - 1
+                j = j
+                idx.append(indices[i, j])
+
+            # step down
+            for k in range(step_mult):
+                i = i
+                j = j + 1
+                idx.append(indices[i, j])
+
+            step_mult += 1
+            if c < size//2:
+                # step right
+                for k in range(step_mult):
+                    i = i + 1
+                    j = j
+                    idx.append(indices[i, j])
+
+                # step up
+                for k in range(step_mult):
+                    i = i
+                    j = j - 1
+                    idx.append(indices[i, j])
+            else:
+                # end reached
+                for k in range(step_mult-1):
+                    i = i + 1
+                    idx.append(indices[i, j])
+
+        assert len(idx) == size*size
+        idx = idx[::-1]
+        idx = torch.tensor(idx)
+        self.register_buffer('forward_shuffle_idx', idx)
+        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+    def forward(self, x, reverse=False):
+        if not reverse:
+            return x[:, self.forward_shuffle_idx]
+        else:
+            return x[:, self.backward_shuffle_idx]
+
+
+class Random(nn.Module):
+    def __init__(self, H, W):
+        super().__init__()
+        indices = np.random.RandomState(1).permutation(H*W)
+        idx = torch.tensor(indices.ravel())
+        self.register_buffer('forward_shuffle_idx', idx)
+        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+    def forward(self, x, reverse=False):
+        if not reverse:
+            return x[:, self.forward_shuffle_idx]
+        else:
+            return x[:, self.backward_shuffle_idx]
+
+
+class AlternateParsing(AbstractPermuter):
+    def __init__(self, H, W):
+        super().__init__()
+        indices = np.arange(W*H).reshape(H,W)
+        for i in range(1, H, 2):
+            indices[i, :] = indices[i, ::-1]
+        idx = indices.flatten()
+        assert len(idx) == H*W
+        idx = torch.tensor(idx)
+        self.register_buffer('forward_shuffle_idx', idx)
+        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+    def forward(self, x, reverse=False):
+        if not reverse:
+            return x[:, self.forward_shuffle_idx]
+        else:
+            return x[:, self.backward_shuffle_idx]
+
+
+if __name__ == "__main__":
+    p0 = AlternateParsing(16, 16)
+    print(p0.forward_shuffle_idx)
+    print(p0.backward_shuffle_idx)
+
+    x = torch.randint(0, 768, size=(11, 256))
+    y = p0(x)
+    xre = p0(y, reverse=True)
+    assert torch.equal(x, xre)
+
+    p1 = SpiralOut(2, 2)
+    print(p1.forward_shuffle_idx)
+    print(p1.backward_shuffle_idx)
+    x = torch.randint(0, 768, size=(11, 2*2))
+    y = p1(x)
+    xre = p1(y, reverse=True)
+    assert torch.equal(x, xre)
+
+    p2 = ColumnMajor(5, 53)
+    print(p2.forward_shuffle_idx)
+    print(p2.backward_shuffle_idx)
+    x = torch.randint(0, 768, size=(11, 5*53))
+    xre = p2(p2(x), reverse=True)
+    assert torch.equal(x, xre)
diff --git a/foleycrafter/models/specvqgan/modules/video_model/r2plus1d_18.py b/foleycrafter/models/specvqgan/modules/video_model/r2plus1d_18.py
new file mode 100644
index 0000000000000000000000000000000000000000..e526d7cb47bfcc50ba1c57ffb9e790c55a4f41fb
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/video_model/r2plus1d_18.py
@@ -0,0 +1,124 @@
+import sys
+
+import torch
+import torch.nn as nn
+import torchvision
+
+sys.path.insert(0, '.')  # nopep8
+from foleycrafter.models.specvqgan.modules.video_model.resnet import r2plus1d_18
+
+FPS = 15
+
+class Identity(nn.Module):
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        return x
+
+class r2plus1d18KeepTemp(nn.Module):
+
+    def __init__(self, pretrained=True):
+        super().__init__()
+
+        self.model = r2plus1d_18(pretrained=pretrained)
+
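+        # patch torchvision's R(2+1)D-18 so that no temporal downsampling happens: the strided
+        # temporal convs in layers 2-4 (and their downsample paths) become stride-1 in time and
+        # pooling is applied over the spatial dims only, giving per-frame (N, 512, T) features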
+        self.model.layer2[0].conv1[0][3] = nn.Conv3d(230, 128, kernel_size=(3, 1, 1), 
+            stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
+        self.model.layer2[0].downsample = nn.Sequential(
+            nn.Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False),
+            nn.BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+        )
+        self.model.layer3[0].conv1[0][3] = nn.Conv3d(460, 256, kernel_size=(3, 1, 1), 
+            stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
+        self.model.layer3[0].downsample = nn.Sequential(
+            nn.Conv3d(128, 256, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False),
+            nn.BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+        )
+        self.model.layer4[0].conv1[0][3] = nn.Conv3d(921, 512, kernel_size=(3, 1, 1), 
+            stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
+        self.model.layer4[0].downsample = nn.Sequential(
+            nn.Conv3d(256, 512, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False),
+            nn.BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+        )
+        self.model.avgpool = nn.AdaptiveAvgPool3d((None, 1, 1))
+        self.model.fc = Identity()
+
+        with torch.no_grad():
+            rand_input = torch.randn((1, 3, 30, 112, 112))
+            output = self.model(rand_input).detach().cpu()
+            print('Validate Video feature shape: ', output.shape) # (1, 512, 30)
+
+    def forward(self, x):
+        N = x.shape[0]
+        return self.model(x).reshape(N, 512, -1)
+
+    def eval(self):
+        return self
+    
+    def encode(self, c):
+        info = None, None, c
+        return c, None, info
+
+    def decode(self, c):
+        return c
+
+    def get_input(self, batch, k, drop_cond=False):
+        x = batch[k].cuda()
+        x = x.permute(0, 2, 1, 3, 4).to(memory_format=torch.contiguous_format) # (N, 3, T, 112, 112)
+        T = x.shape[2]
+        if drop_cond:
+            output = self.model(x) # (N, 512, T)
+        else:
+            cond_x = x[:, :, :T//2] # (N, 3, T//2, 112, 112)
+            x = x[:, :, T//2:] # (N, 3, T//2, 112, 112)
+            cond_feat = self.model(cond_x) # (N, 512, T//2)
+            feat = self.model(x) # (N, 512, T//2)
+            output = torch.cat([cond_feat, feat], dim=-1) # (N, 512, T)
+        assert output.shape[2] == T
+        return output
+
+
+class resnet50(nn.Module):
+
+    def __init__(self, pretrained=True):
+        super().__init__()
+        self.model = torchvision.models.resnet50(pretrained=pretrained)
+        self.model.fc = nn.Identity()
+        # freeze resnet 50 model
+        for params in self.model.parameters():
+            params.requires_grad = False
+
+    def forward(self, x):
+        N = x.shape[0]
+        return self.model(x).reshape(N, 2048)
+
+    def eval(self):
+        return self
+    
+    def encode(self, c):
+        info = None, None, c
+        return c, None, info
+
+    def decode(self, c):
+        return c
+
+    def get_input(self, batch, k, drop_cond=False):
+        x = batch[k].cuda()
+        x = x.permute(0, 2, 1, 3, 4).to(memory_format=torch.contiguous_format) # (N, 3, T, 112, 112)
+        T = x.shape[2]
+        feats = []
+        for t in range(T):
+            xt = x[:, :, t]
+            feats.append(self.model(xt))
+        output = torch.stack(feats, dim=-1)
+        assert output.shape[2] == T
+        return output
+
+
+
+if __name__ == '__main__':
+    model = r2plus1d18KeepTemp(False).cuda()
+    x = {'input': torch.randn((1, 60, 3, 112, 112))}
+    out = model.get_input(x, 'input')
+    print(out.shape)
diff --git a/foleycrafter/models/specvqgan/modules/video_model/resnet.py b/foleycrafter/models/specvqgan/modules/video_model/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5023327f7e53a59fa940983cccb84483a91d581
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/video_model/resnet.py
@@ -0,0 +1,344 @@
+import torch.nn as nn
+
+try:
+    from torchvision.models.utils import load_state_dict_from_url
+except ImportError:  # the helper moved to torch.hub in newer torchvision releases
+    from torch.hub import load_state_dict_from_url
+
+
+__all__ = ['r3d_18', 'mc3_18', 'r2plus1d_18']
+
+model_urls = {
+    'r3d_18': 'https://download.pytorch.org/models/r3d_18-b3b3357e.pth',
+    'mc3_18': 'https://download.pytorch.org/models/mc3_18-a90a0ba3.pth',
+    'r2plus1d_18': 'https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth',
+}
+
+
+class Conv3DSimple(nn.Conv3d):
+    def __init__(self,
+                 in_planes,
+                 out_planes,
+                 midplanes=None,
+                 stride=1,
+                 padding=1):
+
+        super(Conv3DSimple, self).__init__(
+            in_channels=in_planes,
+            out_channels=out_planes,
+            kernel_size=(3, 3, 3),
+            stride=stride,
+            padding=padding,
+            bias=False)
+
+    @staticmethod
+    def get_downsample_stride(stride):
+        return stride, stride, stride
+
+
+class Conv2Plus1D(nn.Sequential):
+
+    def __init__(self,
+                 in_planes,
+                 out_planes,
+                 midplanes,
+                 stride=1,
+                 padding=1):
+        super(Conv2Plus1D, self).__init__(
+            nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3),
+                      stride=(1, stride, stride), padding=(0, padding, padding),
+                      bias=False),
+            nn.BatchNorm3d(midplanes),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1),
+                      stride=(stride, 1, 1), padding=(padding, 0, 0),
+                      bias=False))
+
+    @staticmethod
+    def get_downsample_stride(stride):
+        return stride, stride, stride
+
+
+class Conv3DNoTemporal(nn.Conv3d):
+
+    def __init__(self,
+                 in_planes,
+                 out_planes,
+                 midplanes=None,
+                 stride=1,
+                 padding=1):
+
+        super(Conv3DNoTemporal, self).__init__(
+            in_channels=in_planes,
+            out_channels=out_planes,
+            kernel_size=(1, 3, 3),
+            stride=(1, stride, stride),
+            padding=(0, padding, padding),
+            bias=False)
+
+    @staticmethod
+    def get_downsample_stride(stride):
+        return 1, stride, stride
+
+
+class BasicBlock(nn.Module):
+
+    expansion = 1
+
+    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
+        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
+
+        super(BasicBlock, self).__init__()
+        self.conv1 = nn.Sequential(
+            conv_builder(inplanes, planes, midplanes, stride),
+            nn.BatchNorm3d(planes),
+            nn.ReLU(inplace=True)
+        )
+        self.conv2 = nn.Sequential(
+            conv_builder(planes, planes, midplanes),
+            nn.BatchNorm3d(planes)
+        )
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.conv2(out)
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
+
+        super(Bottleneck, self).__init__()
+        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
+
+        # 1x1x1
+        self.conv1 = nn.Sequential(
+            nn.Conv3d(inplanes, planes, kernel_size=1, bias=False),
+            nn.BatchNorm3d(planes),
+            nn.ReLU(inplace=True)
+        )
+        # Second kernel
+        self.conv2 = nn.Sequential(
+            conv_builder(planes, planes, midplanes, stride),
+            nn.BatchNorm3d(planes),
+            nn.ReLU(inplace=True)
+        )
+
+        # 1x1x1
+        self.conv3 = nn.Sequential(
+            nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
+            nn.BatchNorm3d(planes * self.expansion)
+        )
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.conv2(out)
+        out = self.conv3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class BasicStem(nn.Sequential):
+    """The default conv-batchnorm-relu stem
+    """
+    def __init__(self):
+        super(BasicStem, self).__init__(
+            nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2),
+                      padding=(1, 3, 3), bias=False),
+            nn.BatchNorm3d(64),
+            nn.ReLU(inplace=True))
+
+
+class R2Plus1dStem(nn.Sequential):
+    """R(2+1)D stem is different than the default one as it uses separated 3D convolution
+    """
+    def __init__(self):
+        super(R2Plus1dStem, self).__init__(
+            nn.Conv3d(3, 45, kernel_size=(1, 7, 7),
+                      stride=(1, 2, 2), padding=(0, 3, 3),
+                      bias=False),
+            nn.BatchNorm3d(45),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
+                      stride=(1, 1, 1), padding=(1, 0, 0),
+                      bias=False),
+            nn.BatchNorm3d(64),
+            nn.ReLU(inplace=True))
+
+
+class VideoResNet(nn.Module):
+
+    def __init__(self, block, conv_makers, layers,
+                 stem, num_classes=400,
+                 zero_init_residual=False):
+        """Generic resnet video generator.
+
+        Args:
+            block (nn.Module): resnet building block
+            conv_makers (list(functions)): generator function for each layer
+            layers (List[int]): number of blocks per layer
+            stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
+            num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
+            zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
+        """
+        super(VideoResNet, self).__init__()
+        self.inplanes = 64
+
+        self.stem = stem()
+
+        self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
+        self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2)
+
+        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        # init weights
+        self._initialize_weights()
+
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+
+    def forward(self, x):
+        x = self.stem(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        # Flatten the layer to fc
+        # x = x.flatten(1)
+        # x = self.fc(x)
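+        # instead of flattening into the fc head, drop the pooled singleton dims and keep the
+        # feature map; restore the batch dim if squeeze() removed it (N == 1)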
+        N = x.shape[0]
+        x = x.squeeze()
+        if N == 1:
+            x = x[None]
+
+        return x
+
+    def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
+        downsample = None
+
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            ds_stride = conv_builder.get_downsample_stride(stride)
+            downsample = nn.Sequential(
+                nn.Conv3d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=ds_stride, bias=False),
+                nn.BatchNorm3d(planes * block.expansion)
+            )
+        layers = []
+        layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))
+
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes, conv_builder))
+
+        return nn.Sequential(*layers)
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv3d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out',
+                                        nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm3d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)
+
+
+def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
+    model = VideoResNet(**kwargs)
+
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def r3d_18(pretrained=False, progress=True, **kwargs):
+    """Construct 18 layer Resnet3D model as in
+    https://arxiv.org/abs/1711.11248
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
+        progress (bool): If True, displays a progress bar of the download to stderr
+
+    Returns:
+        nn.Module: R3D-18 network
+    """
+
+    return _video_resnet('r3d_18',
+                         pretrained, progress,
+                         block=BasicBlock,
+                         conv_makers=[Conv3DSimple] * 4,
+                         layers=[2, 2, 2, 2],
+                         stem=BasicStem, **kwargs)
+
+
+def mc3_18(pretrained=False, progress=True, **kwargs):
+    """Constructor for 18 layer Mixed Convolution network as in
+    https://arxiv.org/abs/1711.11248
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
+        progress (bool): If True, displays a progress bar of the download to stderr
+
+    Returns:
+        nn.Module: MC3 Network definition
+    """
+    return _video_resnet('mc3_18',
+                         pretrained, progress,
+                         block=BasicBlock,
+                         conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3,
+                         layers=[2, 2, 2, 2],
+                         stem=BasicStem, **kwargs)
+
+
+def r2plus1d_18(pretrained=False, progress=True, **kwargs):
+    """Constructor for the 18 layer deep R(2+1)D network as in
+    https://arxiv.org/abs/1711.11248
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
+        progress (bool): If True, displays a progress bar of the download to stderr
+
+    Returns:
+        nn.Module: R(2+1)D-18 network
+    """
+    return _video_resnet('r2plus1d_18',
+                         pretrained, progress,
+                         block=BasicBlock,
+                         conv_makers=[Conv2Plus1D] * 4,
+                         layers=[2, 2, 2, 2],
+                         stem=R2Plus1dStem, **kwargs)
diff --git a/foleycrafter/models/specvqgan/modules/vqvae/quantize.py b/foleycrafter/models/specvqgan/modules/vqvae/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..296df15e68c5810368d24cec1ce3abf9db1dd237
--- /dev/null
+++ b/foleycrafter/models/specvqgan/modules/vqvae/quantize.py
@@ -0,0 +1,131 @@
+import torch
+import torch.nn as nn
+
+
+class VectorQuantizer(nn.Module):
+    """
+    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
+    ____________________________________________
+    Discretization bottleneck part of the VQ-VAE.
+    Inputs:
+    - n_e : number of embeddings
+    - e_dim : dimension of embedding
+    - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+    _____________________________________________
+    """
+
+    def __init__(self, n_e, e_dim, beta):
+        super(VectorQuantizer, self).__init__()
+        self.n_e = n_e
+        self.e_dim = e_dim
+        self.beta = beta
+
+        self.embedding = nn.Embedding(self.n_e, self.e_dim)
+        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+        # better inheritance properties (so that when VectorQuantizer1d() inherits from it,
+        # only these permute orders need to change)
+        self.permute_order_in = [0, 2, 3, 1]
+        self.permute_order_out = [0, 3, 1, 2]
+
+    def forward(self, z):
+        """
+        Inputs the output of the encoder network z and maps it to a discrete
+        one-hot vector that is the index of the closest embedding vector e_j
+        z (continuous) -> z_q (discrete)
+        2d: z.shape = (batch, channel, height, width)
+        1d: z.shape = (batch, channel, time)
+        quantization pipeline:
+            1. get encoder input 2d: (B,C,H,W) or 1d: (B, C, T)
+            2. flatten input to 2d: (B*H*W,C) or 1d: (B*T, C)
+        """
+        # reshape z -> (batch, height, width, channel) or (batch, time, channel) and flatten
+        z = z.permute(self.permute_order_in).contiguous()
+        z_flattened = z.view(-1, self.e_dim)
+        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+            torch.matmul(z_flattened, self.embedding.weight.t())
+
+        ## the block below could possibly be replaced by the commented alternative further down
+        # #\start...
+        # find closest encodings
+        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+
+        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z)
+        min_encodings.scatter_(1, min_encoding_indices, 1)
+
+        # dtype min encodings: torch.float32
+        # min_encodings shape: torch.Size([2048, 512])
+        # min_encoding_indices.shape: torch.Size([2048, 1])
+
+        # get quantized latent vectors
+        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+        #.........\end
+
+        # with:
+        # .........\start
+        #min_encoding_indices = torch.argmin(d, dim=1)
+        #z_q = self.embedding(min_encoding_indices)
+        # ......\end......... (TODO)
+
+        # compute loss for embedding
+        loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+
+        # preserve gradients
+        z_q = z + (z_q - z).detach()
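+        # straight-through estimator: the forward pass uses the quantised z_q while the
+        # gradient flows back to the encoder output z as if quantisation were the identity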
+
+        # perplexity
+        e_mean = torch.mean(min_encodings, dim=0)
+        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+
+        # reshape back to match original input shape
+        z_q = z_q.permute(self.permute_order_out).contiguous()
+
+        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+    def get_codebook_entry(self, indices, shape):
+        # shape specifying (batch, height, width, channel)
+        # TODO: check for more easy handling with nn.Embedding
+        min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
+        min_encodings.scatter_(1, indices[:, None], 1)
+
+        # get quantized latent vectors
+        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+        if shape is not None:
+            z_q = z_q.view(shape)
+
+            # reshape back to match original input shape
+            z_q = z_q.permute(self.permute_order_out).contiguous()
+
+        return z_q
+
+class VectorQuantizer1d(VectorQuantizer):
+
+    def __init__(self, n_embed, embed_dim, beta=0.25):
+        super().__init__(n_embed, embed_dim, beta)
+        self.permute_order_in = [0, 2, 1]
+        self.permute_order_out = [0, 2, 1]
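+        # 1d variant for (batch, channel, time) inputs: only the permute orders change,
+        # the quantisation logic of the parent class is reused as-is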
+
+
+if __name__ == '__main__':
+    quantize = VectorQuantizer1d(n_embed=1024, embed_dim=256, beta=0.25)
+
+    # 1d Input (features)
+    enc_outputs = torch.rand(6, 256, 53)
+    quant, emb_loss, info = quantize(enc_outputs)
+    print(quant.shape)
+
+    quantize = VectorQuantizer(n_e=1024, e_dim=256, beta=0.25)
+
+    # Audio
+    enc_outputs = torch.rand(4, 256, 5, 53)
+    quant, emb_loss, info = quantize(enc_outputs)
+    print(quant.shape)
+
+    # Image
+    enc_outputs = torch.rand(4, 256, 16, 16)
+    quant, emb_loss, info = quantize(enc_outputs)
+    print(quant.shape)
diff --git a/foleycrafter/models/specvqgan/onset_baseline/config/__init__.py b/foleycrafter/models/specvqgan/onset_baseline/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eaee0a833230c377934c809dc4a1c65c562002fe
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/config/__init__.py
@@ -0,0 +1 @@
+from .config import init_args
\ No newline at end of file
diff --git a/foleycrafter/models/specvqgan/onset_baseline/config/config.py b/foleycrafter/models/specvqgan/onset_baseline/config/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..631ef2653af7737b6a0bbcfbe1f4a40dad7b8d00
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/config/config.py
@@ -0,0 +1,51 @@
+import argparse
+import numpy as np
+
+def init_args(return_parser=False): 
+    parser = argparse.ArgumentParser(description="""Configure""")
+
+    # basic configuration 
+    parser.add_argument('--exp', type=str, default='test101',
+                        help='checkpoint folder')
+
+    parser.add_argument('--epochs', type=int, default=100,
+                        help='number of total epochs to run (default: 100)')
+
+    parser.add_argument('--start_epoch', default=0, type=int,
+                        help='manual epoch number (useful on restarts) (default: 0)')
+    parser.add_argument('--resume', default='', type=str,
+                        metavar='PATH', help='path to checkpoint (default: None)')
+    parser.add_argument('--resume_optim', default=False, action='store_true')
+    parser.add_argument('--save_step', default=1, type=int)
+    parser.add_argument('--valid_step', default=1, type=int)
+    
+
+    # Dataloader parameter
+    parser.add_argument('--max_sample', default=-1, type=int)
+    parser.add_argument('--repeat', default=1, type=int)
+    parser.add_argument('--num_workers', type=int, default=8)
+    parser.add_argument('--batch_size', default=24, type=int)
+
+    # network parameters
+    parser.add_argument('--pretrained', default=False, action='store_true')
+
+    # optimizer parameters
+    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
+    parser.add_argument('--momentum', type=float, default=0.9)
+    parser.add_argument('--weight_decay', default=5e-4,
+                        type=float, help='weight decay (default: 5e-4)')
+    parser.add_argument('--optim', type=str, default='Adam',
+                        choices=['SGD', 'Adam'])
+    parser.add_argument('--schedule', type=str, default='cos', choices=['none', 'cos', 'step'], required=False)
+
+    parser.add_argument('--aug_img', default=False, action='store_true')
+    parser.add_argument('--test_mode', default=False, action='store_true')
+
+
+    if return_parser:
+        return parser
+
+    # global args
+    args = parser.parse_args()
+
+    return args
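+
+# Example usage (sketch): take the parsed args directly,
+#     args = init_args()
+# or request the parser to extend it before parsing:
+#     parser = init_args(return_parser=True)
+#     parser.add_argument('--extra_flag', action='store_true')  # hypothetical extra option
+#     args = parser.parse_args()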
diff --git a/foleycrafter/models/specvqgan/onset_baseline/data/__init__.py b/foleycrafter/models/specvqgan/onset_baseline/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4eb49348adeb79491b7c8df13f89234951836d97
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/data/__init__.py
@@ -0,0 +1,2 @@
+from .greatesthit import *
+from .impactset import *
\ No newline at end of file
diff --git a/foleycrafter/models/specvqgan/onset_baseline/data/greatesthit.py b/foleycrafter/models/specvqgan/onset_baseline/data/greatesthit.py
new file mode 100644
index 0000000000000000000000000000000000000000..cef9381dbf179941fd82ae9c8069f872c958a8ed
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/data/greatesthit.py
@@ -0,0 +1,158 @@
+import sys
+sys.path.append('..')
+
+from data import *
+from utils import sound, sourcesep
+
+import csv
+import glob
+import h5py
+import io
+import json
+import os
+import pdb
+import pickle
+import random
+import scipy
+import time
+
+import cv2
+import librosa
+import numpy as np
+import soundfile as sf
+from PIL import Image
+from PIL import ImageFilter
+from tqdm import tqdm
+
+import torch
+import torch.nn as nn
+import torchaudio
+import torchvision.transforms as transforms
+# import kornia as K
+
+
+class GreatestHitDataset(object):
+    def __init__(self, args, split='train'):
+        self.split = split
+        if split == 'train':
+            list_sample = './data/greatesthit_train_2.00.json'
+        elif split == 'val':
+            list_sample = './data/greatesthit_valid_2.00.json'
+        elif split == 'test':
+            list_sample = './data/greatesthit_test_2.00.json'
+
+        # save args parameter
+        self.repeat = args.repeat if split == 'train' else 1
+        self.max_sample = args.max_sample
+
+        self.video_transform = transforms.Compose(
+            self.generate_video_transform(args))
+        
+        if isinstance(list_sample, str):
+            with open(list_sample, "r") as f:
+                self.list_sample = json.load(f)
+
+        if self.max_sample > 0:
+            self.list_sample = self.list_sample[0:self.max_sample]
+        self.list_sample = self.list_sample * self.repeat
+
+        random.seed(1234)
+        np.random.seed(1234)
+        num_sample = len(self.list_sample)
+        if self.split == 'train':
+            random.shuffle(self.list_sample)
+
+        # self.class_dist = self.unbalanced_dist()
+        print('Greatesthit Dataloader: # sample of {}: {}'.format(self.split, num_sample))
+
+
+    def __getitem__(self, index):
+        # import pdb; pdb.set_trace()
+        info = self.list_sample[index].split('_')[0]
+        video_path = os.path.join('data', 'greatesthit', 'greatesthit_processed', info)
+        frame_path = os.path.join(video_path, 'frames')
+        audio_path = os.path.join(video_path, 'audio')
+        audio_path = glob.glob(f"{audio_path}/*.wav")[0]
+        # Currently unused; consider removing.
+        meta_path = os.path.join(video_path, 'hit_record.json')
+        if os.path.exists(meta_path):
+            with open(meta_path, "r") as f:
+                meta_dict = json.load(f)
+
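+        # Probe the file for its sample rate; the actual 2 s clip is read further below.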
+        audio, audio_sample_rate = sf.read(audio_path, start=0, stop=1000, dtype='float64', always_2d=True)
+        frame_rate = 15
+        duration = 2.0
+        frame_list = glob.glob(f'{frame_path}/*.jpg')
+        frame_list.sort()
+
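+        # The last field of the sample id is an audio sample index at 22050 Hz; convert it to seconds.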
+        hit_time = float(self.list_sample[index].split('_')[-1]) / 22050
+        if self.split == 'train':
+            frame_start = hit_time * frame_rate + np.random.randint(10) - 5
+            frame_start = max(frame_start, 0)
+            frame_start = min(frame_start, len(frame_list) - duration * frame_rate)
+            
+        else:
+            frame_start = hit_time * frame_rate
+            frame_start = max(frame_start, 0)
+            frame_start = min(frame_start, len(frame_list) - duration * frame_rate)
+        frame_start = int(frame_start)
+        
+        frame_list = frame_list[frame_start: int(
+            frame_start + np.ceil(duration * frame_rate))]
+        audio_start = int(frame_start / frame_rate * audio_sample_rate)
+        audio_end = int(audio_start + duration * audio_sample_rate)
+
+        imgs = self.read_image(frame_list)
+        audio, audio_rate = sf.read(audio_path, start=audio_start, stop=audio_end, dtype='float64', always_2d=True)
+        audio = audio.mean(-1)
+
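+        # Convert the detected onset times (seconds) to 15 fps frame indices, clamp them to the
+        # 30-frame window, and mark those frames as positives in the per-frame label vector.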
+        onsets = librosa.onset.onset_detect(y=audio, sr=audio_rate, units='time', delta=0.3)
+        onsets = np.rint(onsets * frame_rate).astype(int)
+        onsets[onsets>29] = 29
+        label = torch.zeros(len(frame_list))
+        label[onsets] = 1
+
+        batch = {
+            'frames': imgs,
+            'label': label
+        }
+        return batch
+
+    def getitem_test(self, index):
+        return self.__getitem__(index)
+
+    def __len__(self):
+        return len(self.list_sample)
+
+
+    def read_image(self, frame_list):
+        imgs = []
+        convert_tensor = transforms.ToTensor()
+        for img_path in frame_list:
+            image = Image.open(img_path).convert('RGB')
+            image = convert_tensor(image)
+            imgs.append(image.unsqueeze(0))
+        # (T, C, H ,W)
+        imgs = torch.cat(imgs, dim=0).squeeze()
+        imgs = self.video_transform(imgs)
+        imgs = imgs.permute(1, 0, 2, 3)
+        # (C, T, H ,W)
+        return imgs
+
+    def generate_video_transform(self, args):
+        resize_funct = transforms.Resize((128, 128))
+        if self.split == 'train':
+            crop_funct = transforms.RandomCrop(
+                (112, 112))
+            color_funct = transforms.ColorJitter(
+                brightness=0.1, contrast=0.1, saturation=0, hue=0)
+        else:
+            crop_funct = transforms.CenterCrop(
+                (112, 112))
+            color_funct = transforms.Lambda(lambda img: img)
+
+        vision_transform_list = [
+            resize_funct,
+            crop_funct,
+            color_funct,
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ]
+        return vision_transform_list
diff --git a/foleycrafter/models/specvqgan/onset_baseline/data/impactset.py b/foleycrafter/models/specvqgan/onset_baseline/data/impactset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6d3d737176c2b8a3753785edd3951e6baac174b
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/data/impactset.py
@@ -0,0 +1,145 @@
+import sys
+sys.path.append('..')
+
+from data import *
+from utils import sound, sourcesep
+
+import csv
+import glob
+import h5py
+import io
+import json
+import os
+import pdb
+import pickle
+import random
+import scipy
+import time
+
+import cv2
+import librosa
+import numpy as np
+import soundfile as sf
+from PIL import Image
+from PIL import ImageFilter
+from tqdm import tqdm
+
+import torch
+import torch.nn as nn
+import torchaudio
+import torchvision.transforms as transforms
+# import kornia as K
+
+
+class CountixAVDataset(object):
+    def __init__(self, args, split='train'):
+        self.split = split
+        if split == 'train':
+            list_sample = './data/countixAV_train.json'
+        elif split == 'val':
+            list_sample = './data/countixAV_val.json'
+        elif split == 'test':
+            list_sample = './data/countixAV_test.json'
+
+        # save args parameter
+        self.repeat = args.repeat if split == 'train' else 1
+        self.max_sample = args.max_sample
+
+        self.video_transform = transforms.Compose(
+            self.generate_video_transform(args))
+        
+        if isinstance(list_sample, str):
+            with open(list_sample, "r") as f:
+                self.list_sample = json.load(f)
+
+        if self.max_sample > 0:
+            self.list_sample = self.list_sample[0:self.max_sample]
+        self.list_sample = self.list_sample * self.repeat
+
+        random.seed(1234)
+        np.random.seed(1234)
+        num_sample = len(self.list_sample)
+        if self.split == 'train':
+            random.shuffle(self.list_sample)
+
+        # self.class_dist = self.unbalanced_dist()
+        print('CountixAV Dataloader: # sample of {}: {}'.format(self.split, num_sample))
+
+
+    def __getitem__(self, index):
+        # import pdb; pdb.set_trace()
+        info = self.list_sample[index]
+        video_path = os.path.join('data', 'ImpactSet', 'impactset-proccess-resize', info)
+        frame_path = os.path.join(video_path, 'frames')
+        audio_path = os.path.join(video_path, 'audio')
+        audio_path = glob.glob(f"{audio_path}/*_denoised.wav")[0]
+
+        audio, audio_sample_rate = sf.read(audio_path, start=0, stop=1000, dtype='float64', always_2d=True)
+        frame_rate = 15
+        duration = 2.0
+        frame_list = glob.glob(f'{frame_path}/*.jpg')
+        frame_list.sort()
+
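+        # Sample a random 2 s window (30 frames at 15 fps) from the clip.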
+        frame_start = random.randint(0, len(frame_list))
+        frame_start = min(frame_start, len(frame_list) - duration * frame_rate)
+        frame_start = int(frame_start)
+        
+        frame_list = frame_list[frame_start: int(
+            frame_start + np.ceil(duration * frame_rate))]
+        audio_start = int(frame_start / frame_rate * audio_sample_rate)
+        audio_end = int(audio_start + duration * audio_sample_rate)
+
+        imgs = self.read_image(frame_list)
+        audio, audio_rate = sf.read(audio_path, start=audio_start, stop=audio_end, dtype='float64', always_2d=True)
+        audio = audio.mean(-1)
+
+        onsets = librosa.onset.onset_detect(y=audio, sr=audio_rate, units='time', delta=0.3)
+        onsets = np.rint(onsets * frame_rate).astype(int)
+        onsets[onsets>29] = 29
+        label = torch.zeros(len(frame_list))
+        label[onsets] = 1
+
+        batch = {
+            'frames': imgs,
+            'label': label
+        }
+        return batch
+
+    def getitem_test(self, index):
+        return self.__getitem__(index)
+
+    def __len__(self):
+        return len(self.list_sample)
+
+
+    def read_image(self, frame_list):
+        imgs = []
+        convert_tensor = transforms.ToTensor()
+        for img_path in frame_list:
+            image = Image.open(img_path).convert('RGB')
+            image = convert_tensor(image)
+            imgs.append(image.unsqueeze(0))
+        # (T, C, H ,W)
+        imgs = torch.cat(imgs, dim=0).squeeze()
+        imgs = self.video_transform(imgs)
+        imgs = imgs.permute(1, 0, 2, 3)
+        # (C, T, H ,W)
+        return imgs
+
+    def generate_video_transform(self, args):
+        resize_funct = transforms.Resize((128, 128))
+        if self.split == 'train':
+            crop_funct = transforms.RandomCrop(
+                (112, 112))
+            color_funct = transforms.ColorJitter(
+                brightness=0.1, contrast=0.1, saturation=0, hue=0)
+        else:
+            crop_funct = transforms.CenterCrop(
+                (112, 112))
+            color_funct = transforms.Lambda(lambda img: img)
+
+        vision_transform_list = [
+            resize_funct,
+            crop_funct,
+            color_funct,
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ]
+        return vision_transform_list
\ No newline at end of file
diff --git a/foleycrafter/models/specvqgan/onset_baseline/data/transforms.py b/foleycrafter/models/specvqgan/onset_baseline/data/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..21834486ee6324245b49a961fc963a5af927e91a
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/data/transforms.py
@@ -0,0 +1,298 @@
+import torch
+import torchaudio
+import torchaudio.functional
+from torchvision import transforms
+import torchvision.transforms.functional as F
+import torch.nn as nn
+from PIL import Image
+import numpy as np
+import math
+import random
+
+
+class ResizeShortSide(object):
+    def __init__(self, size):
+        super().__init__()
+        self.size = size
+
+    def __call__(self, x):
+        '''
+        x must be PIL.Image
+        '''
+        w, h = x.size
+        short_side = min(w, h)
+        w_target = int((w / short_side) * self.size)
+        h_target = int((h / short_side) * self.size)
+        return x.resize((w_target, h_target))
+
+
+class RandomResizedCrop3D(nn.Module):
+    """Crop the given series of images to random size and aspect ratio.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [N, ..., H, W] shape, where ... means an arbitrary number of leading dimensions
+
+    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
+    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
+    is finally resized to given size.
+    This is popularly used to train the Inception networks.
+
+    Args:
+      size (int or sequence): expected output size of each edge. If size is an
+        int instead of sequence like (h, w), a square output size ``(size, size)`` is
+        made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
+      scale (tuple of float): range of size of the origin size cropped
+      ratio (tuple of float): range of aspect ratio of the origin aspect ratio cropped.
+      interpolation (InterpolationMode): Desired interpolation mode.
+        Default is ``InterpolationMode.BILINEAR``. If the input is a Tensor, only ``NEAREST``, ``BILINEAR``
+        and ``BICUBIC`` are supported.
+    """
+
+    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=transforms.InterpolationMode.BILINEAR):
+        super().__init__()
+        if isinstance(size, tuple) and len(size) == 2:
+            self.size = size
+        else:
+            self.size = (size, size)
+
+        self.interpolation = interpolation
+        self.scale = scale
+        self.ratio = ratio
+
+    @staticmethod
+    def get_params(img, scale, ratio):
+        """Get parameters for ``crop`` for a random sized crop.
+
+        Args:
+          img (PIL Image or Tensor): Input image.
+          scale (list): range of scale of the origin size cropped
+          ratio (list): range of aspect ratio of the origin aspect ratio cropped
+
+        Returns:
+          tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+            sized crop.
+        """
+        width, height = img.size
+        area = height * width
+
+        for _ in range(10):
+            target_area = area * \
+                torch.empty(1).uniform_(scale[0], scale[1]).item()
+            log_ratio = torch.log(torch.tensor(ratio))
+            aspect_ratio = torch.exp(
+                torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
+            ).item()
+
+            w = int(round(math.sqrt(target_area * aspect_ratio)))
+            h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+            if 0 < w <= width and 0 < h <= height:
+                i = torch.randint(0, height - h + 1, size=(1,)).item()
+                j = torch.randint(0, width - w + 1, size=(1,)).item()
+                return i, j, h, w
+
+        # Fallback to central crop
+        in_ratio = float(width) / float(height)
+        if in_ratio < min(ratio):
+            w = width
+            h = int(round(w / min(ratio)))
+        elif in_ratio > max(ratio):
+            h = height
+            w = int(round(h * max(ratio)))
+        else:  # whole image
+            w = width
+            h = height
+        i = (height - h) // 2
+        j = (width - w) // 2
+        return i, j, h, w
+
+    def forward(self, imgs):
+        """
+        Args:
+          img (PIL Image or Tensor): Image to be cropped and resized.
+
+        Returns:
+          PIL Image or Tensor: Randomly cropped and resized image.
+        """
+        i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio)
+        return [F.resized_crop(img, i, j, h, w, self.size, self.interpolation) for img in imgs]
+
+
+class Resize3D(object):
+    def __init__(self, size):
+        super().__init__()
+        self.size = size
+
+    def __call__(self, imgs):
+        '''
+        imgs must be a list of PIL.Image
+        '''
+        return [x.resize((self.size, self.size)) for x in imgs]
+
+
+class RandomHorizontalFlip3D(object):
+    def __init__(self, p=0.5):
+        super().__init__()
+        self.p = p
+
+    def __call__(self, imgs):
+        '''
+        imgs must be a list of PIL.Image
+        '''
+        if np.random.rand() < self.p:
+            return [x.transpose(Image.FLIP_LEFT_RIGHT) for x in imgs]
+        else:
+            return imgs
+
+
+class ColorJitter3D(torch.nn.Module):
+    """Randomly change the brightness, contrast and saturation of an image.
+
+    Args:
+    brightness (float or tuple of float (min, max)): How much to jitter brightness.
+        brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
+        or the given [min, max]. Should be non negative numbers.
+    contrast (float or tuple of float (min, max)): How much to jitter contrast.
+        contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
+        or the given [min, max]. Should be non negative numbers.
+    saturation (float or tuple of float (min, max)): How much to jitter saturation.
+        saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
+        or the given [min, max]. Should be non negative numbers.
+    hue (float or tuple of float (min, max)): How much to jitter hue.
+        hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
+        Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
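+
+    Note: one set of jitter parameters is sampled per call and applied to every frame,
+    so all frames in a clip are jittered consistently.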
+    """
+
+    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
+        super().__init__()
+        self.brightness = (1-brightness, 1+brightness)
+        self.contrast = (1-contrast, 1+contrast)
+        self.saturation = (1-saturation, 1+saturation)
+        self.hue = (0-hue, 0+hue)
+
+    @staticmethod
+    def get_params(brightness, contrast, saturation, hue):
+        """Get a randomized transform to be applied on image.
+
+        Arguments are same as that of __init__.
+
+        Returns:
+            Transform which randomly adjusts brightness, contrast and
+            saturation in a random order.
+        """
+        tfs = []
+
+        if brightness is not None:
+            brightness_factor = random.uniform(brightness[0], brightness[1])
+            tfs.append(transforms.Lambda(
+                lambda img: F.adjust_brightness(img, brightness_factor)))
+
+        if contrast is not None:
+            contrast_factor = random.uniform(contrast[0], contrast[1])
+            tfs.append(transforms.Lambda(
+                lambda img: F.adjust_contrast(img, contrast_factor)))
+
+        if saturation is not None:
+            saturation_factor = random.uniform(saturation[0], saturation[1])
+            tfs.append(transforms.Lambda(
+                lambda img: F.adjust_saturation(img, saturation_factor)))
+
+        if hue is not None:
+            hue_factor = random.uniform(hue[0], hue[1])
+            tfs.append(transforms.Lambda(
+                lambda img: F.adjust_hue(img, hue_factor)))
+
+        random.shuffle(tfs)
+        transform = transforms.Compose(tfs)
+
+        return transform
+
+    def forward(self, imgs):
+        """
+        Args:
+          img (PIL Image or Tensor): Input image.
+
+        Returns:
+          PIL Image or Tensor: Color jittered image.
+        """
+        transform = self.get_params(
+            self.brightness, self.contrast, self.saturation, self.hue)
+        return [transform(img) for img in imgs]
+
+
+class ToTensor3D(object):
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, imgs):
+        '''
+        imgs must be a list of PIL.Image
+        '''
+        return [F.to_tensor(img) for img in imgs]
+
+
+class Normalize3D(object):
+    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False):
+        super().__init__()
+        self.mean = mean
+        self.std = std
+        self.inplace = inplace
+
+    def __call__(self, imgs):
+        '''
+        imgs must be a list of Tensors
+        '''
+        return [F.normalize(img, self.mean, self.std, self.inplace) for img in imgs]
+
+
+class CenterCrop3D(object):
+    def __init__(self, size):
+        super().__init__()
+        self.size = size
+
+    def __call__(self, imgs):
+        '''
+        imgs must be a list of PIL.Image or Tensors
+        '''
+        return [F.center_crop(img, self.size) for img in imgs]
+
+
+class FrequencyMasking(object):
+    def __init__(self, freq_mask_param: int, iid_masks: bool = False):
+        super().__init__()
+        self.masking = torchaudio.transforms.FrequencyMasking(freq_mask_param, iid_masks)
+
+    def __call__(self, item):
+        if 'cond_image' in item.keys():
+            batched_spec = torch.stack(
+                [torch.tensor(item['image']), torch.tensor(item['cond_image'])], dim=0
+            )[:, None] # (2, 1, H, W)
+            masked = self.masking(batched_spec).numpy()
+            item['image'] = masked[0, 0]
+            item['cond_image'] = masked[1, 0]
+        elif 'image' in item.keys():
+            inp = torch.tensor(item['image'])
+            item['image'] = self.masking(inp).numpy()
+        else:
+            raise NotImplementedError()
+        return item
+
+
+class TimeMasking(object):
+    def __init__(self, time_mask_param: int, iid_masks: bool = False):
+        super().__init__()
+        self.masking = torchaudio.transforms.TimeMasking(time_mask_param, iid_masks)
+
+    def __call__(self, item):
+        if 'cond_image' in item.keys():
+            batched_spec = torch.stack(
+                [torch.tensor(item['image']), torch.tensor(item['cond_image'])], dim=0
+            )[:, None] # (2, 1, H, W)
+            masked = self.masking(batched_spec).numpy()
+            item['image'] = masked[0, 0]
+            item['cond_image'] = masked[1, 0]
+        elif 'image' in item.keys():
+            inp = torch.tensor(item['image'])
+            item['image'] = self.masking(inp).numpy()
+        else:
+            raise NotImplementedError()
+        return item
diff --git a/foleycrafter/models/specvqgan/onset_baseline/demo.ipynb b/foleycrafter/models/specvqgan/onset_baseline/demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..98889a002e251dcbc0dc5fd2d4e81f2a8b0bc7f2
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/demo.ipynb
@@ -0,0 +1,352 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Change audio by detecting onset \n",
+    "This notebook contains a method that could change the target video sound with a given audio."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Load packages"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 118,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import IPython\n",
+    "import os\n",
+    "import numpy as np\n",
+    "from moviepy.editor import *\n",
+    "import librosa\n",
+    "from IPython.display import Audio\n",
+    "from IPython.display import Video"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 119,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Read videos\n",
+    "origin_video_path = 'data/target.mp4'\n",
+    "conditional_video_path = 'data/conditional.mp4'\n",
+    "# conditional_video_path = 'data/dog_bark.mp4'\n",
+    "\n",
+    "ori_videoclip = VideoFileClip(origin_video_path)\n",
+    "con_videoclip = VideoFileClip(conditional_video_path)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 120,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<video src=\"data/target.mp4\" controls  width=\"640\" >\n",
+       "      Your browser does not support the <code>video</code> element.\n",
+       "    </video>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.Video object>"
+      ]
+     },
+     "execution_count": 120,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "Video(origin_video_path, width=640)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 121,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<video src=\"data/conditional.mp4\" controls  width=\"640\" >\n",
+       "      Your browser does not support the <code>video</code> element.\n",
+       "    </video>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.Video object>"
+      ]
+     },
+     "execution_count": 121,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "Video(conditional_video_path, width=640)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 122,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# get the audio track from video\n",
+    "ori_audioclip = ori_videoclip.audio\n",
+    "ori_audio, ori_sr = ori_audioclip.to_soundarray(), ori_audioclip.fps\n",
+    "con_audioclip = con_videoclip.audio\n",
+    "con_audio, con_sr = con_audioclip.to_soundarray(), con_audioclip.fps\n",
+    "\n",
+    "ori_audio = ori_audio.mean(-1)\n",
+    "con_audio = con_audio.mean(-1)\n",
+    "\n",
+    "target_sr = 22050\n",
+    "ori_audio = librosa.resample(ori_audio, orig_sr=ori_sr, target_sr=target_sr)\n",
+    "con_audio = librosa.resample(con_audio, orig_sr=con_sr, target_sr=target_sr)\n",
+    "\n",
+    "ori_sr, con_sr = target_sr, target_sr"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 123,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def detect_onset_of_audio(audio, sample_rate):\n",
+    "    onsets = librosa.onset.onset_detect(\n",
+    "        y=audio, sr=sample_rate, units='samples', delta=0.3)\n",
+    "    return onsets\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 124,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhEAAAFZCAYAAAAmfX2OAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAvt0lEQVR4nO3deZwU1bn/8e8zKwgzw74JCLggRhZFBXHXQBSjicYbY4wJcUvivbkxyS8aEi9qjEGzqImGmESNEePV5CbGuCGKYlQ2QUFQEVBANlkEBmQbZub8/uhuaHq6Z7qru7qmqz/v16teTFedqn44M9399KmzmHNOAAAAmSoJOgAAAFCYSCIAAIAnJBEAAMATkggAAOAJSQQAAPCEJAIAAHhCEgEAADwpCzqAXDMzk9RL0vagYwEAoABVSVrr0phIKnRJhCIJxOqggwAAoID1lrSmpUJhTCK2S9KqVatUXV0ddCwAABSMbdu2qU+fPlKarflhTCIkSdXV1SQRAAD4iI6VAADAE5IIAADgCUkEAADwhCQCAAB4QhIBAAA8IYkAAACekEQAAABPSCIAAIAnJBEAAMCTvCQRZnaNmS03s91mNs/MTmmhfKWZ3WpmK81sj5m9b2aX5yNWwIs01qkBgNDxfdprM7tY0l2SrpH0mqRvSHrWzI5yzn2Y4rS/Suou6QpJyyR1y0esgBd19Y367N2v6NCu7fW7rwwPOhwAyJt8fDB/T9L9zrn7oo+vNbPPSPqWpPGJhc3sbEmnSRrgnNsc3b0iD3ECnry+YrOWrP9ES9Z/EnQoAJBXvt7OMLMKScMlTU04NFXSqBSnnS9prqTrzGyNmS0xs1+aWdsUz1FpZtWxTZF10IG84U4GgGLld0tEF0mlktYn7F8vqUeKcwZIOlnSbkkXRK8xSVInScn6RYyXdGMuggUAAOnL1+iMxO9qlmRfTEn02KXOuTnOuWcUuSUyLkVrxERJNXFb79yEDAAAmuN3S8QmSQ1q2urQTU1bJ2LWSVrjnKuN2/euIolHb0lL4ws75/ZI2hN7bGZZhgwAANLha0uEc65O0jxJoxMOjZY0I8Vpr0nqZWbt4/YdIalR0uqcBwkAADzJx+2MOyRdaWaXm9kgM7tTUl9J90qSmU00s4fiyj8i6WNJfzKzo8zsVEm/kPSAc25XHuIFAABp8H2Ip3PuMTPrLGmCpJ6SFkka65xbGS3SU5GkIlb+EzMbLeluRUZpfKzIvBE3+B0r4IVL2b0HAMItLxM4OecmKTLCItmxcUn2LVbTWyAAAKAVYe0MAADgCUkEAADwhCQCAAB4QhIBAAA8IYkAAACekEQAWWIBLgDFiiQCAAB4QhIBAAA8IYkAAACekEQAAABPSCIAAIAnJBEAArG3oTHoEABkiSQCyKG1W1mtPh1vrd6qI254Vr9+YWnQoQDIAkkEkKX4aSIemrkyZTns962H35Bz0p0vLAk6FABZIIkAcmjJ+u1Bh1AQ1tBiA4QCSQSQQ29+uCXoEAAgb0giAACAJyQRYbNjh2QW2XbsCDqa4rNpE3WfKf5miw+/89AgiQAAAJ6QRAAAfPfW6q1aV0uH2rApCzoAoNA51gIHmrVsw3adf89rkqQVt50bcDTIJVoigBzaW1oedAhAqzN/VW3QIcAnJBFADn1SeVDQIQCtTmMjrXVhRRIBAPDVLU+/E3QI8AlJBADAV9t31wcdAnxCEgEAADwhiQAAAJ6QRABZossYgGJFEgEAADwhiQAAAJ6QRAAAAE9IIgAAgCckEQAAwBOSCAAA4AlJBAAA8IQkAsgWE0UAKFIkEUCWHFkEgCKVlyTCzK4xs+VmttvM5pnZKWmed5KZ1ZvZfJ9DDAXnnFZt3cVHGgAgL3xPIszsYkl3SbpV0jGSXpH0rJn1beG8GkkPSZrmd4xhcdcLS3XK3bP165MuCToUAEARyEdLxPck3e+cu885965z7lpJqyR9q4Xzfi/pEUkzfY4vNH49bakk6a6TLw04EgBAMfA1iTCzCknDJU1NODRV0qhmzvu6pEMl3ZzGc1SaWXVsk1SVRcgAACBNfrdEdJFUKml9wv71knokO8HMDpd0m6RLnXP1aTzHeEm1cdtqz9ECyLtNO+qCDgGAR/kanZHY18+S7JOZlSpyC+NG59ySNK89UVJN3NY7iziBjDl6smalkQoEClaZz9ffJKlBTVsduqlp64QUuRVxnKRjzOye6L4SSWZm9ZLGOOdejD/BObdH0p7YYzPLUegAAKA5vrZEOOfqJM2TNDrh0GhJM5Kcsk3SYEnD4rZ7Jb0X/Xm2L4ECWdhR1xB0CAAQCL9bIiTpDkmTzWyuIiMtrpbUV5HkQGY2UdLBzrmvOucaJS2KP9nMNkja7ZxbJKAVamykOT4rVB9QsHxPIpxzj5lZZ0kTJPVUJEkY65xbGS3SU5GkAgAAFJB8tETIOTdJ0qQUx8a1cO5Nkm7KeVAAACArrJ0BAAA8IYkAssQCXACKFUkEkCWmOcgO1QcULpIIAADgCUkEAADwhCQCAAB4QhIBZIk+Edmh/oDCRRIBAAA8IYkAAOTNhu27gw4BOUQSAWSJ1nggfbtYsC5USCJC6q0ehwUdAgAg5EgiQur8r90VdAhFw9EzMCvM+AkULpIIAADgCUkEkGONsqBDAFot4/URKiQRQI692m9Y0CEAQF6QRAA5VldaHnQIBYUuJUDhIokAssRnIIBiRRIBZIssAkib0SUiVEgiAATq/xZ8FHQIyKNG7l+FCkkEkGMb23cMOoSCcsfLK4IOAXn0zEKSxjAhiQByjCGeQGqsnREuJBFAlphxEUgf80SEC0kEkGNGUgGgSJBEAAAAT0gigCwldjZ3NNcCKTHEM1xIIgAAgCckEUCW6AEBpO+T3fVBh4AcIokAAOTNY3NXBR0CcogkAsgxx01fAEWCJAIAAHhCEgFkiXYHAMWKJALIEh0rARQrkogQ29vQGHQIRYFFCQEUK5KIENtTTxKRD4lrZ5BTACgWJBFAlpq0RDA6A0CRIIkAsvTc2x8FHQIABCIvSYSZXWNmy81st5nNM7NTmil7oZk9b2YbzWybmc00s8/kI07Aiw827gg6BAAIhO9JhJldLOkuSbdKOkbSK5KeNbO+KU45VdLzksZKGi7pJUlPmtkxfscKeOHoWQlkZGddQ9AhIEfy0RLxPUn3O+fuc86965y7VtIqSd9KVtg5d61z7ufOudedc0udcz+StFTSeXmIFchYkxSCpCJjS7qk+k6BMLpk8vygQ0CO+JpEmFmFIq0JUxMOTZU0Ks1rlEiqkrQ5xfFKM6uObdGykLR5596gQyhaDY1Ou/fybStdjw0ZE3QIyKMFa7cHHQJyxO+WiC6SSiWtT9i/XlKPNK/xfUntJP01xfHxkmrjttWZhxlO//4gad4Fv5lp7B/m6ugbn9OOPaxYCCC88jU6o8kguCT7mjCzSyTdJOli59yGFMUmSqqJ23p7DxPIXLKljd/buEP1jU4LVm3Nf0AFaEdF26BDAOBBmc/X3ySpQU1bHbqpaevEAa
IdMu+X9B/OuRdSlXPO7ZG0J+48z8ECXmyntSFrjw79jI5ds1hfDDoQABnxtSXCOVcnaZ6k0QmHRkuakeq8aAvEg5K+7Jx72rcAAR84luRqVkNj8kbI68Z+J8+RoLVwzumXz72nKYvWBR0KMuR3S4Qk3SFpspnNlTRT0tWS+kq6V5LMbKKkg51zX40+vkTSQ5K+I2mWmcVaMXY552rzEC+QO+QTTWzcvqflQigqLy7eoHteWiZJWnHbuQFHg0z4nkQ45x4zs86SJkjqKWmRpLHOuZXRIj0VSSpivhGN67fRLebPksb5HS8AIL82kFgWrHy0RMg5N0nSpBTHxiU8Pj0PIRUF42twICaMSToFCgCEDmtnAD4ikQMQZiQRIZa4RDXQ2tU1sHw9UEhIIgAfMeI4M0f87N/M9AkUEJIIAHnVUgvZojUMwipmdfW0RhUSkogQ43588NZs2RV0CEVr994G/eTJdzRj2aagQ0ES0w49Pun+MXe+nOdIkA2SiBCjT0Twvv+3BZr5/sdBh1GU7n91uR54bbm+fN/soENBEldcdGPS/Ss+3pnnSJANkgjAZ3+buyroEArKRffOlMvBcuqrNvNhVCgWr9sWdAjwiCQCQKvz0nup1ttDGP155sqWC6FVIokIMfpEoFBt3rE36BAApIEkAkCr89fXuQUEFAKSiBD78TNLtHoL94WDRvfWA6XT3WHOis2a8T6jKoDWjiQi5L758LygQwA8WbEpuwSYib5av3oL7iOoMcWS9MgMSUTIvffR9qBDAA6Q7oc7SUD4PXjc+YE8781Pvq2RE6dp8466QJ4/TEgiAACBWN++UyDP+6fXVmjD9j16cMaKQJ4/TEgiisDPpyzW/875MOgwipZzTtt2M9ogU3dPW0q9hZyT5WROEK9o7MoeSUTI7W1wmjT9fY3/x8KgQyla/5y/VkNumqo7n18SdCgFZW3tbv3kyXc8n7+udncOo4Ef7jvhAp05aU5gz1/CPbOskUQAefLraUuDDqHgzFu5xfO509/bmMNI4Jflm4NbX4YcInskEQDyKpPW6/pGVnRE+v4+b7XOvuvfzU55Hn+LjBwieyQRRSzIe5FAOlYF+C0Vhef7f1ugxR9t1w3/XJSyzG9fWrbvZ1oiskcSUaRun7JYJ9/+kj7+ZE/QoRSV9du4Tw+0pK6+UVuyGH65q64h5bH1cX1ldu+lpStbJBFF6nfT39earbt0/6vLgw6lqPzg/94KOgQk4ZzTsg2fqIEJiFqFI254Vsfc8rzW1ea2JerxN1frn/PX7nt8T1yrBLwhiShC8a0PvGfm1/JNnwQdQlFqaHTN3r57aOZKffqOl/Xdx+bnLyi0yGvnWJdisvnvPrYgm3CQBElEEdmxp16SdF9c68ML764PKhwgLbloHTj0R89o3J9eT3k89o30XwvWpiyDwuWc05V/nht0GKFEElFETvvFdElSY9w3smUb+GacT3QUzNyKj3dozvLNzd7nTsfLSxjyWWj2NjSqdmfmE44t3fCJLn/wdT3+5mpJ0rvrtvOFySckEUVk0yd75JzLqsMSkK1M2xXO+tXL+uLvZ+qqh/z7Jrk7ywQF/pjwxNsa+pOp+vmUxapvSL8T5Nade/Xi4g37bl/Q18U/JBFF5t9LN+nDZsZQA63Vq8vSXxp80ZrapPt3702eLGyP3uqTlNGHFfJj0vT3mbq/lSKJKDJX/vl1zfpgc9BhFLV+P3xaD81cEXQYoXbNX95Iuv/I/5mix15v/sNobhazZMI/ryxNP4mM9/ibq/URQ6t9QxJRZPY20KzXGkx44m3trKtvuWAI5WN+nz31qW9PXP/35teRaWQStlZp1gcfezrvu48t8PVWWLEjiYA+2EjnyiBs8dBhDOlZv635SdT6/fBpnfLzF/Xmh1u0duuBnV3pbNw6bdtdr7/MXpn6lhQrvgaCJALauosXXxBOuu3FoEMoOL99aZmWrN/ebJnGNDvRrdq8SxdMmqGrJx/4LXXCE2/r6ug3149qd6uW10er8ePHF+naR+dLisx38/KSjft+3xOfXRxgZMWLJAKau4I+EkG58YnIHP8Pz1qpyx98PeW3LET84rn3NObOf2tzMyOMJj77bkbXXLRmW5N9U99Zr8076jRy4jQNvXlqxnHCP1Pe/kirt+zU8J++oK89MEf/eHONJGn+h1s9XW/Fph05jK74kERAP3uGDD4of565UhOeWKQb/rlILy7eoL/MDn8P9A82Zv+mfewtz+sHf1ugZxaua3JszvLcJMWL1+1PLt5eW6sHX1vOUMFW4uTbX9r3883/elszlm3SO+uaJoPpuGDSa7rj+SW5Cq3okERAknT5g69ryqJ1aTcFI3cemrly38/FcF936Ybmb0ek62/zVuuav7yhhatrNe5Pc3T3tKWRAzlamvG19/ePBjj3N6/qpiff0YMzVuTk2sid7Xvq9eX7Zns+f8vOvfrNtKWsauwRSQQkSS8u3qBvPvyG/jZvVdChFLW7Xliq6e9t0N4Qz1WQ6zz1vHte1fT3NupXzy/RI7M/1IJVW3Ny3d++9H6Tfbc89Y5mffCx3viQYaBh89e5vPd5QRKBA1z/94XasG23npi/RlMWfRR0OEVp3J9e10+efEdvrd4aymGgH+V4ZcZ4P3q8+eGbufClP8zShZNm6JWlG/Xi4vW03oXEH19hRWMvyoIOAK3PCT+btu/nZbeeo7JScs18mzxrpSbPWnnAvjk/OkvdqtsEFFHu/HnGypYLFYDL7p8jSbr9C4N18fF9A44G2WJorzd8OoSEX9+G/jp3tf7rkTe0iqmyA3fCz6Zp1MRpmv7eBp31q+n64u9nasP2wpuJry5kt2qu//tCPf7maoaCoijlpSXCzK6R9ANJPSW9Lela59wrzZQ/TdIdkj4laa2knzvn7s1HrIXqfZ8mjIo1Dz/11v5e8Bcec7B+OPZIPTpnlZ6Yv0YPjDteh3RuJ+eclm/aof5d2sly1LkNB1pbu3vfktbvb9yhE26dpp9dMFgDurbTyAGdJUXWfnjgteUqMdOryzZp6fpPNP0Hp6ucFiXfxBZ6evyaUdpT36hj+3ZURVl69f3Y6x/qrdW1mnDeUdq4fY9++9L7uvbTh6t7CFqdJGk5QyhDzfzukWpmF0uaLOkaSa9J+oakKyUd5ZxrMp7NzPpLWiTpj5J+L+kkSZMkXeKc+3saz1ctqba2tlbV1dU5+3+0dlMWrdM3H06+XkCQzhvaS53bVaiirETnDemlow+ulpmprr5RZSWmHXX1alteWrC3TNbV7tKJEwtn0qh2FaXqUlWplR/vVLuKUp01qLs+O6SnhvXpoOq25aqI/h7M5Fsi2O+HT/ty3dbsc8N66SefO1p19Y0ykzq3q5CZ6e/zVuv7f1uQ8rwHv368Rg7orDblpXmMNnfq6ht1xA3PBh1G2lbcdm7QIQRu27ZtqqmpkaQa51yL42bzkUTMlvSGc+5bcfvelfRP59z4JOVvl3S+c25Q3L57JQ11zp2YpHylpMq4XVWSVucyiXht2SZdet9snXJ4F0mSc5KTU0mSN9lYd
To5lZaUpBw25Nz+N+rGRnfAm7YpMvf/rA82q2dNG/WoaaP2lZFGo111Ddq2e6+27apnUZmoPp3aatXm/Z31ykpMQ/t0UKd2FWpsdNqys07lpSVavmmHjuxZrRKL1H9piakkWu8mqcRMJSWSybTi4x0qLy3R/Bz19C9GR/ao0uKPcjOcE8md0K+TykpNjc5F3pecJIski/HL5MTeqWJvWfsfW9LjsT1Ny8ceJxxP2C+Tnn6r6Rweheakwzpry469qmpTprLSWF1ZrkYRNyv2ObPvZyd9uHmnyktNvTq0VWnJ/iCuPGWATjuia06eN9MkwtfbGWZWIWm4pNsSDk2VNCrFaSdGj8d7TtIVZlbunEu88The0o3Zxtqc2OpxXleRy8a62t1aV0uy0Jz4BEKS6hud5qVYiXHD9o35CAkSCUQezGG2WV+9tszbol9+W/HxgX3UzhvaK6BI/O8T0UVSqaT1CfvXS+qR4pweKcqXRa+XmN5OVKT/REyVpNVegk3lipP768geVdq8o04d25Xvy0Qboi0IiWLZeEOj2/fNNiaWWZr2f3soKYlrwdjXkiEtWlOrqjZlOqRzO8Va+9du3a3VW3bqiO5VuvnJd3L53yxYPx4babRatWWn6hudule1Ua8ObVTf6LSrrkFbd9appMTUuX2lKkptX8uDc5EVG50i/za6yM5GF0lEatqW691123T/q8t1+xcGa+P2PfrlVGa2S2ZA13Yae3RP3fPSMo0b1U9H9qjSP+ev0eotu7R6yy6d0L9TzmaShFRRVqJBPar09ZP6yyzaimaRlrXY331ZqSm+ITT2Y6x11CUciP/We2D55o8r4Xqx44+/uaagW/LOHdJTowd1lxSZ0Kq6TeTjMtZCYM2sRxt/vKWyzV0jsbV7/bbIl8qhvTscsP+Yvgc+zidfb2eYWS9JaySNcs7NjNv/Y0mXOeeOTHLOEkl/cs5NjNt3kqRXJfV0zjU7eUGx9olYs3VXq1nQ6QefGagR/TtpYI8q7drboC7tKlVSkof2vwA0NDod+qNngg6jRYMPrtET/3mSSkpMzjmZ2b4Pk3x3gi2WPhGXnNBX3x9zhNpXlqnELGVHy1T1MfW7p+qI7lV+hpgXhdRviD4Rrex2hqRNkhrUtNWhm5q2NsR8lKJ8vaTW2bbUCnSvqmy5UJYuP6m/HnhtuX52wWBdNLz3vjfFvQ2NKXv+V7Up9z2uIJUGkByNHdxDzyzcn0vPnzBaHQ6qaFLOOaf/nbNKRx9crSFx31z23QdnBE3OfOHY3rr5c59Su4rSjOv1hnMH6adPv6uZ48/Ujj0Nuuz+2br1gqNDkUBIYr2RkPM1iXDO1ZnZPEmjJT0ed2i0pCdSnDZT0nkJ+8ZImpukPwSi/BrdMKhntd6NLmwz4byjNOG8o5qUYehg/iy+5Wy1KS/VGx9u0T0vLtMN5w5KmkBIkSThyyOYBCkffvXFoZ7PvfKUAbrylAH7Hs8cf1YuQmo1CnVkCdKTj3f/OyRdaWaXm9kgM7tTUl9J90qSmU00s4fiyt8r6RAzuyNa/nJJV0j6ZR5ihSL3Alfcdq5W3HauDu4QjrHqYTD7R2fte0M+tm9HPTDueA3o2j7gqDJ3WLfCizmVAV3aaemt5wQdRqtW1YaJkcPM99+uc+4xM+ssaYIik00tkjTWOReb+7anIklFrPxyMxsr6U5J/6nIZFP/nc4cEchen05t9dsvH7vv8S2fP1p1DQs1btQhAUZVfH5+0RDV7tyrJ99aqx7VbXTR8N6hmXyoT8e2BT3F8LhR/bRt915dflJ/fapXNbeFWpBsKHxr9Ox3Tgk6hIKUlxTROTdJkQmjkh0bl2Tfy5KObVoafunTqa1+86VjdEzfjgfs71nTVg9dfkJAURWfgd2r9Mx3TtnX1+KqUwe0cEbhOaii8L6ZPnzFCF09ea6uOLm/vj9mYNDhFJTWfruz40HluvDY3hrUs3g64udS6/7tIm9eue7MJgkE8u/i4/sE0lkzny4+vk9ur3dcHy2fODan15T2T6D0h8uG6+TDu2jhTZ8hgWilpn73VM/nTv/BGfqfzzbt64X0FN5XAiCkqirLdMExBwcdhu/aVuSuo91Vp/TXj8+NfAA8dvVI/XXuah3Xr6PG/yP7JcHn/88Ybd5Zp/5d2kkKZiQOkvvqiYeoZ01b3T5lsSRlNZKlpm24R5D5jSQCCNDVpw7QH/79gX71H0P1uWG9CnYNkXwbc1R3TX1nvb4ycn9fnREDOmvEgM5au3VXM2emr+agctUcxAdMa/S1Uf00oEs7Hd+v476OuhVlJaqrD9cKsYWAJAI65+hUk4fCb+PPOVLfOHWAOrf3f56PsPjp54/WpSP6ak99Y9Lhg7noyPfyD07P+hrwT0VpicxMx/XrtG8f7UTB4GsPDhiNgfwys6JLILJ5s5/w2aP0lZGHyMxSzj/Qvbrl+lwwYYxW3HauetU0HfHyynVn6JDO7bKIEkGIb5VC/pBEILRTUrd2NyaZuAvNS6dTppmpbQsTHFn0ne+Rq0bq0oQJufp0OshzfPBX/y7tdGSPKvXq0LbJsevPbrKKAvKA2xlAQOiol7l2lem9ZXVuX6HVW1L3jaiOTsfer0s73XrBYL2ydJM+3LwzZXkEb9b4s9QtOr1/si8+iWuTXDqir2ralmvS9Pf17TMP090vLstLnMWGJAIISKFMwtNaDD64Ju2yqar2vZ+erYoknVe7tK8giWjleiS59dScnjVt9F9nHq7vjT5CMz9g2SW/cDujiHRqV6EHv368KlOsJoj8+uyQnkGHEIhMc6envn2yfnPJMZp8RfqTniVbenlg9ypVliVfIOvOi4fpxAGdmVitlTqyR+ZDOGOTmpWVlnhaihvp4dOkiDx8xQidPrCbXrnujKBDgYp5YaLM3tDblJfo/KG9Ui40lkyyzpWPXj0yZflDOrfT/149Uqce0TWj2NB6xS8+F583tinf/7HHLcXskUQUkdgLqVvcGgzH92OWynzqEh2JUVlWkrRZvRhk2hIxoEvmC3bd8cVhBzxecOMYdWyXfhKC1mXMpzIbhn7GwK4HJOllccnCaXGJopcWDhyIPhFFrhDXMShkgw+u1qRLh8useEfFdM1wSKuXekocYcGshIXtP884NKvzj+vXSSMHdFL/Lu20eUfdvv2/uMj7Eu6IKM6vQkUq/hvgpwd1lyRdcXL/gKIpThcN76O2FaVFfCsj0jcHSMf8CaO14MYxqixL7/Uy/pwj1bldRZO1MEpLTI9efaImXjjkgP1H9WLRrWzxNbSIHNFtf9PdHy4brk2f7Dng1gb8M3P8mdq6cy/Np0AGMukHI0nfOO1QXX3qgGaXZ6eTZW7RElEklk8ce0CzcEmJkUDkUc+athrUs7rZNzegGJ3Qv1PS/ded7W3F1JZeY7wEc4skokjw4QWgNfr9V4Y32de2vFTXnH6YL8/HW2FucTsj5I7sUaWhvTsEHQaQsdMHdtV5Q3plfZ0Li2B59bBpbjguWheSiJCbcu2pQYcAHMClWe73lw1Pu0Nds/jmWXCG9ung27XpE5Fb
3M4A0CrlJIGQdExf5kJpzfJ+e4EcIqdoiQAQStO+f5peX75Z/3Fcyyt/onh899OH65mF63TFSQxvzwWSiBD79KBuQYcABObQru11aNfMZ7tEuB3WrUpLfnqOyot0xthcoxZDjBcJADTFe2PuUJMA8sq5dLtWIuwYbln4SCJCjBdo61BVyV1DIJnXrj8z6BCQJZKIEBt8cIegQ4CkEw/tHHQIrQoTn0GSXvp/p6tXh7ZBh4EskUSE2KCerNPQGqSa1hcoZv27tAs6BOQA7awhtn13fdAhFL2JFw7WRcN7Bx1GwTn5sC5BhwAgDSQRIUarcfAuOaFv0CG0Oul0rOzcnuXCgULA7YwQKyGLQIFivRegMJBEhBgpBArVZSceEnQIyBPWsihsJBEhRi94FComAwq30wd2DToE5Aiv1BAjhwDQGg3qWR10CMgRkogQI4dAa5SsW+WRPRiODBQikogQ43YGCsXx/ZhLo5jEvzNVlvMxVMj47YVYCTkEgFauTXnpvp+H9K7RU98+OcBokCnmiQgxhniiULikNzkQVolvTQtvGqON2/doAEu3FxxfWyLMrKOZTTaz2ug22cw6NFO+3MxuN7OFZrbDzNaa2UNm1svPOMOKHu4AWqPEYZ1VbcpJIAqU358yj0gaJuns6DZM0uRmyh8k6VhJt0T/vVDSEZL+5WeQYVXdloYmFJ5xo/oFHQJ8RiNpePj2KWNmgxRJHEY652ZH910laaaZDXTOvZd4jnOuVtLohOt8W9IcM+vrnPvQr3jDKI3ZhYG8a1/R9G3nuEM66eFZkZf3WYO65TskAB752RJxoqTaWAIhSc65WZJqJY3K4Do1iowK25rsoJlVmll1bJPEWDHk1YgUq3SW0bM1qZIk9XL+0P13LJnBMPz4DYeHn0lED0kbkuzfED3WIjNrI+k2SY8457alKDZekcQktq3OPFTAu7LS5G+J/7gmk1y5uCVLLBBi3M8IjYyTCDO7ycxcC9tx0eLJGtQtxf7E5ymX9Gg0xmuaKTpRkdaK2Ma6ywjcr780TENYRApIihQiPLz0ibhHkQ/35qyQNERS9yTHukpa39zJ0QTir5L6SzqzmVYIOef2SNoTd24LoRUPqgKFiL/b8ON3HB4ZJxHOuU2SNrVUzsxmSqoxsxOcc3Oi+0Yo0lowo5nzYgnE4ZLOcM59nGmMAIDWi07f4eFbnwjn3LuSpkj6o5mNNLORkv4o6an4kRlmttjMLoj+XCbp/yQdJ+lSSaVm1iO6VfgVK5BrJx3WJegQgFaLTsfh4fc8EZdKWihpanR7S9JlCWUGKtI6IUX6M5wf/Xe+pHVxG73U0ColG03QpX1lAJGEw6FMOhR6o0iyQ8PX2Yicc5slfaWFMhb38wrR5wYoSq9cd4a27d6rHjVtgg4FPuteTZIdFkxpGGKMt88P1n3IjT6dDgo6BOQJHeDDg8UVgBz76omHBB0CAOQFSQSQpeMOOXDGyrYVpSlKAkC4kEQAWRrSu+aAxwxfA1AsSCIAAIAnJBEhdnh3hsoBAPxDEhFibcq5Nx8Ex/0MoFmMzQgPkoiQqijjV5sviTkDOQSAYsEnTUgtvGlM0CEUrcpyXlYAigPvdiFVWcatjKD0ZdIkAEWCJALIsZq2rBUHoDiQRAA5Nuao7kGHALRq1W3Lgw4BOUISAeRYCcscA81qX8myTWFBEgEAADwhiQCyxIhOAMWKJAIAAHhCEgEAADwhiQAAAJ6QRAAAAE9IIgAAgCckEQAAwBOSCAAA4AlJBAAA8IQkAgAAeEISAQAAPCGJAAAAnpBEAAAAT0giAACAJyQRQJYGH1wTdAhAwfjR2CODDgE5RBIRIpeO6Bt0CEWpW1Vl0CEABeP0gd2CDgE5VBZ0AMidn3zuaNU3OA3v1zHoUAAARYAkIkRKS0y3XzQk6DAAAEWC2xkAAMATkggAgWpbXhp0CAA8IokAEKjnv3dq0CEA8IgkAkCgenc8KOgQkEdtymh5ChNfkwgz62hmk82sNrpNNrMOGZz/ezNzZnatf1ECAPKlb2eSxjDxuyXiEUnDJJ0d3YZJmpzOiWb2eUkjJK31JzQAAJAN34Z4mtkgRRKHkc652dF9V0maaWYDnXPvNXPuwZLukfQZSU/7FSMAAPDOz5aIEyXVxhIISXLOzZJUK2lUqpPMrESR1opfOOfebulJzKzSzKpjm6Sq7EMHAAAt8TOJ6CFpQ5L9G6LHUrleUr2k36T5POMVSUxi2+oMYgQAAB5lnESY2U3Rzo7NbcdFi7tkl0ixX2Y2XNJ3JI1zziUtk8RESTVxW+/M/kcAAMALL30i7pH0aAtlVkgaIql7kmNdJa1Pcd4pkrpJ+tDMYvtKJf3KzK51zvVLPME5t0fSntjjuPMAAICPMk4inHObJG1qqZyZzZRUY2YnOOfmRPeNUKS1YEaK0yZLeiFh33PR/X/KNFYAAOAf30ZnOOfeNbMpkv5oZt+I7v6DpKfiR2aY2WJJ451zjzvnPpb0cfx1zGyvpI+aG80BAADyz+95Ii6VtFDS1Oj2lqTLEsoMVKR1AgAQQqyPEl6+LgXunNss6SstlGm2E0OyfhAAgMLRprxEu/Y2BB0GfMDaGUCW0h1GBBQrOryHF0kEAADwhCQCQN4N6lkddAgAcoAkAgDgK25mhBdJBIC840MFCAeSCAAA4AlJBADAV/deNlxVlWX6+UVDgg4FOebrPBEAABzfr5MW3DhGJSXcyAobWiKALJXyxgi0iAQinEgiAACAJyQRAADAE5IIAADgCUkEAADwhCQCyKE25byk0nHNGYdKks4d3DPgSABkgyGeAPLus0N66di+HdWjuk3QoQDIAkkEgED06tA26BAAZIm2VwAA4AlJBJBDJx3aJegQACBvSCKAHDqHjoIAighJBJBDTOwLoJiQRAAAAE9IIgAAgCckEQAAwBOSCAAA4AlJBAAA8IQkAgAAeEISAQAAPCGJAAAAnpBEADnUuyOLSgEoHqziCeTAX64coWUbPtGIAZ2DDgUA8oYkAsiBkw7ropMOY/EtAMWF2xkAAMATkggAAOAJSQQAAPCEJAIAAHhCEgEAADwhiQAAAJ74mkSYWUczm2xmtdFtspl1SOO8QWb2r+g5281slpn19TNWAACQGb9bIh6RNEzS2dFtmKTJzZ1gZodKelXSYkmnSxoq6RZJu/0LEwAAZMq3yabMbJAiicNI59zs6L6rJM00s4HOufdSnHqrpGecc9fF7fvArzgBAIA3frZEnCipNpZASJJzbpakWkmjkp1gZiWSzpW0xMyeM7MNZjbbzD6f6knMrNLMqmObpKqc/i8AAEBSfk573UPShiT7N0SPJdNNUntJP5R0g6TrFWnN+IeZneGceznJOeMl3Zi4c9u2bV5iBgCgaGX62ZlxEmFmNynJh3aC46P/umSXSLFf2t8y8oRz7s7oz/PNbJSkb0pKlkRMlHRH3OOekhb36dOnhRABAEAKVZJazCi8tETcI+nRFsqskDREUvckx7pKWp/ivE2S6iW9k7D/XUknJzvBObdH0p7YYzPbLqm3pO0txJi
pKkmrfbp2MaI+c486zS3qM7eoz9zzq06rJK1Np2DGSYRzbpMiH/bNMrOZkmrM7ATn3JzovhGSaiTNSHHtOjN7XdLAhENHSFqZZnxO0pp0ymbCzGI/bnfOca8kS9Rn7lGnuUV95hb1mXs+1mna1/KtY6Vz7l1JUyT90cxGmtlISX+U9FT8yAwzW2xmF8Sd+gtJF5vZVWZ2mJn9l6TzJE3yK1YAAJA5v+eJuFTSQklTo9tbki5LKDNQkdYJSZJz7nFF+j9cFz33SklfcM696nOsAAAgA36OzpBzbrOkr7RQxpLse0DSA37F5dEeSTcrrv8FskJ95h51mlvUZ25Rn7kXeJ1apAsBAABAZliACwAAeEISAQAAPCGJAAAAnpBEAAAAT0giAACAJyQRaTCza8xsuZntNrN5ZnZK0DHlm5mdamZPmtlaM3OJK6taxE3R47vMbLqZfSqhTKWZ3W1mm8xsh5n9y8x6J5TpaGaTzaw2uk02sw4JZfpGY9kRvdZvzKzCr/+7H8xsvJm9bmbbo6vV/tPMBiaUoU4zYGbfMrO3zGxbdJtpZufEHac+sxD9m3VmdlfcPuo0A9G6cgnbR3HHC68+nXNszWySLpZUp8ikV4Mk3SXpE0l9g44tz/VwjqSfSrpQkQXUPp9w/HpFpkq9UNLRiqyvslZSVVyZ3ykyz/unJR0j6UVJ8yWVxpV5VpFJxk6MbgslPRl3vDS678XoNT6tyDTndwddRxnW5xRJ4yR9StJQSU8pMrV7O+rUc52eJ2msItPkHyHp1uhr91PUZ9Z1e7yk5ZIWSLqLv1HP9XiTpEWKrGQd27oWcn0GXqmtfZM0W9LvEva9K2li0LEFWCcHJBGKrMy6TtL1cfsqJW2V9I3o4xpF3tAvjivTS1KDpM9EHw+KXntEXJmR0X0Do4/PiZ7TK67MlyTtllQddN1kUaddo//PU6nTnNbrZklXUJ9Z1WF7SUuiHzTTFU0iqFNPdXmTpPkpjhVkfXI7oxnRpp3hikzZHW+qpFH5j6jV6q9IRr2vnlxkddWXtb+ehksqTyizVpGsPFbmREm1zrnZcWVmSapNKLMoem7Mc4q82Ibn7r+Ud7Gp3zdH/6VOs2BmpWb2JUntJM0U9ZmN30p62jn3QsJ+6tSbw6O3K5ab2aNmNiC6vyDr09dpr0OgiyLNPolLl69X5JeNiFhdJKunQ+LK1DnntiQp0yOuzIYk19+QUOaA53HObTGzOhXo78TMTNIdkl51zi2K7qZOPTCzwYokDW0Uue14gXPuHTOLvXlSnxmIJmLHKnI7IxF/o5mbLemrirTsdJd0g6QZ0X4PBVmfJBHpSZwb3JLsg7d6SiyTrLyXMoXkHklDJJ2c5Bh1mpn3JA2T1EHSFyT92cxOiztOfabJzPpI+rWkMc653c0UpU7T5Jx7Nu7hQjObKel9SV+TNCtWLOG0Vl2f3M5o3iZF7hslZmbd1DRbLGax3sXN1dNHkirMrGMLZbonuX7XhDIHPE/0muUqwN+Jmd0t6XxJZzjnVscdok49cM7VOeeWOefmOufGK9IR8DuiPr0Yrsj/fZ6Z1ZtZvaTTJP139OfY/4U69cg5t0ORDo6Hq0D/RkkimuGcq5M0T9LohEOjJc3If0St1nJF/ij31VO0P8lp2l9P8yTtTSjTU5EeyLEyMyXVmNkJcWVGKNJfIL7M0dFzY8YosordvNz9l/wVHcp1jyK9sM90zi1PKEKd5oYpcp+X+szcNEmDFWnZiW1zJf0l+vMHok6zYmaVinSEXKdC/RsNurdqa9+0f4jn5dFf9p2K3Gs9JOjY8lwP7bX/jcRJ+m70577R49cr0ov4gugf9CNKPjRplaSzFBlWNE3JhyYtUKQ38UhJbyn50KQXotc4K3rNQhvqNSlaX6fpwOFebePKUKeZ1enPJJ0iqZ8iH363KtKSOJr6zFkdT1fTIZ7Uafr198voa76/pBGSnlRkSOchhVqfgVdqIWySrpG0QvuztFODjimAOjhdkeQhcXswetwUGb60TpFhQi9LOjrhGm0k3S3pY0k7oy+gPgllOkl6OPrC2hb9uUNCmb6KzKuwM3qtuyVVBl1HGdZnsrp0ksbFlaFOM6vT++Nepxuib5Cjqc+c1vF0HZhEUKeZ1V9s3oc6ReZl+Lukowq5Pi16MQAAgIzQJwIAAHhCEgEAADwhiQAAAJ6QRAAAAE9IIgAAgCckEQAAwBOSCAAA4AlJBAAA8IQkAgAAeEISAQAAPCGJAAAAnvx/ukPZLG1xBuEAAAAASUVORK5CYII=",
+      "text/plain": [
+       "<Figure size 600x400 with 1 Axes>"
+      ]
+     },
+     "metadata": {
+      "needs_background": "light"
+     },
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "from matplotlib import pyplot as plt\n",
+    "onsets = detect_onset_of_audio(ori_audio, ori_sr)\n",
+    "plt.figure(dpi=100)\n",
+    "\n",
+    "time = np.arange(ori_audio.shape[0])\n",
+    "plt.plot(time, ori_audio)\n",
+    "plt.vlines(onsets, 0, ymax=0.5, colors='r')\n",
+    "plt.show()\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Method\n",
+    "The baseline is quite simple, and it has several steps:\n",
+    "- Take the original waveform (encoded and decoded by our codebook) and detect the onsets to determine the timestamp of sound events\n",
+    "- (Optional) Assume we don't have original waveform, we can use Andrew's great hit model to predict sound from frames and detect onsets from it.\n",
+    "- Detect onsets of conditional waveform (encoded and decoded by our codebook) and clip single onset event from them as sound candicates\n",
+    "- For each onset of original waveform, replace with conditional onset event randomly and then generate sound"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 125,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_onset_audio_range(audio, onsets, i):\n",
+    "    if i == 0:\n",
+    "        prev_offset = int(onsets[i] // 3)\n",
+    "    else:\n",
+    "        prev_offset = int((onsets[i] - onsets[i - 1]) // 3)\n",
+    "\n",
+    "    if i == onsets.shape[0] - 1:\n",
+    "        post_offset = int((audio.shape[0] - onsets[i]) // 4 * 2)\n",
+    "    else:\n",
+    "        post_offset = int((onsets[i + 1] - onsets[i]) // 4 * 2)\n",
+    "    return prev_offset, post_offset\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 126,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ori_onsets = detect_onset_of_audio(ori_audio, ori_sr)\n",
+    "con_onsets = detect_onset_of_audio(con_audio, con_sr)\n",
+    "\n",
+    "np.random.seed(2022)\n",
+    "gen_audio = np.zeros_like(ori_audio)\n",
+    "for i in range(ori_onsets.shape[0]):\n",
+    "    prev_offset, post_offset = get_onset_audio_range(ori_audio, ori_onsets, i)\n",
+    "    j = np.random.choice(con_onsets.shape[0])\n",
+    "    prev_offset_con, post_offset_con = get_onset_audio_range(con_audio, con_onsets, j)\n",
+    "    prev_offset = min(prev_offset, prev_offset_con)\n",
+    "    post_offset = min(post_offset, post_offset_con)\n",
+    "    gen_audio[ori_onsets[i] - prev_offset: ori_onsets[i] + post_offset] = con_audio[con_onsets[j] - prev_offset: con_onsets[j] + post_offset]\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 127,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhEAAAFZCAYAAAAmfX2OAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAuZklEQVR4nO3deZwU9Z3/8feHAQYEZkABuQTBKOIFCAEkgidRcTUem6gxum50NeomMWazkZjDn4nB1azxikkkiVGMCdmYS42IqKhRQEFQEBBEboThkBmu4Ri+vz+6Zujp6emjpquru+v1fDzqQU/Vt6s/86W659Pf+h7mnBMAAEC2WoUdAAAAKE4kEQAAwBeSCAAA4AtJBAAA8IUkAgAA+EISAQAAfCGJAAAAvrQOO4BcMzOT1EvS9rBjAQCgCHWStN5lMJFUySURiiUQa8MOAgCAItZH0rp0hUoxidguSWvWrFFFRUXYsQAAUDRqamp0xBFHSBm25pdiEiFJqqioIIkAACBAdKwEAAC+kEQAAABfSCIAAIAvJBEAAMAXkggAAOALSQQAAPCFJAIAAPhCEgEAAHwhiQAAAL6QRAAtlMEaNQDEe6UUkUQALTB14ccacueLenXpprBDAQra6i27NPxH0/Xwy8vCDgU5RBIBtMBXnnxH1bv36d9+81bYoQAF7e6pi7Vl5179ZNrSsENBDpFEAAACx52M0kQSAQAAfCGJAAAAvuQliTCzm8xshZnVmtlcMxuTouypZvaGmW0xs91mtsTMvpGPOAEAQOZaB/0CZnaZpPsl3STpDUk3SHrezI5zzq1O8pSdkh6W9J73+FRJvzSznc65R4OOFwAAZCYfLRG3Svq1c+5XzrnFzrlbJK2RdGOyws65ec653zvn3nfOrXTOPSnpBUlJWy/MrNzMKuo3SZ0C+j0AAECcQJMIM2sraZikaQmHpkkaneE5hnplX22myARJ1XHbWl/BAgCArATdEtFVUpmkjQn7N0rqkeqJZrbWzPZImiPpZ865XzVTdKKkyritT4siBgAAGQm8T4QncYSwJdmXaIykjpJGSbrbzD50zv2+yYmd2yNpT8OJzVoYKgAg15gnojQFnURsllSnpq0O3dW0daIR59wK7+ECMztc0h2SmiQRAAAgHIHeznDO7ZU0V9K4hEPjJL2ZxalMUnmu4gIAAC2Xj9sZ90mabGZzJM2UdL2kvpJ+IUlmNlFSb+fc1d7PN0taLWmJ9/xTJf2XpIfyECsAAMhQ4EmEc26KmR0m6fuSekpaKGm8c26VV6SnYklFvVaKdZbsL2m/pOWSbpP0y6BjBQAAmctLx0rn3COSHmnm2DUJPz8kWh2ysnrLLvXq3E6ty5jFHMVh1Zad6t25PdcsUOR4Bxe5Fxdt1Nh7X9E1j70ddihARv42f51Ou3eGbvzdO2GHAqCFSCKK3ONvrpQk/fPDzeEGAmRo0usfSYolwACKG0kEAADwhSSiyO3cuz/sEICs1Ozmmo0il3Z+QRQjkogiN2/1trBDALKyeuuusEMAkCMkEQAAwBeSCAAA4AtJBAAA8IUkAgAA+EISAQAAfCGJAAAAvpBEAAAAX0giAACALyQRJWzWR1t0/oOva97qT8IOJVKqd+1T7b66sMMoSq8t3aTzH3xdC9dVhx0KcswlmbByz/46bdu1N//BIGdIIkrY5Y/O0vvra3TFpFlhhxIZ23bt1eA7p2nEXdPDDqUoXf2bt/T++hp9+besShsFp987Q0PufFFVNbVhhwKfSCIioHbfgbBDiIz5a7ZJkmpqWR+iJWpq94UdAvLg4+pY8vDGclYhLlYkEUBAXLL2WwBNbNtF0lisSCKAgCz6uCbsEICicM/UD8IOAT6RRAAB4TYSkJnddEQuWiQRAADAF5IIAADgC0kEkEPPL9gQdgglwWRhh4CAbdmxJ+wQkAMkEUAOTZmzJuwQgIKUOFbpg43bQ4kDuUUSAQRkedWOsEMAgECRRAABmb1ia9ghFC166wPFgSQCCMjT76wNOwQACBRJBAAgcC8vqQo7BASAJAIAELi6A0wDX4pIIgAAeccw3tJAEgEAAHzJSxJhZjeZ2QozqzWzuWY2JkXZS8zsRTPbZGY1ZjbTzM7JR5xASxlfrgBESOBJhJldJul+SXdJGirpdUnPm1nfZp4yVtKLksZLGibpFUnPmNnQoGMFAACZy0dLxK2Sfu2c+5VzbrFz7hZJayTdmKywc+4W59w9zrm3nXPLnHPfkbRM0gXJyptZuZlV1G+SOgX0exS8VVt2hh0CkJWtO/eGHQKAFgg0iTCztoq1JkxLODRN0ugMz9FKscSguZl7JkiqjtsiOzh/Xx1LT6O40GMfKG5Bt0R0lVQmaWPC/o2SemR4jm9K6iDpj80cnyipMm7rk32YQG7QJQLIDP2HSkPrPL1O4tcNS7KvCTO7QtIdkj7nnEs6U4lzbo+kPXHP8R8lAADIWNBJxGZJdWra6tBdTVsnGvE6ZP5a0uedc9ODCQ8AAPgV6O0M59xeSXMljUs4NE7Sm809z2uB+K2kLzrnngssQCDHaAkDECX5uJ1xn6TJZjZH0kxJ10vqK+kXkmRmEyX1ds5d7f18haQnJH1d0iwzq2/F2O2cq85DvAAAIAOBJxHOuSlmdpik70vqKWmhpPHOuVVekZ6KJRX1bvDi+pm31Xtc0jVBxwu0BO0QQGZ4r5SGvHSsdM49IumRZo5dk/Dz6XkICQAAtBBrZ5QQx5B7FBmXfpAWgAJGEgEAAHwhiQAAAL6QRAA5xAhPAFFCEgHkkNHnHMgIc6qUBpKIEsJ7EsWGpAsobiQRQA7tZSVVICP7ea+UBJIIAEDeTX1/Q9ghIAdIIkoI80Sg2DBPRHTV7qsLOwTkAEkEAADwhSQCAJB3dKotDSQRQI4cOEDTPJCpeWs+CTsE5ABJRAlZsI6V0sM0Zc6asEMoOh9W7Qg7BIRgY02tlm7k/74UkESUkCUbtocdQqS9uGhj2CEUneWbdoYdAkKweuuusENAjpBEADniGB4DZIS3SukgiQAAAL6QRAAAAF9IIkoIzenhovZ94JqNJD6rSgdJBJAjqUZ41h1wuv6JOXpg+rL8BQQUqHQpxP3Tl+r6J+aojmHTBY8kooSkWlp36869eYwkml5buinlsWmLNuqn05fmMaIikOKaZVrk0vW72atTHr9/+jJNW7RRry9r/j2FwkASERF8IIeL+s/eAZq8S9Yz767PqNze/az0WehIIkoI9xkBAPlEEgEAAHwhiQAAAL6QRETEO6tZ7CZMe+u4t5stpnEHN2gLH0lECUk1OmPhupo8RoJEf523LuwQis5SkojIe389n1uFjiSihNCxsnDV7qMlIimuWaSwnxa8gkcSAQAAfCGJAAAUJNqpCh9JRESk6C6RUt0Bp2feXa9123bnNqCI8Vv/Uea3zvbXHdDf5q/Tx9Vcs0DQ8pJEmNlNZrbCzGrNbK6ZjUlRtqeZPWVmH5jZATO7Px8xIrkpb6/RV38/T5+5++WwQwEy8tgbK/X1P8zXmT95Ne
xQgJIXeBJhZpdJul/SXZKGSnpd0vNm1reZp5RL2uSVfzfo+JDaG8s3hx0CSlkATTSvemuY7GaqcSBw+WiJuFXSr51zv3LOLXbO3SJpjaQbkxV2zq10zn3dOfeEpOp0JzezcjOrqN8kdcpl8ACAcDB4p/AFmkSYWVtJwyRNSzg0TdLoHL3MBMWSjfptbY7OW5Q279ijP769Rrv27m+0f/qijf5OyJs4J+gT0byqmlr9cc6aJouUzfjA3wqOjou2ZPB/WfhaB3z+rpLKJCX+BdsoqUeOXmOipPvifu6kiCYSj89cpUmvr5AkzVvTeIbKZVU7VFO7TxXt2oQRWuSZyCKS+dXrH+l7f10oqenkUs8v3KB9dQfUpoz+30ChCjqJqJeYTlqSff5O7NweSXsaThzhr3zxy+a+8H7Tlodde+qyTiL4JoAgrdqyq+Hxy0uqmhyvO+DUpiy7c9IEDuRP0Cn+Zkl1atrq0F1NWycAAEARCTSJcM7tlTRX0riEQ+MkvRnkawOFhBYdAKUoH7cz7pM02czmSJop6XpJfSX9QpLMbKKk3s65q+ufYGZDvIcdJXXzft7rnFuUh3hLwtade3NyHpqGkS8fbd6Zk/NwzZaOmt370xdCqALvseScmyLpFknflzRf0lhJ451zq7wiPRVLKuLN87Zhkr7oPf5H0LECQYnvWPlh1Y4QIwGKx+/fWh12CEgjLx0rnXOPSHqkmWPXJNkX3d6RAfLTpM63utyL7wCL1Pxcf9w6AvKHsVNAHkR40BCAEkYSEVGOJobQ8E3ZH65ZoPCQRETQ5h17NPrul3XvC0vSluUPHgrBx9W7NWriS3pg+rK0Zck1gPwhiYiQ+s59j772kT6urtXPXlkeckTRxOyVmau/DfTTF5dqY80e/XT60nADAtAISUQEHTjAV7V8i/JMqrnAJQsUJpKICGF0RmHgFlHmuP6AwkYSAQAAfCGJiKA3lm/JuCxfBHOPPhHZe3Vp5suCc80C+UMSEUGLP64JO4TIIW1omU3b96QvBCDvSCKQEn/8AOTD/jpmci1GJBFIiabhllm5eadq99XprRVbww4lOrhoi9Jjb6zUgQNOb6/kvVJM8rJ2BgoD9+Lz78OqHfrB39/X7n11YYdSlBgZGx1T39+g9m3L9N2/Lgw7FGSBlogIGTXxJX2S5RLhDLFrucROgQzxzNyx35uqmtp9WT2H+i1ef5yzJuwQkCWSiIi56jezsyr//vrqgCIBMnPrlHezKv/2yk8CigRBe28tnzfFhiQiYhauy25kxsfVtQFFEl2Z3FbauWe/ZnxQxbLhkqYv3hh2CMgDvwus1R1wem3pJlXvyq7FCrlBEhFxZ/3vjLBDiJwHX1qme6amXvzs8kdn6ZrH3tYXJ83KU1TF4+JH3gg7BARg3pptSfdf9/jbKUduPDlrla7+zVsafOe0gCJDKiQREbd8086wQ4icqe9v0CMzlmv2R81P+rVgXaxZd84qmuYTzVu9LewQEIDmGiKmL67ShQ83nzje8cz7AUWETJBEIGN+mxuRHM30QGYWpZggj4+lcJFEQDM+qNKSDen7SvBmzd5LS5pPFCa9vkJzVzEm3o+XFm/Uso3bww4DeTTl7dV8kSlAJBHQNY+9rXPvfz3sMErS799KPWTt0p/P1PYshzBCuvbxORr309fCDgN59O2nF+jFRbTeFRqSCDSYnuYNyneAYPxt/vqwQyhac+kzEim3TJmvugN8EhUSkgg0eOWDqpTHaUoMxnf/ulBrP9kVdhhFafaKzFekRfHbtbcu7cgm5BdJBBocSJMkkEIE50fPLg47hKJEXhs9v3zto7BDQBySCDT4/VtrtC9uPHZiy8OarXxbDsrU9zeEHUJRuveFDxrNIZB4zbKEeLTQWpp/JBFo5Ojbn9f+ugP645w16j/hH42Obd6R3bobyE5VTfLZQdN9MP5l3lqd+ZMZWr5pRxBhFbxP3f68nHOaPGtVk2s223U3UBwWNzPkM13/okXra3TGT2bo+QUfBxFWJJFEoImXllTpv//0XpP9GxL+yN37whL9+B/RbYZPNYueH5+9P/log7ufj90Dds5p/pptTUZzfGPKu/po805d8Wh0Z7d8dekmfS/J6o9bExacu+u5RfrJCx/kKywE5LwHko8mu2XK/IbHyzft0LptuxsdH//g61qxeadu/N07QYYXKSQRaOKGyXOT7r/zmUUNjxd/XKOfvbJcj772kbbsiGaT8Uebczvb57Zm5v6vvwf8mzdW6qKfvaET70g+vW9VhJvur3ns7aT775u2tOHxm8s3a9LrK/TwKx+qlqXZ8+pAHkdU1B1w2rJjj87631f1mbtfztvrRhVJBDK2OS5ZuOChfzY8XrIhmpP+PDV7dc7P2dyti6Ubt+uHzy5KegzNmxk3tfgXJx1cwXZ9wjdUBGt3AEnbrGamjZ+zcqseeGlZw8979pMwBokkAr7sj/tmEdV78c++l/v7qhc8/E99srNp35PPMrFSTq3fxuq0+bS/LvctEZc/OitpZ+/LHp2lJ2auaviZ1T2DRRKBrEx5e7WenLWq0b57pkbzHvO2XbnvaLpwXY2G/vDFtOV27d2f89dO5nezV+nBuG91xegv89Zq8syVjfbdP31p8sIIxL4DwSxpP+aeV9KWqV/MDsHISxJhZjeZ2QozqzWzuWY2Jk3507xytWb2kZl9JR9xIr1vP71A303owLZjz34dedtzKZsNN23fU3I95feHOHPecd9/QR9WBdsCVLuvTrf/ZaHue3Gpfjd7VfonFKhvTHlX3/tb45Ue56z6REfe9lyjIc2JqmpqtWNPfpK1Urd3fzBJRCaufXyOJka4A3jQWgf9AmZ2maT7Jd0k6Q1JN0h63syOc841ualsZv0l/UPSJElfkvQZSY+Y2Sbn3NNBx1tMCm3NhYHfnZq2zA2nDdBNp31KlYe0abS/7oBTK5PMLKjwSs7Z970a6Pmrdx+8vm7/y0Ld/peFeu+Oz6qiXZsUz0ot2a2aMB19+/Npy3zrnIH60qh+qmzf9Jota8X1monrHp8T6uv/8rWPmKQqIBb05BxmNlvSO865G+P2LZb0V+fchCTl/0fShc65QXH7fiFpsHPulCTlyyWVx+3qJGltdXW1KioqcvI7vPJBlZ6eu7bxjI2u/p/Yg/hqdKmOJZSJnwfy4PNiXl6SehrqUnTaMd3UqV1rJamehvqUUtd3k+NqWjZZvTct65p9/qtLN6X4LQrbBYN7ZVTupcUbtWsvndLSOXtQd7VrUyapmVldk+x0CTuTfQwn3ZfJ85KF0GRn01KJZZKfJ8nz0pynmN8r9TJ9z4TliyP66pSjDsvJuWpqalRZWSlJlc65tMs7B9oSYWZtJQ2TdHfCoWmSRjfztFO84/FekHStmbVxziV+/Z4g6QctjTWVFZt2BtKJDk2VwgdOoXvmXRb8yqXpi6OX7EdNob9nxhzdVacoN0lEtoK+ndFVUpmkxOUhN0rq0cxzejRTvrV3vsS/5hMl3Rf3cydJa/0E25yRAw7VHRccJ6lxc3v9Q0vcEbfvYJnmnxffg
t9QzqQfPbtINbX7dcWIvvr9W7kfTliI6utZOljXjesnTor6ju1rWueZlE320BJe65v/926zv0Mh69qxXDefcVRGZd9fX6M/zfX3Vrr05D56+p2cvg0LVvw1KyW/JZfsLl2TXcmel+T1EotZklKZvF7yMk1OnmFM1myZ15ZtKuqVai85ubdO7F0ZdhgpDT2ic2ivHXifCE9ii5cl2ZeufLL9cs7tkdQwgUEQ99SP71Wp43vl/yL6wvAjGh5PvOTEJsedc02m+Q3bib0r9eS1I/XbN1fqp9OX6vbxg9StU3nDTHKL7zxX7duWhRtkDlw6rI+OvO25sMNoYuXd5+fsXFt27GmURDz275/WGQO7Z/z8//3C4Cb7DhxwGvCdwrpmR/Q/VJOuHq5fvLpcP5+xXD/83PFqU9ZKt/15gSRp6Y/OU9vWDGTz69JhfTS8X5cmnVvDlsv3SpQFnURsllSnpq0O3dW0taHehmbK75fEur9xCqkT4jWjj9QdFx7f8PPXzz5aXz/76Iafzzuxh0xWUh/G3TqVh7bA0+eH9dG9n4/9kQ4qmamI60j49//8jE7q07nF52xVQB0Rbz7jKH3rnGMbfv72ucfq2+ce/Pmiob1V1srUpqx0rtmwnHN8j1CTiBUTx8vMCjLxL3aBvjucc3slzZU0LuHQOElvNvO0mUnKf1bSnCT9IZBnD1w+pEmP9K+e+alGCUQy5a3LSiqBkKTJ144I7bXrE4ggtSlrpVkTztLr/31GThKIsNx/2ZAm++77wuBGCUQy7dqUkUDkSPeKdqG99o8vPrGgvnCVmny8Q+6TdJ2ZfdnMBpnZTyX1lfQLSTKziWb2RFz5X0jqZ2b3eeW/LOlaST/JQ6xI43NDemv5j8frW+cMbNh3/kk9Q4woPOWtg7kt89GPxwdyXj96VLbTEYceEnYYLXLR0N5aeff5umpUv4Z9o4/qGmJEyJVMbkkM7NEpD5FEV+BJhHNuiqRbJH1f0nxJYyWNd87Vz17TU7Gkor78CknjJZ3ulf+epK8xR0T4Tuh9cMjszWd8KsRICkMQLfO3jx+UtMn/zzc1N5jp4P/Ls189NfcBFbnLP32wX9G3zh2YoiRKxWP//mmdPejwhp87lh+8a//d82MzB3ztTD6/ciUvHSudc49IeqSZY9ck2feqpJMDDgtZuu7UAUn39wixqTJMyXrFt9SXT+3fZN8VI/rq5L5dNP7EHvrHgg1Njj/71ZQTwEbavw7r0/C4fZuDLUeJE0eh+Lz1nbOS7j9jYHcN69dFJ3mr3XZsd/DP3HVjBui6Mck/x+BPvkZnoIh89/xBmrl8i15KmOwqviVCkl75r9NVu69OnQ9pm8/wCkaub7NOv3Vs0hkQf3zxCZKkh644WT0qFmvkgENz+8Il4IcXnaBn5q/XWyu3Ntrf77AODY/blLXS9FtP0wHnSmKEUJTddfEJSftZ3HZerJ9LRbs2+uFFJ2h77T717tw+3+FFCkkEGnn5m6epf9cOum7MAFXv2qfBdx6c9ytxJrr+XTsoynKdRHyqe/J7t/Wdwspamb6fMCcBYtfsgG4dddWofvpk595GC5glzvD4qe4d8x0eAnDlyH5J97eOS8Lj+8AgOCQRaHDzGUdpQLeDH7KJ61uEt9xUYcplj+/TB3bL2bmi5Pbxgxpds106JLSKcdFGCmuZ5B/jl9CgjGFQWeneqTx9oQw9eMXQnJ0rSrhko+fKkX2bPTY4xJkbo4okAg3SfbPuWRnNDpTNyeUcAvGd/pC5Vmmu2Qo6UJacwSnmLDm5b5f8BQJJJBGIc9HQ3imPd2rBEtBo3l0Xn8CkRj6dc0JzS/DEtCM5Kyk3jB2gS+NG3CB8fHKhwZGHFfekQoVoxJGpR1K0LWvVbCcxpEfP+2iZMH4Q/R4KDEkEEKAnrxsZdgglqy2tNyUl2fTkKHyMzoi4Yf266DfXfFqtWxnzywcg3XohP7uSOdWyddnwI/Sd8YNKbi2WqEvXKfLG0zNbwh75RRIRcZee3IfZ+0LCEtP+XDikV5PhxyhtD14xVBcO7hV2GEiCTzCkdHyvivSFkLUB3TqQQPiUrr2sa8fcDb1F/qTq6nDa0cyjUqj4FENKg3qSROTaHRccp2m3jA07jJJ1ct/OYYcAH5pbi2bRnefQ8lTASCIirncXerfnW4fy1mpNp0Dfkq2ZgOLXXJesQ9py172Q8UkWcdyuyD86sLZMuqHIzHRdnMykY3skXz8GhYskIuLS/TlLXHQLCBtJWOk6jtunRYckImKOOZxVDMM2iqW8s8J6CNFgZnQ2LkL8j0XIPf96kv5286mN9vGtLnhPXtt4wilGD2Tu0auGacr1oxrt44otTSbp1nHHaEC3Dg37zh7UPbyAkBGSiAg5tkcntW/LWgL5durRXXXVKKa29mNQz4qs17/gFlzx6l7RTi9/8/SsnsP8EeEiiYiQZEOo0vaJoJtaKIYyTBER47dRdPyJPXMbCLJCEhEhyd6knRl/nRfZfkDec+lJ6t25vf7n0hODCahIJKu3VizAhDhnHNtNx/eq0BUj+oYdSiQxADfCvvyZ/vSJyJP4Ws6kyo8+vJPeuO3MwOIpVrecfXTYISAgzU02lU556zI997UxOY4GmaIlIkKO6tZ4ZAa3KlDoDk+YWIr+DtHC/3fhI4mIkMROlZkspXzKgMOCCgdIq03CNdqmLP231bHHdJXEUuHFpksHbq0WI25nRNB3xh+rP7+zTl85Lf3Supee3EeHtG2twUdU5iEy4KCO5Qc/nr457hg9v3CDrh59ZNrnXTmyn7oc0lbDj+wSYHTItfLWjBwrRiQREXT92KN0/dj0CYQU68R2/kn0fm6p+L4nfu/9Rk18LX31rKP11bMy6w9R1sp0AcP+gLygvQ8AUJAO0Cmi4JFERMShHdqGHQKQlf5xMxcCKEwkERHxw8+dEHYIQFYmnDco7BAApEESERG0RBQOpubITEV7umwBhY4kooScdky3hsfXnto/xEiQiMQhuc8P69Pw+IaxA0KMBIWIHhGFL9Akwsy6mNlkM6v2tslm1jnNcy4xsxfMbLOZOTMbEmSMpWRE/0PVu3N7SdLNZ3wq5GiA9Ib27aKuHWOtZDefyTUbFWcPOjzsEJAjQbcXPiWpj6RzvZ8flTRZ0gUpntNB0huS/k/SpECjKzGtzDTjW6dr9746VbRj4pZ8u2hIL/11/vqkx+hknlwrk9687Szt2V+nTgnXLENhS9d/nztQ0xdvDDsM5EBgSYSZDVIseRjlnJvt7fsPSTPNbKBz7oNkz3POTfbKHpnh65RLKo/b1aklcRezL47oqzZlrZrM8of8yDRP4E/jQRcO6aW2rVupbWuuWaAYBfnOPUVSdX0CIUnOuVmSqiWNzuHrTPDOWb+tzeG5i0qndnREK1T0iUjukLbNX7Os7QJa8ApfkElED0lVSfZXecdyZaKkyritT+riQDBSfeDRNA8cRHJQOrJOIszsDq/DY6ptuFc82aVizez3xTm3xzlXU79J2p6rcxebVJXKN2EUGxIvoPD5af9+WNIf0pRZKekkScm6
4HaTRI8alJyM+0SQ0SHiuFVVOrJOIpxzmyVtTlfOzGZKqjSzEc65t7x9IxW75fBmtq8LAAAKS2B9IpxziyVNlTTJzEaZ2SjFhmw+Gz8yw8yWmNnFcT8f6s0NcZy3a6CZDTGzXPajiJxWfPsNXLeO5c0ei69+xw3hjLRiwEbJapfhst+8Uwpf0G/TKyUtkDTN296TdFVCmYGKtU7Uu1DSPEnPeT//wfv5K4FGWuKG9esSdggl7+tnN79UNSlc9gYeHtnR2iXvyK4srlYqAh0T6JzbKulLacpYws+/lfTb4KKKprJW/BkLWmX7zCb4ok9EZqgnoPDRYAgAKEjc+it8JBEAAMAXkgggDw7teHApdhrpgcyc0LsyfSGEiiQCyIMuhxxMIlrRPwXIyNiju4UdAtIgiSgh3D8EUEroW1v4SCIAAAWJ70WFjyQCyAO+UAEoRSQRAADAF5IIAADgC0kEAKAgsdpn4SOJKCG83QoXvcwBlCKSCAAA4AtJBAAA8IUkooQwphpASeEzreCRRJQQOiEVLmOmCKCR9m3Kwg4BOUASUUJoiQBQLIYf2SXsEJADJBFAHpS34a0GZKt1Ge+bQtc67ACAKBh/Yk/9ae5aDe93aNihAAWhW8fyZo/dMHaAlm/aqeH9aK0odCQRQB60KWulydeODDsMoGCcf1JP/XneuqTHJowflOdo4BdtRQCAvGvVis7GpYAkAgjIF4b3CTsEAAgUSUQJYXRGYRnal/u5AEobSQQAAPCFJAIICC1DAEodSUQJacdcBACAPOKvTgkx1psOXXlr3lJAJo7rWRF2CMgBPvGAHLrkZEZkAJk4vKJd2CEgB0gigByiMQhAlJBEADlEDgEgSkgigByiJQJAlASaRJhZFzObbGbV3jbZzDqnKN/GzP7HzBaY2U4zW29mT5hZryDjBAAA2Qu6JeIpSUMknettQyRNTlH+EEknS/qh9+8lko6R9PcggwRy5dAOB1cmdGKiCAClLbBVPM1skGKJwyjn3Gxv339ImmlmA51zHyQ+xzlXLWlcwnm+KuktM+vrnFud5HXKJcWvKdsph78GkJVxgw7Xgy8tCzsMAMiLIFsiTpFUXZ9ASJJzbpakakmjszhPpSQnaVszxyd456zf1voJFsgF+kQAiJIgk4gekqqS7K/yjqVlZu0k3S3pKedcTTPFJiqWaNRvDNQHACAPsk4izOwOM3NptuFe8WQ3ha2Z/Ymv00bSH7wYb2qunHNuj3Oupn6TtD3b3wkAAGTPT5+IhxX7457KSkknSTo8ybFukjamerKXQPxRUn9JZ6ZohQAKCrczAERJ1kmEc26zpM3pypnZTEmVZjbCOfeWt2+kYrcc3kzxvPoE4mhJZzjntmQbY5QM7lOpd9dWhx0GkLH+XTtoxeadYYcBIAcC6xPhnFssaaqkSWY2ysxGSZok6dn4kRlmtsTMLvYet5b0J0nDJV0pqczMenhb26BiLWp89S0oZa34/0iHGgJKR9DzRFwpaYGkad72nqSrEsoMVKx1Qop1irzQ+3e+pI/jtmxGdAChGHj4wRHGjmkigIwc2oHviMUqsHkiJMk5t1XSl9KUsbjHK8UXFRSh+gYhlmMHkjvz2O56eUmyAXtS+zZleY4GucLaGQAAwBeSCCAg3M0AUOpIIopc/8MOCTsEICsDunUIOwQAORJonwgE7wcXHK+2rVvp88OPCDsUICMTLzlJXaYu0RdH9g07FAAtRBJR5Lp0aKt7/nVw2GEAGevWqVz3fp5rNmocw5VKErczgBxgTAaAKCKJAILCNy8gI7RSFC+SCAAA4AtJBAAA8IUkAgAA+EISAQAAfCGJAAAAvpBEAAAAX0giAACALyQRAADAF5IIICBMnwNkhvdK8SKJAAAEjkShNJFEADlgxuoZAKKHJAIAAPhCEgEAAHwhiQAAAL6QRAAAAF9IIgAUnDZldFQFigFJRATwgRwOx5g23yrbtw07BOQR75XiRRIBAAB8IYkAAAC+kEQAAABfSCIiwESfiKD16twu7BCAgnZYh/KwQ0AAWocdAFDMnr5xtB54aZm+/y+Dwg4FKGgTxh+rrTv36PIRfcMOBTlEEgG0wLB+XfTEl0eEHUbJYSmS0tO1Y7ke+3feK6Um0NsZZtbFzCabWbW3TTazzmmec4eZLTGznWb2iZlNN7ORQcYJAACyF3SfiKckDZF0rrcNkTQ5zXOWSvpPSSdKOlXSSknTzKxbUEECQRjYo1PYIQAFrd9hh0iSzj6ue8iRwK/AbmeY2SDFEodRzrnZ3r7/kDTTzAY65z5I9jzn3FMJ57lV0rWSTpL0UpLXKZcU32OHT+5ENA3n1dRbxujDqh0aNeCwsEMBCtqfvjJaryyp0gWDe4UdCnwKsiXiFEnV9QmEJDnnZkmqljQ6kxOYWVtJ13vPebeZYhO84/Xb2hbEXJLIIfLr2B4V+peT+FBsCa7ZaOjWqVxf+PQRat+2LOxQ4FOQSUQPSVVJ9ld5x5plZv9iZjsk1Ur6hqRxzrnNzRSfKKkybuvjO2IABYGOlUBxyDqJ8Do+ujTbcK94shnRrZn98V5RrP/EaElTJf3RzJLeNHPO7XHO1dRvkrZn+zsBAIDs+ekT8bCkP6Qps1KxPgyHJznWTdLGVE92zu2U9KG3zTKzZYr1i5iYbbAAACAYWScR3m2F5m4tNDCzmZIqzWyEc+4tb99IxW45vJnly5oad54EAAAhC6xPhHNusWK3IiaZ2SgzGyVpkqRn40dmeHNCXOw97mBmP/bK9zOzk83sV4r1c/i/oGItddxfBgAEIeh5Iq6UtEDSNG97T9JVCWUGKtY6IUl1ko6V9LRi80U8q9jtjzHOufcDjhVAgWC9F6A4BDrttXNuq6QvpSljcY9rJV0SZEwAACA3WMUTAAD4QhIBAAB8IYkAAAC+kEREAJ3UUGwYUQQUB5IIAADgC0kEAADwhSQCAAD4QhIBAAB8IYmIADqpAQCCQBIBoOCQ9wLFgSQCAAD4QhIBAAB8IYkAAAC+kESUsO6dyiVJpww4LORIgMwc0rZMkvTp/oeGHAmATJhzLuwYcsrMKiRVV1dXq6KiIuxwQrVm6y79Zd46XTWqn7p0aBt2OEBaq7bs1N/mr9e/jT5Sle3bhB0OEDk1NTWqrKyUpErnXE268iQRAABAUvZJBLczAACALyQRAADAF5IIAADgC0kEAADwhSQCAAD4QhIBAAB8IYkAAAC+kEQAAABfSCIAAIAvJBEAAMCX1mEHEJSamrSzdQIAgDjZ/u0sxbUzektaG3YcAAAUsT7OuXXpCpViEmGSeknanuNTd1IsOekTwLmjiPrMPeo0t6jP3KI+cy+oOu0kab3LIEEoudsZ3i+dNnvKViw3kSRtz2RlM6RGfeYedZpb1GduUZ+5F2CdZnwuOlYCAABfSCIAAIAvJBGZ2yPp/3n/ouWoz9yjTnOL+swt6jP3Qq/TkutYCQAA8oOWCAAA4AtJBAAA8IUkAgAA+EISAQAAfCGJAAAAvpBEZMDMbjKzFWZWa2ZzzWxM2DHlm5mNNbNnzGy9mTkzuyj
huJnZHd7x3WY2w8yOTyhTbmYPmdlmM9tpZn83sz4JZbqY2WQzq/a2yWbWOaFMXy+Wnd65HjSztkH97kEwswlm9raZbTezKjP7q5kNTChDnWbBzG40s/fMrMbbZprZeXHHqc8W8K5ZZ2b3x+2jTrPg1ZVL2DbEHS+++nTOsaXYJF0maa+k6yQNknS/pB2S+oYdW57r4TxJP5J0iSQn6aKE499WbKrUSySdIOkPktZL6hRX5ueKzfN+tqShkl6WNF9SWVyZ5yUtkHSKty2Q9Ezc8TJv38veOc5WbJrzh8Kuoyzrc6qkayQdL2mwpGclrZLUgTr1XacXSBov6Rhvu8t77x5Pfba4bj8taYWkdyXdzzXqux7vkLRQUo+4rVsx12folVrom6TZkn6esG+xpIlhxxZinTRKIiSZpI8lfTtuX7mkbZJu8H6uVOwD/bK4Mr0k1Uk6x/t5kHfukXFlRnn7Bno/n+c9p1dcmcsl1UqqCLtuWlCn3bzfcyx1mtN63SrpWuqzRXXYUdJS7w/NDHlJBHXqqy7vkDS/mWNFWZ/czkjBa9oZJmlawqFpkkbnP6KC1V+xjLqhnpxzeyS9qoP1NExSm4Qy6xXLyuvLnCKp2jk3O67MLEnVCWUWes+t94Jib7ZhufuV8q7S+3er9y912gJmVmZml0vqIGmmqM+W+Jmk55xz0xP2U6f+HO3drlhhZn8wswHe/qKsz5JbxTPHuirW7LMxYf9Gxf6zEVNfF8nqqV9cmb3OuU+SlOkRV6YqyfmrEso0eh3n3CdmtldF+n9iZibpPkn/dM4t9HZTpz6Y2YmKJQ3tFLvteLFzbpGZ1X94Up9Z8BKxkxW7nZGIazR7syVdrVjLzuGSvivpTa/fQ1HWJ0lEZhLnBrck++CvnhLLJCvvp0wxeVjSSZJOTXKMOs3OB5KGSOos6VJJj5vZaXHHqc8MmdkRkh6Q9FnnXG2KotRphpxzz8f9uMDMZkpaLunfJM2qL5bwtIKuT25npLZZsftGiZlZdzXNFqOsvndxqnraIKmtmXVJU+bwJOfvllCm0et452yjIvw/MbOHJF0o6Qzn3Nq4Q9SpD865vc65D51zc5xzExTrCPh1UZ9+DFPsd59rZvvNbL+k0yR9zXtc/7tQpz4553Yq1sHxaBXpNUoSkYJzbq+kuZLGJRwaJ+nN/EdUsFYodlE21JPXn+Q0HaynuZL2JZTpqVgP5PoyMyVVmtmIuDIjFesvEF/mBO+59T6r2Cp2c3P3KwXLG8r1sGK9sM90zq1IKEKd5oYpdp+X+szeS5JOVKxlp36bI+l33uOPRJ22iJmVK9YR8mMV6zUadm/VQt90cIjnl73/7J8qdq+1X9ix5bkeOurgB4mT9A3vcV/v+LcV60V8sXdBP6XkQ5PWSDpLsWFFLyn50KR3FetNPErSe0o+NGm6d46zvHMW21CvR7z6Ok2Nh3u1jytDnWZXpz+WNEbSkYr98btLsZbEcdRnzup4hpoO8aROM6+/n3jv+f6SRkp6RrEhnf2KtT5Dr9Ri2CTdJGmlDmZpY8OOKYQ6OF2x5CFx+6133BQbvvSxYsOEXpV0QsI52kl6SNIWSbu8N9ARCWUOlfSk98aq8R53TijTV7F5FXZ553pIUnnYdZRlfSarSyfpmrgy1Gl2dfrruPdplfcBOY76zGkdz1DjJII6za7+6ud92KvYvAxPSzqumOvTvJMBAABkhT4RAADAF5IIAADgC0kEAADwhSQCAAD4QhIBAAB8IYkAAAC+kEQAAABfSCIAAIAvJBEAAMAXkggAAOALSQQAAPDl/wMsTYwZlKLTRwAAAABJRU5ErkJggg==",
+      "text/plain": [
+       "<Figure size 600x400 with 1 Axes>"
+      ]
+     },
+     "metadata": {
+      "needs_background": "light"
+     },
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "from matplotlib import pyplot as plt\n",
+    "plt.figure(dpi=100)\n",
+    "time = np.arange(gen_audio.shape[0])\n",
+    "plt.plot(time, gen_audio)\n",
+    "plt.show()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 128,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# save audio\n",
+    "import soundfile as sf\n",
+    "sf.write('data/gen_audio.wav', gen_audio, ori_sr)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 129,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "t:   0%|          | 0/49 [00:00<?, ?it/s, now=None]    "
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Moviepy - Building video data/generate.mp4.\n",
+      "MoviePy - Writing audio in generateTEMP_MPY_wvf_snd.mp3\n",
+      "MoviePy - Done.\n",
+      "Moviepy - Writing video data/generate.mp4\n",
+      "\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "                                                             \r"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Moviepy - Done !\n",
+      "Moviepy - video ready data/generate.mp4\n"
+     ]
+    }
+   ],
+   "source": [
+    "gen_audioclip = AudioFileClip(\"data/gen_audio.wav\")\n",
+    "gen_videoclip = ori_videoclip.set_audio(gen_audioclip)\n",
+    "gen_videoclip.write_videofile('data/generate.mp4')\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 130,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<video src=\"data/generate.mp4\" controls  width=\"640\" >\n",
+       "      Your browser does not support the <code>video</code> element.\n",
+       "    </video>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.Video object>"
+      ]
+     },
+     "execution_count": 130,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "Video('data/generate.mp4', width=640)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "ce61937b7f7dfb4402f1892711bcd3e4a6b6f6d238d7280e2db39bcb9fe9525c"
+  },
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/foleycrafter/models/specvqgan/onset_baseline/demo_with_video_onset.ipynb b/foleycrafter/models/specvqgan/onset_baseline/demo_with_video_onset.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..36bdaab9a187a10e617c6c614d1dc03650c1caf2
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/demo_with_video_onset.ipynb
@@ -0,0 +1,548 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Change audio by detecting onset \n",
+    "This notebook contains a method that could change the target video sound with a given audio."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Load packages"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import IPython\n",
+    "import os\n",
+    "import numpy as np\n",
+    "from moviepy.editor import *\n",
+    "import librosa\n",
+    "from IPython.display import Audio\n",
+    "from IPython.display import Video"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Read videos\n",
+    "origin_video_path = 'demo-data/original.mp4'\n",
+    "# conditional_video_path = 'demo-data/conditional.mp4'\n",
+    "conditional_video_path = 'demo-data/dog_bark.mp4'\n",
+    "\n",
+    "ori_videoclip = VideoFileClip(origin_video_path)\n",
+    "con_videoclip = VideoFileClip(conditional_video_path)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<video src=\"demo-data/original.mp4\" controls  width=\"640\" >\n",
+       "      Your browser does not support the <code>video</code> element.\n",
+       "    </video>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.Video object>"
+      ]
+     },
+     "execution_count": 30,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "Video(origin_video_path, width=640)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<video src=\"demo-data/dog_bark.mp4\" controls  width=\"640\" >\n",
+       "      Your browser does not support the <code>video</code> element.\n",
+       "    </video>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.Video object>"
+      ]
+     },
+     "execution_count": 31,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "Video(conditional_video_path, width=640)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# get the audio track from video\n",
+    "ori_audioclip = ori_videoclip.audio\n",
+    "ori_audio, ori_sr = ori_audioclip.to_soundarray(), ori_audioclip.fps\n",
+    "con_audioclip = con_videoclip.audio\n",
+    "con_audio, con_sr = con_audioclip.to_soundarray(), con_audioclip.fps\n",
+    "\n",
+    "ori_audio = ori_audio.mean(-1)\n",
+    "con_audio = con_audio.mean(-1)\n",
+    "\n",
+    "target_sr = 22050\n",
+    "ori_audio = librosa.resample(ori_audio, orig_sr=ori_sr, target_sr=target_sr)\n",
+    "con_audio = librosa.resample(con_audio, orig_sr=con_sr, target_sr=target_sr)\n",
+    "\n",
+    "ori_sr, con_sr = target_sr, target_sr"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def detect_onset_of_audio(audio, sample_rate):\n",
+    "    onsets = librosa.onset.onset_detect(\n",
+    "        y=audio, sr=sample_rate, units='samples', delta=0.3)\n",
+    "    return onsets\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 34,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "(base64-encoded PNG omitted: waveform of ori_audio with librosa-detected onsets marked by red vertical lines)",
+      "text/plain": [
+       "<Figure size 600x400 with 1 Axes>"
+      ]
+     },
+     "metadata": {
+      "needs_background": "light"
+     },
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "from matplotlib import pyplot as plt\n",
+    "onsets = detect_onset_of_audio(ori_audio, ori_sr)\n",
+    "plt.figure(dpi=100)\n",
+    "\n",
+    "time = np.arange(ori_audio.shape[0])\n",
+    "plt.plot(time, ori_audio)\n",
+    "plt.vlines(onsets, 0, ymax=0.8, colors='r')\n",
+    "plt.show()\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Method\n",
+    "The baseline is quite simple, and it has several steps:\n",
+    "- Take the original video, and apply self-trained video onset detection model to detect the onset\n",
+    "- Detect onsets of conditional waveform (encoded and decoded by our codebook) and clip single onset event from them as sound candicates\n",
+    "- For each onset of original waveform, replace with conditional onset event randomly and then generate sound"
+   ]
+  },
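+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The next cell is a minimal, self-contained sketch of the replacement step above. It is not used by the rest of the notebook, which builds the real version with `get_onset_audio_range` a few cells below; the helper name `replace_onset_events` and the fixed `half_window` are illustrative assumptions."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "# Sketch of the replacement step: paste a randomly chosen conditional onset\n",
+    "# event at every video-predicted onset. Onset positions are sample indices;\n",
+    "# half_window (samples copied on each side of an onset) is a simplification of\n",
+    "# the per-onset windows computed by get_onset_audio_range further below.\n",
+    "def replace_onset_events(ori_audio, con_audio, video_onsets, con_onsets, half_window, seed=2022):\n",
+    "    rng = np.random.default_rng(seed)\n",
+    "    gen = np.zeros_like(ori_audio)\n",
+    "    for onset in video_onsets:\n",
+    "        src = int(con_onsets[rng.integers(len(con_onsets))])\n",
+    "        # clip the copy window so neither array is indexed out of range\n",
+    "        left = min(half_window, onset, src)\n",
+    "        right = min(half_window, len(ori_audio) - onset, len(con_audio) - src)\n",
+    "        gen[onset - left: onset + right] = con_audio[src - left: src + right]\n",
+    "    return gen"
+   ]
+  },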
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "env: CUDA_VISIBLE_DEVICES=9\n",
+      "=> loading checkpoint 'checkpoints/EXP1/checkpoint_ep70.pth.tar'\n",
+      "=> loaded checkpoint 'checkpoints/EXP1/checkpoint_ep70.pth.tar' (epoch 70)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%env CUDA_VISIBLE_DEVICES=9\n",
+    "import argparse\n",
+    "import numpy as np\n",
+    "import os\n",
+    "import sys\n",
+    "import time\n",
+    "from tqdm import tqdm\n",
+    "from collections import OrderedDict\n",
+    "\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "from torch.utils.data import DataLoader\n",
+    "from torch.utils.tensorboard import SummaryWriter\n",
+    "\n",
+    "\n",
+    "from config import init_args\n",
+    "import data\n",
+    "import models\n",
+    "from models import *\n",
+    "from utils import utils, torch_utils\n",
+    "\n",
+    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+    "\n",
+    "\n",
+    "net = models.VideoOnsetNet(pretrained=False).to(device)\n",
+    "resume = 'checkpoints/EXP1/checkpoint_ep70.pth.tar'\n",
+    "net, _ = torch_utils.load_model(resume, net, device=device, strict=True)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 36,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torchvision.transforms as transforms\n",
+    "from PIL import Image\n",
+    "\n",
+    "\n",
+    "vision_transform_list = [\n",
+    "    transforms.Resize((128, 128)),\n",
+    "    transforms.CenterCrop((112, 112)),\n",
+    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
+    "]\n",
+    "video_transform = transforms.Compose(vision_transform_list)\n",
+    "\n",
+    "def read_image(frame_list):\n",
+    "    imgs = []\n",
+    "    convert_tensor = transforms.ToTensor()\n",
+    "    for img_path in frame_list:\n",
+    "        image = Image.open(img_path).convert('RGB')\n",
+    "        image = convert_tensor(image)\n",
+    "        imgs.append(image.unsqueeze(0))\n",
+    "    # (T, C, H ,W)\n",
+    "    imgs = torch.cat(imgs, dim=0).squeeze()\n",
+    "    imgs = video_transform(imgs)\n",
+    "    imgs = imgs.permute(1, 0, 2, 3)\n",
+    "    # (C, T, H ,W)\n",
+    "    return imgs\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 37,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# process videos into frames and read them\n",
+    "import glob\n",
+    "\n",
+    "save_path = 'demo-data/original_frames'\n",
+    "if os.path.exists(save_path):\n",
+    "    os.system(f'rm -rf {save_path}')\n",
+    "os.makedirs(save_path)\n",
+    "command = f'ffmpeg -v quiet -y -i \\\"{origin_video_path}\\\" -f image2 -vf \\\"scale=-1:360,fps=15\\\" -qscale:v 3 \\\"{save_path}\\\"/frame%06d.jpg'\n",
+    "os.system(command)\n",
+    "\n",
+    "frame_list = glob.glob(f'{save_path}/*.jpg')\n",
+    "frame_list.sort()\n",
+    "frame_list = frame_list[:2 * 15]\n",
+    "\n",
+    "frames = read_image(frame_list)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 38,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "inputs = {\n",
+    "    'frames': frames.unsqueeze(0).to(device)\n",
+    "}\n",
+    "pred = net(inputs).squeeze()\n",
+    "pred = torch.sigmoid(pred).data.cpu().numpy()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 39,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def postprocess_video_onsets(probs, thres=0.5, nearest=5):\n",
+    "    # import pdb; pdb.set_trace()\n",
+    "    video_onsets = []\n",
+    "    pred = np.array(probs, copy=True)\n",
+    "    while True:\n",
+    "        max_ind = np.argmax(pred)\n",
+    "        video_onsets.append(max_ind)\n",
+    "        low = max(max_ind - nearest, 0)\n",
+    "        high = min(max_ind + nearest, pred.shape[0])\n",
+    "        pred[low: high] = 0\n",
+    "        if (pred > thres).sum() == 0:\n",
+    "            break\n",
+    "    video_onsets.sort()\n",
+    "    video_onsets = np.array(video_onsets)\n",
+    "    return video_onsets\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 40,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "(base64-encoded PNG omitted: waveform of ori_audio with video-predicted onsets marked by red vertical lines)",
+      "text/plain": [
+       "<Figure size 600x400 with 1 Axes>"
+      ]
+     },
+     "metadata": {
+      "needs_background": "light"
+     },
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "from matplotlib import pyplot as plt\n",
+    "# video_onsets = (np.nonzero(pred > 0.5)[0] / 15 * ori_sr).astype(int)\n",
+    "video_onsets = postprocess_video_onsets(pred, thres=0.5, nearest=4)\n",
+    "video_onsets = (video_onsets / 15 * ori_sr).astype(int)\n",
+    "plt.figure(dpi=100)\n",
+    "\n",
+    "time = np.arange(ori_audio.shape[0])\n",
+    "plt.plot(time, ori_audio)\n",
+    "plt.vlines(video_onsets, 0, ymax=0.8, colors='r')\n",
+    "plt.show()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 41,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array([-0.06068027, -0.0599093 , -0.05623583, -0.01206349])"
+      ]
+     },
+     "execution_count": 41,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "(onsets - video_onsets) / ori_sr\n",
+    "# video_onsets"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 42,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_onset_audio_range(audio_len, onsets, i):\n",
+    "    if i == 0:\n",
+    "        prev_offset = int(onsets[i] // 3)\n",
+    "    else:\n",
+    "        prev_offset = int((onsets[i] - onsets[i - 1]) // 3)\n",
+    "\n",
+    "    if i == onsets.shape[0] - 1:\n",
+    "        post_offset = int((audio_len - onsets[i]) // 4 * 2)\n",
+    "    else:\n",
+    "        post_offset = int((onsets[i + 1] - onsets[i]) // 4 * 2)\n",
+    "    return prev_offset, post_offset\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 43,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ori_onsets = detect_onset_of_audio(ori_audio, ori_sr)\n",
+    "con_onsets = detect_onset_of_audio(con_audio, con_sr)\n",
+    "\n",
+    "np.random.seed(2022)\n",
+    "gen_audio = np.zeros_like(ori_audio)\n",
+    "for i in range(video_onsets.shape[0]):\n",
+    "    prev_offset, post_offset = get_onset_audio_range(int(con_sr * 2), video_onsets, i)\n",
+    "    j = np.random.choice(con_onsets.shape[0])\n",
+    "    prev_offset_con, post_offset_con = get_onset_audio_range(con_audio.shape[0], con_onsets, j)\n",
+    "    prev_offset = min(prev_offset, prev_offset_con)\n",
+    "    post_offset = min(post_offset, post_offset_con)\n",
+    "    gen_audio[video_onsets[i] - prev_offset: video_onsets[i] + post_offset] = con_audio[con_onsets[j] - prev_offset: con_onsets[j] + post_offset]\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 44,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhAAAAFZCAYAAADJvxawAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAtHklEQVR4nO3dd5wU9f3H8feXXu/ovRcBQUGqKAoq9phgi5piSDQx0SSaYpT8YtdoTKImxvZTfzGxJJpixYKioigSQQQpgiIQ6tHvDo47rnx/f+zesbfsleFm5ju7+3o+HvPgdnZm9sP37nbf953vfMdYawUAAOBFI9cFAACA9EOAAAAAnhEgAACAZwQIAADgGQECAAB4RoAAAACeESAAAIBnTVwX4DdjjJHUQ1Kh61oAAEhDbSVtsnVMFJVxAUKx8LDBdREAAKSxXpI21rZBJgaIQklav369cnJyXNcCAEDaKCgoUO/evaV69OJnYoCQJOXk5BAgAAAICIMoAQCAZwQIAADgGQECAAB4RoAAAACeESAAAIBnBAgAAOAZAQIAAHhGgAAAAJ4RIAAAgGcECABwpI57FSEgtLs/CBAA4MAby/M06ubX9eanea5LySrFpeU66fdzdOXfF7kuJe0RIAAgZNZaXfrXBcrfV6rvPLbAdTlZ5Zp/LdEX2/fq+Y83uS4l7REgACBka3cUuS4hayUGh91F+x1Wkv4IEAAQkLdWbtXa7XsPWr9zb4mDarLHf3cU1evU0K0zV4RQTeYyQQ4mMcYcL+lqSWMkdZd0trX2uTr2mSzpLknDJW2SdKe19kEPr5kjKT8/P5/beQNwZt7qHbro4Q8kSY9fMl4927XUuh1FOv6wzhr4y5dT7nPsoI56+OKxatWsSZilZpx+186UJLVs2lgLr5uqWcvydPLhXfV/c9fo96+vSrnPGz+drEFd2oRZZiQVFBQoNzdXknKttQW1bRt0gDhd0rGSPpL0L9URIIwx/SUtlfSwpIfi+94v6SJr7b/q+ZoECADO3Tv7sxo/rOry/BXHamTvdv4WlEUqA4RXQ7u11atXHe9zNenFS4AINOZaa1+R9IokGWPqs8v3Jf3XWntV/PEKY8xYST9XLIAcxBjTXFLzhFVtD7VeICwVFVbH3fmWOrVppmcvP1aNGtXr9wNZ4iv3vaeJAzpqdN92uvrUoa7LyRqfbinU8x9v1D1vfKZnLz9G7Vo1c11SpEVtDMRESbOS1r0maawxpmkN+8yQlJ+wbAiuPMAfT8xfp42792nxhnwt21RryEeWmvfFDt331mqtyit0XUpWufLvH2vN9r0adfPrrkuJvKgFiG6Skke+5CnWU9Kphn1ul5SbsPQKrDrAJ4kjwSuY1Aa1OP/Bea5LyFrrd3K1TG2iFiAkKfnd1NSwPrbS2hJrbUHlIom4DiBj5O8rdV1C1rqNqzRqFbUAsUWxXohEXSSVSdoRfjkAgGxVVkHvYG2iFiDmSTo5ad0pkhZYa4nhANJG/caNA+kr0ABhjGljjBlljBkVX9U//rhP/PnbjTF/TdjlQUl9jTF3GWOGGWO+I+kSSb8Lsk4AAOBN0LOVjJX0VsLju+L//kXSdMUml+pT+aS1do0x5gxJd0u6QrGJpH5c3zkgAABAOIKeB+JtHRgEmer56SnWzZE0OriqAPd2JczBT1c3gHQUtTEQQFb4YtvB90cAgHRCgACAANRz9l0gbREgAMeYRwqIKn45a0OAABx7ZO4a1yUAgGcECMCxFxdvqnsjAIgYAgQARFx+EfPoufDGiq2uS4g0AgQARNwT89e5LgE4CAECACLOMtIWEUSAAAAAnhEgAACAZwQIAADgGQECAAB4RoAAAACeESAAIAB+3grjxcWb/TsYPCmv4AqYmhAgACDiVuYVui4ha72+fIvrEiKLAAEAQA0KistclxBZBAgAAOAZAQIAAHhGgAAAAJ4RIAAgAM98uN51CVmpoJg7l4aFAAEAAVi7o8h1CVnpTW7BHRoCBAAA8IwAAYTs5/9Y7LoEAGgwAgQQsn8u3OC6BARswy5OX7hQtL9MeQXFrsvIGk1cFwAAmWThup0694F5rsvISmNvfUNF+8tdl5E16IEAAB/9c+FG1yVkLcJDuAgQAADAMwIEEKLte0pclwAAviBAACG67rmlrktAwPy8jTcQZQQIIERfbNvrugQA8AUBAgB89NG6Xa5LyEoVFdZ1CVmHAAEAPvp0S6HrErLSFuZ/CB0BAgAAeEaAAAAAnhEgAACAZwQIAADgGQECAAB4RoAAAACeESCACLCWa9gBpBcCBBABs1dsdV0CkNbeX73DdQlZhwABhKim+yS8uGRTuIUAGeZfCze4LiHrECDSXElZuT5ev5tpXNMEZyoAZAoCRJr70VOLNO2+9/TAnNWuS0E9WJEgcOhKyytcl5CVSssrGKeUAgEizc1anidJenTuGseVAAjSc4s2avD/vKIXFnO6K0xFJWUac8vr+uaj/3FdSuQQIAAgDVz19MeSpB//bZHbQrLMnFXbVFBcprmfb3ddSuQQIIAIqGFsJZDSY+/R4xiWt1Zuc11CZBEg0thWbl+bdmo6jcrZVXhx44vLXZcQOTVd4YTgECDS2Ml3v1P1dQUDfNLCZ1v3uC4ByEi8BYaPAJHG8veVVn1N+E5vO/fud10CkNbogQgfAQKIgHc/Y4AW0BCfbil0XULWIUAAgE/W7djruoSsRS9e+AgQGYLTf4B7fIghmxAgAACAZwSIDMEIZABAmEIJEMaYy40xa4wxxcaYhcaY42rZdooxxqZYhoZRKwAcCmutPtmYH9rrFe0vC+21oo45cdwIPEAYYy6QdI+k2yQdJeldSa8YY/rUsesQSd0Tls8CLBMAGuTNT7fq+ueXhfZ6h1//Gjd4ihv/69muS8hKYfRA/FTSo9baR6y1K6y1V0laL+kHdey31Vq7JWEpD7xSwKG9JWXaWshfUunqlaVbQn/NsgoCBNwJNEAYY5pJGiNpVtJTsyQdU8fui4wxm40xs40xJ9TyGs2NMTmVi6S2DasacGPkTbM0/rbZ2lZY4roUAKhT0D0QnSQ1lpSXtD5PUrca9tks6XuSzpV0jqSVkmYbY46vYfsZkvITlg0NrDkt0ZWZ/ir/mly8frfbQnBImAgR2aZJSK+T/OlmUqyLbWjtSsVCQ6V5xpjekn4u6Z0Uu9wu6a6Ex22V4SHi/dXb9exHG6utIz4AAMIUdIDYLqlcB/c2dNHBvRK1+UDSN1I9Ya0tkVTV52uyYEL0rz0833UJAIAsF+gpDGvtfkkLJZ2c9NTJkt73cKijFDu1AQAAIiCMUxh3SXrcGLNA0jzFxjf0kfSgJBljbpfU01p7cfzxVZLWSlomqZliPQ/nxhfUIPP7XQAk4/ceLgUeIKy1TxtjOkq6XrH5HJZKOsNauy6+SXfFAkWlZpJ+J6mnpH2KBYkzrbUvB11rOisoZlIZwKUsOHsKVBPKIEpr7f2S7q/huelJj++UdGcIZWWc1dv2aGDnNq7LABASBk/DJe6FkUE27Nrnug
QAISssLuUybjhBgEgzvFEAqDT38+064sZZuvGF8KbQBioRINLMB1/srPE5wkVm4LuYnoyDIY2/fTU2Zc5f5q2rY0vAfwSINLM5n9MUme6Wl5a7LgFpYvnmAtclIIsRIICI+e/OItclAECdCBBphrMUAIAoIECkEWutPt+2x3UZAAAQINLJPxZs0ANvr3ZdBoAUmEgK2YYAkUYenbvGdQkAAEgiQGQUhkcAbpSWV+i1ZVtcl5GVlm7Md11C1iJAAEADPfD2au0qKnVdRtYpLa/Ql+6d67qMrEWAAIAGevmTza5LyEql5RWuS8hqBAgAAOAZAQIAkJZcTB+OAwgQmYRRlACAkBAgAKCBPt1S6LqErLSraL/rErIaAQIAGmB9RO5dsq2wxHUJoZt61xzXJWQ1AgQANMCKiNwRc9xtb7guIXRF+8tdl5DVCBAZZD+XNGWMDbuKVF7BoBYA0UWAyCCXPb7QdQnwyaTfvMX3E0CkESCAiHpjRZ7rEgCgRgSINLFr736tzGOkNwAgGggQaeLyJz+q13YVnDcHAISAAJEm5n2xo17brd2xN+BKAERVPjf0QogIEBnGGKZ2BbJVhaUHEuFp4roAIBvs2FOiwuIy12UAgG8IEEAIxtyafZP8ZAt6/ZCtOIWRYSxdmAAQCN5fqyNAAAAAzwgQGYbu1OjJ38fIeCAT0AFRHQEiw9DFFj3XP7/UdQkA4DsCBBCweavrN4cH0hN9fshWBAggYA2ZHHTmks0qLOYUCBAF9O9WR4AAArZ9T8kh73vFUx/pkscW+FgNMhkfcAgTASLDvL1ym+sS4LP/rN3pugQAYoxZMgJEhrn5peX8kANZivEYCBMBAgAyBH86IEwEiAxEB0R0cHt1hGnmJ5tdl5DR+G2ujgABBIi7IyJM/1yw3nUJyCIEiAw04Jcv65S75+i655jAyKWi/WWa+/l2X471ncc+1NaCYl+OBeDQ8PdAdQSIDLUqb48e/2Cd8ouYQ8CVw69/TdP//KEvx3rz06264YVlvhwLmYvPN4SJAJHhdhXtd10CfLK18NDnk0BwonT7mSUb8l2XgCxCgMhwZ90713UJ8MnekjLXJQBZzdLHUw0BIsMV8qGTMT7dUui6BACoQoAAAKAeGERZHQEiDby/2p+R/Mgs1loVl5a7LgNAliJApIGnP2zYtd3vfsb9MTJNcWm5+s94WUOve1XvrOL760pFhdX3n1jouoys9OyiDa5LyHoEiDRQ3sDZDL/56H98qgT1tSegsScXPDRP+ftK9dT8/1atu/j/+P66snjDbpWW06/twk+eXuy6hKzXxHUBqN3uov16aQnT06abyXe+Fchx56/Zqfvf+lx/fn9tIMeHN/vLKlyXADhDD0TEvbp0iy/HGXvrG3px8SZfjoW67dgb3Pwbz3+8iQ+uiKDvIbswiLI6AkTE7S/354Ni+54S/ehviyRJ63cWqd+1M/Xcoo2+HBvh2sKU1pHBvU6QzQgQEdfQ8Q/Jjr3jTR0X716/6umPfT02YnY7mP1zG7NU1mpVXqH6XTtTFzw0z98Dkx/q1O/amep37UzuTJuBCBAR53eA2Lh7X7XHu/bu10NzVnOjJh/d9OLy0F9z3G1vMCq9Fqfc/Y6k2BgSP0XxMzFKl/au31lU9fX1L6T/zf2YibI6AkQELV6/WwvX7dSuvft168wVgb7WUbe8rttf+VTjfz1bf5z9WaCvlQ2stXrW0amhnzy9WP/+aIMW/XeXKiqsnpr/X328freTWsKydGO+bn95hQqL63/TuJIy/z5go3gK48/vrXVdQpXEWp744L81b4i0xFUYEbNx9z595b73nLz2Xa+v0t6SMvVs31IXT+znpAYXrLUyPt0R6fezVvlynEP102cOvrRt7R1nej7O1sJijb9ttmacPlSXTR5Ytd7PtvLDl+L3elm4bpe+cXRffWVUjzrr++5fF+qv3xmv91dv1/wvduqc0T3Vt2NrLVy3Sz3atVD33Jb1fv3yCAaIFZsLXJdQZWi3ttUel1dYFewr1SNzv9DYfh00aVAn7dizXxt2FWlM3/aR+tlKJYLfbqdCCRDGmMslXS2pu6Rlkq6y1r5by/aTJd0labikTZLutNY+GEatrr3k+EqJh975QpJ0/fPLdNbIHrrz3CPVslnjque37ynRlvxijeiZW7Vu8frd6t+5tfaWlKlL2xZq3OjAm8CW/GK9tGSTLp7YT82axDq89pdVVH1dm1V5hXrig3X6wZSBnt7Uvbhr1kr98c3PJR34oC0pK9dPn1msmUs265GLx2rq4V2r7ZNXUKxnF23UiUO7aHCXNvrNqytlrdWlxw3Qn976PJA6G6LftTM1rl979WrfSs8u2qgfnzRYnds000nDuqpbTgvd+OIyHTe4s2Yu2aTnPt6kkb3baXG85+L2Vz6tChDnP/i+Ply7S9KBtiotr9ANLyzT948fqNxWTdWmeZNq3/9K1lpVWKV87lCt2b636usF63Zpwbpduurpj/V/08fqO48tkDHSKYd31WvL8jSwc2ut3hbb/p1V27S1oFhfe3i+JOkPST1v91wwStOO6ln1uPLntaSsXE0aNar6P5SWV0TyapgXFm/SHy86ymkNj85do1teWq6cFtU/Yp6cv07XP195W/rVB+2XGHbLyivUyBg1amS0b395tfchP3uRcOiMDThSGWMukPS4pMslvSfpMkmXSjrcWntQn5Yxpr+kpZIelvSQpGMl3S/pImvtv+rxejmS8vPz85WTk+Pb/yMsD85ZrTte+dR1GfVy05eH64YXlh20vlOb5nrnF1P0nzU7Nf3PH1atv3RSfz0yd03V41+ffYR++ewnunzKQJWUVWhC/w763uOpZ/U7lL+i66PftTOrPf7q2F56ZkH1sQQrbj5Nxki3vLRcT87Pvm7Yk4Z20aPTx1VrqwvH9dbf6zFD6itXHqfT/3Dgb4V3f3GCendo5Utdyd87L8b0ba+F63bV+PzaO87Ur19eof+NB+pkD35jTKRnoAzq96W+DvV7s/aOM/X51j2aeteclM8/9u1x1d5TwrbsplPVunlmd9wXFBQoNzdXknKttbV2Z4URIOZL+sha+4OEdSskPWetnZFi+99I+rK1dljCugcljbTWTkyxfXNJzRNWtZW0wc8AMfez7VVzKBgTW+KvXvW1kRK+Nge2PVBnQs01bSM9/O6BD1gccMHY3gqid7M+H4Lw14XjevtyHL53NfOrjQ9Vpn5vzh3dS00bR+80y3ljemlsvw6+HMtLgAg0ShljmkkaI+mOpKdmSTqmht0mxp9P9JqkS4wxTa21yaOlZki6oaG11mZlXqGeXpCZvxDpgvbPHJn64RIltHEw/vVRNK90Gt23vW8Bwoug+2I6SWosKS9pfZ6kbjXs062G7ZvEj5c8r/Ptio2XqNRWkq/f5TF92+vqU4fIWls1iKay3+bA44Ofk7X12s7aA5cHPTQndZdptrv61CGBHLektLxqDASCd9H4PurVvuHjWbYVlugxpvOuUVC/L/X129dWOn39oLhu15qM6JFb90YBCOtkTvJ5EpNiX
...(base64-encoded PNG of the generated-audio waveform figure, elided)...",
+      "text/plain": [
+       "<Figure size 600x400 with 1 Axes>"
+      ]
+     },
+     "metadata": {
+      "needs_background": "light"
+     },
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "from matplotlib import pyplot as plt\n",
+    "plt.figure(dpi=100)\n",
+    "time = np.arange(gen_audio.shape[0])\n",
+    "plt.plot(time, gen_audio)\n",
+    "plt.show()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 45,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# save audio\n",
+    "import soundfile as sf\n",
+    "sf.write('data/gen_audio.wav', gen_audio, ori_sr)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 46,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "t:  58%|█████▊    | 26/45 [00:41<00:05,  3.45it/s, now=None]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Moviepy - Building video data/generate.mp4.\n",
+      "MoviePy - Writing audio in generateTEMP_MPY_wvf_snd.mp3\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "t:  58%|█████▊    | 26/45 [00:42<00:05,  3.45it/s, now=None]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "MoviePy - Done.\n",
+      "Moviepy - Writing video data/generate.mp4\n",
+      "\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "t:  58%|█████▊    | 26/45 [01:03<00:05,  3.45it/s, now=None]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Moviepy - Done !\n",
+      "Moviepy - video ready data/generate.mp4\n"
+     ]
+    }
+   ],
+   "source": [
+    "gen_audioclip = AudioFileClip(\"data/gen_audio.wav\")\n",
+    "gen_videoclip = ori_videoclip.set_audio(gen_audioclip)\n",
+    "gen_videoclip.write_videofile('data/generate.mp4')\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 47,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<video src=\"data/generate.mp4\" controls  width=\"640\" >\n",
+       "      Your browser does not support the <code>video</code> element.\n",
+       "    </video>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.Video object>"
+      ]
+     },
+     "execution_count": 47,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "Video('data/generate.mp4', width=640)\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "419ed25a44e8f5205333d07bc5a26d3abb4bd07afa4dac02924f75b129c3e2d9"
+  },
+  "kernelspec": {
+   "display_name": "Python 3.8.8 ('AVanalogy')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.8"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/foleycrafter/models/specvqgan/onset_baseline/main.py b/foleycrafter/models/specvqgan/onset_baseline/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..be1b7968118f37a6663fa01a471be74ab905ff86
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/main.py
@@ -0,0 +1,202 @@
+import argparse
+import numpy as np
+import os
+import sys
+import time
+from tqdm import tqdm
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+from config import init_args
+import data
+import models
+from models import *
+from utils import utils, torch_utils
+
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+def validation(args, net, criterion, data_loader, device='cuda'):
+    # import pdb; pdb.set_trace()
+    net.eval()
+    pred_all = torch.tensor([]).to(device)
+    target_all = torch.tensor([]).to(device)
+    with torch.no_grad():
+        for step, batch in tqdm(enumerate(data_loader), total=len(data_loader), desc="Validation"):
+            pred, target = predict(args, net, batch, device)
+            pred_all = torch.cat([pred_all, pred], dim=0)
+            target_all = torch.cat([target_all, target], dim=0)
+
+    res = criterion.evaluate(pred_all, target_all)
+    torch.cuda.empty_cache()
+    net.train()
+    return res
+
+
+def predict(args, net, batch, device):
+    inputs = {
+        'frames': batch['frames'].to(device)
+    }
+    pred = net(inputs)   
+    target = batch['label'].to(device)  
+    return pred, target
+
+
+def train(args, device):
+    # save dir
+    gpus = torch.cuda.device_count()
+    gpu_ids = list(range(gpus))
+
+    # ----- make dirs for checkpoints ----- #
+    sys.stdout = utils.LoggerOutput(os.path.join('checkpoints', args.exp, 'log.txt'))
+    os.makedirs('./checkpoints/' + args.exp, exist_ok=True)
+
+    writer = SummaryWriter(os.path.join('./checkpoints', args.exp, 'visualization'))
+    # ------------------------------------- #
+    tqdm.write('{}'.format(args)) 
+
+    # ------------------------------------ #
+
+    
+    # ----- Dataset and Dataloader ----- #
+    train_dataset = data.GreatestHitDataset(args, split='train')
+    # train_dataset.getitem_test(1)
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.num_workers,
+        pin_memory=True,
+        drop_last=False)
+    
+    val_dataset = data.GreatestHitDataset(args, split='val')
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        pin_memory=True,
+        drop_last=False)
+    # --------------------------------- #
+
+    # ----- Network ----- #
+    net = models.VideoOnsetNet(pretrained=False).to(device)
+    criterion = models.BCLoss(args)
+    optimizer = torch_utils.make_optimizer(net, args)
+    # --------------------- #
+
+    # -------- Loading checkpoint weights ------------- #
+    if args.resume:
+        resume = './checkpoints/' + args.resume
+        net, args.start_epoch = torch_utils.load_model(resume, net, device=device, strict=True)
+        if args.resume_optim:
+            tqdm.write('loading optimizer...')
+            optim_state = torch.load(resume)['optimizer']
+            optimizer.load_state_dict(optim_state)
+            tqdm.write('loaded optimizer!')
+        else:
+            args.start_epoch = 0
+
+    # ------------------- 
+    net = nn.DataParallel(net, device_ids=gpu_ids)
+    #  --------- Random or resume validation ------------ #
+    res = validation(args, net, criterion, val_loader, device)
+    writer.add_scalars('VideoOnset' + '/validation', res, args.start_epoch)
+    tqdm.write("Beginning, Validation results: {}".format(res))
+    tqdm.write('\n')
+
+    # ----------------- Training ---------------- #
+    # import pdb; pdb.set_trace()
+    VALID_STEP = args.valid_step
+    for epoch in range(args.start_epoch, args.epochs):
+        running_loss = 0.0
+        torch_utils.adjust_learning_rate(optimizer, epoch, args)
+        for step, batch in tqdm(enumerate(train_loader), total=len(train_loader), desc="Training"):
+            pred, target = predict(args, net, batch, device)
+            loss = criterion(pred, target)        
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
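+            # `step % 1 == 0` is always true, so the loss is logged every step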
+            if step % 1 == 0:
+                tqdm.write("Epoch: {}/{}, step: {}/{}, loss: {}".format(epoch+1, args.epochs, step+1, len(train_loader), loss))
+                running_loss += loss.item()
+
+            current_step = epoch * len(train_loader) + step + 1
+            BOARD_STEP = 3
+            if (step+1) % BOARD_STEP == 0:
+                writer.add_scalar('VideoOnset' + '/training loss', running_loss / BOARD_STEP, current_step)
+                running_loss = 0.0
+        
+        
+        # ----------- Validation -------------- #
+        if (epoch + 1) % VALID_STEP == 0:
+            res = validation(args, net, criterion, val_loader, device)
+            writer.add_scalars('VideoOnset' + '/validation', res, epoch + 1)
+            tqdm.write("Epoch: {}/{}, Validation results: {}".format(epoch + 1, args.epochs, res))
+
+        # ---------- Save model ----------- #
+        SAVE_STEP = args.save_step
+        if (epoch + 1) % SAVE_STEP == 0:
+            path = os.path.join('./checkpoints', args.exp, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar')
+            torch.save({'epoch': epoch + 1,
+                        'step': current_step,
+                        'state_dict': net.state_dict(),
+                        'optimizer': optimizer.state_dict(),
+                        },
+                        path)
+        # --------------------------------- #
+    torch.cuda.empty_cache()
+    tqdm.write('Training Complete!')
+    writer.close()
+
+
+def test(args, device):
+    # save dir
+    gpus = torch.cuda.device_count()
+    gpu_ids = list(range(gpus))
+
+    # ----- make dirs for results ----- #
+    sys.stdout = utils.LoggerOutput(os.path.join('results', args.exp, 'log.txt'))
+    os.makedirs('./results/' + args.exp, exist_ok=True)
+    # ------------------------------------- #
+    tqdm.write('{}'.format(args)) 
+    # ------------------------------------ #
+    # ----- Dataset and Dataloader ----- #
+    test_dataset = data.GreatestHitDataset(args, split='test')
+    test_loader = DataLoader(
+        test_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        pin_memory=True,
+        drop_last=False)
+
+    # --------------------------------- #
+    # ----- Network ----- #
+    net = models.VideoOnsetNet(pretrained=False).to(device)
+    criterion = models.BCLoss(args)
+    # -------- Loading checkpoint weights ------------- #
+    if args.resume:
+        resume = './checkpoints/' + args.resume
+        net, _ = torch_utils.load_model(resume, net, device=device, strict=True)
+
+    # ------------------- #
+    net = nn.DataParallel(net, device_ids=gpu_ids)
+    #  --------- Testing ------------ #
+    res = validation(args, net, criterion, test_loader, device)
+    tqdm.write("Testing results: {}".format(res))
+
+
+# CUDA_VISIBLE_DEVICES=1 python main.py --exp='EXP1' --epochs=100 --batch_size=12 --num_workers=8 --save_step=10 --valid_step=1 --lr=0.0001 --optim='Adam' --repeat=1 --schedule='cos'
+if __name__ == '__main__':
+    args = init_args()
+    if args.test_mode:
+        test(args, DEVICE)
+    else:
+        train(args, DEVICE)
\ No newline at end of file
diff --git a/foleycrafter/models/specvqgan/onset_baseline/main_cxav.py b/foleycrafter/models/specvqgan/onset_baseline/main_cxav.py
new file mode 100644
index 0000000000000000000000000000000000000000..498ce1fd3cddb79d0e175501ed43c009fe9aa098
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/main_cxav.py
@@ -0,0 +1,202 @@
+import argparse
+import numpy as np
+import os
+import sys
+import time
+from tqdm import tqdm
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+from config import init_args
+import data
+import models
+from models import *
+from utils import utils, torch_utils
+
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+def validation(args, net, criterion, data_loader, device='cuda'):
+    # import pdb; pdb.set_trace()
+    net.eval()
+    pred_all = torch.tensor([]).to(device)
+    target_all = torch.tensor([]).to(device)
+    with torch.no_grad():
+        for step, batch in tqdm(enumerate(data_loader), total=len(data_loader), desc="Validation"):
+            pred, target = predict(args, net, batch, device)
+            pred_all = torch.cat([pred_all, pred], dim=0)
+            target_all = torch.cat([target_all, target], dim=0)
+
+    res = criterion.evaluate(pred_all, target_all)
+    torch.cuda.empty_cache()
+    net.train()
+    return res
+
+
+def predict(args, net, batch, device):
+    inputs = {
+        'frames': batch['frames'].to(device)
+    }
+    pred = net(inputs)   
+    target = batch['label'].to(device)  
+    return pred, target
+
+
+def train(args, device):
+    # save dir
+    gpus = torch.cuda.device_count()
+    gpu_ids = list(range(gpus))
+
+    # ----- make dirs for checkpoints ----- #
+    sys.stdout = utils.LoggerOutput(os.path.join('checkpoints', args.exp, 'log.txt'))
+    os.makedirs('./checkpoints/' + args.exp, exist_ok=True)
+
+    writer = SummaryWriter(os.path.join('./checkpoints', args.exp, 'visualization'))
+    # ------------------------------------- #
+    tqdm.write('{}'.format(args)) 
+
+    # ------------------------------------ #
+
+    
+    # ----- Dataset and Dataloader ----- #
+    train_dataset = data.CountixAVDataset(args, split='train')
+    # train_dataset.getitem_test(1)
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.num_workers,
+        pin_memory=True,
+        drop_last=False)
+    
+    val_dataset = data.CountixAVDataset(args, split='val')
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        pin_memory=True,
+        drop_last=False)
+    # --------------------------------- #
+
+    # ----- Network ----- #
+    net = models.VideoOnsetNet(pretrained=False).to(device)
+    criterion = models.BCLoss(args)
+    optimizer = torch_utils.make_optimizer(net, args)
+    # --------------------- #
+
+    # -------- Loading checkpoint weights ------------- #
+    if args.resume:
+        resume = './checkpoints/' + args.resume
+        net, args.start_epoch = torch_utils.load_model(resume, net, device=device, strict=True)
+        if args.resume_optim:
+            tqdm.write('loading optimizer...')
+            optim_state = torch.load(resume)['optimizer']
+            optimizer.load_state_dict(optim_state)
+            tqdm.write('loaded optimizer!')
+        else:
+            args.start_epoch = 0
+
+    # ------------------- 
+    net = nn.DataParallel(net, device_ids=gpu_ids)
+    #  --------- Random or resume validation ------------ #
+    res = validation(args, net, criterion, val_loader, device)
+    writer.add_scalars('VideoOnset' + '/validation', res, args.start_epoch)
+    tqdm.write("Beginning, Validation results: {}".format(res))
+    tqdm.write('\n')
+
+    # ----------------- Training ---------------- #
+    # import pdb; pdb.set_trace()
+    VALID_STEP = args.valid_step
+    for epoch in range(args.start_epoch, args.epochs):
+        running_loss = 0.0
+        torch_utils.adjust_learning_rate(optimizer, epoch, args)
+        for step, batch in tqdm(enumerate(train_loader), total=len(train_loader), desc="Training"):
+            pred, target = predict(args, net, batch, device)
+            loss = criterion(pred, target)        
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            if step % 1 == 0:
+                tqdm.write("Epoch: {}/{}, step: {}/{}, loss: {}".format(epoch+1, args.epochs, step+1, len(train_loader), loss))
+                running_loss += loss.item()
+
+            current_step = epoch * len(train_loader) + step + 1
+            BOARD_STEP = 3
+            if (step+1) % BOARD_STEP == 0:
+                writer.add_scalar('VideoOnset' + '/training loss', running_loss / BOARD_STEP, current_step)
+                running_loss = 0.0
+        
+        
+        # ----------- Validation -------------- #
+        if (epoch + 1) % VALID_STEP == 0:
+            res = validation(args, net, criterion, val_loader, device)
+            writer.add_scalars('VideoOnset' + '/validation', res, epoch + 1)
+            tqdm.write("Epoch: {}/{}, Validation results: {}".format(epoch + 1, args.epochs, res))
+
+        # ---------- Save model ----------- #
+        SAVE_STEP = args.save_step
+        if (epoch + 1) % SAVE_STEP == 0:
+            path = os.path.join('./checkpoints', args.exp, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar')
+            torch.save({'epoch': epoch + 1,
+                        'step': current_step,
+                        'state_dict': net.state_dict(),
+                        'optimizer': optimizer.state_dict(),
+                        },
+                        path)
+        # --------------------------------- #
+    torch.cuda.empty_cache()
+    tqdm.write('Training Complete!')
+    writer.close()
+
+
+def test(args, device):
+    # save dir
+    gpus = torch.cuda.device_count()
+    gpu_ids = list(range(gpus))
+
+    # ----- make dirs for results ----- #
+    sys.stdout = utils.LoggerOutput(os.path.join('results', args.exp, 'log.txt'))
+    os.makedirs('./results/' + args.exp, exist_ok=True)
+    # ------------------------------------- #
+    tqdm.write('{}'.format(args)) 
+    # ------------------------------------ #
+    # ----- Dataset and Dataloader ----- #
+    test_dataset = data.CountixAVDataset(args, split='test')
+    test_loader = DataLoader(
+        test_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        pin_memory=True,
+        drop_last=False)
+
+    # --------------------------------- #
+    # ----- Network ----- #
+    net = models.VideoOnsetNet(pretrained=False).to(device)
+    criterion = models.BCLoss(args)
+    # -------- Loading checkpoint weights ------------- #
+    if args.resume:
+        resume = './checkpoints/' + args.resume
+        net, _ = torch_utils.load_model(resume, net, device=device, strict=True)
+
+    # ------------------- #
+    net = nn.DataParallel(net, device_ids=gpu_ids)
+    #  --------- Testing ------------ #
+    res = validation(args, net, criterion, test_loader, device)
+    tqdm.write("Testing results: {}".format(res))
+
+
+# CUDA_VISIBLE_DEVICES=1 python main.py --exp='EXP1' --epochs=100 --batch_size=12 --num_workers=8 --save_step=10 --valid_step=1 --lr=0.0001 --optim='Adam' --repeat=1 --schedule='cos'
+if __name__ == '__main__':
+    args = init_args()
+    if args.test_mode:
+        test(args, DEVICE)
+    else:
+        train(args, DEVICE)
\ No newline at end of file
diff --git a/foleycrafter/models/specvqgan/onset_baseline/models/__init__.py b/foleycrafter/models/specvqgan/onset_baseline/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b314242ca0d707d9e6f4a39937fbe119eaf88c62
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/models/__init__.py
@@ -0,0 +1,3 @@
+from .resnet import *
+from .r2plus1d_18 import *
+from .video_onset_net import *
\ No newline at end of file
diff --git a/foleycrafter/models/specvqgan/onset_baseline/models/r2plus1d_18.py b/foleycrafter/models/specvqgan/onset_baseline/models/r2plus1d_18.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c2d3a4de4ff8d1166100ddc47f14d09ab1119b3
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/models/r2plus1d_18.py
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+
+
+import sys
+sys.path.append('..')
+from foleycrafter.models.specvqgan.onset_baseline.models.resnet import r2plus1d_18
+
+
+class r2plus1d18KeepTemp(nn.Module):
+
+    def __init__(self, pretrained=True):
+        super().__init__()
+
+        self.model = r2plus1d_18(pretrained=pretrained)
+
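+        # Keep the temporal resolution of the stock R(2+1)D-18: the temporal
+        # convs and downsample blocks of layers 2-4 are rebuilt with temporal
+        # stride 1 (spatial stride unchanged), and the final pooling collapses
+        # only the spatial dims. The in-channel counts 230/460/921 are the
+        # (2+1)D midplanes of the original blocks.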
+        self.model.layer2[0].conv1[0][3] = nn.Conv3d(230, 128, kernel_size=(3, 1, 1), 
+            stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
+        self.model.layer2[0].downsample = nn.Sequential(
+            nn.Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False),
+            nn.BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+        )
+        self.model.layer3[0].conv1[0][3] = nn.Conv3d(460, 256, kernel_size=(3, 1, 1), 
+            stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
+        self.model.layer3[0].downsample = nn.Sequential(
+            nn.Conv3d(128, 256, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False),
+            nn.BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+        )
+        self.model.layer4[0].conv1[0][3] = nn.Conv3d(921, 512, kernel_size=(3, 1, 1), 
+            stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
+        self.model.layer4[0].downsample = nn.Sequential(
+            nn.Conv3d(256, 512, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False),
+            nn.BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+        )
+        self.model.avgpool = nn.AdaptiveAvgPool3d((None, 1, 1))
+        self.model.fc = nn.Identity()
+
+
+    def forward(self, x):
+        # import pdb; pdb.set_trace()
+        x = self.model(x)
+        return x
+
+
+
+
+if __name__ == '__main__':
+    model = r2plus1d18KeepTemp(False).cuda()
+    rand_input = torch.randn((1, 3, 30, 112, 112)).cuda()
+    out = model(rand_input)
+
diff --git a/foleycrafter/models/specvqgan/onset_baseline/models/resnet.py b/foleycrafter/models/specvqgan/onset_baseline/models/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bc15653409a60c61a4d053ee9a69dc4be119e65
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/models/resnet.py
@@ -0,0 +1,348 @@
+import torch.nn as nn
+
+# from torchvision.models.utils import load_state_dict_from_url
+from torch.hub import load_state_dict_from_url
+
+
+__all__ = ['r3d_18', 'mc3_18', 'r2plus1d_18']
+
+model_urls = {
+    'r3d_18': 'https://download.pytorch.org/models/r3d_18-b3b3357e.pth',
+    'mc3_18': 'https://download.pytorch.org/models/mc3_18-a90a0ba3.pth',
+    'r2plus1d_18': 'https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth',
+}
+
+
+class Conv3DSimple(nn.Conv3d):
+    def __init__(self,
+                 in_planes,
+                 out_planes,
+                 midplanes=None,
+                 stride=1,
+                 padding=1):
+
+        super(Conv3DSimple, self).__init__(
+            in_channels=in_planes,
+            out_channels=out_planes,
+            kernel_size=(3, 3, 3),
+            stride=stride,
+            padding=padding,
+            bias=False)
+
+    @staticmethod
+    def get_downsample_stride(stride):
+        return stride, stride, stride
+
+
+class Conv2Plus1D(nn.Sequential):
+
+    def __init__(self,
+                 in_planes,
+                 out_planes,
+                 midplanes,
+                 stride=1,
+                 padding=1):
+        super(Conv2Plus1D, self).__init__(
+            nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3),
+                      stride=(1, stride, stride), padding=(0, padding, padding),
+                      bias=False),
+            nn.BatchNorm3d(midplanes),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1),
+                      stride=(stride, 1, 1), padding=(padding, 0, 0),
+                      bias=False))
+
+    @staticmethod
+    def get_downsample_stride(stride):
+        return stride, stride, stride
+
+
+class Conv3DNoTemporal(nn.Conv3d):
+
+    def __init__(self,
+                 in_planes,
+                 out_planes,
+                 midplanes=None,
+                 stride=1,
+                 padding=1):
+
+        super(Conv3DNoTemporal, self).__init__(
+            in_channels=in_planes,
+            out_channels=out_planes,
+            kernel_size=(1, 3, 3),
+            stride=(1, stride, stride),
+            padding=(0, padding, padding),
+            bias=False)
+
+    @staticmethod
+    def get_downsample_stride(stride):
+        return 1, stride, stride
+
+
+class BasicBlock(nn.Module):
+
+    expansion = 1
+
+    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
+        midplanes = (inplanes * planes * 3 * 3 *
+                     3) // (inplanes * 3 * 3 + 3 * planes)
+
+        super(BasicBlock, self).__init__()
+        self.conv1 = nn.Sequential(
+            conv_builder(inplanes, planes, midplanes, stride),
+            nn.BatchNorm3d(planes),
+            nn.ReLU(inplace=True)
+        )
+        self.conv2 = nn.Sequential(
+            conv_builder(planes, planes, midplanes),
+            nn.BatchNorm3d(planes)
+        )
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.conv2(out)
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
+
+        super(Bottleneck, self).__init__()
+        midplanes = (inplanes * planes * 3 * 3 *
+                     3) // (inplanes * 3 * 3 + 3 * planes)
+
+        # 1x1x1
+        self.conv1 = nn.Sequential(
+            nn.Conv3d(inplanes, planes, kernel_size=1, bias=False),
+            nn.BatchNorm3d(planes),
+            nn.ReLU(inplace=True)
+        )
+        # Second kernel
+        self.conv2 = nn.Sequential(
+            conv_builder(planes, planes, midplanes, stride),
+            nn.BatchNorm3d(planes),
+            nn.ReLU(inplace=True)
+        )
+
+        # 1x1x1
+        self.conv3 = nn.Sequential(
+            nn.Conv3d(planes, planes * self.expansion,
+                      kernel_size=1, bias=False),
+            nn.BatchNorm3d(planes * self.expansion)
+        )
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.conv2(out)
+        out = self.conv3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class BasicStem(nn.Sequential):
+    """The default conv-batchnorm-relu stem
+    """
+
+    def __init__(self):
+        super(BasicStem, self).__init__(
+            nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2),
+                      padding=(1, 3, 3), bias=False),
+            nn.BatchNorm3d(64),
+            nn.ReLU(inplace=True))
+
+
+class R2Plus1dStem(nn.Sequential):
+    """R(2+1)D stem is different than the default one as it uses separated 3D convolution
+    """
+
+    def __init__(self):
+        super(R2Plus1dStem, self).__init__(
+            nn.Conv3d(3, 45, kernel_size=(1, 7, 7),
+                      stride=(1, 2, 2), padding=(0, 3, 3),
+                      bias=False),
+            nn.BatchNorm3d(45),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
+                      stride=(1, 1, 1), padding=(1, 0, 0),
+                      bias=False),
+            nn.BatchNorm3d(64),
+            nn.ReLU(inplace=True))
+
+
+class VideoResNet(nn.Module):
+
+    def __init__(self, block, conv_makers, layers,
+                 stem, num_classes=400,
+                 zero_init_residual=False):
+        """Generic resnet video generator.
+        Args:
+            block (nn.Module): resnet building block
+            conv_makers (list(functions)): generator function for each layer
+            layers (List[int]): number of blocks per layer
+            stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
+            num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
+            zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
+        """
+        super(VideoResNet, self).__init__()
+        self.inplanes = 64
+
+        self.stem = stem()
+
+        self.layer1 = self._make_layer(
+            block, conv_makers[0], 64, layers[0], stride=1)
+        self.layer2 = self._make_layer(
+            block, conv_makers[1], 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(
+            block, conv_makers[2], 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(
+            block, conv_makers[3], 512, layers[3], stride=2)
+
+        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        # init weights
+        self._initialize_weights()
+
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+
+    def forward(self, x):
+        x = self.stem(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        # Flatten the layer to fc
+        # x = x.flatten(1)
+        # x = self.fc(x)
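+        # drop the singleton dims left by pooling; squeeze() also removes the
+        # batch dim when N == 1, so it is restored below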
+        N = x.shape[0]
+        x = x.squeeze()
+        if N == 1:
+            x = x[None]
+
+        return x
+
+    def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
+        downsample = None
+
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            ds_stride = conv_builder.get_downsample_stride(stride)
+            downsample = nn.Sequential(
+                nn.Conv3d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=ds_stride, bias=False),
+                nn.BatchNorm3d(planes * block.expansion)
+            )
+        layers = []
+        layers.append(block(self.inplanes, planes,
+                      conv_builder, stride, downsample))
+
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes, conv_builder))
+
+        return nn.Sequential(*layers)
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv3d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out',
+                                        nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm3d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)
+
+
+def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
+    model = VideoResNet(**kwargs)
+
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def r3d_18(pretrained=False, progress=True, **kwargs):
+    """Construct 18 layer Resnet3D model as in
+    https://arxiv.org/abs/1711.11248
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
+        progress (bool): If True, displays a progress bar of the download to stderr
+    Returns:
+        nn.Module: R3D-18 network
+    """
+
+    return _video_resnet('r3d_18',
+                         pretrained, progress,
+                         block=BasicBlock,
+                         conv_makers=[Conv3DSimple] * 4,
+                         layers=[2, 2, 2, 2],
+                         stem=BasicStem, **kwargs)
+
+
+def mc3_18(pretrained=False, progress=True, **kwargs):
+    """Constructor for 18 layer Mixed Convolution network as in
+    https://arxiv.org/abs/1711.11248
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
+        progress (bool): If True, displays a progress bar of the download to stderr
+    Returns:
+        nn.Module: MC3 Network definition
+    """
+    return _video_resnet('mc3_18',
+                         pretrained, progress,
+                         block=BasicBlock,
+                         conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3,
+                         layers=[2, 2, 2, 2],
+                         stem=BasicStem, **kwargs)
+
+
+def r2plus1d_18(pretrained=False, progress=True, **kwargs):
+    """Constructor for the 18 layer deep R(2+1)D network as in
+    https://arxiv.org/abs/1711.11248
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
+        progress (bool): If True, displays a progress bar of the download to stderr
+    Returns:
+        nn.Module: R(2+1)D-18 network
+    """
+    return _video_resnet('r2plus1d_18',
+                         pretrained, progress,
+                         block=BasicBlock,
+                         conv_makers=[Conv2Plus1D] * 4,
+                         layers=[2, 2, 2, 2],
+                         stem=R2Plus1dStem, **kwargs)
diff --git a/foleycrafter/models/specvqgan/onset_baseline/models/video_onset_net.py b/foleycrafter/models/specvqgan/onset_baseline/models/video_onset_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..01fc395c1809c7234e47152328ca419c21575196
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/models/video_onset_net.py
@@ -0,0 +1,78 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from sklearn.metrics import average_precision_score
+import sys
+sys.path.append('..')
+from foleycrafter.models.specvqgan.onset_baseline.models import r2plus1d18KeepTemp
+from foleycrafter.models.specvqgan.onset_baseline.utils import torch_utils
+
+class VideoOnsetNet(nn.Module):
+    # Video Onset detection network
+    def __init__(self, pretrained):
+        super(VideoOnsetNet, self).__init__()
+        self.net = r2plus1d18KeepTemp(pretrained=pretrained)
+        self.fc = nn.Sequential(
+            nn.Linear(512, 128),
+            nn.ReLU(True),
+            nn.Linear(128, 1)
+        )
+
+    def forward(self, inputs, loss=False, evaluate=False):
+        # import pdb; pdb.set_trace()
+        x = inputs['frames']
+        x = self.net(x)
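+        # backbone features are (N, 512, T); transpose to (N, T, 512) so the
+        # shared MLP head predicts one onset logit per frame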
+        x = x.transpose(-1, -2)
+        x = self.fc(x)
+        x = x.squeeze(-1)
+
+        return x
+
+
+class BCLoss(nn.Module):
+    # binary classification loss
+    def __init__(self, args):
+        super(BCLoss, self).__init__()
+
+    def forward(self, pred, target):
+        # import pdb; pdb.set_trace()
+        pred = pred.contiguous().view(-1)
+        target = target.contiguous().view(-1)
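+        # onset frames are rare, so weight the positive class by the
+        # negative/positive ratio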
+        pos_weight = (target.shape[0] - target.sum()) / target.sum()
+        criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(pred.device)
+        loss = criterion(pred, target.float())
+        return loss
+
+    def evaluate(self, pred, target):
+        # import pdb; pdb.set_trace()
+
+        pred = pred.contiguous().view(-1)
+        target = target.contiguous().view(-1)
+        pred = torch.sigmoid(pred)
+        pred = pred.data.cpu().numpy()
+        target = target.data.cpu().numpy()
+        
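+        # evaluate on a class-balanced subset: keep equal numbers of positive
+        # and negative frames before computing AP and accuracy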
+        pos_index = np.nonzero(target == 1)[0]
+        neg_index = np.nonzero(target == 0)[0]
+        balance_num = min(pos_index.shape[0], neg_index.shape[0])
+        index = np.concatenate((pos_index[:balance_num], neg_index[:balance_num]), axis=0)
+        pred = pred[index]
+        target = target[index]
+        ap = average_precision_score(target, pred)
+        acc = torch_utils.binary_acc(pred, target, thred=0.5)
+        res = {
+            'AP': ap,
+            'Acc': acc
+        }
+        return res
+
+
+
+if __name__ == '__main__':
+    model = VideoOnsetNet(False).cuda()
+    rand_input = torch.randn((1, 3, 30, 112, 112)).cuda()
+    inputs = {
+        'frames': rand_input
+    }
+    out = model(inputs)
\ No newline at end of file
diff --git a/foleycrafter/models/specvqgan/onset_baseline/onset_gen.py b/foleycrafter/models/specvqgan/onset_baseline/onset_gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dbb12dad941a6e3c526bcea8575506e7bf071d5
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/onset_gen.py
@@ -0,0 +1,189 @@
+import glob
+import os
+import numpy as np
+from moviepy.editor import *
+import librosa
+import soundfile as sf
+
+import argparse
+import numpy as np
+import os
+import sys
+import time
+from tqdm import tqdm
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+import torchvision.transforms as transforms
+from PIL import Image
+import shutil
+
+from config import init_args
+import data
+import models
+from models import *
+from utils import utils, torch_utils
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+vision_transform_list = [
+    transforms.Resize((128, 128)),
+    transforms.CenterCrop((112, 112)),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+]
+video_transform = transforms.Compose(vision_transform_list)
+
+def read_image(frame_list):
+    imgs = []
+    convert_tensor = transforms.ToTensor()
+    for img_path in frame_list:
+        image = Image.open(img_path).convert('RGB')
+        image = convert_tensor(image)
+        imgs.append(image.unsqueeze(0))
+    # (T, C, H ,W)
+    imgs = torch.cat(imgs, dim=0).squeeze()
+    imgs = video_transform(imgs)
+    imgs = imgs.permute(1, 0, 2, 3)
+    # (C, T, H ,W)
+    return imgs
+
+
+def get_video_frames(origin_video_path):
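+    # extract frames at 15 fps with ffmpeg and keep the first 2 seconds
+    # (30 frames) as network input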
+    save_path = 'results/temp'
+    if os.path.exists(save_path):
+        os.system(f'rm -rf {save_path}')
+    os.makedirs(save_path)
+    command = f'ffmpeg -v quiet -y -i \"{origin_video_path}\" -f image2 -vf \"scale=-1:360,fps=15\" -qscale:v 3 \"{save_path}\"/frame%06d.jpg'
+    os.system(command)
+    frame_list = glob.glob(f'{save_path}/*.jpg')
+    frame_list.sort()
+    frame_list = frame_list[:2 * 15]
+    frames = read_image(frame_list)
+    return frames
+
+
+def postprocess_video_onsets(probs, thres=0.5, nearest=5):
+    # import pdb; pdb.set_trace()
+    video_onsets = []
+    pred = np.array(probs, copy=True)
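+    # greedy peak picking: repeatedly take the highest-probability frame and
+    # zero out its +/- `nearest` neighbourhood until nothing exceeds `thres`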
+    while True:
+        max_ind = np.argmax(pred)
+        video_onsets.append(max_ind)
+        low = max(max_ind - nearest, 0)
+        high = min(max_ind + nearest, pred.shape[0])
+        pred[low: high] = 0
+        if (pred > thres).sum() == 0:
+            break
+    video_onsets.sort()
+    video_onsets = np.array(video_onsets)
+    return video_onsets
+
+
+def detect_onset_of_audio(audio, sample_rate):
+    onsets = librosa.onset.onset_detect(
+        y=audio, sr=sample_rate, units='samples', delta=0.3)
+    return onsets
+
+
+def get_onset_audio_range(audio_len, onsets, i):
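+    # how far a copied segment may extend around onset i: a third of the gap
+    # to the previous onset (or clip start) before it, and roughly half of
+    # the gap to the next onset (or clip end) after it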
+    if i == 0:
+        prev_offset = int(onsets[i] // 3)
+    else:
+        prev_offset = int((onsets[i] - onsets[i - 1]) // 3)
+
+    if i == onsets.shape[0] - 1:
+        post_offset = int((audio_len - onsets[i]) // 4 * 2)
+    else:
+        post_offset = int((onsets[i + 1] - onsets[i]) // 4 * 2)
+    return prev_offset, post_offset
+
+
+def generate_audio(con_videoclip, video_onsets):
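+    # For each visual onset, pick a random onset in the conditional audio and
+    # paste a short window around it into a 2-second silent track at the
+    # visual-onset position; window sizes are clipped so segments from
+    # neighbouring onsets do not overlap.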
+    np.random.seed(2022)
+    con_audioclip = con_videoclip.audio
+    con_audio, con_sr = con_audioclip.to_soundarray(), con_audioclip.fps
+    con_audio = con_audio.mean(-1)
+    target_sr = 22050
+    if target_sr != con_sr:
+        con_audio = librosa.resample(con_audio, orig_sr=con_sr, target_sr=target_sr)
+        con_sr = target_sr
+    
+    con_onsets = detect_onset_of_audio(con_audio, con_sr)
+    gen_audio = np.zeros(int(2 * con_sr))
+
+    for i in range(video_onsets.shape[0]):
+        prev_offset, post_offset = get_onset_audio_range(
+            int(con_sr * 2), video_onsets, i)
+        j = np.random.choice(con_onsets.shape[0])
+        prev_offset_con, post_offset_con = get_onset_audio_range(
+            con_audio.shape[0], con_onsets, j)
+        prev_offset = min(prev_offset, prev_offset_con)
+        post_offset = min(post_offset, post_offset_con)
+        gen_audio[video_onsets[i] - prev_offset: video_onsets[i] +
+                post_offset] = con_audio[con_onsets[j] - prev_offset: con_onsets[j] + post_offset]
+    return gen_audio
+
+
+def generate_video(net, original_video_list, cond_video_list_0, cond_video_list_1, cond_video_list_2):
+    save_folder = 'results/onset_baseline/vis'
+    os.makedirs(save_folder, exist_ok=True)
+    origin_video_folder = os.path.join(save_folder, '0_original')
+    os.makedirs(origin_video_folder, exist_ok=True)
+
+    for i in range(len(original_video_list)):
+        # import pdb; pdb.set_trace()
+        shutil.copy(original_video_list[i], os.path.join(
+            origin_video_folder, original_video_list[i].split('/')[-1]))
+        
+        ori_videoclip = VideoFileClip(original_video_list[i])
+
+        frames = get_video_frames(original_video_list[i])
+        inputs = {
+            'frames': frames.unsqueeze(0).to(device)
+        }
+        pred = net(inputs).squeeze()
+        pred = torch.sigmoid(pred).data.cpu().numpy()
+        video_onsets = postprocess_video_onsets(pred, thres=0.5, nearest=4)
+        video_onsets = (video_onsets / 15 * 22050).astype(int)
+
+        for ind, cond_video in enumerate([cond_video_list_0[i], cond_video_list_1[i], cond_video_list_2[i]]):
+            cond_video_folder = os.path.join(save_folder, f'{ind * 2 + 1}_conditional_{ind}')
+            os.makedirs(cond_video_folder, exist_ok=True)
+            shutil.copy(cond_video, os.path.join(
+                cond_video_folder, cond_video.split('/')[-1]))
+            con_videoclip = VideoFileClip(cond_video)
+            gen_audio = generate_audio(con_videoclip, video_onsets)
+            save_audio_path = 'results/gen_audio.wav'
+            sf.write(save_audio_path, gen_audio, 22050)
+            gen_audioclip = AudioFileClip(save_audio_path)
+            gen_videoclip = ori_videoclip.set_audio(gen_audioclip)
+            save_gen_folder = os.path.join(save_folder, f'{ind * 2 + 2}_generate_{ind}')
+            os.makedirs(save_gen_folder, exist_ok=True)
+            gen_videoclip.write_videofile(os.path.join(save_gen_folder, original_video_list[i].split('/')[-1]))
+
+
+
+if __name__ == '__main__': 
+    net = models.VideoOnsetNet(pretrained=False).to(device)
+    resume = 'checkpoints/EXP1/checkpoint_ep100.pth.tar'
+    net, _ = torch_utils.load_model(resume, net, device=device, strict=True)
+    read_folder = ''  # path to a directory generated with `audio_generation.py`
+    original_video_list = glob.glob(f'{read_folder}/2sec_full_orig_video/*.mp4')
+    original_video_list.sort()
+
+    cond_video_list_0 = glob.glob(f'{read_folder}/2sec_full_cond_video_0/*.mp4')
+    cond_video_list_0.sort()
+
+    cond_video_list_1 = glob.glob(f'{read_folder}/2sec_full_cond_video_1/*.mp4')
+    cond_video_list_1.sort()
+
+    cond_video_list_2 = glob.glob(f'{read_folder}/2sec_full_cond_video_2/*.mp4')
+    cond_video_list_2.sort()
+
+    generate_video(net, original_video_list, cond_video_list_0, cond_video_list_1, cond_video_list_2)
\ No newline at end of file
diff --git a/foleycrafter/models/specvqgan/onset_baseline/onset_gen_cxav.py b/foleycrafter/models/specvqgan/onset_baseline/onset_gen_cxav.py
new file mode 100644
index 0000000000000000000000000000000000000000..e82e1393d3c2ac4f6633f88f79f7ae2c59dccfd6
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/onset_gen_cxav.py
@@ -0,0 +1,184 @@
+import glob
+import os
+import numpy as np
+from moviepy.editor import *
+import librosa
+import soundfile as sf
+
+import argparse
+import numpy as np
+import os
+import sys
+import time
+from tqdm import tqdm
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+import torchvision.transforms as transforms
+from PIL import Image
+import shutil
+
+from config import init_args
+import data
+import models
+from models import *
+from utils import utils, torch_utils
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+vision_transform_list = [
+    transforms.Resize((128, 128)),
+    transforms.CenterCrop((112, 112)),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+]
+video_transform = transforms.Compose(vision_transform_list)
+
+def read_image(frame_list):
+    imgs = []
+    convert_tensor = transforms.ToTensor()
+    for img_path in frame_list:
+        image = Image.open(img_path).convert('RGB')
+        image = convert_tensor(image)
+        imgs.append(image.unsqueeze(0))
+    # (T, C, H ,W)
+    imgs = torch.cat(imgs, dim=0).squeeze()
+    imgs = video_transform(imgs)
+    imgs = imgs.permute(1, 0, 2, 3)
+    # (C, T, H ,W)
+    return imgs
+
+
+def get_video_frames(origin_video_path):
+    save_path = 'results/temp'
+    if os.path.exists(save_path):
+        os.system(f'rm -rf {save_path}')
+    os.makedirs(save_path)
+    command = f'ffmpeg -v quiet -y -i \"{origin_video_path}\" -f image2 -vf \"scale=-1:360,fps=15\" -qscale:v 3 \"{save_path}\"/frame%06d.jpg'
+    os.system(command)
+    frame_list = glob.glob(f'{save_path}/*.jpg')
+    frame_list.sort()
+    frame_list = frame_list[:2 * 15]
+    frames = read_image(frame_list)
+    return frames
+
+
+def postprocess_video_onsets(probs, thres=0.5, nearest=5):
+    # import pdb; pdb.set_trace()
+    video_onsets = []
+    pred = np.array(probs, copy=True)
+    while True:
+        max_ind = np.argmax(pred)
+        video_onsets.append(max_ind)
+        low = max(max_ind - nearest, 0)
+        high = min(max_ind + nearest, pred.shape[0])
+        pred[low: high] = 0
+        if (pred > thres).sum() == 0:
+            break
+    video_onsets.sort()
+    video_onsets = np.array(video_onsets)
+    return video_onsets
+
+
+def detect_onset_of_audio(audio, sample_rate):
+    onsets = librosa.onset.onset_detect(
+        y=audio, sr=sample_rate, units='samples', delta=0.3)
+    return onsets
+
+
+def get_onset_audio_range(audio_len, onsets, i):
+    if i == 0:
+        prev_offset = int(onsets[i] // 3)
+    else:
+        prev_offset = int((onsets[i] - onsets[i - 1]) // 3)
+
+    if i == onsets.shape[0] - 1:
+        post_offset = int((audio_len - onsets[i]) // 4 * 2)
+    else:
+        post_offset = int((onsets[i + 1] - onsets[i]) // 4 * 2)
+    return prev_offset, post_offset
+
+
+def generate_audio(con_videoclip, video_onsets):
+    np.random.seed(2022)
+    con_audioclip = con_videoclip.audio
+    con_audio, con_sr = con_audioclip.to_soundarray(), con_audioclip.fps
+    con_audio = con_audio.mean(-1)
+    target_sr = 22050
+    if target_sr != con_sr:
+        con_audio = librosa.resample(con_audio, orig_sr=con_sr, target_sr=target_sr)
+        con_sr = target_sr
+    
+    con_onsets = detect_onset_of_audio(con_audio, con_sr)
+    gen_audio = np.zeros(int(2 * con_sr))
+
+    for i in range(video_onsets.shape[0]):
+        prev_offset, post_offset = get_onset_audio_range(
+            int(con_sr * 2), video_onsets, i)
+        j = np.random.choice(con_onsets.shape[0])
+        prev_offset_con, post_offset_con = get_onset_audio_range(
+            con_audio.shape[0], con_onsets, j)
+        prev_offset = min(prev_offset, prev_offset_con)
+        post_offset = min(post_offset, post_offset_con)
+        gen_audio[video_onsets[i] - prev_offset: video_onsets[i] +
+                post_offset] = con_audio[con_onsets[j] - prev_offset: con_onsets[j] + post_offset]
+    return gen_audio
+
+
+def generate_video(net, original_video_list, cond_video_lists):
+    save_folder = 'results/onset_baseline_cxav/vis4'
+    os.makedirs(save_folder, exist_ok=True)
+    origin_video_folder = os.path.join(save_folder, '0_original')
+    os.makedirs(origin_video_folder, exist_ok=True)
+
+    for i in range(len(original_video_list)):
+        # import pdb; pdb.set_trace()
+        shutil.copy(original_video_list[i], os.path.join(
+            origin_video_folder, cond_video_lists[0][i].split('/')[-1]))
+        
+        ori_videoclip = VideoFileClip(original_video_list[i])
+
+        frames = get_video_frames(original_video_list[i])
+        inputs = {
+            'frames': frames.unsqueeze(0).to(device)
+        }
+        pred = net(inputs).squeeze()
+        pred = torch.sigmoid(pred).data.cpu().numpy()
+        video_onsets = postprocess_video_onsets(pred, thres=0.5, nearest=4)
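+        # convert onset indices from video frames (15 fps) to audio samples (22.05 kHz)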
+        video_onsets = (video_onsets / 15 * 22050).astype(int)
+
+        for ind, cond_idx in enumerate(range(len(cond_video_lists))):
+            cond_video = cond_video_lists[cond_idx][i]
+            cond_video_folder = os.path.join(save_folder, f'{ind * 2 + 1}_conditional_{ind}')
+            os.makedirs(cond_video_folder, exist_ok=True)
+            shutil.copy(cond_video, os.path.join(
+                cond_video_folder, cond_video.split('/')[-1]))
+            con_videoclip = VideoFileClip(cond_video)
+            gen_audio = generate_audio(con_videoclip, video_onsets)
+            save_audio_path = 'results/gen_audio.wav'
+            sf.write(save_audio_path, gen_audio, 22050)
+            gen_audioclip = AudioFileClip(save_audio_path)
+            gen_videoclip = ori_videoclip.set_audio(gen_audioclip)
+            save_gen_folder = os.path.join(save_folder, f'{ind * 2 + 2}_generate_{ind}')
+            os.makedirs(save_gen_folder, exist_ok=True)
+            gen_videoclip.write_videofile(os.path.join(save_gen_folder, cond_video.split('/')[-1]))
+
+
+
+if __name__ == '__main__': 
+    net = models.VideoOnsetNet(pretrained=False).to(device)
+    resume = 'checkpoints/cxav_train/checkpoint_ep100.pth.tar'
+    net, _ = torch_utils.load_model(resume, net, device=device, strict=True)
+    read_folder = ''  # path to a directory generated with `audio_generation.py`
+
+    cond_video_list_0 = glob.glob(f'{read_folder}/2sec_full_cond_video_0/*.mp4')
+    cond_video_list_0.sort()
+    original_video_list = ['_to_'.join(v.replace('2sec_full_cond_video_0', '2sec_full_orig_video').split('_to_')[:2])+'.mp4' for v in cond_video_list_0]
+    assert len(original_video_list) == len(cond_video_list_0)
+
+    generate_video(net, original_video_list, [cond_video_list_0,])
\ No newline at end of file
diff --git a/foleycrafter/models/specvqgan/onset_baseline/utils/__init__.py b/foleycrafter/models/specvqgan/onset_baseline/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..097a8993463a066fdbf215c91e723c7ee44727d8
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/utils/__init__.py
@@ -0,0 +1,6 @@
+from . import sourcesep
+from . import utils
+from . import sound
+from . import vis_utils
+from . import torch_utils
+from .data_sampler import ASMRSampler
\ No newline at end of file
diff --git a/foleycrafter/models/specvqgan/onset_baseline/utils/data_sampler.py b/foleycrafter/models/specvqgan/onset_baseline/utils/data_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c425a9c4570b93fafda1c6179554db26068be44
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/utils/data_sampler.py
@@ -0,0 +1,85 @@
+import copy
+import csv
+import json
+import numpy as np
+import os
+import pickle
+import random
+
+import torch
+from torch.utils.data.sampler import Sampler
+
+import pdb
+
+class ASMRSampler(Sampler): 
+    """
+    Total videos: 2794. The sampler stops once at most $BATCH_SIZE distinct videos remain. 
+    """
+    def __init__(self, list_sample, batch_size, rand_per_epoch=True): 
+        self.list_sample = list_sample
+        self.batch_size = batch_size
+        if not rand_per_epoch: 
+            random.seed(1234)
+
+        self.N = len(self.list_sample)
+        self.sample_class_dict =  self.generate_vid_dict()
+        # self.indexes = self.gen_index_batchwise()
+        # pdb.set_trace()
+
+    def generate_vid_dict(self): 
+        # append each sample's original index so it can be recovered when batching
+        for i in range(len(self.list_sample)):
+            self.list_sample[i].append(i)
+        sample_class_dict = {}
+        for i in range(len(self.list_sample)): 
+            video_name = self.list_sample[i][0]
+            if video_name not in sample_class_dict: 
+                sample_class_dict[video_name] = []
+            sample_class_dict[video_name].append(self.list_sample[i])
+        
+        return sample_class_dict
+
+    def gen_index_batchwise(self): 
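+        # each batch samples clips from `batch_size` distinct videos; a clip is removed from the pool once used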
+        indexes = []
+        scd_copy = copy.deepcopy(self.sample_class_dict)
+        for i in range(self.N // self.batch_size): 
+            if len(scd_copy) <= self.batch_size: 
+                break 
+            batch_vid = random.sample(list(scd_copy.keys()), self.batch_size)
+            for vid in batch_vid: 
+                rand_clip = random.choice(scd_copy[vid])
+                indexes.append(rand_clip[-1])
+                scd_copy[vid].remove(rand_clip)   # remove the clip we just used
+                # remove dict if empty
+                if len(scd_copy[vid]) == 0: 
+                    del scd_copy[vid]
+        
+        # add remain items to indexes
+        # for k, v in scd_copy.items(): 
+        #     for item in v: 
+        #         indexes.append(item[-1])
+        return indexes
+            
+    def __iter__(self): 
+        return iter(self.gen_index_batchwise())
+
+    def __len__(self): 
+        return self.N
+
+
+class VoxcelebSampler(Sampler): 
+    def __init__(self, list_sample, batch_size, rand_per_epoch=True): 
+        self.list_sample = list_sample
+        self.batch_size = batch_size
+        if not rand_per_epoch: 
+            random.seed(1234)
+        
+        self.N = len(self.list_sample)
+        self.sample_class_dict = self.generate_vid_dict()
+
+    def generate_vid_dict(self): 
+        # mirrors ASMRSampler.generate_vid_dict: group clips by video name
+        for i in range(len(self.list_sample)):
+            self.list_sample[i].append(i)
+        sample_class_dict = {}
+        for i in range(len(self.list_sample)): 
+            video_name = self.list_sample[i][0]
+            if video_name not in sample_class_dict: 
+                sample_class_dict[video_name] = []
+            sample_class_dict[video_name].append(self.list_sample[i])
+        return sample_class_dict
\ No newline at end of file
diff --git a/foleycrafter/models/specvqgan/onset_baseline/utils/sound.py b/foleycrafter/models/specvqgan/onset_baseline/utils/sound.py
new file mode 100644
index 0000000000000000000000000000000000000000..a389c09aa21a8185ba0b4d1a63e327a8e40e4906
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/utils/sound.py
@@ -0,0 +1,151 @@
+import copy
+import numpy as np
+import scipy.io.wavfile
+import scipy.signal
+
+from . import utils as ut
+
+import pdb
+
+def load_sound(wav_fname):
+    rate, samples = scipy.io.wavfile.read(wav_fname)
+    times = (1./rate) * np.arange(len(samples))
+    return Sound(times, rate, samples)
+
+
+class Sound:
+    def __init__(self, times, rate, samples=None):
+        # Allow Sound(samples, sr)
+        if samples is None:
+            samples = times
+            times = None
+        if samples.dtype == np.float32:
+            samples = samples.astype('float64')
+
+        self.rate = rate
+        # self.samples = ut.atleast_2d_col(samples)
+        self.samples = samples
+
+        self.length = samples.shape[0]
+        if times is None:
+            self.times = np.arange(len(self.samples)) / float(self.rate)
+        else:
+            self.times = times
+
+    def copy(self):
+        return copy.deepcopy(self)
+
+    def parts(self):
+        return (self.times, self.rate, self.samples)
+
+    def __getitem__(self, key):
+        # __getslice__ was removed in Python 3; slicing goes through __getitem__
+        return Sound(self.times[key], self.rate, self.samples[key])
+
+    def duration(self):
+        return self.samples.shape[0] / float(self.rate)
+
+    def normalized(self, check=True):
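+        # float samples are only clipped to [-1, 1]; integer samples are rescaled by their dtype's max value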
+        if self.samples.dtype == np.double:
+            assert (not check) or np.max(np.abs(self.samples)) <= 4.
+            x = copy.deepcopy(self)
+            x.samples = np.clip(x.samples, -1., 1.)
+            return x
+        else:
+            s = copy.deepcopy(self)
+            s.samples = np.array(s.samples, 'double') / np.iinfo(s.samples.dtype).max
+            s.samples[s.samples < -1] = -1
+            s.samples[s.samples > 1] = 1
+            return s
+
+    def unnormalized(self, dtype_name='int32'):
+        s = self.normalized()
+        inf = np.iinfo(np.dtype(dtype_name))
+        samples = np.clip(s.samples, -1., 1.)
+        samples = inf.max * samples
+        samples = np.array(np.clip(samples, inf.min, inf.max), dtype_name)
+        s.samples = samples
+        return s
+
+    def sample_from_time(self, t, bound=False):
+        if bound:
+            return min(max(0, int(np.round(t * self.rate))), self.samples.shape[0]-1)
+        else:
+            return int(np.round(t * self.rate))
+
+    st = sample_from_time  # short alias used by slice_time below
+
+    def shift_zero(self):
+        s = copy.deepcopy(self)
+        s.times -= s.times[0]
+        return s
+
+    def select_channel(self, c):
+        s = copy.deepcopy(self)
+        s.samples = s.samples[:, c]
+        return s
+
+    def left_pad_silence(self, n):
+        if n == 0:
+            return self.shift_zero()
+        else:
+            if np.ndim(self.samples) == 1:
+                samples = np.concatenate([[0] * n, self.samples])
+            else:
+                samples = np.vstack(
+                [np.zeros((n, self.samples.shape[1]), self.samples.dtype), self.samples])
+        return Sound(None, self.rate, samples)
+
+    def right_pad_silence(self, n):
+        if n == 0:
+            return self.shift_zero()
+        else:
+            if np.ndim(self.samples) == 1:
+                samples = np.concatenate([self.samples, [0] * n])
+            else:
+                samples = np.vstack([self.samples, np.zeros(
+                (n, self.samples.shape[1]), self.samples.dtype)])
+        return Sound(None, self.rate, samples)
+
+    def pad_slice(self, s1, s2):
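+        # take samples [s1, s2), padding with silence where the requested range falls outside the signal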
+        assert s1 < self.samples.shape[0] and s2 >= 0
+        s = self[max(0, s1): min(s2, self.samples.shape[0])]
+        s = s.left_pad_silence(max(0, -s1))
+        s = s.right_pad_silence(max(0, s2 - self.samples.shape[0]))
+        return s
+
+    def to_mono(self, force_copy=True):
+        s = copy.deepcopy(self)
+        # average the channels (stand-in for the missing make_mono helper)
+        if np.ndim(s.samples) > 1:
+            s.samples = s.samples.mean(axis=1)
+        return s
+
+    def slice_time(self, t1, t2):
+        return self[self.st(t1): self.st(t2)]
+
+    @property
+    def nchannels(self):
+        return 1 if np.ndim(self.samples) == 1 else self.samples.shape[1]
+
+    def save(self, fname):
+        s = self.unnormalized('int16')
+        scipy.io.wavfile.write(fname, s.rate, s.samples.transpose())
+
+    def resampled(self, new_rate, clip=True):
+        if new_rate == self.rate:
+            return copy.deepcopy(self)
+        else:
+            # assert self.samples.shape[1] == 1
+            return Sound(None, new_rate, self.resample(self.samples, float(new_rate)/self.rate, clip=clip))
+
+    def trim_to_size(self, n):
+        return Sound(None, self.rate, self.samples[:n])
+
+    def resample(self, signal, sc, clip = True, num_samples = None):
+        n = int(round(signal.shape[0] * sc)) if num_samples is None else num_samples
+        r = scipy.signal.resample(signal, n)
+    
+        if clip:
+            r = np.clip(r, -1, 1)
+        else: 
+            r = r.astype(np.int16)
+        return r
diff --git a/foleycrafter/models/specvqgan/onset_baseline/utils/sourcesep.py b/foleycrafter/models/specvqgan/onset_baseline/utils/sourcesep.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7498738c83db288ec64edbb432f763f172067bd
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/utils/sourcesep.py
@@ -0,0 +1,266 @@
+import numpy as np
+
+import torch
+import torchaudio.functional
+import torchaudio
+from . import utils
+
+import pdb
+
+
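+# frame length/step are given in milliseconds in the params struct and converted to sample counts here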
+def stft_frame_length(pr): return int(pr.frame_length_ms * pr.samp_sr * 0.001)
+
+def stft_frame_step(pr): return int(pr.frame_step_ms * pr.samp_sr * 0.001)
+
+
+def stft_num_fft(pr): return int(2**np.ceil(np.log2(stft_frame_length(pr))))
+
+def log10(x): return torch.log(x)/torch.log(torch.tensor(10.))
+
+
+def db_from_amp(x, cuda=False):
+    if cuda: 
+        return 20. * log10(torch.max(torch.tensor(1e-5).to('cuda'), x.float()))
+    else: 
+        return 20. * log10(torch.max(torch.tensor(1e-5), x.float()))
+
+
+def amp_from_db(x):
+    return torch.pow(10., x / 20.)
+
+
+def norm_range(x, min_val, max_val):
+    return 2.*(x - min_val)/float(max_val - min_val) - 1.
+
+def unnorm_range(y, min_val, max_val):
+    return 0.5*float(max_val - min_val) * (y + 1) + min_val
+
+def normalize_spec(spec, pr):
+    return norm_range(spec, pr.spec_min, pr.spec_max)
+
+
+def unnormalize_spec(spec, pr):
+    return unnorm_range(spec, pr.spec_min, pr.spec_max)
+
+
+def normalize_phase(phase, pr):
+    return norm_range(phase, -np.pi, np.pi)
+
+
+def unnormalize_phase(phase, pr):
+    return unnorm_range(phase, -np.pi, np.pi)
+
+
+def normalize_ims(im): 
+    if isinstance(im, np.ndarray): 
+        im = im.astype('float32')
+    else: 
+        im = im.float()
+    return -1. + 2. * im
+
+
+def stft(samples, pr, cuda=False):
+    # return_complex=False keeps the (real, imag) last-dim layout used below;
+    # newer PyTorch versions require this argument to be passed explicitly.
+    spec_complex = torch.stft(
+        samples, 
+        stft_num_fft(pr),
+        hop_length=stft_frame_step(pr), 
+        win_length=stft_frame_length(pr),
+        return_complex=False).transpose(1,2)
+
+    real = spec_complex[..., 0]
+    imag = spec_complex[..., 1]
+    mag = torch.sqrt((real**2) + (imag**2))
+    phase = utils.angle(real, imag)
+    if pr.log_spec:
+        mag = db_from_amp(mag, cuda=cuda)
+    return mag, phase
+
+
+def make_complex(mag, phase):
+    return torch.cat(((mag * torch.cos(phase)).unsqueeze(-1), (mag * torch.sin(phase)).unsqueeze(-1)), -1)
+
+
+def istft(mag, phase, pr):
+    if pr.log_spec:
+        mag = amp_from_db(mag)
+    # torchaudio.functional.istft has been removed; torch.istft expects a complex
+    # spectrogram, so convert the (real, imag) layout first.
+    spec = torch.view_as_complex(make_complex(mag, phase).transpose(1, 2).contiguous())
+    samples = torch.istft(
+        spec,
+        stft_num_fft(pr),
+        hop_length=stft_frame_step(pr),
+        win_length=stft_frame_length(pr))
+    return samples
+
+
+
+def aud2spec(sample, pr, stereo=False, norm=False, cuda=True): 
+    sample = sample[:, :pr.sample_len]
+    spec, phase = stft(sample.transpose(1,2).reshape((sample.shape[0]*2, -1)), pr, cuda=cuda)
+    spec = spec.reshape(sample.shape[0], 2, pr.spec_len, -1)
+    phase = phase.reshape(sample.shape[0], 2, pr.spec_len, -1)
+    return spec, phase
+
+
+def mix_sounds(samples0, pr, samples1=None, cuda=False, dominant=False, noise_ratio=0):
+    # pdb.set_trace()
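+    # RMS-normalize both sources, optionally scale the second one, truncate to a common length, and sum to form the mixture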
+    samples0 = utils.normalize_rms(samples0, pr.input_rms)
+    if samples1 is not None:
+        samples1 = utils.normalize_rms(samples1, pr.input_rms)
+
+    if dominant: 
+        samples0 = samples0[:, :pr.sample_len]
+        samples1 = samples1[:, :pr.sample_len] * noise_ratio
+    else: 
+        samples0 = samples0[:, :pr.sample_len]
+        samples1 = samples1[:, :pr.sample_len]
+    
+    samples_mix = (samples0 + samples1)
+    if cuda: 
+        samples0 = samples0.to('cuda')
+        samples1 = samples1.to('cuda')
+        samples_mix = samples_mix.to('cuda')
+
+    spec_mix, phase_mix = stft(samples_mix, pr, cuda=cuda)
+
+    spec0, phase0 = stft(samples0, pr, cuda=cuda)
+    spec1, phase1 = stft(samples1, pr, cuda=cuda)
+
+    spec_mix = spec_mix[:, :pr.spec_len]
+    phase_mix = phase_mix[:, :pr.spec_len]
+    spec0 = spec0[:, :pr.spec_len]
+    spec1 = spec1[:, :pr.spec_len]
+    phase0 = phase0[:, :pr.spec_len]
+    phase1 = phase1[:, :pr.spec_len]
+
+    return utils.Struct(
+        samples=samples_mix.float(),
+        phase=phase_mix.float(),
+        spec=spec_mix.float(),
+        sample_parts=[samples0, samples1],
+        spec_parts=[spec0.float(), spec1.float()],
+        phase_parts=[phase0.float(), phase1.float()])
+
+
+def pit_loss(pred_spec_fg, pred_spec_bg, snd, pr, cuda=True, vis=False):
+    # if pr.norm_spec: 
+    def ns(x): return normalize_spec(x, pr)
+    # else: 
+    #     def ns(x): return x
+    if pr.norm:
+        gts_ = [[ns(snd.spec_parts[0]), None],
+                [ns(snd.spec_parts[1]), None]]
+        preds = [[ns(pred_spec_fg), None],
+                [ns(pred_spec_bg), None]]
+    else:
+        gts_ = [[snd.spec_parts[0], None],
+                [snd.spec_parts[1], None]]
+        preds = [[pred_spec_fg, None],
+                [pred_spec_bg, None]]
+
+    def l1(x, y): return torch.mean(torch.abs(x - y), (1, 2))
+    losses = []
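+    # permutation-invariant training: evaluate the loss under both source orderings and keep the smaller one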
+    for i in range(2):
+        gt = [gts_[i % 2], gts_[(i+1) % 2]]
+        fg_spec = pr.l1_weight * l1(preds[0][0], gt[0][0])
+        bg_spec = pr.l1_weight * l1(preds[1][0], gt[1][0])
+        losses.append(fg_spec + bg_spec)
+
+    losses = torch.cat([x.unsqueeze(0) for x in losses], dim=0)
+    if vis:
+        print(losses)
+    loss_val = torch.min(losses, dim=0)
+    if vis:
+        print(loss_val[1])
+    loss = torch.mean(loss_val[0])
+
+    return loss
+
+
+def diff_loss(spec_diff, phase_diff, snd, pr, device, norm=False, vis=False):
+    def ns(x): return normalize_spec(x, pr)
+    def nph(x): return normalize_phase(x, pr)  # renamed from `np` to avoid shadowing numpy
+    criterion = torch.nn.L1Loss()
+    
+    gt_spec_diff = snd.spec_diff
+    gt_phase_diff = snd.phase_diff
+    criterion = criterion.to(device)
+
+    if norm:
+        gt_spec_diff = ns(gt_spec_diff)
+        gt_phase_diff = nph(gt_phase_diff)
+        pred_spec_diff = ns(spec_diff)
+        pred_phase_diff = nph(phase_diff)
+    else:
+        pred_spec_diff = spec_diff
+        pred_phase_diff = phase_diff
+
+    spec_loss = criterion(pred_spec_diff, gt_spec_diff)
+    phase_loss = criterion(pred_phase_diff, gt_phase_diff)
+    loss = pr.l1_weight * spec_loss + pr.phase_weight * phase_loss
+    if vis:
+        print(loss)
+    return loss
+
+# def pit_loss(out, snd, pr, cuda=False, vis=False):
+#     def ns(x): return normalize_spec(x, pr)
+#     def np(x): return normalize_phase(x, pr)
+#     if cuda: 
+#         snd['spec_part0'] = snd['spec_part0'].to('cuda')
+#         snd['phase_part0'] = snd['phase_part0'].to('cuda')
+#         snd['spec_part1'] = snd['spec_part1'].to('cuda')
+#         snd['phase_part1'] = snd['phase_part1'].to('cuda')
+#     # gts_ = [[ns(snd['spec_part0'][:, 0, :, :]), np(snd['phase_part0'][:, 0, :, :])],
+#     #         [ns(snd['spec_part1'][:, 0, :, :]), np(snd['phase_part1'][:, 0, :, :])]]
+#     gts_ = [[ns(snd.spec_parts[0]), np(snd.phase_parts[0])],
+#             [ns(snd.spec_parts[1]), np(snd.phase_parts[1])]]
+#     preds = [[ns(out.pred_spec_fg), np(out.pred_phase_fg)],
+#              [ns(out.pred_spec_bg), np(out.pred_phase_bg)]]
+    
+#     def l1(x, y): return torch.mean(torch.abs(x - y), (1, 2))
+#     losses = []
+#     for i in range(2):
+#         gt = [gts_[i % 2], gts_[(i+1) % 2]]
+#         #   print 'preds[0][0] shape =', shape(preds[0][0])
+#         # fg_spec = pr.l1_weight * l1(preds[0][0], gt[0][0])
+#         # fg_phase = pr.phase_weight * l1(preds[0][1], gt[0][1])
+
+#         # bg_spec = pr.l1_weight * l1(preds[1][0], gt[1][0])
+#         # bg_phase = pr.phase_weight * l1(preds[1][1], gt[1][1])
+
+#         # losses.append(fg_spec + fg_phase + bg_spec + bg_phase)
+#         fg_spec = pr.l1_weight * l1(preds[0][0], gt[0][0])
+
+#         bg_spec = pr.l1_weight * l1(preds[1][0], gt[1][0])
+
+#         losses.append(fg_spec + bg_spec)
+#     # pdb.set_trace()
+#     # pdb.set_trace()
+#     losses = torch.cat([x.unsqueeze(0) for x in losses], dim=0)
+#     if vis: 
+#         print(losses)
+#     loss_val = torch.min(losses, dim=0)
+#     if vis: 
+#         print(loss_val[1])
+#     loss = torch.mean(loss_val[0])
+
+#     return loss
+
+# def stereo_mel()
+
+
+def audio_stft(stft, audio, pr):
+    N, C, A = audio.size()
+    audio = audio.view(N * C, A)
+    spec = stft(audio)
+    spec = spec.transpose(-1, -2)
+    spec = db_from_amp(spec, cuda=True)
+    spec = normalize_spec(spec, pr)
+    _, T, F = spec.size()
+    spec = spec.view(N, C, T, F)
+    return spec
+
+
+def normalize_audio(samples, desired_rms=0.1, eps=1e-4):
+    # print(np.mean(samples**2))
+    rms = np.maximum(eps, np.sqrt(np.mean(samples**2)))
+    samples = samples * (desired_rms / rms)
+    return samples
\ No newline at end of file
diff --git a/foleycrafter/models/specvqgan/onset_baseline/utils/torch_utils.py b/foleycrafter/models/specvqgan/onset_baseline/utils/torch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4137e54d86b0ef520868c79f264c04852c590723
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/utils/torch_utils.py
@@ -0,0 +1,113 @@
+from collections import OrderedDict
+import os
+import numpy as np
+import random
+import sys
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+
+# the training `data` package is imported lazily inside get_dataloader so that
+# importing this module (e.g. for load_model) does not require it on sys.path
+
+
+# ---------------------------------------------------- #
+def load_model(cp_path, net, device=None, strict=True): 
+    if not device:
+        device = torch.device('cpu')
+    if os.path.isfile(cp_path): 
+        print("=> loading checkpoint '{}'".format(cp_path))
+        checkpoint = torch.load(cp_path, map_location=device)
+
+        # strip the 'module.' prefix added by DataParallel, if present
+        if list(checkpoint['state_dict'].keys())[0][:7] == 'module.': 
+            state_dict = OrderedDict()
+            for k, v in checkpoint['state_dict'].items(): 
+                name = k[7:]
+                state_dict[name] = v
+        else: 
+            state_dict = checkpoint['state_dict']
+        net.load_state_dict(state_dict, strict=strict) 
+
+        print("=> loaded checkpoint '{}' (epoch {})"
+                    .format(cp_path, checkpoint['epoch']))
+        start_epoch = checkpoint['epoch']
+    else: 
+        print("=> no checkpoint found at '{}'".format(cp_path))
+        start_epoch = 0
+        sys.exit()
+    
+    return net, start_epoch
+
+
+# ---------------------------------------------------- #
+def binary_acc(pred, target, thred):
+    pred = pred > thred
+    acc = np.sum(pred == target) / target.shape[0]
+    return acc
+
+def calc_acc(prob, labels, k):
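+    # top-k accuracy: a sample counts as correct if its true label is among the k highest-scoring classes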
+    pred = torch.argsort(prob, dim=-1, descending=True)[..., :k]
+    top_k_acc = torch.sum(pred == labels.view(-1, 1)).float() / labels.size(0)
+    return top_k_acc
+
+# ---------------------------------------------------- #
+
+def get_dataloader(args, pr, split='train', shuffle=False, drop_last=False, batch_size=None):
+    sys.path.append('..')
+    import data  # lazy import of the training data package (see note at the top of this file)
+    data_loader = getattr(data, pr.dataloader)
+    if split == 'train':
+        read_list = pr.list_train
+    elif split == 'val':
+        read_list = pr.list_val
+    elif split == 'test':
+        read_list = pr.list_test
+    dataset = data_loader(args, pr, read_list, split=split)
+    batch_size = batch_size if batch_size else args.batch_size
+    dataset.getitem_test(1)
+    loader = DataLoader(
+        dataset, 
+        batch_size=batch_size, 
+        shuffle=shuffle, 
+        num_workers=args.num_workers, 
+        pin_memory=True, 
+        drop_last=drop_last)
+    
+    return dataset, loader
+
+
+# ---------------------------------------------------- #
+def make_optimizer(model, args):
+    '''
+    Args:
+        model: NN to train
+    Returns:
+        optimizer: pytorch optmizer for updating the given model parameters.
+    '''
+    if args.optim == 'SGD':
+        optimizer = torch.optim.SGD(
+            filter(lambda p: p.requires_grad, model.parameters()),
+            lr=args.lr,
+            momentum=args.momentum,
+            weight_decay=args.weight_decay,
+            nesterov=False
+        )
+    elif args.optim == 'Adam':
+        optimizer = torch.optim.Adam(
+            filter(lambda p: p.requires_grad, model.parameters()),
+            lr=args.lr,
+            weight_decay=args.weight_decay,
+        )
+    return optimizer
+
+
+def adjust_learning_rate(optimizer, epoch, args):
+    """Decay the learning rate based on schedule"""
+    lr = args.lr
+    if args.schedule == 'cos':  # cosine lr schedule
+        lr *= 0.5 * (1. + np.cos(np.pi * epoch / args.epochs))
+    elif args.schedule == 'none':  # no lr schedule
+        lr = args.lr
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
\ No newline at end of file
diff --git a/foleycrafter/models/specvqgan/onset_baseline/utils/utils.py b/foleycrafter/models/specvqgan/onset_baseline/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9f7e72a27f3ff0954606d473a2a953fa4127590
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/utils/utils.py
@@ -0,0 +1,158 @@
+import copy
+import errno
+import inspect
+import numpy as np
+import os
+import sys
+
+import torch
+
+import pdb
+
+
+class LoggerOutput(object):
+    def __init__(self, fpath=None):
+        self.console = sys.stdout
+        self.file = None
+        if fpath is not None:
+            self.mkdir_if_missing(os.path.dirname(fpath))
+            self.file = open(fpath, 'w')
+
+    def __del__(self):
+        self.close()
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, *args):
+        self.close()
+
+    def write(self, msg):
+        self.console.write(msg)
+        if self.file is not None:
+            self.file.write(msg)
+
+    def flush(self):
+        self.console.flush()
+        if self.file is not None:
+            self.file.flush()
+            os.fsync(self.file.fileno())
+
+    def close(self):
+        # only close the log file; closing sys.stdout would break subsequent prints
+        if self.file is not None:
+            self.file.close()
+
+    def mkdir_if_missing(self, dir_path):
+        try:
+            os.makedirs(dir_path)
+        except OSError as e:
+            if e.errno != errno.EEXIST:
+                raise
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self):
+        self.initialized = False
+        self.val = None
+        self.avg = None
+        self.sum = None
+        self.count = None
+
+    def initialize(self, val, weight):
+        self.val = val
+        self.avg = val
+        self.sum = val*weight
+        self.count = weight
+        self.initialized = True
+
+    def update(self, val, weight=1):
+        val = np.asarray(val)
+        if not self.initialized:
+            self.initialize(val, weight)
+        else:
+            self.add(val, weight)
+
+    def add(self, val, weight):
+        self.val = val
+        self.sum += val * weight
+        self.count += weight
+        self.avg = self.sum / self.count
+
+    def value(self):
+        if self.val is None:
+            return 0.
+        else:
+            return self.val.tolist()
+
+    def average(self):
+        if self.avg is None:
+            return 0.
+        else:
+            return self.avg.tolist()
+
+
+class Struct:
+    def __init__(self, *dicts, **fields):
+        for d in dicts:
+            for k, v in d.items():
+                setattr(self, k, v)
+        self.__dict__.update(fields)
+
+    def to_dict(self):
+        return {a: getattr(self, a) for a in self.attrs()}
+
+    def attrs(self):
+        xs = set(dir(self)) - set(dir(Struct))
+        xs = [x for x in xs if ((not (hasattr(self.__class__, x) and isinstance(getattr(self.__class__, x), property)))
+              and (not inspect.ismethod(getattr(self, x))))]
+        return sorted(xs)
+
+    def updated(self, other_struct_=None, **kwargs):
+        s = copy.deepcopy(self)
+        if other_struct_ is not None:
+            s.__dict__.update(other_struct_.to_dict())
+        s.__dict__.update(kwargs)
+        return s
+
+    def copy(self):
+        return copy.deepcopy(self)
+
+    def __str__(self):
+        attrs = ', '.join('%s=%s' % (a, getattr(self, a)) for a in self.attrs())
+        return 'Struct(%s)' % attrs
+
+
+class Params(Struct):
+    def __init__(self, **kwargs):
+        self.__dict__.update(kwargs)
+
+
+def normalize_rms(samples, desired_rms=0.1, eps=1e-4):
+    rms = torch.max(torch.tensor(eps), torch.sqrt(
+        torch.mean(samples**2, dim=1)).float())
+    samples = samples * desired_rms / rms.unsqueeze(1)
+    return samples
+
+
+def normalize_rms_np(samples, desired_rms=0.1, eps=1e-4):
+    rms = np.maximum(eps, np.sqrt(np.mean(samples**2, 1)))
+    samples = samples * (desired_rms / rms)
+    return samples
+
+
+def angle(real, imag): 
+    return torch.atan2(imag, real)
+
+
+def atleast_2d_col(x):
+    x = np.asarray(x)
+    if np.ndim(x) == 0:
+        return x[np.newaxis, np.newaxis]
+    if np.ndim(x) == 1:
+        return x[:, np.newaxis]
+    else:
+        return x
diff --git a/foleycrafter/models/specvqgan/onset_baseline/utils/vis_utils.py b/foleycrafter/models/specvqgan/onset_baseline/utils/vis_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..67b95108fa71cb842eb405545bfca799b922fd4e
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/utils/vis_utils.py
@@ -0,0 +1,706 @@
+import copy
+import cv2
+import itertools as itl
+import json
+import kornia as K
+from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+from pathlib import Path
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+import pylab
+import random
+
+import torch
+
+import pdb
+
+def clip_rescale(x, lo = None, hi = None):
+    if lo is None:
+        lo = np.min(x)
+    if hi is None:
+        hi = np.max(x)
+    return np.clip((x - lo)/(hi - lo), 0., 1.)
+
+def apply_cmap(im, cmap = pylab.cm.jet, lo = None, hi = None):
+    return cmap(clip_rescale(im, lo, hi).flatten()).reshape(im.shape[:2] + (-1,))[:, :, :3]
+
+def cmap_im(cmap, im, lo = None, hi = None):
+    return np.uint8(255*apply_cmap(im, cmap, lo, hi))
+
+def calc_acc(prob, labels, k=1, resol=None):
+    # `resol` (patch grid resolution) used to be read from an implicit global and is now an explicit argument
+    thred = 0.5
+    pred = torch.argsort(prob, dim=-1, descending=True)[..., :k]
+    corr = (pred.view(-1) == labels).cpu().numpy()
+    corr = corr.reshape((-1, resol*resol))
+    acc = corr.sum(1) / (resol*resol)  # fraction of correctly matched patches per image
+    corr_index = np.where(acc > thred)[0]
+    return corr_index
+
+# def compute_acc_list(A_IS, k=0): 
+#     criterion = nn.NLLLoss()
+#     M, N = A_IS.size()
+#     target = torch.from_numpy(np.repeat(np.eye(N), M // N, axis=0)).to(DEVICE)
+#     _, labels = target.max(dim=1)
+#     loss = criterion(torch.log(A_IS), labels.long())
+#     acc = None
+#     if k > 0:
+#         corr_index = calc_acc(A_IS, labels, k)
+#     return corr_index
+
+def get_fcn_sim(full_img, feat_audio, net, B, resol, norm=True):
+    feat_img = net.forward_fcn(full_img)
+    feat_img = feat_img.permute(0, 2,3,1).reshape(-1, 128)
+    A_II, A_IS, A_SI = net.GetAMatrix(feat_img, feat_audio, norm=norm)
+    A_IS_ = A_IS.reshape((B, resol*resol, B))
+    A_IIS_ = (A_II @ A_IS).reshape((B, resol*resol, B))
+    A_II_ = A_II.reshape((B, resol*resol, B*resol*resol))
+
+    return A_IS_, A_IIS_, A_II_
+
+def upsample_lowest(sim, im_h, im_w, pr): 
+    sim_h, sim_w = sim.shape
+    prob_map_per_patch = np.zeros((im_h, im_w, pr.resol*pr.resol))
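+    # each patch writes its similarity into the pixel region it covers; overlapping contributions are averaged below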
+    # pdb.set_trace()
+    for i in range(pr.resol): 
+        for j in range(pr.resol): 
+            y1 = pr.patch_stride * i 
+            y2 = pr.patch_stride * i + pr.psize
+            x1 = pr.patch_stride * j
+            x2 = pr.patch_stride * j + pr.psize
+            prob_map_per_patch[y1:y2, x1:x2, i * pr.resol + j] = sim[i, j]
+    # pdb.set_trace()
+    upsampled = np.sum(prob_map_per_patch, axis=-1) / np.sum(prob_map_per_patch > 0, axis=-1)
+
+    return upsampled
+
+
+def grid_interp(pr, input, output_size, mode='bilinear'):
+    # import pdb; pdb.set_trace()
+    n = 1
+    c = 1
+    ih, iw = input.shape
+    input = input.view(n, c, ih, iw)
+    oh, ow = output_size
+
+    pad = (pr.psize - pr.patch_stride) // 2 
+    ch = oh - pad * 2 
+    cw = ow - pad * 2
+    # normalize to [-1, 1]
+    h = (torch.arange(0, oh) - pad) / (ch-1) * 2 - 1
+    w = (torch.arange(0, ow) - pad) / (cw-1) * 2 - 1
+
+    grid = torch.zeros(oh, ow, 2)
+    grid[:, :, 0] = w.unsqueeze(0).repeat(oh, 1)
+    grid[:, :, 1] = h.unsqueeze(0).repeat(ow, 1).transpose(0, 1)
+    grid = grid.unsqueeze(0).repeat(n, 1, 1, 1) # grid.shape: [n, oh, ow, 2]
+    grid = grid.to(input.device)
+    res = torch.nn.functional.grid_sample(input, grid, mode=mode, padding_mode="border", align_corners=False).squeeze()
+    return res 
+
+
+def upsample_lowest_torch(sim, im_h, im_w, pr): 
+    sim = sim.reshape(pr.resol*pr.resol)
+    # the per-patch coverage template is precomputed in pr.template
+    prob_map_per_patch = torch.from_numpy(pr.template).to('cuda')
+    prob_map_per_patch = prob_map_per_patch * sim.reshape(1,1,-1)
+    upsampled = torch.sum(prob_map_per_patch, dim=-1) / torch.sum(prob_map_per_patch > 0, dim=-1)
+
+    return upsampled
+
+
+def gen_vis_map(prob, im_h, im_w, pr, bound=False, lo=0, hi=0.3, mode='nearest'): 
+    """
+    prob: probability map for patches
+    im_h, im_w: original image size
+    resol: resolution of patches
+    bound: whether to give low and high bound for probability
+    lo: 
+    hi: 
+    mode: upsample method for probability
+    """
+    resol = pr.resol
+    if mode == 'nearest': 
+        resample = PIL.Image.NEAREST
+    elif mode == 'bilinear': 
+        resample = PIL.Image.BILINEAR
+    sim = prob.reshape((resol, resol))
+    # pdb.set_trace()
+    # upsample the patch similarity map to the image size
+    if mode in ['nearest', 'bilinear']: 
+        if torch.is_tensor(sim): 
+            sim = sim.cpu().numpy()
+        sim_up = np.array(Image.fromarray(sim).resize((im_w, im_h), resample=resample))
+    elif mode == 'lowest': 
+        sim_up = upsample_lowest_torch(sim, im_w, im_h, pr)
+        sim_up = sim_up.detach().cpu().numpy()
+    elif mode == 'grid': 
+        sim_up = grid_interp(pr, sim, (im_h, im_w), 'bilinear')
+        sim_up = sim_up.detach().cpu().numpy()
+
+    if not bound: 
+        lo = None
+        hi = None
+    # generate heat map
+    # pdb.set_trace()
+    vis = cmap_im(pylab.cm.jet, sim_up, lo=lo, hi=hi)
+
+    # p controls how strongly the heat map is blended over the original image
+    p = sim_up / sim_up.max() * 0.3 + 0.3
+    p = p[..., None]
+    
+    return p, vis
+
+
+def gen_upsampled_prob(prob, im_h, im_w, pr, bound=False, lo=0, hi=0.3, mode='nearest'): 
+    """
+    prob: probability map for patches
+    im_h, im_w: original image size
+    resol: resolution of patches
+    bound: whether to give low and high bound for probability
+    lo: 
+    hi: 
+    mode: upsample method for probability
+    """
+    resol = pr.resol
+    if mode == 'nearest': 
+        resample = PIL.Image.NEAREST
+    elif mode == 'bilinear': 
+        resample = PIL.Image.BILINEAR
+    sim = prob.reshape((resol, resol))
+    # pdb.set_trace()
+    # upsample the patch similarity map to the image size
+    if mode in ['nearest', 'bilinear']: 
+        if torch.is_tensor(sim): 
+            sim = sim.cpu().numpy()
+        sim_up = np.array(Image.fromarray(sim).resize((im_w, im_h), resample=resample))
+    elif mode == 'lowest': 
+        sim_up = upsample_lowest_torch(sim, im_w, im_h, pr)
+        sim_up = sim_up.cpu().numpy()
+    sim_up = sim_up / sim_up.max()
+    return sim_up
+
+
+def gen_vis_map_probmap_up(prob_up, bound=False, lo=0, hi=0.3, mode='nearest'): 
+    if mode == 'nearest': 
+        resample = PIL.Image.NEAREST
+    elif mode == 'bilinear': 
+        resample = PIL.Image.BILINEAR
+    if not bound: 
+        lo = None
+        hi = None
+    vis = cmap_im(pylab.cm.jet, prob_up, lo=None, hi=None)
+    if bound: 
+        # when hi gets larger, the cmap overlay becomes less visible
+        p = prob_up / prob_up.max() * (0.3+0.4*(1-hi)) + 0.3
+    else: 
+        # if not bound, cmap always weights 0.3 on original image
+        p = prob_up / prob_up.max() * 0.3 + 0.3
+    p = p[..., None]
+    
+    return p, vis
+
+
+def rgb2bgr(im): 
+    return cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
+
+def gen_bbox_patches(im, patch_ind, resol, patch_size=64, lin_w=3, lin_color=np.array([255,0,0])): 
+    # TODO: make it work for different image size
+    stride = int((256-patch_size)/(resol-1))
+    
+    im_w, im_h = im.shape[1], im.shape[0]
+
+    r_ind = patch_ind // resol
+    c_ind = patch_ind % resol
+    y1 = r_ind * stride
+    y2 = r_ind * stride + patch_size
+    x1 = c_ind * stride
+    x2 = c_ind * stride + patch_size
+
+    im_bbox = copy.deepcopy(im)
+    im_bbox[y1:y1+lin_w, x1:x2, :] = lin_color
+    im_bbox[y2-lin_w:y2, x1:x2, :] = lin_color
+    im_bbox[y1:y2, x1:x1+lin_w, :] = lin_color
+    im_bbox[y1:y2, x2-lin_w:x2, :] = lin_color
+    
+    return (x1, y1, x2-x1, y2-y1), im_bbox 
+
+def get_fcn_sim(full_img, feat_audio, net, B, resol, norm=True):
+    feat_img = net.forward_fcn(full_img)
+    feat_img = feat_img.permute(0, 2,3,1).reshape(-1, 128)
+    A_II, A_IS, A_SI = net.GetAMatrix(feat_img, feat_audio, norm=norm)
+    A_IS_ = A_IS.reshape((B, resol*resol, B))
+    A_IIS_ = (A_II @ A_IS).reshape((B, resol*resol, B))
+    A_II_ = A_II.reshape((B, resol*resol, B, resol*resol))
+    return A_IS_, A_IIS_, A_II_
+
+def put_text(im, text, loc, font_scale=4): 
+    fontScale = font_scale
+    thickness = int(fontScale / 4)
+    fontColor = (0,255,255)
+    lineType = 4
+    im = cv2.putText(im, text, loc, cv2.FONT_HERSHEY_SIMPLEX, fontScale, fontColor, thickness, lineType)
+    return im
+
+def im2video(save_path, frame_list, fps=5): 
+    height, width, _ = frame_list[0].shape
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v') 
+    video = cv2.VideoWriter(save_path, fourcc, fps, (width, height))
+    
+    for frame in frame_list: 
+        video.write(rgb2bgr(frame))
+
+    cv2.destroyAllWindows()
+    video.release()
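+    # re-encode with ffmpeg to h264/yuv420p for broad browser compatibility, then remove the intermediate file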
+    new_name = "{}_new{}".format(save_path[:-4], save_path[-4:])
+    os.system("ffmpeg -v quiet -y -i \"{}\" -pix_fmt yuv420p -vcodec h264 -strict -2 -acodec aac \"{}\"".format(save_path, new_name))
+    os.system("rm -rf \"{}\"".format(save_path))
+
+def get_face_landmark(frame_path_): 
+    video_folder = Path(frame_path_).parent.parent
+    frame_name = frame_path_.split('/')[-1]
+    face_landmark_path = os.path.join(video_folder, "face_bbox_landmark.json")
+    if not os.path.exists(face_landmark_path): 
+        return None
+    with open(face_landmark_path, 'r') as f:
+        face_landmark = json.load(f)
+    if len(face_landmark[frame_name]) == 0: 
+        return None
+    b = face_landmark[frame_name][0]
+    return b
+
+def make_color_wheel():
+    # same source as color_flow
+
+    RY = 15
+    YG = 6
+    GC = 4
+    CB = 11
+    BM = 13
+    MR = 6
+
+    ncols = RY + YG + GC + CB + BM + MR
+
+    #colorwheel = zeros(ncols, 3) # r g b
+    # matlab correction
+    colorwheel = np.zeros((1+ncols, 4)) # r g b
+
+    col = 0
+    #RY
+    colorwheel[1:1+RY, 1] = 255
+    colorwheel[1:1+RY, 2] = np.floor(255*np.arange(0, 1+RY-1)/RY).T
+    col = col+RY
+
+    #YG
+    colorwheel[col+1:col+1+YG, 1] = 255 - np.floor(255*np.arange(0,1+YG-1)/YG).T
+    colorwheel[col+1:col+1+YG, 2] = 255
+    col = col+YG
+
+    #GC
+    colorwheel[col+1:col+1+GC, 2] = 255
+    colorwheel[col+1:col+1+GC, 3] = np.floor(255*np.arange(0,1+GC-1)/GC).T
+    col = col+GC
+
+    #CB
+    colorwheel[col+1:col+1+CB, 2] = 255 - np.floor(255*np.arange(0,1+CB-1)/CB).T
+    colorwheel[col+1:col+1+CB, 3] = 255
+    col = col+CB
+
+    #BM
+    colorwheel[col+1:col+1+BM, 3] = 255
+    colorwheel[col+1:col+1+BM, 1] = np.floor(255*np.arange(0,1+BM-1)/BM).T
+    col = col+BM
+
+    #MR
+    colorwheel[col+1:col+1+MR, 3] = 255 - np.floor(255*np.arange(0,1+MR-1)/MR).T
+    colorwheel[col+1:col+1+MR, 1] = 255  
+
+    # 1-based to 0-based indices
+    return colorwheel[1:, 1:]
+
+def warp(im, flow): 
+    # im : C x H x W
+    # flow : 2 x H x W, such that flow[dst_y, dst_x] = (src_x, src_y),
+    #     where (src_x, src_y) is the pixel location we want to sample from.
+
+    # grid_sample the grid is in the range in [-1, 1] 
+    grid =  -1. + 2. * flow/(-1 + np.array([im.shape[2], im.shape[1]], np.float32))[:, None, None]
+
+    # print('grid range =', grid.min(), grid.max())
+    ft = torch.FloatTensor
+    warped = torch.nn.functional.grid_sample(
+        ft(im[None].astype(np.float32)), 
+        ft(grid.transpose((1, 2, 0))[None]), 
+        mode = 'bilinear', padding_mode = 'zeros', align_corners=True)
+    return warped.cpu().numpy()[0].astype(im.dtype)
+
+def compute_color(u, v):
+    # from same source as color_flow; please see above comment
+    # nan_idx = ut.lor(np.isnan(u), np.isnan(v))
+    nan_idx = np.logical_or(np.isnan(u), np.isnan(v))
+    u[nan_idx] = 0
+    v[nan_idx] = 0
+    colorwheel = make_color_wheel()
+    ncols = colorwheel.shape[0]
+    
+    rad = np.sqrt(u**2 + v**2)
+
+    a = np.arctan2(-v, -u)/np.pi
+    
+    #fk = (a + 1)/2. * (ncols-1) + 1
+    fk = (a + 1)/2. * (ncols-1)
+
+    k0 = np.array(np.floor(fk), 'l')
+
+    k1 = k0 + 1
+    k1[k1 == ncols] = 1
+
+    f = fk - k0
+
+    im = np.zeros(u.shape + (3,))
+    
+    for i in range(colorwheel.shape[1]):
+        tmp = colorwheel[:, i]
+        col0 = tmp[k0]/255.
+        col1 = tmp[k1]/255.
+        col = (1-f)*col0 + f*col1
+
+        idx = rad <= 1
+        col[idx] = 1 - rad[idx]*(1-col[idx])
+        col[np.logical_not(idx)] *= 0.75
+        im[:, :, i] = np.uint8(np.floor(255*col*(1-nan_idx)))
+
+    return im
+
+def color_flow(flow, max_flow = None):
+    flow = flow.copy()
+    # based on flowToColor.m by Deqing Sun, orignally based on code by Daniel Scharstein
+    UNKNOWN_FLOW_THRESH = 1e9
+    UNKNOWN_FLOW = 1e10
+    height, width, nbands = flow.shape
+    assert nbands == 2
+    u, v = flow[:,:,0], flow[:,:,1]
+    maxu = -999.
+    maxv = -999.
+    minu = 999.
+    minv = 999.
+    maxrad = -1.
+
+    idx_unknown = np.logical_or(np.abs(u) > UNKNOWN_FLOW_THRESH,  np.abs(v) > UNKNOWN_FLOW_THRESH)
+    u[idx_unknown] = 0
+    v[idx_unknown] = 0
+    
+    maxu = max(maxu, np.max(u))
+    maxv = max(maxv, np.max(v))
+    
+    minu = min(minu, np.min(u))
+    minv = min(minv, np.min(v))
+
+    rad = np.sqrt(u**2 + v**2)
+    maxrad = max(maxrad, np.max(rad))
+
+    if max_flow is not None and max_flow > 0:
+        maxrad = max_flow
+
+    u = u/(maxrad + np.spacing(1))
+    v = v/(maxrad + np.spacing(1))
+    
+    im = compute_color(u, v)
+    im[idx_unknown] = 0
+    return im
+
+def plt_fig_to_np_img(fig): 
+    canvas = FigureCanvas(fig)  # draw the canvas, cache the renderer
+    canvas.draw() 
+    width, height = fig.get_size_inches() * fig.get_dpi()
+    image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
+    image = image.reshape(int(height), int(width), 3)
+
+    return image
+
+def save_np_img(image, path): 
+    cv2.imwrite(path, rgb2bgr(image))
+
+def find_patch_topk_aud(mat, top_k): 
+    top_k_ind = torch.argsort(mat, dim=-1, descending=True)[..., :top_k].squeeze()
+    top_k_ind = top_k_ind.reshape(-1).cpu().numpy()
+    return top_k_ind
+
+def find_patch_pred_topk(mat, top_k, target): 
+    M, N = mat.size()
+    labels = torch.from_numpy(target * np.ones(M)).to('cuda')
+    top_k_ind = torch.sum(torch.argsort(mat, dim=-1, descending=True)[..., :top_k] == labels.view(-1, 1), dim=-1).nonzero().reshape(-1)
+    top_k_ind  = top_k_ind.reshape(-1).cpu().numpy()
+    return top_k_ind
+
+def gen_masked_img(mask_ind, resol, img): 
+    mask = torch.zeros(resol*resol)
+    mask = mask.scatter_(0, torch.from_numpy(mask_ind), 1.)
+    mask = mask.reshape(resol, resol).numpy()
+    img_h = img.shape[1]
+    img_w = img.shape[0]
+    mask_up = np.array(Image.fromarray(mask*255).resize((img_h, img_w), resample=PIL.Image.NEAREST))
+    mask_up = mask_up[..., None]
+    image_seg = np.uint8(img * 0.7 + mask_up * 0.3)
+    
+    return image_seg
+
+def drop_2rand_ch(patch, remain_c=0): 
+    B, P, C, H, W = patch.shape
+    patch_c = patch[:, :, remain_c, :, :].unsqueeze(2)
+    # patch_droped = torch.zeros_like(patch)
+    # patch_droped[:, :, remain_c, :, :] = patch_c
+    c_std = torch.std(patch_c, dim=(3,4))
+    gauss_n = 0.5 + (0.01 * c_std.reshape(B, P, 1, 1, 1) * torch.randn(B, P, 2, H, W).to('cuda'))
+    
+    patch_dropped = torch.cat([gauss_n[:, :, :remain_c], patch_c, gauss_n[:, :, remain_c:]], dim=2)
+    
+    return patch_dropped
+    # pdb.set_trace()
+
+def vis_patch(patch, exp_path, resol, b_step): 
+    B, P, C, H, W = patch.shape
+    for i in range(B): 
+        patch_i = patch[i].reshape(resol, resol, C, H, W)
+        patch_i = patch_i.permute(2, 0, 3, 1, 4)
+        patch_folded_i = patch_i.reshape(C, resol*H, resol*W)
+        patch_folded_i = (patch_folded_i * 255).cpu().numpy().astype(np.uint8).transpose(1,2,0)
+        cv2.imwrite('{}/{}_{}_patch_folded.jpg'.format(exp_path, str(b_step).zfill(4), str(i).zfill(4)), rgb2bgr(patch_folded_i))
+
+def blur_patch(patch, k_size=3, sigma=0.5): 
+    B, P, C, H, W = patch.shape
+    gauss = K.filters.GaussianBlur2d((k_size, k_size), (sigma, sigma))
+    patch = patch.reshape(B*P, C, H, W)
+    blur_patch = gauss(patch).reshape(B, P, C, H, W)
+    return blur_patch
+
+def gray_project_patch(patch, device):
+    N, P, C, H, W = patch.size()
+    a = torch.tensor([[-1, 2, -1]]).float()
+    B = (torch.eye(3) - (a.T @ a) / (a @ a.T)).to(device)
+    patch = patch.permute(0, 1, 3, 4, 2)
+    patch = (patch @ B).permute(0, 1, 4, 2, 3)
+    return patch
+
+def parse_color(c):
+    if type(c) == type((0,)) or type(c) == type(np.array([1])):
+        return c
+    elif type(c) == type(''):
+        return color_from_string(c)
+
+def colors_from_input(color_input, default, n):
+    """ Parse color given as input argument; gives user several options """
+    # todo: generalize this to non-colors
+    if color_input is None:
+        expanded = [default] * n
+    elif isinstance(color_input, tuple) and len(color_input) == 3 and all(isinstance(c, int) for c in color_input):
+        # expand (r, g, b) -> [(r, g, b), (r, g, b), ..]
+        expanded = [color_input] * n
+    else:
+        # general case: [(r1, g1, b1), (r2, g2, b2), ...]
+        expanded = color_input
+
+    expanded = list(map(parse_color, expanded))
+    return expanded
+
+def draw_pts(im, points, colors = None, width = 1, texts = None):
+    # ut.check(colors is None or len(colors) == len(points))
+    points = list(points)
+    colors = colors_from_input(colors, (255, 0, 0), len(points))
+    rects = [(p[0] - width/2, p[1] - width/2, width, width) for p in points]
+    return draw_rects(im, rects, fills = colors, outlines = [None]*len(points), texts = texts)
+
+def to_pil(im): 
+    return Image.fromarray(np.uint8(im))
+
+def from_pil(pil): 
+    return np.array(pil)
+
+def draw_on(f, im):
+    pil = to_pil(im)
+    draw = ImageDraw.ImageDraw(pil)
+    f(draw)
+    return from_pil(pil)
+
+def fail(s = ''): raise RuntimeError(s)
+
+def check(cond, str = 'Check failed!'):
+    if not cond:
+        fail(str)
+
+def draw_rects(im, rects, outlines = None, fills = None, texts = None, text_colors = None, line_widths = None, as_oval = False):
+    rects = list(rects)
+    outlines = colors_from_input(outlines, (0, 0, 255), len(rects))
+    outlines = list(outlines)
+    text_colors = colors_from_input(text_colors, (255, 255, 255), len(rects))
+    text_colors = list(text_colors)
+    fills = colors_from_input(fills, None, len(rects))
+    fills = list(fills)
+    
+    if texts is None: texts = [None] * len(rects)
+    if line_widths is None: line_widths = [None] * len(rects)
+    
+    def check_size(x, s): 
+        check(x is None or len(list(x)) == len(rects), "%s different size from rects" % s)
+    check_size(outlines, 'outlines')
+    check_size(fills, 'fills')
+    check_size(texts, 'texts')
+    check_size(text_colors, 'texts')
+    
+    def f(draw):
+        for (x, y, w, h), outline, fill, text, text_color, lw in zip(rects, outlines, fills, texts, text_colors, line_widths):
+            if lw is None:
+                if as_oval:
+                    draw.ellipse((x, y, x + w, y + h), outline = outline, fill = fill)
+                else:
+                    draw.rectangle((x, y, x + w, y + h), outline = outline, fill = fill)
+            else:
+                d = int(np.ceil(lw/2))
+                draw.rectangle((x-d, y-d, x+w+d, y+d), fill = outline)
+                draw.rectangle((x-d, y-d, x+d, y+h+d), fill = outline)
+                
+                draw.rectangle((x+w+d, y+h+d, x-d, y+h-d), fill = outline)
+                draw.rectangle((x+w+d, y+h+d, x+w-d, y-d), fill = outline)
+                
+            if text is not None:
+                # draw text inside rectangle outline
+                border_width = 2
+                draw.text((border_width + x, y), text, fill = text_color)
+    return draw_on(f, im)
+
+def rand_color():
+    return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+
+def int_tuple(x): 
+    return tuple([int(v) for v in x])
+
+itup = int_tuple
+
+red = (255, 0, 0)
+green = (0, 255, 0)
+blue = (0, 0, 255)
+yellow = (255, 255, 0)
+purple = (255, 0, 255)
+cyan = (0, 255, 255)
+
+
+def stash_seed(new_seed = 0):
+    """ Sets the random seed to new_seed. Returns the old seed. """
+    if type(new_seed) == type(''):
+        new_seed = hash(new_seed) % 2**32
+
+    py_state = random.getstate()
+    random.seed(new_seed)
+
+    np_state = np.random.get_state()
+    np.random.seed(new_seed)
+    return (py_state, np_state)
+
+
+def do_with_seed(f, seed = 0):
+    old_seed = stash_seed(seed)
+    res = f()
+    unstash_seed(old_seed[0], old_seed[1])
+    return res
+
+def sample_at_most(xs, bound):
+    return random.sample(xs, min(bound, len(xs)))
+
+class ColorChooser:
+    def __init__(self, dist_thresh = 500, attempts = 500, init_colors = [], init_pts = []):
+        self.pts = init_pts
+        self.colors = init_colors
+        self.attempts = attempts
+        self.dist_thresh = dist_thresh
+
+    def choose(self, new_pt = (0, 0)):
+        new_pt = np.array(new_pt)
+        nearby_colors = []
+        for pt, c in zip(self.pts, self.colors):
+            if np.sum((pt - new_pt)**2) <= self.dist_thresh**2:
+                nearby_colors.append(c)
+
+        if len(nearby_colors) == 0:
+            color_best = rand_color()
+        else:
+            nearby_colors = np.array(sample_at_most(nearby_colors, 100), 'l')
+            choices = np.array(np.random.rand(self.attempts, 3)*256, 'l')
+            dists = np.sqrt(np.sum((choices[:, np.newaxis, :] - nearby_colors[np.newaxis, :, :])**2, axis = 2))
+            costs = np.min(dists, axis = 1)
+            assert costs.shape == (len(choices),)
+            # pick the candidate color farthest from its nearest existing color
+            color_best = itup(choices[np.argmax(costs)])
+
+        self.pts.append(new_pt)
+        self.colors.append(color_best)
+        return color_best
+
+def unstash_seed(py_state, np_state):
+    random.setstate(py_state)
+    np.random.set_state(np_state)
+
+def distinct_colors(n):
+    #cc = ColorChooser(attempts = 10, init_colors = [red, green, blue, yellow, purple, cyan], init_pts = [(0, 0)]*6)
+    cc = ColorChooser(attempts = 100, init_colors = [red, green, blue, yellow, purple, cyan], init_pts = [(0, 0)]*6)
+    do_with_seed(lambda : [cc.choose((0,0)) for x in range(n)])
+    return cc.colors[:n]
+
+def make(w, h, fill = (0,0,0)):
+    return np.uint8(np.tile([[fill]], (h, w, 1)))
+
+def rgb_from_gray(img, copy = True, remove_alpha = True):
+    if img.ndim == 3 and img.shape[2] == 3:
+        return img.copy() if copy else img
+    elif img.ndim == 3 and img.shape[2] == 4:
+        return (img.copy() if copy else img)[..., :3]
+    elif img.ndim == 3 and img.shape[2] == 1:
+        return np.tile(img, (1,1,3))
+    elif img.ndim == 2:
+        return np.tile(img[:,:,np.newaxis], (1,1,3))
+    else:
+        raise RuntimeError('Cannot convert to rgb. Shape: ' + str(img.shape))
+
+def hstack_ims(ims, bg_color = (0, 0, 0)):
+    max_h = max([im.shape[0] for im in ims])
+    result = []
+    for im in ims:
+        #frame = np.zeros((max_h, im.shape[1], 3))
+        frame = make(im.shape[1], max_h, bg_color)
+        frame[:im.shape[0],:im.shape[1]] = rgb_from_gray(im)
+        result.append(frame)
+    return np.hstack(result)
+
+def gen_ranked_prob_map(prob_map): 
+    prob_ranked = torch.zeros_like(prob_map)
+    _, index = torch.topk(prob_map, len(prob_map), largest=False)
+    prob_ranked[index] = torch.arange(len(prob_map)).float().cuda()
+    prob_ranked = prob_ranked.float() / torch.max(prob_ranked)
+    return prob_ranked
+
+def get_topk_patch_mask(prob_map): 
+    # _, index = 
+    pass
+
+def load_img(frame_path): 
+    image = Image.open(frame_path).convert('RGB')
+    image = image.resize((256, 256), resample=PIL.Image.BILINEAR)
+    image = np.array(image)
+
+    img_h = image.shape[0]
+    img_w = image.shape[1]
+
+    return image, img_h, img_w
+
+def plt_subp_show_img(fig, img, cols, rows, subp_index, interpolation='bilinear', aspect='auto'): 
+    fig.add_subplot(rows, cols, subp_index)
+    plt.cla()
+    plt.axis('off')
+    plt.imshow(img, interpolation=interpolation, aspect=aspect)
+    return fig
+
\ No newline at end of file
diff --git a/foleycrafter/models/specvqgan/onset_baseline/webify.py b/foleycrafter/models/specvqgan/onset_baseline/webify.py
new file mode 100644
index 0000000000000000000000000000000000000000..67bbf6399015e362e74a30003a74bf9e3f9f7c3a
--- /dev/null
+++ b/foleycrafter/models/specvqgan/onset_baseline/webify.py
@@ -0,0 +1,241 @@
+import os
+import datetime
+import sys
+import shutil
+import glob
+import argparse
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--path', type=str)
+    parser.add_argument('--imgsize', type=int, default=100)
+    parser.add_argument('--num', type=int, default=10000)
+
+    args = parser.parse_args()
+    return args
+
+
+# --------------------------------------  joint ----------------------------------- #
+def create_audio_visual_sec(args, f, name):
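+    # builds an HTML table with one column per subdirectory of args.path and one row per sample index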
+    dir_list = [name for name in os.listdir(
+        args.path) if os.path.isdir(os.path.join(args.path, name))]
+    dir_list.sort()
+
+    f.write('''<div align = "center">''')
+
+    joint_sec = """
+<h3>{}</h3>
+<table>
+<tbody>
+<tr>
+<th>Index #</th>
+    """.format(name)
+    for dir_name in dir_list:
+        joint_sec += '''\n<th>{}</th>'''.format(dir_name)
+    joint_sec += '''\n</tr>\n'''
+    f.write(joint_sec)
+
+    item_list = []
+    count = []
+    for i in range(len(dir_list)):
+        file_list = os.listdir(os.path.join(args.path, dir_list[i]))
+        file_list.sort()
+        count.append(len(file_list))
+        item_list.append(file_list)
+    file_count = min(count)
+    for j in range(min(file_count, args.num)):
+        f.write('''<tr>\n''')
+        for i in range(-1, len(dir_list)):
+            if i == -1:
+                f.write('''<td> sample #{} </td>'''.format(str(j)))
+                f.write('\n')
+            else:
+                sample = os.path.join(dir_list[i], item_list[i][j])
+                if sample.split('.')[-1] in ['wav', 'mp3']:
+                    f.write('''    <td> <div align = "center"><audio controls=""><source src='{}' type="audio/{}"></audio></div> </td>'''.format(
+                        sample, sample.split('.')[-1]))
+                elif sample.split('.')[-1] in ['jpg', 'png', 'gif']:
+                    f.write(
+                        '''    <td> <div align = "center"><img src='{}' style="zoom:{}%" /></div> </td>'''.format(sample, args.imgsize))
+                elif sample.split('.')[-1] in ['mp4', 'avi', 'webm']:
+                    f.write('''    <td> <div align = "center"><video id='{}' controls height='400'><source src="{}" type="video/{}" preload = "none"></video> <p>Speed: <input type="text" size="5" value="1" oninput="document.getElementById('{}').playbackRate = parseFloat(event.target.value);"></p></div> </td>'''.format(
+                        sample, sample, sample.split('.')[-1], sample))
+                f.write('\n')
+                # <video id='{}' controls><source src="{}" type="video/{}"></video>
+
+        f.write('''</tr>\n''')
+
+    f.write('''</tbody></table>\n''')
+    f.write('''</div>\n''')
+
+
+# --------------------------------------  Audio  ----------------------------------- #
+def create_audio_sec(args, f, name):
+    f.write('''<div align = "center">''')
+
+    audio_sec = """
+<h3>{}</h3>
+<table>
+<tbody>
+<tr>
+<th>Index #</th>
+<th>Mixture</th> 
+<th>Original audio #1</th>
+<th>Original audio #2</th>
+<th>Separated audio #1</th>
+<th>Separated audio #2</th>
+<th>regenerated audio mix</th>
+<th>regenerated audio #1</th>
+<th>regenerated audio #2</th>
+</tr>\n
+    """.format(name)
+    f.write(audio_sec)
+    folder_path = os.path.join(args.path, 'audio')
+    dir_list = os.listdir(folder_path)
+    dir_list.sort()
+    audio_list = []
+    for i in range(len(dir_list)):
+        l = os.listdir(os.path.join(folder_path, dir_list[i]))
+        l.sort()
+        audio_list.append(l)
+
+    for j in range(len(audio_list[0])):
+        f.write('''<tr>\n''')
+        for i in range(-1, len(dir_list)):
+            if i == -1:
+                f.write('''<td> audio #{} </td>'''.format(str(j)))
+                f.write('\n')
+            else:
+                audio_path = os.path.join(
+                    folder_path, dir_list[i], audio_list[i][j])
+                f.write('''    <td> <audio controls=""><source src='{}' type="audio/{}"></audio> </td>'''.format(
+                    audio_path, audio_path.split('.')[-1]))
+                f.write('\n')
+        f.write('''</tr>\n''')
+
+    f.write('''</tbody></table>\n''')
+    f.write('''</div>\n''')
+
+
+# --------------------------------------  Image ----------------------------------- #
+def create_image_sec(args, f, name):
+    f.write('''<div align = "center">''')
+
+    image_sec = """
+<h3>{}</h3>
+<table>
+<tbody>
+<tr>
+<th>Index #</th>
+<th>Mixture Spec </th> 
+<th>Original Spec #1</th>
+<th>Original Spec #2</th>
+<th>Separated Spec #1</th>
+<th>Separated Spec #2</th>
+</tr>\n
+    """.format(name)
+
+    f.write(image_sec)
+    folder_path = os.path.join(args.path, 'spec_img')
+    dir_list = os.listdir(folder_path)
+    dir_list.sort()
+    image_list = []
+    for i in range(len(dir_list)):
+        l = os.listdir(os.path.join(folder_path, dir_list[i]))
+        l.sort()
+        image_list.append(l)
+
+    for j in range(len(image_list[0])):
+        f.write('''<tr>\n''')
+        for i in range(-1, len(dir_list)):
+            if i == -1:
+                f.write('''<td> audio #{} </td>'''.format(str(j)))
+                f.write('\n')
+            else:
+                img_path = os.path.join(
+                    folder_path, dir_list[i], image_list[i][j])
+                f.write('''    <td> <div align = "center"><img src='{}' style="zoom:{}%" /></div> </td>'''.format(
+                    img_path, 175))
+                f.write('\n')
+        f.write('''</tr>\n''')
+
+    f.write('''</tbody></table>\n''')
+    f.write('''</div>\n''')
+
+# --------------------------------------  Video ----------------------------------- #
+
+
+def create_video_sec(args, f, name):
+    f.write('''<div align = "center">''')
+
+    video_sec = """
+<h3>{}</h3>
+<table>
+<tbody>
+<tr>
+<th></th>
+<th></th>
+<th></th>
+</tr>\n
+    """.format(name)
+
+    f.write(video_sec)
+    # folder_path = os.path.join(args.path, 'videos')
+    video_list = glob.glob('%s/*.mp4' % args.path)
+    video_list.sort()
+
+    columns = 3
+    rows = len(video_list) // columns + 1
+
+    for i in range(rows):
+        f.write('''<tr>\n''')
+        for j in range(columns):
+            index = i * columns + j
+            if index < len(video_list):
+                video_path = video_list[i * columns + j]
+                f.write('''    <td> <div align = "center"><h4>{}</h4><video width="480" onmouseover = "this.controls = true;" onmouseout = "this.controls = false;"><source src="{}" type="video/{}"></video></div> </td>'''.format(
+                    video_path.split('/')[-1], video_path, video_path.split('.')[-1]))
+                f.write('\n')
+
+        f.write('''</tr>\n''')
+
+    f.write('''</tbody></table>\n''')
+    f.write('''</div>\n''')
+
+
+def webify(args):
+    html_file = os.path.join(args.path, 'index.html')
+    f = open(html_file, 'wt')
+
+    # head
+    # <link rel="stylesheet" type="text/css" title="Cool stylesheet" href="style.css">
+    head = """<!DOCTYPE html>
+<html lang="en"><head><meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
+<meta name="viewport" content="width=device-width, initial-scale=1.0">
+<title>Listening and Looking - UM Owens Lab</title>
+</head>
+    """
+    f.write(head)
+
+    intro_sec = '''
+<body data-gr-c-s-loaded="true">
+<h1> Listening and Looking - UM Owens Lab </h1>
+<h5> Creator: Ziyang Chen <br>
+University of Michigan </h5>
+<p> This page contains the results of the experiment.</p>
+'''
+    f.write(intro_sec)
+    # create_audio_sec(args, f, "Audio Separation")
+    # create_image_sec(args, f, 'Spectrogram Visualization')
+    # create_video_sec(args, f, 'CAM Visualization')
+    create_audio_visual_sec(args, f, 'Stereo CRW')
+    f.write('''</body>\n''')
+    f.write('''</html>\n''')
+    f.close()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    webify(args)
+    print('Webify succeeded!')
diff --git a/foleycrafter/models/specvqgan/util.py b/foleycrafter/models/specvqgan/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..deb92db0bd3157ffe72bab1ea909a14eceea8694
--- /dev/null
+++ b/foleycrafter/models/specvqgan/util.py
@@ -0,0 +1,150 @@
+import hashlib
+import os
+
+import requests
+from tqdm import tqdm
+
+URL_MAP = {
+    'vggishish_lpaps': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt',
+    'vggishish_mean_std_melspec_10s_22050hz': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/train_means_stds_melspec_10s_22050hz.txt',
+    'melception': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/melception-21-05-10T09-28-40.pt',
+}
+
+CKPT_MAP = {
+    'vggishish_lpaps': 'vggishish16.pt',
+    'vggishish_mean_std_melspec_10s_22050hz': 'train_means_stds_melspec_10s_22050hz.txt',
+    'melception': 'melception-21-05-10T09-28-40.pt',
+}
+
+MD5_MAP = {
+    'vggishish_lpaps': '197040c524a07ccacf7715d7080a80bd',
+    'vggishish_mean_std_melspec_10s_22050hz': 'f449c6fd0e248936c16f6d22492bb625',
+    'melception': 'a71a41041e945b457c7d3d814bbcf72d',
+}
+
+
+def download(url, local_path, chunk_size=1024):
+    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+    with requests.get(url, stream=True) as r:
+        total_size = int(r.headers.get("content-length", 0))
+        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+            with open(local_path, "wb") as f:
+                for data in r.iter_content(chunk_size=chunk_size):
+                    if data:
+                        f.write(data)
+                        pbar.update(len(data))  # update by the bytes actually written, not the nominal chunk size
+
+
+def md5_hash(path):
+    with open(path, "rb") as f:
+        content = f.read()
+    return hashlib.md5(content).hexdigest()
+
+
+def get_ckpt_path(name, root, check=False):
+    assert name in URL_MAP
+    path = os.path.join(root, CKPT_MAP[name])
+    if not os.path.exists(path) or (check and md5_hash(path) != MD5_MAP[name]):
+        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+        download(URL_MAP[name], path)
+        md5 = md5_hash(path)
+        assert md5 == MD5_MAP[name], md5
+    return path
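+
+# Hedged usage sketch (editor's note): get_ckpt_path() downloads the named checkpoint
+# into `root` on first use and, with check=True, re-downloads it when the MD5 does not match:
+#     path = get_ckpt_path("vggishish_lpaps", root="./ckpts", check=True)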
+
+
+class KeyNotFoundError(Exception):
+    def __init__(self, cause, keys=None, visited=None):
+        self.cause = cause
+        self.keys = keys
+        self.visited = visited
+        messages = list()
+        if keys is not None:
+            messages.append("Key not found: {}".format(keys))
+        if visited is not None:
+            messages.append("Visited: {}".format(visited))
+        messages.append("Cause:\n{}".format(cause))
+        message = "\n".join(messages)
+        super().__init__(message)
+
+
+def retrieve(
+    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
+):
+    """Given a nested list or dict return the desired value at key expanding
+    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
+    is done in-place.
+
+    Parameters
+    ----------
+        list_or_dict : list or dict
+            Possibly nested list or dictionary.
+        key : str
+            key/to/value, path like string describing all keys necessary to
+            consider to get to the desired value. List indices can also be
+            passed here.
+        splitval : str
+            String that defines the delimiter between keys of the
+            different depth levels in `key`.
+        default : obj
+            Value returned if :attr:`key` is not found.
+        expand : bool
+            Whether to expand callable nodes on the path or not.
+
+    Returns
+    -------
+        The desired value or if :attr:`default` is not ``None`` and the
+        :attr:`key` is not found returns ``default``.
+
+    Raises
+    ------
+        Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
+        ``None``.
+    """
+
+    keys = key.split(splitval)
+
+    success = True
+    try:
+        visited = []
+        parent = None
+        last_key = None
+        for key in keys:
+            if callable(list_or_dict):
+                if not expand:
+                    raise KeyNotFoundError(
+                        ValueError(
+                            "Trying to get past callable node with expand=False."
+                        ),
+                        keys=keys,
+                        visited=visited,
+                    )
+                list_or_dict = list_or_dict()
+                parent[last_key] = list_or_dict
+
+            last_key = key
+            parent = list_or_dict
+
+            try:
+                if isinstance(list_or_dict, dict):
+                    list_or_dict = list_or_dict[key]
+                else:
+                    list_or_dict = list_or_dict[int(key)]
+            except (KeyError, IndexError, ValueError) as e:
+                raise KeyNotFoundError(e, keys=keys, visited=visited)
+
+            visited += [key]
+        # final expansion of retrieved value
+        if expand and callable(list_or_dict):
+            list_or_dict = list_or_dict()
+            parent[last_key] = list_or_dict
+    except KeyNotFoundError as e:
+        if default is None:
+            raise e
+        else:
+            list_or_dict = default
+            success = False
+
+    if not pass_success:
+        return list_or_dict
+    else:
+        return list_or_dict, success
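+
+# Hedged usage sketch (editor's note): retrieve() walks a nested dict/list along a
+# "/"-delimited key path, expanding callable nodes along the way when expand=True, e.g.:
+#     cfg = {"model": {"params": {"lr": 1e-4, "betas": [0.9, 0.999]}}}
+#     retrieve(cfg, "model/params/lr")                    # -> 0.0001
+#     retrieve(cfg, "model/params/betas/1")               # -> 0.999 (list index given as a string)
+#     retrieve(cfg, "model/params/missing", default=0.0)  # -> 0.0 instead of raising KeyNotFoundError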
diff --git a/foleycrafter/models/time_detector/model.py b/foleycrafter/models/time_detector/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..78c97ed083ebde61e6173739f7b2a567bc8a0f3f
--- /dev/null
+++ b/foleycrafter/models/time_detector/model.py
@@ -0,0 +1,16 @@
+import torch
+import torch.nn as nn
+from foleycrafter.models.specvqgan.onset_baseline.models import VideoOnsetNet
+
+class TimeDetector(nn.Module):
+    def __init__(self, video_length=150, audio_length=1024):
+        super(TimeDetector, self).__init__()
+        self.pred_net = VideoOnsetNet(pretrained=False)
+        self.soft_fn  = nn.Tanh()
+        self.up_sampler = nn.Linear(video_length, audio_length)
+         
+    def forward(self, inputs):
+        x = self.pred_net(inputs)
+        x = self.up_sampler(x)
+        x = self.soft_fn(x)
+        return x
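+
+# Editor's note (shape sketch, assuming the default constructor arguments): VideoOnsetNet
+# is expected to emit per-frame onset scores whose last dimension is `video_length` (150);
+# the linear layer resamples them to `audio_length` (1024) and the Tanh squashes them to (-1, 1).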
\ No newline at end of file
diff --git a/foleycrafter/models/time_detector/resnet.py b/foleycrafter/models/time_detector/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..07a01f25a886c9e8e32bc6e4833c284d302bc050
--- /dev/null
+++ b/foleycrafter/models/time_detector/resnet.py
@@ -0,0 +1,347 @@
+import torch.nn as nn
+
+from torch.hub import load_state_dict_from_url
+
+
+__all__ = ['r3d_18', 'mc3_18', 'r2plus1d_18']
+
+model_urls = {
+    'r3d_18': 'https://download.pytorch.org/models/r3d_18-b3b3357e.pth',
+    'mc3_18': 'https://download.pytorch.org/models/mc3_18-a90a0ba3.pth',
+    'r2plus1d_18': 'https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth',
+}
+
+
+class Conv3DSimple(nn.Conv3d):
+    def __init__(self,
+                 in_planes,
+                 out_planes,
+                 midplanes=None,
+                 stride=1,
+                 padding=1):
+
+        super(Conv3DSimple, self).__init__(
+            in_channels=in_planes,
+            out_channels=out_planes,
+            kernel_size=(3, 3, 3),
+            stride=stride,
+            padding=padding,
+            bias=False)
+
+    @staticmethod
+    def get_downsample_stride(stride):
+        return stride, stride, stride
+
+
+class Conv2Plus1D(nn.Sequential):
+
+    def __init__(self,
+                 in_planes,
+                 out_planes,
+                 midplanes,
+                 stride=1,
+                 padding=1):
+        super(Conv2Plus1D, self).__init__(
+            nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3),
+                      stride=(1, stride, stride), padding=(0, padding, padding),
+                      bias=False),
+            nn.BatchNorm3d(midplanes),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1),
+                      stride=(stride, 1, 1), padding=(padding, 0, 0),
+                      bias=False))
+
+    @staticmethod
+    def get_downsample_stride(stride):
+        return stride, stride, stride
+
+
+class Conv3DNoTemporal(nn.Conv3d):
+
+    def __init__(self,
+                 in_planes,
+                 out_planes,
+                 midplanes=None,
+                 stride=1,
+                 padding=1):
+
+        super(Conv3DNoTemporal, self).__init__(
+            in_channels=in_planes,
+            out_channels=out_planes,
+            kernel_size=(1, 3, 3),
+            stride=(1, stride, stride),
+            padding=(0, padding, padding),
+            bias=False)
+
+    @staticmethod
+    def get_downsample_stride(stride):
+        return 1, stride, stride
+
+
+class BasicBlock(nn.Module):
+
+    expansion = 1
+
+    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
+        midplanes = (inplanes * planes * 3 * 3 *
+                     3) // (inplanes * 3 * 3 + 3 * planes)
+
+        super(BasicBlock, self).__init__()
+        self.conv1 = nn.Sequential(
+            conv_builder(inplanes, planes, midplanes, stride),
+            nn.BatchNorm3d(planes),
+            nn.ReLU(inplace=True)
+        )
+        self.conv2 = nn.Sequential(
+            conv_builder(planes, planes, midplanes),
+            nn.BatchNorm3d(planes)
+        )
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.conv2(out)
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
+
+        super(Bottleneck, self).__init__()
+        midplanes = (inplanes * planes * 3 * 3 *
+                     3) // (inplanes * 3 * 3 + 3 * planes)
+
+        # 1x1x1
+        self.conv1 = nn.Sequential(
+            nn.Conv3d(inplanes, planes, kernel_size=1, bias=False),
+            nn.BatchNorm3d(planes),
+            nn.ReLU(inplace=True)
+        )
+        # Second kernel
+        self.conv2 = nn.Sequential(
+            conv_builder(planes, planes, midplanes, stride),
+            nn.BatchNorm3d(planes),
+            nn.ReLU(inplace=True)
+        )
+
+        # 1x1x1
+        self.conv3 = nn.Sequential(
+            nn.Conv3d(planes, planes * self.expansion,
+                      kernel_size=1, bias=False),
+            nn.BatchNorm3d(planes * self.expansion)
+        )
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.conv2(out)
+        out = self.conv3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class BasicStem(nn.Sequential):
+    """The default conv-batchnorm-relu stem
+    """
+
+    def __init__(self):
+        super(BasicStem, self).__init__(
+            nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2),
+                      padding=(1, 3, 3), bias=False),
+            nn.BatchNorm3d(64),
+            nn.ReLU(inplace=True))
+
+
+class R2Plus1dStem(nn.Sequential):
+    """R(2+1)D stem is different than the default one as it uses separated 3D convolution
+    """
+
+    def __init__(self):
+        super(R2Plus1dStem, self).__init__(
+            nn.Conv3d(3, 45, kernel_size=(1, 7, 7),
+                      stride=(1, 2, 2), padding=(0, 3, 3),
+                      bias=False),
+            nn.BatchNorm3d(45),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
+                      stride=(1, 1, 1), padding=(1, 0, 0),
+                      bias=False),
+            nn.BatchNorm3d(64),
+            nn.ReLU(inplace=True))
+
+
+class VideoResNet(nn.Module):
+
+    def __init__(self, block, conv_makers, layers,
+                 stem, num_classes=400,
+                 zero_init_residual=False):
+        """Generic resnet video generator.
+        Args:
+            block (nn.Module): resnet building block
+            conv_makers (list(functions)): generator function for each layer
+            layers (List[int]): number of blocks per layer
+            stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
+            num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
+            zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
+        """
+        super(VideoResNet, self).__init__()
+        self.inplanes = 64
+
+        self.stem = stem()
+
+        self.layer1 = self._make_layer(
+            block, conv_makers[0], 64, layers[0], stride=1)
+        self.layer2 = self._make_layer(
+            block, conv_makers[1], 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(
+            block, conv_makers[2], 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(
+            block, conv_makers[3], 512, layers[3], stride=2)
+
+        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        # init weights
+        self._initialize_weights()
+
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+
+    def forward(self, x):
+        x = self.stem(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        # Collapse the pooled (1, 1, 1) spatio-temporal dims to get (N, 512) clip features;
+        # the classification head is intentionally skipped here.
+        # x = self.fc(x)
+        x = x.flatten(1)
+
+        return x
+
+    def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
+        downsample = None
+
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            ds_stride = conv_builder.get_downsample_stride(stride)
+            downsample = nn.Sequential(
+                nn.Conv3d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=ds_stride, bias=False),
+                nn.BatchNorm3d(planes * block.expansion)
+            )
+        layers = []
+        layers.append(block(self.inplanes, planes,
+                      conv_builder, stride, downsample))
+
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes, conv_builder))
+
+        return nn.Sequential(*layers)
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv3d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out',
+                                        nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm3d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)
+
+
+def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
+    model = VideoResNet(**kwargs)
+
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def r3d_18(pretrained=False, progress=True, **kwargs):
+    """Construct 18 layer Resnet3D model as in
+    https://arxiv.org/abs/1711.11248
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
+        progress (bool): If True, displays a progress bar of the download to stderr
+    Returns:
+        nn.Module: R3D-18 network
+    """
+
+    return _video_resnet('r3d_18',
+                         pretrained, progress,
+                         block=BasicBlock,
+                         conv_makers=[Conv3DSimple] * 4,
+                         layers=[2, 2, 2, 2],
+                         stem=BasicStem, **kwargs)
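+
+# Hedged usage sketch (editor's note): with the classification head bypassed in
+# VideoResNet.forward above, each builder yields clip-level features, e.g.
+#     backbone = r3d_18(pretrained=False)
+#     feats = backbone(clips)   # clips: (N, 3, T, H, W) float tensor -> feats: (N, 512)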
+
+
+def mc3_18(pretrained=False, progress=True, **kwargs):
+    """Constructor for 18 layer Mixed Convolution network as in
+    https://arxiv.org/abs/1711.11248
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
+        progress (bool): If True, displays a progress bar of the download to stderr
+    Returns:
+        nn.Module: MC3 Network definition
+    """
+    return _video_resnet('mc3_18',
+                         pretrained, progress,
+                         block=BasicBlock,
+                         conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3,
+                         layers=[2, 2, 2, 2],
+                         stem=BasicStem, **kwargs)
+
+
+def r2plus1d_18(pretrained=False, progress=True, **kwargs):
+    """Constructor for the 18 layer deep R(2+1)D network as in
+    https://arxiv.org/abs/1711.11248
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
+        progress (bool): If True, displays a progress bar of the download to stderr
+    Returns:
+        nn.Module: R(2+1)D-18 network
+    """
+    return _video_resnet('r2plus1d_18',
+                         pretrained, progress,
+                         block=BasicBlock,
+                         conv_makers=[Conv2Plus1D] * 4,
+                         layers=[2, 2, 2, 2],
+                         stem=R2Plus1dStem, **kwargs)
\ No newline at end of file
diff --git a/foleycrafter/pipelines/auffusion_pipeline.py b/foleycrafter/pipelines/auffusion_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cdaf10e60f53cee3cc86a0103b97560c5aa84bb
--- /dev/null
+++ b/foleycrafter/pipelines/auffusion_pipeline.py
@@ -0,0 +1,2103 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+from dataclasses import dataclass
+
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, ImageProjection
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.image_processor import PipelineImageInput
+from diffusers.models.attention_processor import FusedAttnProcessor2_0
+from diffusers.utils import (
+    deprecate,
+    is_accelerate_available,
+    is_accelerate_version,
+    logging,
+    replace_example_docstring,
+)
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from huggingface_hub import snapshot_download
+from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler
+from transformers import PretrainedConfig, AutoTokenizer
+import torch.nn as nn
+import os, json, PIL
+import numpy as np
+import torch.nn.functional as F
+from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+from diffusers.utils.outputs import BaseOutput
+import matplotlib.pyplot as plt
+
+from foleycrafter.models.auffusion_unet import UNet2DConditionModel
+from foleycrafter.models.adapters.ip_adapter import VideoProjModel
+from foleycrafter.models.auffusion.loaders.ip_adapter import IPAdapterMixin
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+
+def json_dump(data_json, json_save_path):
+    with open(json_save_path, 'w') as f:
+        json.dump(data_json, f, indent=4)
+
+
+def json_load(json_path):
+    with open(json_path, 'r') as f:
+        data = json.load(f)
+    return data
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str):
+    text_encoder_config = PretrainedConfig.from_pretrained(
+        pretrained_model_name_or_path
+    )
+    model_class = text_encoder_config.architectures[0]
+
+    if model_class == "CLIPTextModel":
+        from transformers import CLIPTextModel
+        return CLIPTextModel
+    if "t5" in model_class.lower():
+        from transformers import T5EncoderModel
+        return T5EncoderModel
+    if "clap" in model_class.lower():
+        from transformers import ClapTextModelWithProjection
+        return ClapTextModelWithProjection
+    else:
+        raise ValueError(f"{model_class} is not supported.")
+
+
+class ConditionAdapter(nn.Module):
+    def __init__(self, config):
+        super(ConditionAdapter, self).__init__()
+        self.config = config
+        self.proj = nn.Linear(self.config["condition_dim"], self.config["cross_attention_dim"])
+        self.norm = torch.nn.LayerNorm(self.config["cross_attention_dim"])
+        print(f"INITIATED: ConditionAdapter: {self.config}")
+
+    def forward(self, x):
+        x = self.proj(x)
+        x = self.norm(x)
+        return x
+    
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path):
+        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
+        ckpt_path = os.path.join(pretrained_model_name_or_path, "condition_adapter.pt")
+        config = json.loads(open(config_path).read())
+        instance = cls(config)
+        instance.load_state_dict(torch.load(ckpt_path))
+        print(f"LOADED: ConditionAdapter from {pretrained_model_name_or_path}")
+        return instance
+
+    def save_pretrained(self, pretrained_model_name_or_path):
+        os.makedirs(pretrained_model_name_or_path, exist_ok=True)
+        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
+        ckpt_path = os.path.join(pretrained_model_name_or_path, "condition_adapter.pt")        
+        json_dump(self.config, config_path)
+        torch.save(self.state_dict(), ckpt_path)
+        print(f"SAVED: ConditionAdapter {self.config['model_name']} to {pretrained_model_name_or_path}")
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
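+
+# Hedged usage sketch (editor's note): in a classifier-free-guidance step this is typically
+# applied right after mixing the conditional/unconditional predictions, e.g.
+#     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+#     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=0.7)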
+
+
+
+LRELU_SLOPE = 0.1
+MAX_WAV_VALUE = 32768.0
+
+
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+
+def get_config(config_path):
+    config = json.loads(open(config_path).read())
+    config = AttrDict(config)
+    return config
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def apply_weight_norm(m):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        weight_norm(m)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size*dilation - dilation)/2)
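+
+# Editor's note: for stride-1 convolutions this is "same" padding, e.g.
+#     get_padding(3)             -> 1
+#     get_padding(3, dilation=3) -> 3   (effective kernel size 7, so 3 samples on each side)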
+
+
+class ResBlock1(torch.nn.Module):
+    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ResBlock1, self).__init__()
+        self.h = h
+        self.convs1 = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                               padding=get_padding(kernel_size, dilation[0]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                               padding=get_padding(kernel_size, dilation[1]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+                               padding=get_padding(kernel_size, dilation[2])))
+        ])
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1)))
+        ])
+        self.convs2.apply(init_weights)
+
+    def forward(self, x):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
+        super(ResBlock2, self).__init__()
+        self.h = h
+        self.convs = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                               padding=get_padding(kernel_size, dilation[0]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                               padding=get_padding(kernel_size, dilation[1])))
+        ])
+        self.convs.apply(init_weights)
+
+    def forward(self, x):
+        for c in self.convs:
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs:
+            remove_weight_norm(l)
+
+
+
+class Generator(torch.nn.Module):
+    def __init__(self, h):
+        super(Generator, self).__init__()
+        self.h = h
+        self.num_kernels = len(h.resblock_kernel_sizes)
+        self.num_upsamples = len(h.upsample_rates)
+        # self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
+        self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)) # change: 80 --> 512
+        resblock = ResBlock1 if h.resblock == '1' else ResBlock2
+
+        self._device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+            if (k-u) % 2 == 0:
+                self.ups.append(weight_norm(
+                    ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
+                                    k, u, padding=(k-u)//2)))
+            else:
+                self.ups.append(weight_norm(
+                    ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
+                                    k, u, padding=(k-u)//2+1, output_padding=1)))
+            
+            # self.ups.append(weight_norm(
+            #     ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
+            #                     k, u, padding=(k-u)//2)))
+            
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = h.upsample_initial_channel//(2**(i+1))
+            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+                self.resblocks.append(resblock(h, ch, k, d))
+
+        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+
+    @property
+    def device(self) -> torch.device:
+        return torch.device(self._device)
+
+    @property
+    def dtype(self):
+        # report the dtype of the module's parameters rather than the bound `type` method
+        return next(self.parameters()).dtype
+
+    def forward(self, x):
+        x = self.conv_pre(x)
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i*self.num_kernels+j](x)
+                else:
+                    xs += self.resblocks[i*self.num_kernels+j](x)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        print('Removing weight norm...')
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+        remove_weight_norm(self.conv_pre)
+        remove_weight_norm(self.conv_post)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None):
+        if subfolder is not None:
+            pretrained_model_name_or_path = os.path.join(pretrained_model_name_or_path, subfolder)
+        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
+        ckpt_path   = os.path.join(pretrained_model_name_or_path, "vocoder.pt")
+
+        config  = get_config(config_path)
+        vocoder = cls(config)
+
+        state_dict_g = torch.load(ckpt_path)
+        vocoder.load_state_dict(state_dict_g["generator"])
+        vocoder.eval()
+        vocoder.remove_weight_norm()
+        return vocoder
+    
+    @torch.no_grad()
+    def inference(self, mels, lengths=None):
+        self.eval()
+        with torch.no_grad():
+            wavs = self(mels).squeeze(1)
+
+        wavs = (wavs.cpu().numpy() * MAX_WAV_VALUE).astype("int16")
+
+        if lengths is not None:
+            wavs = wavs[:, :lengths]
+
+        return wavs
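+
+# Hedged usage sketch (editor's note, assuming a checkpoint directory containing the
+# `config.json` and `vocoder.pt` files expected by Generator.from_pretrained above):
+#     vocoder = Generator.from_pretrained("<ckpt_dir>", subfolder="vocoder")
+#     wav = vocoder.inference(mels)   # mels: (B, num_mels, T) float tensor -> int16 waveform array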
+
+
+
+def normalize_spectrogram(
+    spectrogram: torch.Tensor,
+    max_value: float = 200, 
+    min_value: float = 1e-5, 
+    power: float = 1., 
+) -> torch.Tensor:
+    
+    # Rescale to 0-1
+    max_value = np.log(max_value) # 5.298317366548036
+    min_value = np.log(min_value) # -11.512925464970229
+    spectrogram = torch.clamp(spectrogram, min=min_value, max=max_value)
+    data = (spectrogram - min_value) / (max_value - min_value)
+    # Apply the power curve
+    data = torch.pow(data, power)
+    # 1D -> 3D
+    data = data.repeat(3, 1, 1)
+    # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner
+    data = torch.flip(data, [1])
+
+    return data
+
+
+def denormalize_spectrogram(
+    data: torch.Tensor,
+    max_value: float = 200, 
+    min_value: float = 1e-5, 
+    power: float = 1, 
+) -> torch.Tensor:
+    
+    assert len(data.shape) == 3, "Expected 3 dimensions, got {}".format(len(data.shape))
+
+    max_value = np.log(max_value)
+    min_value = np.log(min_value)
+    # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner
+    data = torch.flip(data, [1])    
+    if data.shape[0] == 1:
+        data = data.repeat(3, 1, 1)        
+    assert data.shape[0] == 3, "Expected 3 channels, got {}".format(data.shape[0])
+    data = data[0]
+    # Reverse the power curve
+    data = torch.pow(data, 1 / power)
+    # Rescale to max value
+    spectrogram = data * (max_value - min_value) + min_value
+
+    return spectrogram
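+
+# Editor's note: normalize_spectrogram() and denormalize_spectrogram() are approximate
+# inverses (exact up to the clamping): the log-magnitude spectrogram is clamped to
+# [log(1e-5), log(200)], rescaled to [0, 1], replicated to 3 channels and flipped into
+# image orientation; denormalization undoes these steps and returns a single-channel
+# log-magnitude spectrogram, e.g.
+#     img  = normalize_spectrogram(spec)    # (3, H, W) values in [0, 1]
+#     spec = denormalize_spectrogram(img)   # (H, W) log-magnitude values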
+
+def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
+    """
+    Convert a PyTorch tensor to a NumPy image.
+    """
+    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
+    return images
+
+def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
+    """
+    Convert a numpy image or a batch of images to a PIL image.
+    """
+    if images.ndim == 3:
+        images = images[None, ...]
+    images = (images * 255).round().astype("uint8")
+    if images.shape[-1] == 1:
+        # special case for grayscale (single channel) images
+        pil_images = [PIL.Image.fromarray(image.squeeze(), mode="L") for image in images]
+    else:
+        pil_images = [PIL.Image.fromarray(image) for image in images]
+
+    return pil_images
+
+
+def image_add_color(spec_img):
+    cmap = plt.get_cmap('viridis')
+    image = cmap(np.array(spec_img)[:, :, 0])[:, :, :3]  # drop the alpha channel
+    image = (image - image.min()) / (image.max() - image.min())
+    image = PIL.Image.fromarray(np.uint8(image*255))
+    return image
+
+
+@dataclass
+class PipelineOutput(BaseOutput):
+    """
+    Output class for audio pipelines.
+
+    Args:
+        audios (`np.ndarray`)
+            List of denoised audio samples of a NumPy array of shape `(batch_size, num_channels, sample_rate)`.
+    """
+
+    images: Union[List[PIL.Image.Image], np.ndarray]
+    spectrograms: Union[List[np.ndarray], np.ndarray]
+    audios: Union[List[np.ndarray], np.ndarray]
+
+
+
+class AuffusionPipeline(DiffusionPipeline):
+
+    r"""
+    Pipeline for text-to-spectrogram (and audio) generation, built on Stable Diffusion.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+        feature_extractor ([`CLIPImageProcessor`]):
+            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+    """
+    _optional_components = ["safety_checker", "feature_extractor", "text_encoder_list", "tokenizer_list", "adapter_list", "vocoder"]
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        unet: UNet2DConditionModel,        
+        scheduler: KarrasDiffusionSchedulers,
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPImageProcessor,
+        text_encoder_list: Optional[List[Callable]] = None,
+        tokenizer_list: Optional[List[Callable]] = None,
+        vocoder: Generator = None,
+        requires_safety_checker: bool = False,        
+        adapter_list: Optional[List[Callable]] = None,
+        tokenizer_model_max_length: Optional[int] = 77,  # 77 is the CLIP tokenizer's default max length (also applied to the other text encoders)
+    ):
+        super().__init__()
+
+        self.text_encoder_list = text_encoder_list
+        self.tokenizer_list = tokenizer_list
+        self.vocoder = vocoder
+        self.adapter_list = adapter_list
+        self.tokenizer_model_max_length = tokenizer_model_max_length
+
+        self.register_modules(
+            vae=vae,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+        )
+
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: str = "auffusion/auffusion-full-no-adapter",
+        dtype: torch.dtype = torch.float16,
+        device: str = "cuda",
+    ):
+        if not os.path.isdir(pretrained_model_name_or_path):
+            pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path) 
+        
+        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
+        unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
+        feature_extractor = CLIPImageProcessor.from_pretrained(pretrained_model_name_or_path, subfolder="feature_extractor")
+        scheduler = PNDMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
+
+        vocoder = Generator.from_pretrained(pretrained_model_name_or_path, subfolder="vocoder").to(device, dtype)
+
+        text_encoder_list, tokenizer_list, adapter_list = [], [], []
+        
+        condition_json_path = os.path.join(pretrained_model_name_or_path, "condition_config.json")
+        condition_json_list = json.loads(open(condition_json_path).read())
+        
+        for i, condition_item in enumerate(condition_json_list):
+            
+            # Load condition text encoder and its tokenizer
+            text_encoder_path = os.path.join(pretrained_model_name_or_path, condition_item["text_encoder_name"])
+            tokenizer = AutoTokenizer.from_pretrained(text_encoder_path)
+            tokenizer_list.append(tokenizer)
+            text_encoder_cls = import_model_class_from_model_name_or_path(text_encoder_path)
+            text_encoder = text_encoder_cls.from_pretrained(text_encoder_path).to(device, dtype)
+            text_encoder_list.append(text_encoder)
+            print(f"LOADING CONDITION ENCODER {i}")
+
+            # Load Condition Adapter
+            adapter_path = os.path.join(pretrained_model_name_or_path, condition_item["condition_adapter_name"])
+            adapter = ConditionAdapter.from_pretrained(adapter_path).to(device, dtype)
+            adapter_list.append(adapter)
+            print(f"LOADING CONDITION ADAPTER {i}")
+
+
+        pipeline = cls(
+            vae=vae,
+            unet=unet,
+            text_encoder_list=text_encoder_list,
+            tokenizer_list=tokenizer_list,
+            vocoder=vocoder,
+            adapter_list=adapter_list,
+            scheduler=scheduler,
+            safety_checker=None,
+            feature_extractor=feature_extractor,
+        )
+        pipeline = pipeline.to(device, dtype)
+
+        return pipeline
+    
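+    # Hedged usage sketch (editor's note; the prompt and output handling below are
+    # illustrative assumptions, not part of the original module):
+    #     pipe = AuffusionPipeline.from_pretrained("auffusion/auffusion-full-no-adapter")
+    #     out  = pipe(prompt="footsteps on gravel", duration=10)
+    #     waveform = out.audios[0]   # assuming the call returns a PipelineOutput with `audios`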
+
+    def to(self, device, dtype=None):
+        super().to(device, dtype)
+
+        self.vocoder.to(device, dtype)
+
+        for text_encoder in self.text_encoder_list:
+            text_encoder.to(device, dtype)
+       
+        if self.adapter_list is not None:
+            for adapter in self.adapter_list:
+                adapter.to(device, dtype)
+
+        return self
+
+
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+        steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
+    def enable_sequential_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta')` and loaded to the GPU only when their specific submodule has its `forward` method called.
+        Note that offloading happens on a submodule basis. Memory savings are higher than with
+        `enable_model_cpu_offload`, but performance is lower.
+        """
+        if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        for cpu_offloaded_model in [self.unet, *self.text_encoder_list, self.vae]:
+            cpu_offload(cpu_offloaded_model, device)
+
+        if self.safety_checker is not None:
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
+
+    def enable_model_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+        """
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate import cpu_offload_with_hook
+        else:
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        hook = None
+        for cpu_offloaded_model in [*self.text_encoder_list, self.unet, self.vae]:
+            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+        if self.safety_checker is not None:
+            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+        # We'll offload the last model manually.
+        self.final_offload_hook = hook
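+
+    # Editor's note: of the two offload helpers above, `enable_sequential_cpu_offload` gives
+    # the largest memory savings at a speed cost (per-submodule offload), while
+    # `enable_model_cpu_offload` moves whole models at a time and is usually the faster
+    # choice when memory allows.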
+
+    @property
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
+    
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+    ):
+
+        assert len(self.text_encoder_list) == len(self.tokenizer_list), "Number of text_encoders must match number of tokenizers"
+        if self.adapter_list is not None:
+            assert len(self.text_encoder_list) == len(self.adapter_list), "Number of text_encoders must match number of adapters"
+
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        def get_prompt_embeds(prompt_list, device):
+            if isinstance(prompt_list, str):
+                prompt_list = [prompt_list]
+
+            prompt_embeds_list = []
+            for prompt in prompt_list:
+                encoder_hidden_states_list = []
+
+                # Generate condition embedding
+                for j in range(len(self.text_encoder_list)):
+                    # get condition embedding using condition encoder
+                    input_ids = self.tokenizer_list[j](prompt, return_tensors="pt").input_ids.to(device)
+                    cond_embs = self.text_encoder_list[j](input_ids).last_hidden_state # [bz, text_len, text_dim]
+                    # padding to max_length
+                    if cond_embs.shape[1] < self.tokenizer_model_max_length: 
+                        cond_embs = F.pad(cond_embs, (0, 0, 0, self.tokenizer_model_max_length - cond_embs.shape[1]), value=0)
+                    else:
+                        cond_embs = cond_embs[:, :self.tokenizer_model_max_length, :]
+
+                    # use condition adapter
+                    if self.adapter_list is not None:
+                        cond_embs = self.adapter_list[j](cond_embs)
+                        encoder_hidden_states_list.append(cond_embs)
+
+                prompt_embeds = torch.cat(encoder_hidden_states_list, dim=1)
+                prompt_embeds_list.append(prompt_embeds)
+
+            prompt_embeds = torch.cat(prompt_embeds_list, dim=0)
+            return prompt_embeds
+
+
+        if prompt_embeds is None:           
+            prompt_embeds = get_prompt_embeds(prompt, device)
+
+        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+
+            if negative_prompt is None:
+                negative_prompt_embeds = torch.zeros_like(prompt_embeds).to(dtype=prompt_embeds.dtype, device=device)
+            else:
+                if prompt is not None and type(prompt) is not type(negative_prompt):
+                    raise TypeError(
+                        f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                        f" {type(prompt)}."
+                    )
+                if isinstance(negative_prompt, str):
+                    negative_prompt = [negative_prompt]
+                elif batch_size != len(negative_prompt):
+                    raise ValueError(
+                        f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                        f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                        " the batch size of `prompt`."
+                    )
+                # embed the negative prompt so the classifier-free-guidance branch below always has embeddings
+                negative_prompt_embeds = get_prompt_embeds(negative_prompt, device)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+        return prompt_embeds
+            
+
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        return image, has_nsfw_concept
+
+    def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
+        latents = 1 / self.vae.config.scaling_factor * latents
+        image = self.vae.decode(latents, return_dict=False)[0]
+        image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        return image
+
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+    
+
+    def check_inputs(
+        self,
+        prompt,
+        height,
+        width,
+        callback_steps,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+    ):
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if latents is None:
+            # randn_tensor (also used by the pipeline below) handles CPU generators and lists of generators
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        height: Optional[int] = 256,
+        width: Optional[int] = 1024,
+        num_inference_steps: int = 100,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pt",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
+        duration: Optional[float] = 10,
+    ):
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+        audio_length = int(duration * 16000)
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+        )
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        prompt_embeds = self._encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds
+        )
+
+        # 4. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        # 5. Prepare latent variables
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    return_dict=False,
+                )[0]
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                if do_classifier_free_guidance and guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)
+
+        if output_type != "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
+
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
+        # Generate audio
+        spectrograms, audios = [], []
+        for img in image:
+            spectrogram = denormalize_spectrogram(img)
+            audio = self.vocoder.inference(spectrogram, lengths=audio_length)[0]
+            audios.append(audio)
+            spectrograms.append(spectrogram)
+
+        # Convert to PIL
+        images = pt_to_numpy(image)
+        images = numpy_to_pil(images)
+        images = [image_add_color(image) for image in images]
+
+        if not return_dict:
+            return (images, audios, spectrograms)
+
+        return PipelineOutput(images=images, audios=audios, spectrograms=spectrograms)
+
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used,
+            `timesteps` must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+                timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+                must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+class AuffusionNoAdapterPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
+):
+    r"""
+    Pipeline for text-to-image generation using Stable Diffusion.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+    implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+    The pipeline also inherits the following loading methods:
+        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+        text_encoder ([`~transformers.CLIPTextModel`]):
+            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+        tokenizer ([`~transformers.CLIPTokenizer`]):
+            A `CLIPTokenizer` to tokenize text.
+        unet ([`UNet2DConditionModel`]):
+            A `UNet2DConditionModel` to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+            about a model's potential harms.
+        feature_extractor ([`~transformers.CLIPImageProcessor`]):
+            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+    """
+
+    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+    _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: KarrasDiffusionSchedulers,
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPImageProcessor,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        requires_safety_checker: bool = True,
+    ):
+        super().__init__()
+
+        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the HF中国镜像站 Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file"
+            )
+            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the HF中国镜像站 Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if safety_checker is None and requires_safety_checker:
+            logger.warning(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and HF中国镜像站"
+                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
+        if safety_checker is not None and feature_extractor is None:
+            raise ValueError(
+                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+            )
+
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the HF中国镜像站 Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
+        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
+        **kwargs,
+    ):
+        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+        prompt_embeds_tuple = self.encode_prompt(
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=lora_scale,
+            **kwargs,
+        )
+
+        # concatenate for backwards comp
+        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+        return prompt_embeds
+
+    def encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            lora_scale (`float`, *optional*):
+                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+        """
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = text_inputs.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            if clip_skip is None:
+                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+                prompt_embeds = prompt_embeds[0]
+            else:
+                prompt_embeds = self.text_encoder(
+                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+                )
+                # Access the `hidden_states` first, that contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                # We also need to apply the final LayerNorm here to not mess with the
+                # representations. The `last_hidden_states` that we typically use for
+                # obtaining the final prompt representations passes through the LayerNorm
+                # layer.
+                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+            max_length = prompt_embeds.shape[1]
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = uncond_input.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            negative_prompt_embeds = self.text_encoder(
+                uncond_input.input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            negative_prompt_embeds = negative_prompt_embeds[0]
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        return prompt_embeds, negative_prompt_embeds
+
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+    ):
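+        # Each entry of the returned list corresponds to one loaded IP-Adapter; with classifier-free
+        # guidance enabled, negative and positive image embeddings are concatenated along the batch dim.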
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            repeat_dims = [1]
+            image_embeds = []
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                    )
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                else:
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                image_embeds.append(single_image_embeds)
+
+        return image_embeds
+
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        return image, has_nsfw_concept
+
+    def decode_latents(self, latents):
+        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
+
+        latents = 1 / self.vae.config.scaling_factor * latents
+        image = self.vae.decode(latents, return_dict=False)[0]
+        image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        return image
+
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
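+        # For example, `DDIMScheduler.step` accepts both `eta` and `generator`, while schedulers such as
+        # `EulerDiscreteScheduler` only accept `generator`, so `eta` is silently dropped for them.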
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def check_inputs(
+        self,
+        prompt,
+        height,
+        width,
+        callback_steps,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+    ):
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
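+        # With an SD-style 4-channel UNet and a VAE scale factor of 8 (assumptions, not checked here),
+        # a 256x1024 request yields latents of shape (batch_size, 4, 32, 128).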
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+        r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+        The suffixes after the scaling factors represent the stages where they are being applied.
+
+        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+        that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+        Args:
+            s1 (`float`):
+                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+                mitigate "oversmoothing effect" in the enhanced denoising process.
+            s2 (`float`):
+                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+                mitigate "oversmoothing effect" in the enhanced denoising process.
+            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+        """
+        if not hasattr(self, "unet"):
+            raise ValueError("The pipeline must have `unet` for using FreeU.")
+        self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+    def disable_freeu(self):
+        """Disables the FreeU mechanism if enabled."""
+        self.unet.disable_freeu()
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
+    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+        """
+        self.fusing_unet = False
+        self.fusing_vae = False
+
+        if unet:
+            self.fusing_unet = True
+            self.unet.fuse_qkv_projections()
+            self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+        if vae:
+            if not isinstance(self.vae, AutoencoderKL):
+                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+            self.fusing_vae = True
+            self.vae.fuse_qkv_projections()
+            self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
+    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """Disable QKV projection fusion if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+
+        """
+        if unet:
+            if not self.fusing_unet:
+                logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.unet.unfuse_qkv_projections()
+                self.fusing_unet = False
+
+        if vae:
+            if not self.fusing_vae:
+                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.vae.unfuse_qkv_projections()
+                self.fusing_vae = False
+
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                guidance scale values used to generate the embedding vectors (one embedding per element of `w`)
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 50,
+        timesteps: List[int] = None,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
+        clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
+    ):
+        r"""
+        The call function to the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                A higher guidance scale value encourages the model to generate images closely linked to the text
+                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+                generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor is generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+                provided, text embeddings are generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+                using zero terminal SNR.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that calls at the end of each denoising steps during the inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+
+        Examples:
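+
+            A minimal, illustrative sketch (the checkpoint path below is a placeholder, not a verified
+            repository id):
+
+            ```py
+            >>> import torch
+            >>> pipe = AuffusionNoAdapterPipeline.from_pretrained(
+            ...     "path/to/auffusion-checkpoint", torch_dtype=torch.float16
+            ... ).to("cuda")
+            >>> image = pipe("dog barking in a park", num_inference_steps=50).images[0]
+            ```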
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+                otherwise a `tuple` is returned where the first element is a list with the generated images and the
+                second element is a list of `bool`s indicating whether the corresponding generated image contains
+                "not-safe-for-work" (nsfw) content.
+        """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+        # to deal with lora scaling and other possible forward hooks
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            height,
+            width,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            callback_on_step_end_tensor_inputs,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # 3. Encode input prompt
+        lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )
+
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            self.do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=lora_scale,
+            clip_skip=self.clip_skip,
+        )
+
+        # For classifier free guidance, we need to do two forward passes.
+        # Here we concatenate the unconditional and text embeddings into a single batch
+        # to avoid doing two forward passes
+        if self.do_classifier_free_guidance:
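+            # If the conditional embeddings do not match the negative embeddings' shape (typically a
+            # single-token embedding passed via `prompt_embeds`), write them into the first token slot of
+            # a clone of the negative embeddings so both halves of the CFG batch share one shape.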
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                tmp_embeds = negative_prompt_embeds.clone()
+                tmp_embeds[:, 0:1, :] = prompt_embeds
+                prompt_embeds = tmp_embeds
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+        # TODO
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+        # if ip_adapter_image is not None:
+        #     if self.unet.multi_frames_condition:
+        #         output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, VideoProjModel) else True
+        #     else:
+        #         output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+        #     # NOTE: ip_adapter_image should be a list with len() == 50
+        #     image_embeds, negative_image_embeds = self.encode_image(
+        #         ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+        #     )
+        #     # import ipdb; ipdb.set_trace()
+        #     image_embeds = image_embeds.unsqueeze(0)
+        #     negative_image_embeds = negative_image_embeds.unsqueeze(0)
+        #     if not self.unet.multi_frames_condition:
+        #         image_embeds = torch.mean(image_embeds, dim=1, keepdim=False)
+        #         negative_image_embeds = negative_image_embeds[:,0, ...]
+
+        #     if self.do_classifier_free_guidance:
+        #         image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
+        # 4. Prepare timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+        # 5. Prepare latent variables
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 6.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None or ip_adapter_image_embeds is not None else None
+
+        # 6.2 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
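+
+        # Note (illustrative): `time_cond_proj_dim` is typically only set for
+        # guidance-distilled UNets (e.g. LCM-style models); for those models
+        # classifier-free guidance is skipped and the guidance weight is injected
+        # through this embedding instead.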
+
+        # 7. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self._num_timesteps = len(timesteps)
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                )[0]
+
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+        if output_type != "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+                0
+            ]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
+
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
\ No newline at end of file
diff --git a/foleycrafter/pipelines/pipeline_controlnet.py b/foleycrafter/pipelines/pipeline_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..11f1de506080f840224d9111be642082e5ad5f5c
--- /dev/null
+++ b/foleycrafter/pipelines/pipeline_controlnet.py
@@ -0,0 +1,1340 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+
+from foleycrafter.models.auffusion_unet import UNet2DConditionModel
+from foleycrafter.models.auffusion.loaders.ip_adapter import IPAdapterMixin
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> # !pip install opencv-python transformers accelerate
+        >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
+        >>> from diffusers.utils import load_image
+        >>> import numpy as np
+        >>> import torch
+
+        >>> import cv2
+        >>> from PIL import Image
+
+        >>> # download an image
+        >>> image = load_image(
+        ...     "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
+        ... )
+        >>> image = np.array(image)
+
+        >>> # get canny image
+        >>> image = cv2.Canny(image, 100, 200)
+        >>> image = image[:, :, None]
+        >>> image = np.concatenate([image, image, image], axis=2)
+        >>> canny_image = Image.fromarray(image)
+
+        >>> # load control net and stable diffusion v1-5
+        >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+        >>> pipe = StableDiffusionControlNetPipeline.from_pretrained(
+        ...     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+        ... )
+
+        >>> # speed up diffusion process with faster scheduler and memory optimization
+        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+        >>> # remove following line if xformers is not installed
+        >>> pipe.enable_xformers_memory_efficient_attention()
+
+        >>> pipe.enable_model_cpu_offload()
+
+        >>> # generate image
+        >>> generator = torch.manual_seed(0)
+        >>> image = pipe(
+        ...     "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image
+        ... ).images[0]
+        ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used,
+            `timesteps` must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+            must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
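+
+# Illustrative usage of `retrieve_timesteps` (a sketch, not part of the pipeline; the
+# scheduler instance and the custom schedule below are assumptions for the example):
+#
+#     # let the scheduler build its default schedule
+#     timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=25, device="cuda")
+#
+#     # custom spacing, only valid if the scheduler's `set_timesteps` accepts a
+#     # `timesteps` argument; otherwise the ValueError above is raised
+#     timesteps, num_inference_steps = retrieve_timesteps(scheduler, device="cuda", timesteps=[999, 749, 499, 249, 0])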
+
+class StableDiffusionControlNetPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
+):
+    r"""
+    Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+    implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+    The pipeline also inherits the following loading methods:
+        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+        text_encoder ([`~transformers.CLIPTextModel`]):
+            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+        tokenizer ([`~transformers.CLIPTokenizer`]):
+            A `CLIPTokenizer` to tokenize text.
+        unet ([`UNet2DConditionModel`]):
+            A `UNet2DConditionModel` to denoise the encoded image latents.
+        controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
+            Provides additional conditioning to the `unet` during the denoising process. If you set multiple
+            ControlNets as a list, the outputs from each ControlNet are added together to create one combined
+            additional conditioning.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+            about a model's potential harms.
+        feature_extractor ([`~transformers.CLIPImageProcessor`]):
+            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+    """
+
+    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+    _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+        scheduler: KarrasDiffusionSchedulers,
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPImageProcessor,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        requires_safety_checker: bool = True,
+    ):
+        super().__init__()
+
+        if safety_checker is None and requires_safety_checker:
+            logger.warning(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
+        if safety_checker is not None and feature_extractor is None:
+            raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+            )
+
+        if isinstance(controlnet, (list, tuple)):
+            controlnet = MultiControlNetModel(controlnet)
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            controlnet=controlnet,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
+        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+        self.control_image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+        )
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
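+
+    # Worked example (illustrative): a standard SD-style VAE has
+    # len(block_out_channels) == 4, so vae_scale_factor == 2 ** 3 == 8 and the latents
+    # are 8x smaller than the pixel (or spectrogram) canvas along each spatial dimension.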
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for
+        allowing the processing of larger images.
+        """
+        self.vae.enable_tiling()
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
+        **kwargs,
+    ):
+        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+        prompt_embeds_tuple = self.encode_prompt(
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=lora_scale,
+            **kwargs,
+        )
+
+        # concatenate for backwards comp
+        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+        return prompt_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            lora_scale (`float`, *optional*):
+                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+        """
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if not USE_PEFT_BACKEND:
+                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = text_inputs.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            if clip_skip is None:
+                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+                prompt_embeds = prompt_embeds[0]
+            else:
+                prompt_embeds = self.text_encoder(
+                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+                )
+                # Access the `hidden_states` first, that contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                # We also need to apply the final LayerNorm here to not mess with the
+                # representations. The `last_hidden_states` that we typically use for
+                # obtaining the final prompt representations passes through the LayerNorm
+                # layer.
+                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+            max_length = prompt_embeds.shape[1]
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = uncond_input.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            negative_prompt_embeds = self.text_encoder(
+                uncond_input.input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            negative_prompt_embeds = negative_prompt_embeds[0]
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(self.text_encoder, lora_scale)
+
+        return prompt_embeds, negative_prompt_embeds
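+
+    # Shape note (illustrative, assuming an SD1.x-style CLIP text encoder): for a single
+    # prompt, `encode_prompt` returns two tensors of shape (num_images_per_prompt, 77, 768);
+    # with classifier-free guidance the caller concatenates them as
+    # torch.cat([negative_prompt_embeds, prompt_embeds]).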
+
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            repeat_dims = [1]
+            image_embeds = []
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                    )
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                else:
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                image_embeds.append(single_image_embeds)
+
+        return image_embeds 
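+
+    # Shape note (a sketch; exact sizes depend on the image encoder and projection layer):
+    # the return value is a list with one tensor per IP-Adapter. With classifier-free
+    # guidance enabled, each tensor stacks the negative embeddings in front of the
+    # positive ones, so its leading dimension is 2 * num_images_per_prompt.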
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
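+
+    # Note (illustrative): with output_hidden_states=True the penultimate hidden states
+    # are returned for both the image and an all-zeros image (the unconditional branch);
+    # otherwise the pooled `image_embeds` and a zero tensor of the same shape are returned.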
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        return image, has_nsfw_concept
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+    def decode_latents(self, latents):
+        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
+
+        latents = 1 / self.vae.config.scaling_factor * latents
+        image = self.vae.decode(latents, return_dict=False)[0]
+        image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        return image
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def check_inputs(
+        self,
+        prompt,
+        image,
+        callback_steps,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        controlnet_conditioning_scale=1.0,
+        control_guidance_start=0.0,
+        control_guidance_end=1.0,
+        callback_on_step_end_tensor_inputs=None,
+    ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+        # `prompt` needs more sophisticated handling when there are multiple
+        # conditionings.
+        if isinstance(self.controlnet, MultiControlNetModel):
+            if isinstance(prompt, list):
+                logger.warning(
+                    f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+                    " prompts. The conditionings will be fixed across the prompts."
+                )
+
+        # Check `image`
+        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+        )
+        if (
+            isinstance(self.controlnet, ControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, ControlNetModel)
+        ):
+            self.check_image(image, prompt, prompt_embeds)
+        elif (
+            isinstance(self.controlnet, MultiControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+        ):
+            if not isinstance(image, list):
+                raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+            # When `image` is a nested list:
+            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+            elif any(isinstance(i, list) for i in image):
+                raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
+            elif len(image) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+                )
+
+            for image_ in image:
+                self.check_image(image_, prompt, prompt_embeds)
+        else:
+            assert False
+
+        # Check `controlnet_conditioning_scale`
+        if (
+            isinstance(self.controlnet, ControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, ControlNetModel)
+        ):
+            if not isinstance(controlnet_conditioning_scale, float):
+                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+        elif (
+            isinstance(self.controlnet, MultiControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+        ):
+            if isinstance(controlnet_conditioning_scale, list):
+                if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+                    raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
+                if len(controlnet_conditioning_scale) != len(self.controlnet.nets):
+                    raise ValueError(
+                        "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must"
+                        " have the same length as the number of controlnets."
+                    )
+        else:
+            assert False
+
+        if not isinstance(control_guidance_start, (tuple, list)):
+            control_guidance_start = [control_guidance_start]
+
+        if not isinstance(control_guidance_end, (tuple, list)):
+            control_guidance_end = [control_guidance_end]
+
+        if len(control_guidance_start) != len(control_guidance_end):
+            raise ValueError(
+                f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+            )
+
+        if isinstance(self.controlnet, MultiControlNetModel):
+            if len(control_guidance_start) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+                )
+
+        for start, end in zip(control_guidance_start, control_guidance_end):
+            if start >= end:
+                raise ValueError(
+                    f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+                )
+            if start < 0.0:
+                raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+            if end > 1.0:
+                raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+    def check_image(self, image, prompt, prompt_embeds):
+        image_is_pil = isinstance(image, PIL.Image.Image)
+        image_is_tensor = isinstance(image, torch.Tensor)
+        image_is_np = isinstance(image, np.ndarray)
+        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+        if (
+            not image_is_pil
+            and not image_is_tensor
+            and not image_is_np
+            and not image_is_pil_list
+            and not image_is_tensor_list
+            and not image_is_np_list
+        ):
+            raise TypeError(
+                f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+            )
+
+        if image_is_pil:
+            image_batch_size = 1
+        else:
+            image_batch_size = len(image)
+
+        if prompt is not None and isinstance(prompt, str):
+            prompt_batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            prompt_batch_size = len(prompt)
+        elif prompt_embeds is not None:
+            prompt_batch_size = prompt_embeds.shape[0]
+
+        if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+            raise ValueError(
+                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+            )
+
+    def prepare_image(
+        self,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance=False,
+        guess_mode=False,
+    ):
+        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+        image_batch_size = image.shape[0]
+
+        if image_batch_size == 1:
+            repeat_by = batch_size
+        else:
+            # image batch size is the same as prompt batch size
+            repeat_by = num_images_per_prompt
+
+        image = image.repeat_interleave(repeat_by, dim=0)
+
+        image = image.to(device=device, dtype=dtype)
+
+        if do_classifier_free_guidance and not guess_mode:
+            image = torch.cat([image] * 2)
+
+        return image
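+
+    # Note (illustrative): the control image is resized to (height, width), repeated to
+    # match the effective batch size, and doubled along the batch dimension for
+    # classifier-free guidance unless `guess_mode` is enabled.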
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
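+
+    # Worked example (illustrative values): for a 256x1024 canvas with vae_scale_factor=8
+    # and num_channels_latents=4, the latents have shape
+    # (batch_size * num_images_per_prompt, 4, 32, 128) and are pre-scaled by
+    # `scheduler.init_noise_sigma`.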
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
+    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+        r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+        The suffixes after the scaling factors represent the stages where they are being applied.
+
+        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+        that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+        Args:
+            s1 (`float`):
+                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+                mitigate "oversmoothing effect" in the enhanced denoising process.
+            s2 (`float`):
+                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+                mitigate "oversmoothing effect" in the enhanced denoising process.
+            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+        """
+        if not hasattr(self, "unet"):
+            raise ValueError("The pipeline must have `unet` for using FreeU.")
+        self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
+    def disable_freeu(self):
+        """Disables the FreeU mechanism if enabled."""
+        self.unet.disable_freeu()
+
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                Guidance scale values for which to generate embedding vectors.
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
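+
+    # Illustrative call (a sketch): `__call__` feeds this `guidance_scale - 1`, e.g.
+    # w = torch.tensor([6.5]) for guidance_scale=7.5, and receives a sinusoidal embedding
+    # of shape (1, embedding_dim), analogous to a timestep embedding.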
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        image: PipelineImageInput = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 50,
+        timesteps: List[int] = None,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+        controlnet_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
+        guess_mode: bool = False,
+        control_guidance_start: Union[float, List[float]] = 0.0,
+        control_guidance_end: Union[float, List[float]] = 1.0,
+        clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
+    ):
+        r"""
+        The call function to the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
+                `List[np.ndarray]`, `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+                specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
+                accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
+                and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
+                `init`, images must be passed as a list such that each element of the list can be correctly batched for
+                input to a single ControlNet.
+            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                A higher guidance scale value encourages the model to generate images closely linked to the text
+                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+                generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor is generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+                provided, text embeddings are generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that is called every `callback_steps` steps during inference. The function is called with
+                the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function is called. If not specified, the callback is called at
+                every step.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+                the corresponding scale as a list.
+            guess_mode (`bool`, *optional*, defaults to `False`):
+                The ControlNet encoder tries to recognize the content of the input image even if you remove all
+                prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+                The percentage of total steps at which the ControlNet starts applying.
+            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The percentage of total steps at which the ControlNet stops applying.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+                the `._callback_tensor_inputs` attribute of your pipeline class.
+
+        Examples:
+
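+            A minimal illustrative sketch; `pipe` and `control_image` are placeholders, assuming the
+            pipeline was built with `build_foleycrafter` and is conditioned on the single-channel
+            temporal map its ControlNet expects:
+
+            >>> output = pipe(
+            ...     prompt="footsteps on gravel",
+            ...     image=control_image,
+            ...     num_inference_steps=25,
+            ...     guidance_scale=7.5,
+            ... )
+            >>> spectrogram_image = output.images[0]
+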
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+                otherwise a `tuple` is returned where the first element is a list with the generated images and the
+                second element is a list of `bool`s indicating whether the corresponding generated image contains
+                "not-safe-for-work" (nsfw) content.
+        """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
+        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+        # align format for control guidance
+        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+            control_guidance_start, control_guidance_end = (
+                mult * [control_guidance_start],
+                mult * [control_guidance_end],
+            )
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            image,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            controlnet_conditioning_scale,
+            control_guidance_start,
+            control_guidance_end,
+            callback_on_step_end_tensor_inputs,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+        global_pool_conditions = (
+            controlnet.config.global_pool_conditions
+            if isinstance(controlnet, ControlNetModel)
+            else controlnet.nets[0].config.global_pool_conditions
+        )
+        guess_mode = guess_mode or global_pool_conditions
+
+        # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            self.do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
+            clip_skip=self.clip_skip,
+        )
+        # For classifier free guidance, we need to do two forward passes.
+        # Here we concatenate the unconditional and text embeddings into a single batch
+        # to avoid doing two forward passes
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
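+            # After this concat, `prompt_embeds` stacks the negative (unconditional) rows before the
+            # positive (text-conditioned) ones, which is why the guidance step below splits the
+            # prediction with `chunk(2)` into `noise_pred_uncond` and `noise_pred_text`.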
+        
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+            
+        # 4. Prepare image
+        if isinstance(controlnet, ControlNetModel):
+            image = self.prepare_image(
+                image=image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=controlnet.dtype,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                guess_mode=guess_mode,
+            )
+            height, width = image.shape[-2:]
+        elif isinstance(controlnet, MultiControlNetModel):
+            images = []
+
+            for image_ in image:
+                image_ = self.prepare_image(
+                    image=image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=controlnet.dtype,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
+                    guess_mode=guess_mode,
+                )
+
+                images.append(image_)
+
+            image = images
+            height, width = image[0].shape[-2:]
+        else:
+            assert False
+
+        # 5. Prepare timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        self._num_timesteps = len(timesteps)
+
+        # 6. Prepare latent variables
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 6.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = (
+            {"image_embeds": image_embeds}
+            if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
+            else None
+        )
+
+        # 7.2 Create tensor stating which controlnets to keep
+        controlnet_keep = []
+        for i in range(len(timesteps)):
+            keeps = [
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                for s, e in zip(control_guidance_start, control_guidance_end)
+            ]
+            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
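+        # For example, with control_guidance_start=0.0, control_guidance_end=0.5 and 20 timesteps,
+        # `keeps` is 1.0 for steps 0-9 and 0.0 afterwards, so the ControlNet residuals only
+        # influence the first half of the denoising schedule.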
+
+        # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        is_unet_compiled = is_compiled_module(self.unet)
+        is_controlnet_compiled = is_compiled_module(self.controlnet)
+        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # Relevant thread:
+                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+                    torch._inductor.cudagraph_mark_step_begin()
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # controlnet(s) inference
+                if guess_mode and self.do_classifier_free_guidance:
+                    # Infer ControlNet only for the conditional batch.
+                    control_model_input = latents
+                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                else:
+                    control_model_input = latent_model_input
+                    controlnet_prompt_embeds = prompt_embeds
+
+                if isinstance(controlnet_keep[i], list):
+                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+                else:
+                    controlnet_cond_scale = controlnet_conditioning_scale
+                    if isinstance(controlnet_cond_scale, list):
+                        controlnet_cond_scale = controlnet_cond_scale[0]
+                    cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+                down_block_res_samples, mid_block_res_sample = self.controlnet(
+                    control_model_input,
+                    t,
+                    encoder_hidden_states=controlnet_prompt_embeds,
+                    controlnet_cond=image,
+                    conditioning_scale=cond_scale,
+                    guess_mode=guess_mode,
+                    return_dict=False,
+                )
+
+                if guess_mode and self.do_classifier_free_guidance:
+                    # Inferred ControlNet only for the conditional batch.
+                    # To apply the output of ControlNet to both the unconditional and conditional batches,
+                    # add 0 to the unconditional batch to keep it unchanged.
+                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+                # predict the noise residual
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
+                    down_block_additional_residuals=down_block_res_samples,
+                    mid_block_additional_residual=mid_block_res_sample,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                )[0]
+
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+        # If we do sequential model offloading, let's offload unet and controlnet
+        # manually for max memory savings
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.unet.to("cpu")
+            self.controlnet.to("cpu")
+            torch.cuda.empty_cache()
+
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+                0
+            ]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
+
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/foleycrafter/utils/audio_to_mel_af.py b/foleycrafter/utils/audio_to_mel_af.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0335eba4637457ca78ff5990f86b085bef49f59
--- /dev/null
+++ b/foleycrafter/utils/audio_to_mel_af.py
@@ -0,0 +1,181 @@
+import numpy as np
+from PIL import Image
+
+import math
+import os
+import random
+import torch
+import json
+import torch.utils.data
+import numpy as np
+import librosa
+from librosa.util import normalize
+from scipy.io.wavfile import read
+from librosa.filters import mel as librosa_mel_fn
+
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+
+MAX_WAV_VALUE = 32768.0
+
+
+def load_wav(full_path):
+    sampling_rate, data = read(full_path)
+    return data, sampling_rate
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+    return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+    return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+    output = dynamic_range_compression_torch(magnitudes)
+    return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+    output = dynamic_range_decompression_torch(magnitudes)
+    return output
+
+
+mel_basis = {}
+hann_window = {}
+
+
+def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+    if torch.min(y) < -1.:
+        print('min value is ', torch.min(y))
+    if torch.max(y) > 1.:
+        print('max value is ', torch.max(y))
+
+    global mel_basis, hann_window
+    if str(fmax)+'_'+str(y.device) not in mel_basis:
+        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+        mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
+    y = y.squeeze(1)
+
+    # complex tensor as default, then use view_as_real for future pytorch compatibility
+    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
+                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
+    spec = torch.view_as_real(spec)
+    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
+
+    spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
+    spec = spectral_normalize_torch(spec)
+
+    return spec
+
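+# Illustrative call to `mel_spectrogram` above (values are placeholders): with the Auffusion
+# settings used in `get_mel_spectrogram_from_audio` below, one second of 16 kHz audio in [-1, 1]
+# yields a (1, 256, 100) log-mel tensor:
+#
+#   wav = torch.randn(1, 16000).clamp(-1, 1)
+#   spec = mel_spectrogram(wav, n_fft=2048, num_mels=256, sampling_rate=16000,
+#                          hop_size=160, win_size=1024, fmin=0, fmax=8000)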
+
+def spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+    if torch.min(y) < -1.:
+        print('min value is ', torch.min(y))
+    if torch.max(y) > 1.:
+        print('max value is ', torch.max(y))
+
+    global hann_window
+    hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
+    y = y.squeeze(1)
+
+    # complex tensor as default, then use view_as_real for future pytorch compatibility
+    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
+                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
+    spec = torch.view_as_real(spec)
+    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
+
+    return spec
+
+
+def normalize_spectrogram(
+    spectrogram: torch.Tensor,
+    max_value: float = 200, 
+    min_value: float = 1e-5, 
+    power: float = 1., 
+    inverse: bool = False
+) -> torch.Tensor:
+    # Rescale to 0-1
+    max_value = np.log(max_value) # 5.298317366548036
+    min_value = np.log(min_value) # -11.512925464970229
+
+    assert spectrogram.max() <= max_value and spectrogram.min() >= min_value
+
+    data = (spectrogram - min_value) / (max_value - min_value)
+
+    # Invert
+    if inverse:
+        data = 1 - data
+
+    # Apply the power curve
+    data = torch.pow(data, power)  
+    
+    # 1D -> 3D
+    data = data.unsqueeze(1)
+    # data = data.repeat(1, 3, 1, 1)
+    # (b f) (h w) c -> b f (h w) c -> b t (h w) c -> b t (h' w') c 
+
+    # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner
+    data = torch.flip(data, [1])
+
+    return data
+
+def denormalize_spectrogram(
+    data: torch.Tensor,
+    max_value: float = 200, 
+    min_value: float = 1e-5, 
+    power: float = 1, 
+    inverse: bool = False,
+) -> torch.Tensor:
+    
+    max_value = np.log(max_value)
+    min_value = np.log(min_value)
+
+    # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner
+    data = torch.flip(data, [1])
+
+    assert len(data.shape) == 3, "Expected 3 dimensions, got {}".format(len(data.shape))
+    
+    if data.shape[0] == 1:
+        data = data.repeat(3, 1, 1)
+        
+    assert data.shape[0] == 3, "Expected 3 channels, got {}".format(data.shape[0])
+    data = data[0]
+
+    # Reverse the power curve
+    data = torch.pow(data, 1 / power)
+
+    # Invert
+    if inverse:
+        data = 1 - data
+
+    # Rescale to max value
+    spectrogram = data * (max_value - min_value) + min_value
+
+    return spectrogram
+
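+# The two helpers above form a pair: `normalize_spectrogram` rescales log-mel values from
+# [log(1e-5), log(200)] into [0, 1] so the spectrogram can be handled as an image, while
+# `denormalize_spectrogram` maps image-range values back to the log-mel magnitudes expected by
+# the vocoder.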
+
+def get_mel_spectrogram_from_audio(audio):
+    # for auffusion 
+    spec = mel_spectrogram(audio, n_fft=2048, num_mels=256, sampling_rate=16000, hop_size=160, win_size=1024, fmin=0, fmax=8000, center=False)
+
+    # for audioldm
+    # spec = mel_spectrogram(audio, n_fft=1024, num_mels=64, sampling_rate=16000, hop_size=160, win_size=1024, fmin=0, fmax=8000, center=False)
+    spec = normalize_spectrogram(spec)
+    return spec
\ No newline at end of file
diff --git a/foleycrafter/utils/converter.py b/foleycrafter/utils/converter.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ecfaa22c7f17e024b7b7d0e142f4ab5785e13eb
--- /dev/null
+++ b/foleycrafter/utils/converter.py
@@ -0,0 +1,398 @@
+# Copy from https://github.com/happylittlecat2333/Auffusion/blob/main/converter.py
+import numpy as np
+from PIL import Image
+
+import math
+import os
+import random
+import torch
+import json
+import torch.utils.data
+import numpy as np
+import librosa
+# from librosa.util import normalize
+from scipy.io.wavfile import read
+from librosa.filters import mel as librosa_mel_fn
+
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+
+MAX_WAV_VALUE = 32768.0
+
+
+def load_wav(full_path):
+    sampling_rate, data = read(full_path)
+    return data, sampling_rate
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+    return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+    return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+    output = dynamic_range_compression_torch(magnitudes)
+    return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+    output = dynamic_range_decompression_torch(magnitudes)
+    return output
+
+
+mel_basis = {}
+hann_window = {}
+
+
+def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+    if torch.min(y) < -1.:
+        print('min value is ', torch.min(y))
+    if torch.max(y) > 1.:
+        print('max value is ', torch.max(y))
+
+    global mel_basis, hann_window
+    if str(fmax)+'_'+str(y.device) not in mel_basis:
+        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+        mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
+    y = y.squeeze(1)
+
+    # complex tensor as default, then use view_as_real for future pytorch compatibility
+    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
+                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
+    spec = torch.view_as_real(spec)
+    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
+
+    spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
+    spec = spectral_normalize_torch(spec)
+
+    return spec
+
+
+def spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+    if torch.min(y) < -1.:
+        print('min value is ', torch.min(y))
+    if torch.max(y) > 1.:
+        print('max value is ', torch.max(y))
+
+    global hann_window
+    hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
+    y = y.squeeze(1)
+
+    # complex tensor as default, then use view_as_real for future pytorch compatibility
+    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
+                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
+    spec = torch.view_as_real(spec)
+    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
+
+    return spec
+
+
+def normalize_spectrogram(
+    spectrogram: torch.Tensor,
+    max_value: float = 200, 
+    min_value: float = 1e-5, 
+    power: float = 1., 
+    inverse: bool = False
+) -> torch.Tensor:
+    
+    # Rescale to 0-1
+    max_value = np.log(max_value) # 5.298317366548036
+    min_value = np.log(min_value) # -11.512925464970229
+
+    assert spectrogram.max() <= max_value and spectrogram.min() >= min_value
+
+    data = (spectrogram - min_value) / (max_value - min_value)
+
+    # Invert
+    if inverse:
+        data = 1 - data
+
+    # Apply the power curve
+    data = torch.pow(data, power)  
+    
+    # 1D -> 3D
+    data = data.repeat(3, 1, 1)
+
+    # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner
+    data = torch.flip(data, [1])
+
+    return data
+
+
+
+def denormalize_spectrogram(
+    data: torch.Tensor,
+    max_value: float = 200, 
+    min_value: float = 1e-5, 
+    power: float = 1, 
+    inverse: bool = False,
+) -> torch.Tensor:
+    
+    max_value = np.log(max_value)
+    min_value = np.log(min_value)
+
+    # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner
+    data = torch.flip(data, [1])
+
+    assert len(data.shape) == 3, "Expected 3 dimensions, got {}".format(len(data.shape))
+    
+    if data.shape[0] == 1:
+        data = data.repeat(3, 1, 1)
+        
+    assert data.shape[0] == 3, "Expected 3 channels, got {}".format(data.shape[0])
+    data = data[0]
+
+    # Reverse the power curve
+    data = torch.pow(data, 1 / power)
+
+    # Invert
+    if inverse:
+        data = 1 - data
+
+    # Rescale to max value
+    spectrogram = data * (max_value - min_value) + min_value
+
+    return spectrogram
+
+
+def get_mel_spectrogram_from_audio(audio, device="cpu"):
+    audio = audio / MAX_WAV_VALUE
+    audio = librosa.util.normalize(audio) * 0.95
+    # print(' >>> normalize done <<< ')
+        
+    audio = torch.FloatTensor(audio)
+    audio = audio.unsqueeze(0)    
+
+    waveform = audio.to(device)
+    spec = mel_spectrogram(waveform, n_fft=2048, num_mels=256, sampling_rate=16000, hop_size=160, win_size=1024, fmin=0, fmax=8000, center=False)
+    return audio, spec
+
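+# Illustrative call to `get_mel_spectrogram_from_audio` above (the file name is a placeholder):
+# this variant expects raw integer PCM samples, e.g. straight from `scipy.io.wavfile.read`, and
+# peak-normalizes them internally:
+#
+#   sr, pcm = read("example.wav")                       # mono int16 samples at 16 kHz
+#   audio, spec = get_mel_spectrogram_from_audio(pcm)
+#   # audio: (1, num_samples) float tensor in [-0.95, 0.95]; spec: (1, 256, num_frames)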
+
+
+LRELU_SLOPE = 0.1
+MAX_WAV_VALUE = 32768.0
+
+
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+
+def get_config(config_path):
+    config = json.loads(open(config_path).read())
+    config = AttrDict(config)
+    return config
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def apply_weight_norm(m):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        weight_norm(m)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size*dilation - dilation)/2)
+
+
+class ResBlock1(torch.nn.Module):
+    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ResBlock1, self).__init__()
+        self.h = h
+        self.convs1 = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                               padding=get_padding(kernel_size, dilation[0]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                               padding=get_padding(kernel_size, dilation[1]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+                               padding=get_padding(kernel_size, dilation[2])))
+        ])
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1)))
+        ])
+        self.convs2.apply(init_weights)
+
+    def forward(self, x):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
+        super(ResBlock2, self).__init__()
+        self.h = h
+        self.convs = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                               padding=get_padding(kernel_size, dilation[0]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                               padding=get_padding(kernel_size, dilation[1])))
+        ])
+        self.convs.apply(init_weights)
+
+    def forward(self, x):
+        for c in self.convs:
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs:
+            remove_weight_norm(l)
+
+
+
+class Generator(torch.nn.Module):
+    def __init__(self, h):
+        super(Generator, self).__init__()
+        self.h = h
+        self.num_kernels = len(h.resblock_kernel_sizes)
+        self.num_upsamples = len(h.upsample_rates)
+        self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)) # change: 80 --> 512
+        resblock = ResBlock1 if h.resblock == '1' else ResBlock2
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+            if (k-u) % 2 == 0:
+                self.ups.append(weight_norm(
+                    ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
+                                    k, u, padding=(k-u)//2)))
+            else:
+                self.ups.append(weight_norm(
+                    ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
+                                    k, u, padding=(k-u)//2+1, output_padding=1)))
+            
+            # self.ups.append(weight_norm(
+            #     ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
+            #                     k, u, padding=(k-u)//2)))
+            
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = h.upsample_initial_channel//(2**(i+1))
+            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+                self.resblocks.append(resblock(h, ch, k, d))
+
+        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+
+    def forward(self, x):
+        x = self.conv_pre(x)
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i*self.num_kernels+j](x)
+                else:
+                    xs += self.resblocks[i*self.num_kernels+j](x)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+        remove_weight_norm(self.conv_pre)
+        remove_weight_norm(self.conv_post)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None):
+        if subfolder is not None:
+            pretrained_model_name_or_path = os.path.join(pretrained_model_name_or_path, subfolder)
+        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
+        ckpt_path = os.path.join(pretrained_model_name_or_path, "vocoder.pt")
+
+        config = get_config(config_path)
+        vocoder = cls(config)
+
+        state_dict_g = torch.load(ckpt_path)
+        vocoder.load_state_dict(state_dict_g["generator"])
+        vocoder.eval()
+        vocoder.remove_weight_norm()
+        return vocoder    
+    
+    
+    @torch.no_grad()
+    def inference(self, mels, lengths=None):
+        self.eval()
+        with torch.no_grad():
+            wavs = self(mels).squeeze(1)
+
+        wavs = (wavs.cpu().numpy() * MAX_WAV_VALUE).astype("int16")
+
+        if lengths is not None:
+            wavs = wavs[:, :lengths]
+
+        return wavs
+    
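+# Illustrative vocoder round trip for the `Generator` class above (the checkpoint path is a
+# placeholder): load the pretrained vocoder and turn a denormalized mel spectrogram back into
+# int16 audio:
+#
+#   vocoder = Generator.from_pretrained("models/auffusion", subfolder="vocoder")
+#   wav = vocoder.inference(mel)     # mel: (batch, num_mels, num_frames) float tensor
+#   # wav: int16 numpy array of shape (batch, num_samples)
+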
+def normalize(images):
+    """
+    Normalize an image array to [-1,1].
+    """
+    if images.min() >= 0:
+        return 2.0 * images - 1.0
+    else:
+        return images
+
+def pad_spec(spec, spec_length, pad_value=0, random_crop=True): # spec: [3, mel_dim, spec_len]
+    assert spec_length % 8 == 0, "spec_length must be divisible by 8"
+    if spec.shape[-1] < spec_length:
+        # pad spec to spec_length
+        spec = F.pad(spec, (0, spec_length - spec.shape[-1]), value=pad_value)
+    else:
+        # random crop
+        if random_crop:
+            start = random.randint(0, spec.shape[-1] - spec_length)
+            spec = spec[:, :, start:start+spec_length]
+        else:
+            spec = spec[:, :, :spec_length]
+    return spec
\ No newline at end of file
diff --git a/foleycrafter/utils/spec_to_mel.py b/foleycrafter/utils/spec_to_mel.py
new file mode 100644
index 0000000000000000000000000000000000000000..b77358dd8ae5af3473da0c8f25834ab2c2596a27
--- /dev/null
+++ b/foleycrafter/utils/spec_to_mel.py
@@ -0,0 +1,403 @@
+import torch
+import torchaudio
+import torch.nn.functional as F
+import numpy as np
+from scipy.signal import get_window
+import librosa.util as librosa_util
+from librosa.util import pad_center, tiny
+from librosa.filters import mel as librosa_mel_fn
+import io
+# spectrogram to mel
+
+class STFT(torch.nn.Module):
+    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
+
+    def __init__(self, filter_length, hop_length, win_length, window="hann"):
+        super(STFT, self).__init__()
+        self.filter_length = filter_length
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.window = window
+        self.forward_transform = None
+        scale = self.filter_length / self.hop_length
+        fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+        cutoff = int((self.filter_length / 2 + 1))
+        fourier_basis = np.vstack(
+            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
+        )
+
+        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+        inverse_basis = torch.FloatTensor(
+            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
+        )
+
+        if window is not None:
+            assert filter_length >= win_length
+            # get window and zero center pad it to filter_length
+            fft_window = get_window(window, win_length, fftbins=True)
+            fft_window = pad_center(fft_window, size=filter_length)
+            fft_window = torch.from_numpy(fft_window).float()
+
+            # window the bases
+            forward_basis *= fft_window
+            inverse_basis *= fft_window
+
+        self.register_buffer("forward_basis", forward_basis.float())
+        self.register_buffer("inverse_basis", inverse_basis.float())
+
+    def transform(self, input_data):
+        num_batches = input_data.size(0)
+        num_samples = input_data.size(1)
+
+        self.num_samples = num_samples
+
+        # similar to librosa, reflect-pad the input
+        input_data = input_data.view(num_batches, 1, num_samples)
+        input_data = F.pad(
+            input_data.unsqueeze(1),
+            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
+            mode="reflect",
+        )
+        input_data = input_data.squeeze(1)
+
+        forward_transform = F.conv1d(
+            input_data,
+            torch.autograd.Variable(self.forward_basis, requires_grad=False),
+            stride=self.hop_length,
+            padding=0,
+        ).cpu()
+
+        cutoff = int((self.filter_length / 2) + 1)
+        real_part = forward_transform[:, :cutoff, :]
+        imag_part = forward_transform[:, cutoff:, :]
+
+        magnitude = torch.sqrt(real_part**2 + imag_part**2)
+        phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
+
+        return magnitude, phase
+
+    def inverse(self, magnitude, phase):
+        recombine_magnitude_phase = torch.cat(
+            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
+        )
+
+        inverse_transform = F.conv_transpose1d(
+            recombine_magnitude_phase,
+            torch.autograd.Variable(self.inverse_basis, requires_grad=False),
+            stride=self.hop_length,
+            padding=0,
+        )
+
+        if self.window is not None:
+            window_sum = window_sumsquare(
+                self.window,
+                magnitude.size(-1),
+                hop_length=self.hop_length,
+                win_length=self.win_length,
+                n_fft=self.filter_length,
+                dtype=np.float32,
+            )
+            # remove modulation effects
+            approx_nonzero_indices = torch.from_numpy(
+                np.where(window_sum > tiny(window_sum))[0]
+            )
+            window_sum = torch.autograd.Variable(
+                torch.from_numpy(window_sum), requires_grad=False
+            )
+            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
+                approx_nonzero_indices
+            ]
+
+            # scale by hop ratio
+            inverse_transform *= float(self.filter_length) / self.hop_length
+
+        inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
+        inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
+
+        return inverse_transform
+
+    def forward(self, input_data):
+        self.magnitude, self.phase = self.transform(input_data)
+        reconstruction = self.inverse(self.magnitude, self.phase)
+        return reconstruction
+    
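+# Illustrative round trip for the `STFT` class above: `transform` produces magnitude/phase frames
+# and `inverse` resynthesizes the waveform from them:
+#
+#   stft = STFT(filter_length=1024, hop_length=160, win_length=1024)
+#   mag, phase = stft.transform(torch.randn(1, 16000))   # each: (1, 513, num_frames)
+#   recon = stft.inverse(mag, phase)                      # (1, 1, 16000)
+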
+def window_sumsquare(
+    window,
+    n_frames,
+    hop_length,
+    win_length,
+    n_fft,
+    dtype=np.float32,
+    norm=None,
+):
+    """
+    # from librosa 0.6
+    Compute the sum-square envelope of a window function at a given hop length.
+
+    This is used to estimate modulation effects induced by windowing
+    observations in short-time Fourier transforms.
+
+    Parameters
+    ----------
+    window : string, tuple, number, callable, or list-like
+        Window specification, as in `get_window`
+
+    n_frames : int > 0
+        The number of analysis frames
+
+    hop_length : int > 0
+        The number of samples to advance between frames
+
+    win_length : [optional]
+        The length of the window function.  By default, this matches `n_fft`.
+
+    n_fft : int > 0
+        The length of each analysis frame.
+
+    dtype : np.dtype
+        The data type of the output
+
+    Returns
+    -------
+    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
+        The sum-squared envelope of the window function
+    """
+    if win_length is None:
+        win_length = n_fft
+
+    n = n_fft + hop_length * (n_frames - 1)
+    x = np.zeros(n, dtype=dtype)
+
+    # Compute the squared window at the desired length
+    win_sq = get_window(window, win_length, fftbins=True)
+    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
+    win_sq = librosa_util.pad_center(win_sq, size=n_fft)
+
+    # Fill the envelope
+    for i in range(n_frames):
+        sample = i * hop_length
+        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
+    return x
+
+
+def griffin_lim(magnitudes, stft_fn, n_iters=30):
+    """
+    PARAMS
+    ------
+    magnitudes: spectrogram magnitudes
+    stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
+    """
+
+    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
+    angles = angles.astype(np.float32)
+    angles = torch.autograd.Variable(torch.from_numpy(angles))
+    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+
+    for i in range(n_iters):
+        _, angles = stft_fn.transform(signal)
+        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+    return signal
+
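+# Illustrative use of `griffin_lim` above (`magnitudes` is a placeholder tensor of STFT
+# magnitudes): the algorithm iteratively re-estimates phase to recover a time-domain signal from
+# magnitudes alone:
+#
+#   stft = STFT(filter_length=1024, hop_length=160, win_length=1024)
+#   wav = griffin_lim(magnitudes, stft, n_iters=30)       # magnitudes: (1, 513, num_frames)
+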
+def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
+    """
+    PARAMS
+    ------
+    C: compression factor
+    """
+    return normalize_fun(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+    """
+    PARAMS
+    ------
+    C: compression factor used to compress
+    """
+    return torch.exp(x) / C 
+class TacotronSTFT(torch.nn.Module):
+    def __init__(
+        self,
+        filter_length,
+        hop_length,
+        win_length,
+        n_mel_channels,
+        sampling_rate,
+        mel_fmin,
+        mel_fmax,
+    ):
+        super(TacotronSTFT, self).__init__()
+        self.n_mel_channels = n_mel_channels
+        self.sampling_rate = sampling_rate
+        self.stft_fn = STFT(filter_length, hop_length, win_length)
+        mel_basis = librosa_mel_fn(
+            sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
+        )
+        mel_basis = torch.from_numpy(mel_basis).float()
+        self.register_buffer("mel_basis", mel_basis)
+
+    def spectral_normalize(self, magnitudes, normalize_fun):
+        output = dynamic_range_compression(magnitudes, normalize_fun)
+        return output
+
+    def spectral_de_normalize(self, magnitudes):
+        output = dynamic_range_decompression(magnitudes)
+        return output
+
+    def mel_spectrogram(self, y, normalize_fun=torch.log):
+        """Computes mel-spectrograms from a batch of waves
+        PARAMS
+        ------
+        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
+
+        RETURNS
+        -------
+        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
+        """
+        assert torch.min(y.data) >= -1, torch.min(y.data)
+        assert torch.max(y.data) <= 1, torch.max(y.data)
+
+        magnitudes, phases = self.stft_fn.transform(y)
+        magnitudes = magnitudes.data
+        mel_output = torch.matmul(self.mel_basis, magnitudes)
+        mel_output = self.spectral_normalize(mel_output, normalize_fun)
+        energy = torch.norm(magnitudes, dim=1)
+
+        log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
+
+        return mel_output, log_magnitudes, energy
+    
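+# Illustrative use of `TacotronSTFT` above: called directly, it maps a (B, T) waveform in [-1, 1]
+# to (B, n_mel_channels, num_frames) log-mel features (the helpers below wrap this for fbank input):
+#
+#   stft = TacotronSTFT(1024, 160, 1024, 256, 16000, 0, 8000)
+#   mel, log_mag, energy = stft.mel_spectrogram(torch.randn(1, 16000).clamp(-1, 1))
+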
+def pad_wav(waveform, segment_length):
+    waveform_length = waveform.shape[-1]
+    assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
+    if segment_length is None or waveform_length == segment_length:
+        return waveform
+    elif waveform_length > segment_length:
+        return waveform[:,:segment_length]
+    elif waveform_length < segment_length:
+        temp_wav = np.zeros((1, segment_length))
+        temp_wav[:, :waveform_length] = waveform
+    return temp_wav
+
+def normalize_wav(waveform):
+    waveform = waveform - np.mean(waveform)
+    waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
+    return waveform * 0.5
+
+def _pad_spec(fbank, target_length=1024):
+    n_frames = fbank.shape[0]
+    p = target_length - n_frames
+    # cut and pad
+    if p > 0:
+        m = torch.nn.ZeroPad2d((0, 0, 0, p))
+        fbank = m(fbank)
+    elif p < 0:
+        fbank = fbank[0:target_length, :]
+
+    if fbank.size(-1) % 2 != 0:
+        fbank = fbank[..., :-1]
+
+    return fbank
+
+def get_mel_from_wav(audio, _stft):
+    audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
+    audio = torch.autograd.Variable(audio, requires_grad=False)
+    melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
+    melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
+    log_magnitudes_stft = (
+        torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32)
+    )
+    energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
+    return melspec, log_magnitudes_stft, energy 
+
+def read_wav_file_io(bytes):
+    # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
+    waveform, sr = torchaudio.load(bytes, format='mp4')  # Faster!!!
+    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
+    # waveform = waveform.numpy()[0, ...]
+    # waveform = normalize_wav(waveform)
+    # waveform = waveform[None, ...]
+
+    # waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
+    # waveform = 0.5 * waveform
+    
+    return waveform
+
+def load_audio(bytes, sample_rate=16000):
+    waveform, sr = torchaudio.load(bytes, format='mp4')
+    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)
+    return waveform
+
+def read_wav_file(filename):
+    # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
+    waveform, sr = torchaudio.load(filename)  # Faster!!!
+    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
+    waveform = waveform.numpy()[0, ...]
+    waveform = normalize_wav(waveform)
+    waveform = waveform[None, ...]
+    
+    waveform = waveform / np.max(np.abs(waveform))
+    waveform = 0.5 * waveform
+    
+    return waveform
+
+def norm_wav_tensor(waveform: torch.FloatTensor):
+    waveform = waveform.numpy()[0, ...]
+    waveform = normalize_wav(waveform)
+    waveform = waveform[None, ...]
+    waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
+    waveform = 0.5 * waveform
+    return waveform
+
+def wav_to_fbank(filename, target_length=1024, fn_STFT=None):
+    if fn_STFT is None:
+        fn_STFT = TacotronSTFT(
+            1024, # filter_length
+            160, # hop_length
+            1024, # win_length
+            64, # n_mel
+            16000, # sample_rate
+            0, # fmin
+            8000, # fmax
+        )
+
+    # load the waveform and pad/crop it to the target length
+    waveform = read_wav_file(filename)
+    waveform = pad_wav(waveform, target_length * 160)  # hop size is 160
+
+    waveform = waveform[0, ...]
+    waveform = torch.FloatTensor(waveform)
+
+    fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
+
+    fbank = torch.FloatTensor(fbank.T)
+    log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T)
+
+    fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
+        log_magnitudes_stft, target_length
+    )
+
+    return fbank, log_magnitudes_stft, waveform
+
+def wav_tensor_to_fbank(waveform, target_length=512, fn_STFT=None):
+    if fn_STFT is None:
+        fn_STFT = TacotronSTFT(
+            1024, # filter_length
+            160, # hop_length
+            1024, # win_length
+            256, # n_mel
+            16000, # sample_rate
+            0, # fmin
+            8000, # fmax
+        ) # settings used in practice
+
+    fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
+
+    fbank = torch.FloatTensor(fbank.T) 
+    log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T)
+
+    fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
+        log_magnitudes_stft, target_length
+    )
+
+    return fbank
\ No newline at end of file
diff --git a/foleycrafter/utils/util.py b/foleycrafter/utils/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd135cc41f2f7619deed8cf58ac016eb02faa2ab
--- /dev/null
+++ b/foleycrafter/utils/util.py
@@ -0,0 +1,1696 @@
+import torch
+import torchvision
+import torchaudio
+import torchvision.transforms as transforms
+from diffusers import UNet2DConditionModel, ControlNetModel
+from foleycrafter.pipelines.pipeline_controlnet import StableDiffusionControlNetPipeline
+from foleycrafter.pipelines.auffusion_pipeline import AuffusionNoAdapterPipeline, Generator
+from foleycrafter.models.auffusion_unet import UNet2DConditionModel as af_UNet2DConditionModel
+from diffusers.models import AutoencoderKLTemporalDecoder, AutoencoderKL
+from diffusers.schedulers import EulerDiscreteScheduler, DDIMScheduler, PNDMScheduler, KarrasDiffusionSchedulers
+from diffusers.utils.import_utils import is_xformers_available
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection,\
+    SpeechT5HifiGan, ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast,\
+    CLIPTextModel, CLIPTokenizer
+import glob
+from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip, VideoClip
+from moviepy.audio.AudioClip import AudioArrayClip
+import numpy as np
+from safetensors import safe_open
+import random 
+from typing import Union, Optional
+import decord
+import os
+import os.path as osp
+import imageio
+import soundfile as sf
+from PIL import Image, ImageOps
+import torch.distributed as dist
+import io
+from omegaconf import OmegaConf
+import json
+
+from dataclasses import dataclass
+from enum import Enum
+import typing as T
+import warnings
+import pydub
+from scipy.io import wavfile
+
+from einops import rearrange
+
+def zero_rank_print(s):
+    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s, flush=True)
+
+def build_foleycrafter(
+    pretrained_model_name_or_path: str="auffusion/auffusion-full-no-adapter",
+) -> StableDiffusionControlNetPipeline:
+    vae               = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
+    unet              = af_UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') 
+    scheduler         = PNDMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    tokenizer         = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
+    text_encoder      = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
+
+    controlnet = ControlNetModel.from_unet(unet, conditioning_channels=1)
+
+    pipe = StableDiffusionControlNetPipeline(
+        vae=vae,
+        controlnet=controlnet,
+        unet=unet,
+        scheduler=scheduler,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        feature_extractor=None,
+        safety_checker=None,
+        requires_safety_checker=False,
+    )
+
+    return pipe
+
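+# Illustrative use of `build_foleycrafter` above: build the ControlNet-augmented Auffusion
+# pipeline and move it to the GPU before inference (the device choice is an assumption):
+#
+#   pipe = build_foleycrafter().to("cuda")
+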
+def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
+    if len(videos.shape) == 4:
+        videos = videos.unsqueeze(0)
+    videos = rearrange(videos, "b c t h w -> t b c h w")
+    outputs = []
+    for x in videos:
+        x = torchvision.utils.make_grid(x, nrow=n_rows)
+        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+        if rescale:
+            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
+        x = torch.clamp((x * 255), 0, 255).numpy().astype(np.uint8)
+        outputs.append(x)
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    imageio.mimsave(path, outputs, fps=fps)
+
+def save_videos_from_pil_list(videos: list, path: str, fps=7):
+    for i in range(len(videos)):
+        videos[i] = ImageOps.scale(videos[i], 255)
+ 
+    imageio.mimwrite(path, videos, fps=fps)
+
+
+def seed_everything(seed: int) -> None:
+    r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`,
+    :obj:`numpy` and :python:`Python`.
+
+    Args:
+        seed (int): The desired seed.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+def get_video_frames(video: np.ndarray, num_frames: int=200):
+    video_length = video.shape[0]
+    video_idx = np.linspace(0, video_length-1, num_frames, dtype=int)
+    video = video[video_idx, ...]
+    return video
+
+def random_audio_video_clip(audio: np.ndarray, video: np.ndarray, fps:float, \
+                            sample_rate:int=16000, duration:int=5, num_frames: int=20):
+    """
+        Random sample video clips with duration
+    """
+    video_length = video.shape[0]
+    audio_length = audio.shape[-1]
+    av_duration  = int(video_length / fps)
+    assert av_duration >= duration, \
+        f"video duration {av_duration} is less than {duration}"
+
+    # random sample start time
+    start_time  = random.uniform(0, av_duration - duration)
+    end_time    = start_time + duration
+
+    start_idx, end_idx = start_time / av_duration, end_time / av_duration
+
+    video_start_frame, video_end_frame\
+                       = video_length * start_idx, video_length * end_idx
+    audio_start_frame, audio_end_frame\
+                       = audio_length * start_idx, audio_length * end_idx
+
+    # print(f"time_idx : {start_time}:{end_time}")
+    # print(f"video_idx: {video_start_frame}:{video_end_frame}")
+    # print(f"audio_idx: {audio_start_frame}:{audio_end_frame}")
+
+    audio_idx = np.linspace(audio_start_frame, audio_end_frame, sample_rate * duration, dtype=int)
+    video_idx = np.linspace(video_start_frame, video_end_frame, num_frames, dtype=int)
+
+    audio = audio[..., audio_idx]
+    video = video[video_idx, ...]
+    
+    return audio, video
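+
+# Illustrative sketch (hypothetical shapes): sample a 5 s clip from a 10 s video
+# decoded at 30 fps with 16 kHz mono audio.
+#
+#   audio = np.zeros((1, 160000))         # (channels, samples), 10 s at 16 kHz
+#   video = np.zeros((300, 224, 224, 3))  # (frames, H, W, C), 10 s at 30 fps
+#   audio_clip, video_clip = random_audio_video_clip(
+#       audio, video, fps=30, sample_rate=16000, duration=5, num_frames=20)
+#   # audio_clip: (1, 80000) samples, video_clip: (20, 224, 224, 3) frames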
+
+def get_full_indices(reader: Union[decord.VideoReader, decord.AudioReader])\
+    -> np.ndarray:
+    if isinstance(reader, decord.VideoReader):
+        return np.linspace(0, len(reader) - 1, len(reader), dtype=int)
+    elif isinstance(reader, decord.AudioReader):
+        return np.linspace(0, reader.shape[-1] - 1, reader.shape[-1], dtype=int)
+
+def get_frames(video_path:str, onset_list, frame_nums=1024):
+    video = decord.VideoReader(video_path)
+    video_frame = len(video)
+
+    frames_list = []
+    for start, end in onset_list:
+        video_start = int(start / frame_nums * video_frame)
+        video_end   = int(end   / frame_nums * video_frame)
+
+        frames_list.extend(range(video_start, video_end))
+    frames = video.get_batch(frames_list).asnumpy()
+    return frames
+
+def get_frames_in_video(video_path:str, onset_list, frame_nums=1024, audio_length_in_s=10):
+    # unlike get_frames, this function takes the actual video length into account
+    video = decord.VideoReader(video_path)
+    video_frame = len(video)
+    duration = video_frame / video.get_avg_fps()
+    frames_list = []
+    video_onset_list  = []
+    for start, end in onset_list:
+        if int(start / frame_nums * duration) >= audio_length_in_s:
+            continue
+        video_start = int(start / audio_length_in_s * duration / frame_nums * video_frame)
+        if video_start >= video_frame:
+            continue
+        video_end   = int(end   / audio_length_in_s * duration / frame_nums * video_frame)
+        video_onset_list.append([int(start / audio_length_in_s * duration), int(end / audio_length_in_s * duration)])
+        frames_list.extend(range(video_start, video_end))
+    frames = video.get_batch(frames_list).asnumpy()
+    return frames, video_onset_list
+
+def save_multimodal(video, audio, output_path, audio_fps:int=16000, video_fps:int=8, remove_audio:bool=True):
+    imgs = [img for img in video]
+    # if audio.shape[0] == 1 or audio.shape[0] == 2:
+    #     audio = audio.T #[len, channel]
+    # audio = np.repeat(audio, 2, axis=1)
+    output_dir = osp.dirname(output_path)
+    try:
+        wavfile.write(osp.join(output_dir, "audio.wav"), audio_fps, audio)
+    except Exception:
+        sf.write(osp.join(output_dir, "audio.wav"), audio, audio_fps)
+    audio_clip = AudioFileClip(osp.join(output_dir, "audio.wav"))
+    # audio_clip = AudioArrayClip(audio, fps=audio_fps)
+    video_clip = ImageSequenceClip(imgs, fps=video_fps)
+    video_clip = video_clip.set_audio(audio_clip)
+    video_clip.write_videofile(output_path, video_fps, audio=True, audio_fps=audio_fps)
+    if remove_audio:
+        os.remove(osp.join(output_dir, "audio.wav"))
+    return
+
+def save_multimodal_by_frame(video, audio, output_path, audio_fps:int=16000):
+    imgs = [img for img in video]
+    # if audio.shape[0] == 1 or audio.shape[0] == 2:
+    #     audio = audio.T #[len, channel]
+    # audio = np.repeat(audio, 2, axis=1)
+    # output_dir = osp.dirname(output_path)
+    output_dir = output_path
+    wavfile.write(osp.join(output_dir, "audio.wav"), audio_fps, audio)
+    audio_clip = AudioFileClip(osp.join(output_dir, "audio.wav"))
+    # audio_clip = AudioArrayClip(audio, fps=audio_fps)
+    os.makedirs(osp.join(output_dir, 'frames'), exist_ok=True)
+    for num, img in enumerate(imgs):
+        if isinstance(img, np.ndarray):
+            img = Image.fromarray(img.astype(np.uint8))
+        img.save(osp.join(output_dir, 'frames', f"{num}.jpg"))
+    return
+
+def sanity_check(data: dict, save_path: str="sanity_check", batch_size: int=4, sample_rate: int=16000):
+    video_path = osp.join(save_path, 'video')
+    audio_path = osp.join(save_path, 'audio')
+    av_path    = osp.join(save_path, 'av')
+
+    video, audio, text = data['pixel_values'], data['audio'], data['text']
+    video = (video / 2 + 0.5).clamp(0, 1)
+
+    zero_rank_print(f"Saving {text} audio: {audio[0].shape} video: {video[0].shape}")
+
+    for bsz in range(batch_size):
+        os.makedirs(video_path, exist_ok=True)
+        os.makedirs(audio_path, exist_ok=True)
+        os.makedirs(av_path, exist_ok=True)
+        # save_videos_grid(video[bsz:bsz+1,...], f"{osp.join(video_path, str(bsz) + '.mp4')}")
+        bsz_audio = audio[bsz,...].permute(1, 0).cpu().numpy()
+        bsz_video = video_tensor_to_np(video[bsz, ...])
+        sf.write(f"{osp.join(audio_path, str(bsz) + '.wav')}", bsz_audio, sample_rate)
+        save_multimodal(bsz_video, bsz_audio, osp.join(av_path, str(bsz) + '.mp4'))
+
+def video_tensor_to_np(video: torch.Tensor, rescale: bool=True, scale: bool=False):
+    if scale:
+        video = (video / 2 + 0.5).clamp(0, 1)
+    # c f h w -> f h w c
+    if video.shape[0] == 3:
+        video = video.permute(1, 2, 3, 0).detach().cpu().numpy()
+    elif video.shape[1] == 3:
+        video = video.permute(0, 2, 3, 1).detach().cpu().numpy()
+    if rescale:
+        video = video * 255
+    return video
+
+def composite_audio_video(video: str, audio: str, path:str, video_fps:int=7, audio_sample_rate:int=16000):
+    video = decord.VideoReader(video)
+    audio = decord.AudioReader(audio, sample_rate=audio_sample_rate)
+    audio = audio.get_batch(get_full_indices(audio)).asnumpy()
+    video = video.get_batch(get_full_indices(video)).asnumpy()
+    save_multimodal(video, audio, path, audio_fps=audio_sample_rate, video_fps=video_fps)
+    return
+
+# for video pipeline
+def append_dims(x, target_dims):
+    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+    dims_to_append = target_dims - x.ndim
+    if dims_to_append < 0:
+        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+    return x[(...,) + (None,) * dims_to_append]
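+
+# Worked example: broadcasting a per-sample scalar over a (B, C, H, W) tensor.
+#
+#   sigma = torch.ones(4)          # shape (4,)
+#   sigma = append_dims(sigma, 4)  # shape (4, 1, 1, 1)
+#   # now `sigma * x` broadcasts against an x of shape (4, 3, 64, 64)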
+
+def resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
+    h, w = input.shape[-2:]
+    factors = (h / size[0], w / size[1])
+
+    # First, we have to determine sigma
+    # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
+    sigmas = (
+        max((factors[0] - 1.0) / 2.0, 0.001),
+        max((factors[1] - 1.0) / 2.0, 0.001),
+    )
+
+    # Now the kernel size. Good results come from 3 sigma, but that is kind of slow. Pillow uses 1 sigma
+    # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
+    # but does it in two passes, which gives better results. Let's try 2 sigmas for now
+    ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
+
+    # Make sure it is odd
+    if (ks[0] % 2) == 0:
+        ks = ks[0] + 1, ks[1]
+
+    if (ks[1] % 2) == 0:
+        ks = ks[0], ks[1] + 1
+
+    input = _gaussian_blur2d(input, ks, sigmas)
+
+    output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
+    return output
+
+def _gaussian_blur2d(input, kernel_size, sigma):
+    if isinstance(sigma, tuple):
+        sigma = torch.tensor([sigma], dtype=input.dtype)
+    else:
+        sigma = sigma.to(dtype=input.dtype)
+
+    ky, kx = int(kernel_size[0]), int(kernel_size[1])
+    bs = sigma.shape[0]
+    kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
+    kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
+    out_x = _filter2d(input, kernel_x[..., None, :])
+    out = _filter2d(out_x, kernel_y[..., None])
+
+    return out
+
+def _filter2d(input, kernel):
+    # prepare kernel
+    b, c, h, w = input.shape
+    tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
+
+    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
+
+    height, width = tmp_kernel.shape[-2:]
+
+    padding_shape: list[int] = _compute_padding([height, width])
+    input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
+
+    # kernel and input tensor reshape to align element-wise or batch-wise params
+    tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
+    input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
+
+    # convolve the tensor with the kernel.
+    output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
+
+    out = output.view(b, c, h, w)
+    return out
+
+
+def _gaussian(window_size: int, sigma):
+    if isinstance(sigma, float):
+        sigma = torch.tensor([[sigma]])
+
+    batch_size = sigma.shape[0]
+
+    x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
+
+    if window_size % 2 == 0:
+        x = x + 0.5
+
+    gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
+
+    return gauss / gauss.sum(-1, keepdim=True)
+
+def _compute_padding(kernel_size):
+    """Compute padding tuple."""
+    # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
+    # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
+    if len(kernel_size) < 2:
+        raise AssertionError(kernel_size)
+    computed = [k - 1 for k in kernel_size]
+
+    # for even kernels we need to do asymmetric padding :(
+    out_padding = 2 * len(kernel_size) * [0]
+
+    for i in range(len(kernel_size)):
+        computed_tmp = computed[-(i + 1)]
+
+        pad_front = computed_tmp // 2
+        pad_rear = computed_tmp - pad_front
+
+        out_padding[2 * i + 0] = pad_front
+        out_padding[2 * i + 1] = pad_rear
+
+    return out_padding
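+
+# Worked example: for a (height, width) kernel of (3, 5) the computed padding is
+# [2, 2, 1, 1], i.e. (left, right, top, bottom) as expected by F.pad, since the
+# last spatial dimension (width) is padded first.
+#
+#   _compute_padding([3, 5])  # -> [2, 2, 1, 1]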
+
+def print_gpu_memory_usage(info: str, cuda_id: int = 0):
+    print(f">>> {info} <<<")
+    reserved = torch.cuda.memory_reserved(cuda_id) / 1024 ** 3
+    used     = torch.cuda.memory_allocated(cuda_id) / 1024 ** 3
+
+    print("reserved: ", reserved, "G")
+    print("allocated: ", used, "G")
+    print("available (reserved - allocated): ", reserved - used, "G")
+
+# use for dsp mel2spec
+@dataclass(frozen=True)
+class SpectrogramParams:
+    """
+    Parameters for the conversion from audio to spectrograms to images and back.
+
+    Includes helpers to convert to and from EXIF tags, allowing these parameters to be stored
+    within spectrogram images.
+
+    To understand what these parameters do and to customize them, read `spectrogram_converter.py`
+    and the linked torchaudio documentation.
+    """
+
+    # Whether the audio is stereo or mono
+    stereo: bool = False
+
+    # FFT parameters
+    sample_rate: int = 44100
+    step_size_ms: int = 10
+    window_duration_ms: int = 100
+    padded_duration_ms: int = 400
+
+    # Mel scale parameters
+    num_frequencies: int = 200
+    # TODO(hayk): Set these to [20, 20000] for newer models
+    min_frequency: int = 0
+    max_frequency: int = 10000
+    mel_scale_norm: T.Optional[str] = None
+    mel_scale_type: str = "htk"
+    max_mel_iters: int = 200
+
+    # Griffin Lim parameters
+    num_griffin_lim_iters: int = 32
+
+    # Image parameterization
+    power_for_image: float = 0.25
+
+    class ExifTags(Enum):
+        """
+        Custom EXIF tags for the spectrogram image.
+        """
+
+        SAMPLE_RATE = 11000
+        STEREO = 11005
+        STEP_SIZE_MS = 11010
+        WINDOW_DURATION_MS = 11020
+        PADDED_DURATION_MS = 11030
+
+        NUM_FREQUENCIES = 11040
+        MIN_FREQUENCY = 11050
+        MAX_FREQUENCY = 11060
+
+        POWER_FOR_IMAGE = 11070
+        MAX_VALUE = 11080
+
+    @property
+    def n_fft(self) -> int:
+        """
+        The number of samples in each STFT window, with padding.
+        """
+        return int(self.padded_duration_ms / 1000.0 * self.sample_rate)
+
+    @property
+    def win_length(self) -> int:
+        """
+        The number of samples in each STFT window.
+        """
+        return int(self.window_duration_ms / 1000.0 * self.sample_rate)
+
+    @property
+    def hop_length(self) -> int:
+        """
+        The number of samples between each STFT window.
+        """
+        return int(self.step_size_ms / 1000.0 * self.sample_rate)
+
+    def to_exif(self) -> T.Dict[int, T.Any]:
+        """
+        Return a dictionary of EXIF tags for the current values.
+        """
+        return {
+            self.ExifTags.SAMPLE_RATE.value: self.sample_rate,
+            self.ExifTags.STEREO.value: self.stereo,
+            self.ExifTags.STEP_SIZE_MS.value: self.step_size_ms,
+            self.ExifTags.WINDOW_DURATION_MS.value: self.window_duration_ms,
+            self.ExifTags.PADDED_DURATION_MS.value: self.padded_duration_ms,
+            self.ExifTags.NUM_FREQUENCIES.value: self.num_frequencies,
+            self.ExifTags.MIN_FREQUENCY.value: self.min_frequency,
+            self.ExifTags.MAX_FREQUENCY.value: self.max_frequency,
+            self.ExifTags.POWER_FOR_IMAGE.value: float(self.power_for_image),
+        } 
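+
+# With the defaults above (sample_rate=44100, padded_duration_ms=400,
+# window_duration_ms=100, step_size_ms=10) the derived STFT sizes are:
+#
+#   params = SpectrogramParams()
+#   params.n_fft       # 17640 samples (0.4 s of padded window)
+#   params.win_length  # 4410 samples  (0.1 s window)
+#   params.hop_length  # 441 samples   (10 ms hop)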
+
+class SpectrogramImageConverter:
+    """
+    Convert between spectrogram images and audio segments.
+
+    This is a wrapper around SpectrogramConverter that additionally converts from spectrograms
+    to images and back. The real audio processing lives in SpectrogramConverter.
+    """
+
+    def __init__(self, params: SpectrogramParams, device: str = "cuda"):
+        self.p = params
+        self.device = device
+        self.converter = SpectrogramConverter(params=params, device=device)
+
+    def spectrogram_image_from_audio(
+        self,
+        segment: pydub.AudioSegment,
+    ) -> Image.Image:
+        """
+        Compute a spectrogram image from an audio segment.
+
+        Args:
+            segment: Audio segment to convert
+
+        Returns:
+            Spectrogram image (in pillow format)
+        """
+        assert int(segment.frame_rate) == self.p.sample_rate, "Sample rate mismatch"
+
+        if self.p.stereo:
+            if segment.channels == 1:
+                print("WARNING: Mono audio but stereo=True, cloning channel")
+                segment = segment.set_channels(2)
+            elif segment.channels > 2:
+                print("WARNING: Multi channel audio, reducing to stereo")
+                segment = segment.set_channels(2)
+        else:
+            if segment.channels > 1:
+                print("WARNING: Stereo audio but stereo=False, setting to mono")
+                segment = segment.set_channels(1)
+
+        spectrogram = self.converter.spectrogram_from_audio(segment)
+
+        image = image_from_spectrogram(
+            spectrogram,
+            power=self.p.power_for_image,
+        )
+
+        # Store conversion params in exif metadata of the image
+        exif_data = self.p.to_exif()
+        exif_data[SpectrogramParams.ExifTags.MAX_VALUE.value] = float(np.max(spectrogram))
+        exif = image.getexif()
+        exif.update(exif_data.items())
+
+        return image
+
+    def audio_from_spectrogram_image(
+        self,
+        image: Image.Image,
+        apply_filters: bool = True,
+        max_value: float = 30e6,
+    ) -> pydub.AudioSegment:
+        """
+        Reconstruct an audio segment from a spectrogram image.
+
+        Args:
+            image: Spectrogram image (in pillow format)
+            apply_filters: Apply post-processing to improve the reconstructed audio
+            max_value: Scaled max amplitude of the spectrogram. Shouldn't matter.
+        """
+        spectrogram = spectrogram_from_image(
+            image,
+            max_value=max_value,
+            power=self.p.power_for_image,
+            stereo=self.p.stereo,
+        )
+
+        segment = self.converter.audio_from_spectrogram(
+            spectrogram,
+            apply_filters=apply_filters,
+        )
+
+        return segment
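+
+# Usage sketch (illustrative; the file path is hypothetical): round-trip an audio
+# clip through the spectrogram image representation.
+#
+#   params = SpectrogramParams(sample_rate=44100)
+#   converter = SpectrogramImageConverter(params, device="cuda")
+#   segment = pydub.AudioSegment.from_file("example.wav").set_frame_rate(params.sample_rate)
+#   image = converter.spectrogram_image_from_audio(segment)
+#   reconstructed = converter.audio_from_spectrogram_image(image)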
+    
+def image_from_spectrogram(spectrogram: np.ndarray, power: float = 0.25) -> Image.Image:
+    """
+    Compute a spectrogram image from a spectrogram magnitude array.
+
+    This is the inverse of spectrogram_from_image, except for discretization error from
+    quantizing to uint8.
+
+    Args:
+        spectrogram: (channels, frequency, time)
+        power: A power curve to apply to the spectrogram to preserve contrast
+
+    Returns:
+        image: (frequency, time, channels)
+    """
+    # Rescale to 0-1
+    max_value = np.max(spectrogram)
+    data = spectrogram / max_value
+
+    # Apply the power curve
+    data = np.power(data, power)
+
+    # Rescale to 0-255
+    data = data * 255
+
+    # Invert
+    data = 255 - data
+
+    # Convert to uint8
+    data = data.astype(np.uint8)
+
+    # Munge channels into a PIL image
+    if data.shape[0] == 1:
+        # TODO(hayk): Do we want to write single channel to disk instead?
+        image = Image.fromarray(data[0], mode="L").convert("RGB")
+    elif data.shape[0] == 2:
+        data = np.array([np.zeros_like(data[0]), data[0], data[1]]).transpose(1, 2, 0)
+        image = Image.fromarray(data, mode="RGB")
+    else:
+        raise NotImplementedError(f"Unsupported number of channels: {data.shape[0]}")
+
+    # Flip Y
+    image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)
+
+    return image
+
+
+def spectrogram_from_image(
+    image: Image.Image,
+    power: float = 0.25,
+    stereo: bool = False,
+    max_value: float = 30e6,
+) -> np.ndarray:
+    """
+    Compute a spectrogram magnitude array from a spectrogram image.
+
+    This is the inverse of image_from_spectrogram, except for discretization error from
+    quantizing to uint8.
+
+    Args:
+        image: (frequency, time, channels)
+        power: The power curve applied to the spectrogram
+        stereo: Whether the spectrogram encodes stereo data
+        max_value: The max value of the original spectrogram. In practice doesn't matter.
+
+    Returns:
+        spectrogram: (channels, frequency, time)
+    """
+    # Convert to RGB if single channel
+    if image.mode in ("P", "L"):
+        image = image.convert("RGB")
+
+    # Flip Y
+    image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)
+
+    # Munge channels into a numpy array of (channels, frequency, time)
+    data = np.array(image).transpose(2, 0, 1)
+    if stereo:
+        # Take the G and B channels as done in image_from_spectrogram
+        data = data[[1, 2], :, :]
+    else:
+        data = data[0:1, :, :]
+
+    # Convert to floats
+    data = data.astype(np.float32)
+
+    # Invert
+    data = 255 - data
+
+    # Rescale to 0-1
+    data = data / 255
+
+    # Reverse the power curve
+    data = np.power(data, 1 / power)
+
+    # Rescale to max value
+    data = data * max_value
+
+    return data
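+
+# Note: image_from_spectrogram and spectrogram_from_image are approximate inverses;
+# the round trip below loses only uint8 quantization error (assuming a mono spectrogram).
+#
+#   spec = np.abs(np.random.randn(1, 200, 512)).astype(np.float32)  # (channels, freq, time)
+#   img = image_from_spectrogram(spec, power=0.25)
+#   spec_back = spectrogram_from_image(img, power=0.25, stereo=False,
+#                                      max_value=float(np.max(spec)))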
+
+class SpectrogramConverter:
+    """
+    Convert between audio segments and spectrogram tensors using torchaudio.
+
+    In this class a "spectrogram" is defined as a (batch, time, frequency) tensor with float values
+    that represent the amplitude of the frequency at that time bucket (in the frequency domain).
+    Frequencies are given in the perceptual Mel scale defined by the params. A more specific term
+    used in some functions is "mel amplitudes".
+
+    The spectrogram computed from `spectrogram_from_audio` is complex valued, but it only
+    returns the amplitude, because the phase is chaotic and hard to learn. The function
+    `audio_from_spectrogram` is an approximate inverse of `spectrogram_from_audio`, which
+    approximates the phase information using the Griffin-Lim algorithm.
+
+    Each channel in the audio is treated independently, and the spectrogram has a batch dimension
+    equal to the number of channels in the input audio segment.
+
+    Both the Griffin Lim algorithm and the Mel scaling process are lossy.
+
+    For more information, see https://pytorch.org/audio/stable/transforms.html
+    """
+
+    def __init__(self, params: SpectrogramParams, device: str = "cuda"):
+        self.p = params
+
+        self.device = check_device(device)
+
+        if device.lower().startswith("mps"):
+            warnings.warn(
+                "WARNING: MPS does not support audio operations, falling back to CPU for them",
+                stacklevel=2,
+            )
+            self.device = "cpu"
+
+        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html
+        self.spectrogram_func = torchaudio.transforms.Spectrogram(
+            n_fft=params.n_fft,
+            hop_length=params.hop_length,
+            win_length=params.win_length,
+            pad=0,
+            window_fn=torch.hann_window,
+            power=None,
+            normalized=False,
+            wkwargs=None,
+            center=True,
+            pad_mode="reflect",
+            onesided=True,
+        ).to(self.device)
+
+        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.GriffinLim.html
+        self.inverse_spectrogram_func = torchaudio.transforms.GriffinLim(
+            n_fft=params.n_fft,
+            n_iter=params.num_griffin_lim_iters,
+            win_length=params.win_length,
+            hop_length=params.hop_length,
+            window_fn=torch.hann_window,
+            power=1.0,
+            wkwargs=None,
+            momentum=0.99,
+            length=None,
+            rand_init=True,
+        ).to(self.device)
+
+        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.MelScale.html
+        self.mel_scaler = torchaudio.transforms.MelScale(
+            n_mels=params.num_frequencies,
+            sample_rate=params.sample_rate,
+            f_min=params.min_frequency,
+            f_max=params.max_frequency,
+            n_stft=params.n_fft // 2 + 1,
+            norm=params.mel_scale_norm,
+            mel_scale=params.mel_scale_type,
+        ).to(self.device)
+
+        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseMelScale.html
+        self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale(
+            n_stft=params.n_fft // 2 + 1,
+            n_mels=params.num_frequencies,
+            sample_rate=params.sample_rate,
+            f_min=params.min_frequency,
+            f_max=params.max_frequency,
+            # max_iter=params.max_mel_iters, # for higher versions of torchaudio
+            # tolerance_loss=1e-5, # for higher versions of torchaudio
+            # tolerance_change=1e-8, # for higher versions of torchaudio
+            # sgdargs=None, # for higher versions of torchaudio
+            norm=params.mel_scale_norm,
+            mel_scale=params.mel_scale_type,
+        ).to(self.device)
+
+    def spectrogram_from_audio(
+        self,
+        audio: pydub.AudioSegment,
+    ) -> np.ndarray:
+        """
+        Compute a spectrogram from an audio segment.
+
+        Args:
+            audio: Audio segment which must match the sample rate of the params
+
+        Returns:
+            spectrogram: (channel, frequency, time)
+        """
+        assert int(audio.frame_rate) == self.p.sample_rate, "Audio sample rate must match params"
+
+        # Get the samples as a numpy array in (batch, samples) shape
+        waveform = np.array([c.get_array_of_samples() for c in audio.split_to_mono()])
+
+        # Convert to floats if necessary
+        if waveform.dtype != np.float32:
+            waveform = waveform.astype(np.float32)
+
+        waveform_tensor = torch.from_numpy(waveform).to(self.device)
+        amplitudes_mel = self.mel_amplitudes_from_waveform(waveform_tensor)
+        return amplitudes_mel.cpu().numpy()
+
+    def audio_from_spectrogram(
+        self,
+        spectrogram: np.ndarray,
+        apply_filters: bool = True,
+    ) -> pydub.AudioSegment:
+        """
+        Reconstruct an audio segment from a spectrogram.
+
+        Args:
+            spectrogram: (batch, frequency, time)
+            apply_filters: Post-process with normalization and compression
+
+        Returns:
+            audio: Audio segment with channels equal to the batch dimension
+        """
+        # Move to device
+        amplitudes_mel = torch.from_numpy(spectrogram).to(self.device)
+
+        # Reconstruct the waveform
+        waveform = self.waveform_from_mel_amplitudes(amplitudes_mel)
+
+        # Convert to audio segment
+        segment = audio_from_waveform(
+            samples=waveform.cpu().numpy(),
+            sample_rate=self.p.sample_rate,
+            # Normalize the waveform to the range [-1, 1]
+            normalize=True,
+        )
+
+        # Optionally apply post-processing filters
+        if apply_filters:
+            segment = apply_filters_func(
+                segment,
+                compression=False,
+            )
+
+        return segment
+
+    def mel_amplitudes_from_waveform(
+        self,
+        waveform: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Torch-only function to compute Mel-scale amplitudes from a waveform.
+
+        Args:
+            waveform: (batch, samples)
+
+        Returns:
+            amplitudes_mel: (batch, frequency, time)
+        """
+        # Compute the complex-valued spectrogram
+        spectrogram_complex = self.spectrogram_func(waveform)
+
+        # Take the magnitude
+        amplitudes = torch.abs(spectrogram_complex)
+
+        # Convert to mel scale
+        return self.mel_scaler(amplitudes)
+
+    def waveform_from_mel_amplitudes(
+        self,
+        amplitudes_mel: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Torch-only function to approximately reconstruct a waveform from Mel-scale amplitudes.
+
+        Args:
+            amplitudes_mel: (batch, frequency, time)
+
+        Returns:
+            waveform: (batch, samples)
+        """
+        # Convert from mel scale to linear
+        amplitudes_linear = self.inverse_mel_scaler(amplitudes_mel)
+
+        # Run the approximate algorithm to compute the phase and recover the waveform
+        return self.inverse_spectrogram_func(amplitudes_linear)
+    
+def check_device(device: str, backup: str = "cpu") -> str:
+    """
+    Check that the device is valid and available. If not,
+    """
+    cuda_not_found = device.lower().startswith("cuda") and not torch.cuda.is_available()
+    mps_not_found = device.lower().startswith("mps") and not torch.backends.mps.is_available()
+
+    if cuda_not_found or mps_not_found:
+        warnings.warn(f"WARNING: {device} is not available, using {backup} instead.", stacklevel=3)
+        return backup
+
+    return device
+
+def audio_from_waveform(
+    samples: np.ndarray, sample_rate: int, normalize: bool = False
+) -> pydub.AudioSegment:
+    """
+    Convert a numpy array of samples of a waveform to an audio segment.
+
+    Args:
+        samples: (channels, samples) array
+    """
+    # Normalize volume to fit in int16
+    if normalize:
+        samples *= np.iinfo(np.int16).max / np.max(np.abs(samples))
+
+    # Transpose and convert to int16
+    samples = samples.transpose(1, 0)
+    samples = samples.astype(np.int16)
+
+    # Write to the bytes of a WAV file
+    wav_bytes = io.BytesIO()
+    wavfile.write(wav_bytes, sample_rate, samples)
+    wav_bytes.seek(0)
+
+    # Read into pydub
+    return pydub.AudioSegment.from_wav(wav_bytes)
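+
+# Illustrative sketch: wrap a synthetic 1 s mono sine wave into a pydub segment.
+#
+#   sr = 16000
+#   t = np.linspace(0, 1, sr, endpoint=False)
+#   wave = np.sin(2 * np.pi * 440 * t)[None, :].astype(np.float32)  # (channels, samples)
+#   segment = audio_from_waveform(wave, sample_rate=sr, normalize=True)
+#   # segment.frame_rate == 16000, segment.channels == 1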
+
+
+def apply_filters_func(segment: pydub.AudioSegment, compression: bool = False) -> pydub.AudioSegment:
+    """
+    Apply post-processing filters to the audio segment to compress it and
+    keep at a -10 dBFS level.
+    """
+    # TODO(hayk): Come up with a principled strategy for these filters and experiment end-to-end.
+    # TODO(hayk): Is this going to make audio unbalanced between sequential clips?
+
+    if compression:
+        segment = pydub.effects.normalize(
+            segment,
+            headroom=0.1,
+        )
+
+        segment = segment.apply_gain(-10 - segment.dBFS)
+
+        # TODO(hayk): This is quite slow, ~1.7 seconds on a beefy CPU
+        segment = pydub.effects.compress_dynamic_range(
+            segment,
+            threshold=-20.0,
+            ratio=4.0,
+            attack=5.0,
+            release=50.0,
+        )
+
+    desired_db = -12
+    segment = segment.apply_gain(desired_db - segment.dBFS)
+
+    segment = pydub.effects.normalize(
+        segment,
+        headroom=0.1,
+    )
+
+    return segment
+
+def shave_segments(path, n_shave_prefix_segments=1):
+    """
+    Removes segments. Positive values shave the first segments, negative shave the last segments.
+    """
+    if n_shave_prefix_segments >= 0:
+        return ".".join(path.split(".")[n_shave_prefix_segments:])
+    else:
+        return ".".join(path.split(".")[:n_shave_prefix_segments])
+
+
+def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
+    """
+    Updates paths inside resnets to the new naming scheme (local renaming)
+    """
+    mapping = []
+    for old_item in old_list:
+        new_item = old_item.replace("in_layers.0", "norm1")
+        new_item = new_item.replace("in_layers.2", "conv1")
+
+        new_item = new_item.replace("out_layers.0", "norm2")
+        new_item = new_item.replace("out_layers.3", "conv2")
+
+        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
+        new_item = new_item.replace("skip_connection", "conv_shortcut")
+
+        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+        mapping.append({"old": old_item, "new": new_item})
+
+    return mapping
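+
+# Worked example: local renaming of an LDM resnet key.
+#
+#   renew_resnet_paths(["input_blocks.1.0.in_layers.0.weight"])
+#   # -> [{"old": "input_blocks.1.0.in_layers.0.weight",
+#   #      "new": "input_blocks.1.0.norm1.weight"}]
+#   # the block prefix itself is renamed later by assign_to_checkpoint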
+
+
+def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
+    """
+    Updates paths inside resnets to the new naming scheme (local renaming)
+    """
+    mapping = []
+    for old_item in old_list:
+        new_item = old_item
+
+        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
+        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+        mapping.append({"old": old_item, "new": new_item})
+
+    return mapping
+
+
+def renew_attention_paths(old_list, n_shave_prefix_segments=0):
+    """
+    Updates paths inside attentions to the new naming scheme (local renaming)
+    """
+    mapping = []
+    for old_item in old_list:
+        new_item = old_item
+
+        #         new_item = new_item.replace('norm.weight', 'group_norm.weight')
+        #         new_item = new_item.replace('norm.bias', 'group_norm.bias')
+
+        #         new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
+        #         new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
+
+        #         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+        mapping.append({"old": old_item, "new": new_item})
+
+    return mapping
+
+
+def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
+    """
+    Updates paths inside attentions to the new naming scheme (local renaming)
+    """
+    mapping = []
+    for old_item in old_list:
+        new_item = old_item
+
+        new_item = new_item.replace("norm.weight", "group_norm.weight")
+        new_item = new_item.replace("norm.bias", "group_norm.bias")
+
+        new_item = new_item.replace("q.weight", "to_q.weight")
+        new_item = new_item.replace("q.bias", "to_q.bias")
+
+        new_item = new_item.replace("k.weight", "to_k.weight")
+        new_item = new_item.replace("k.bias", "to_k.bias")
+
+        new_item = new_item.replace("v.weight", "to_v.weight")
+        new_item = new_item.replace("v.bias", "to_v.bias")
+
+        new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+        new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
+
+        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+        mapping.append({"old": old_item, "new": new_item})
+    return mapping
+
+
+def assign_to_checkpoint(
+    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
+    """
+    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
+    attention layers, and takes into account additional replacements that may arise.
+
+    Assigns the weights to the new checkpoint.
+    """
+    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+    # Splits the attention layers into three variables.
+    if attention_paths_to_split is not None:
+        for path, path_map in attention_paths_to_split.items():
+            old_tensor = old_checkpoint[path]
+            channels = old_tensor.shape[0] // 3
+
+            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
+
+            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+
+            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
+            query, key, value = old_tensor.split(channels // num_heads, dim=1)
+
+            checkpoint[path_map["query"]] = query.reshape(target_shape)
+            checkpoint[path_map["key"]] = key.reshape(target_shape)
+            checkpoint[path_map["value"]] = value.reshape(target_shape)
+
+    for path in paths:
+        new_path = path["new"]
+
+        # These have already been assigned
+        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
+            continue
+
+        # Global renaming happens here
+        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
+        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
+        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
+
+        if additional_replacements is not None:
+            for replacement in additional_replacements:
+                new_path = new_path.replace(replacement["old"], replacement["new"])
+
+        # proj_attn.weight has to be converted from conv 1D to linear
+        if "proj_attn.weight" in new_path:
+            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+        elif 'to_out.0.weight' in new_path:
+            checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
+        elif any([qkv in new_path for qkv in ['to_q', 'to_k', 'to_v']]):
+            checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
+        else:
+            checkpoint[new_path] = old_checkpoint[path["old"]]
+
+
+def conv_attn_to_linear(checkpoint):
+    keys = list(checkpoint.keys())
+    attn_keys = ["query.weight", "key.weight", "value.weight"]
+    for key in keys:
+        if ".".join(key.split(".")[-2:]) in attn_keys:
+            if checkpoint[key].ndim > 2:
+                checkpoint[key] = checkpoint[key][:, :, 0, 0]
+        elif "proj_attn.weight" in key:
+            if checkpoint[key].ndim > 2:
+                checkpoint[key] = checkpoint[key][:, :, 0]
+
+
+def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
+    """
+    Creates a diffusers UNet (or ControlNet) config based on the config of the LDM model.
+    """
+    if controlnet:
+        unet_params = original_config.model.params.control_stage_config.params
+    else:
+        unet_params = original_config.model.params.unet_config.params
+
+    vae_params = original_config.model.params.first_stage_config.params.ddconfig
+
+    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
+
+    down_block_types = []
+    resolution = 1
+    for i in range(len(block_out_channels)):
+        block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
+        down_block_types.append(block_type)
+        if i != len(block_out_channels) - 1:
+            resolution *= 2
+
+    up_block_types = []
+    for i in range(len(block_out_channels)):
+        block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
+        up_block_types.append(block_type)
+        resolution //= 2
+
+    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
+
+    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
+    use_linear_projection = (
+        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
+    )
+    if use_linear_projection:
+        # stable diffusion 2-base-512 and 2-768
+        if head_dim is None:
+            head_dim = [5, 10, 20, 20]
+
+    class_embed_type = None
+    projection_class_embeddings_input_dim = None
+
+    if "num_classes" in unet_params:
+        if unet_params.num_classes == "sequential":
+            class_embed_type = "projection"
+            assert "adm_in_channels" in unet_params
+            projection_class_embeddings_input_dim = unet_params.adm_in_channels
+        else:
+            raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
+
+    config = {
+        "sample_size": image_size // vae_scale_factor,
+        "in_channels": unet_params.in_channels,
+        "down_block_types": tuple(down_block_types),
+        "block_out_channels": tuple(block_out_channels),
+        "layers_per_block": unet_params.num_res_blocks,
+        "cross_attention_dim": unet_params.context_dim,
+        "attention_head_dim": head_dim,
+        "use_linear_projection": use_linear_projection,
+        "class_embed_type": class_embed_type,
+        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
+    }
+
+    if not controlnet:
+        config["out_channels"] = unet_params.out_channels
+        config["up_block_types"] = tuple(up_block_types)
+
+    return config
+
+
+def create_vae_diffusers_config(original_config, image_size: int):
+    """
+    Creates a diffusers VAE config based on the config of the LDM model.
+    """
+    vae_params = original_config.model.params.first_stage_config.params.ddconfig
+    _ = original_config.model.params.first_stage_config.params.embed_dim
+
+    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
+    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
+    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+
+    config = {
+        "sample_size": image_size,
+        "in_channels": vae_params.in_channels,
+        "out_channels": vae_params.out_ch,
+        "down_block_types": tuple(down_block_types),
+        "up_block_types": tuple(up_block_types),
+        "block_out_channels": tuple(block_out_channels),
+        "latent_channels": vae_params.z_channels,
+        "layers_per_block": vae_params.num_res_blocks,
+    }
+    return config
+
+
+def create_diffusers_schedular(original_config):
+    schedular = DDIMScheduler(
+        num_train_timesteps=original_config.model.params.timesteps,
+        beta_start=original_config.model.params.linear_start,
+        beta_end=original_config.model.params.linear_end,
+        beta_schedule="scaled_linear",
+    )
+    return schedular
+
+def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
+    """
+    Takes a state dict and a config, and returns a converted checkpoint.
+    """
+
+    # extract state_dict for UNet
+    unet_state_dict = {}
+    keys = list(checkpoint.keys())
+
+    if controlnet:
+        unet_key = "control_model."
+    else:
+        unet_key = "model.diffusion_model."
+
+    # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
+    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
+        print(
+            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+        )
+        for key in keys:
+            if key.startswith("model.diffusion_model"):
+                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+    else:
+        if sum(k.startswith("model_ema") for k in keys) > 100:
+            print(
+                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+            )
+
+        for key in keys:
+            if key.startswith(unet_key):
+                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+
+    new_checkpoint = {}
+
+    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+    if config["class_embed_type"] is None:
+        # No parameters to port
+        ...
+    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
+        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
+        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
+        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+    else:
+        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
+
+    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+    if not controlnet:
+        new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+        new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+        new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+        new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+    # Retrieves the keys for the input blocks only
+    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+    input_blocks = {
+        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
+        for layer_id in range(num_input_blocks)
+    }
+
+    # Retrieves the keys for the middle blocks only
+    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+    middle_blocks = {
+        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+        for layer_id in range(num_middle_blocks)
+    }
+
+    # Retrieves the keys for the output blocks only
+    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+    output_blocks = {
+        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
+        for layer_id in range(num_output_blocks)
+    }
+
+    for i in range(1, num_input_blocks):
+        block_id = (i - 1) // (config["layers_per_block"] + 1)
+        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+        resnets = [
+            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+        ]
+        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+                f"input_blocks.{i}.0.op.weight"
+            )
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+                f"input_blocks.{i}.0.op.bias"
+            )
+
+        paths = renew_resnet_paths(resnets)
+        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+        assign_to_checkpoint(
+            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+        )
+
+        if len(attentions):
+            paths = renew_attention_paths(attentions)
+            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
+            assign_to_checkpoint(
+                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+            )
+
+    resnet_0 = middle_blocks[0]
+    attentions = middle_blocks[1]
+    resnet_1 = middle_blocks[2]
+
+    resnet_0_paths = renew_resnet_paths(resnet_0)
+    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
+
+    resnet_1_paths = renew_resnet_paths(resnet_1)
+    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
+
+    attentions_paths = renew_attention_paths(attentions)
+    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+    assign_to_checkpoint(
+        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+    )
+
+    for i in range(num_output_blocks):
+        block_id = i // (config["layers_per_block"] + 1)
+        layer_in_block_id = i % (config["layers_per_block"] + 1)
+        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+        output_block_list = {}
+
+        for layer in output_block_layers:
+            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+            if layer_id in output_block_list:
+                output_block_list[layer_id].append(layer_name)
+            else:
+                output_block_list[layer_id] = [layer_name]
+
+        if len(output_block_list) > 1:
+            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+
+            resnet_0_paths = renew_resnet_paths(resnets)
+            paths = renew_resnet_paths(resnets)
+
+            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+            assign_to_checkpoint(
+                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+            )
+
+            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
+            if ["conv.bias", "conv.weight"] in output_block_list.values():
+                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+                    f"output_blocks.{i}.{index}.conv.weight"
+                ]
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+                    f"output_blocks.{i}.{index}.conv.bias"
+                ]
+
+                # Clear attentions as they have been attributed above.
+                if len(attentions) == 2:
+                    attentions = []
+
+            if len(attentions):
+                paths = renew_attention_paths(attentions)
+                meta_path = {
+                    "old": f"output_blocks.{i}.1",
+                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
+                }
+                assign_to_checkpoint(
+                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+                )
+        else:
+            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+            for path in resnet_0_paths:
+                old_path = ".".join(["output_blocks", str(i), path["old"]])
+                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+
+                new_checkpoint[new_path] = unet_state_dict[old_path]
+
+    if controlnet:
+        # conditioning embedding
+
+        orig_index = 0
+
+        new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
+            f"input_hint_block.{orig_index}.weight"
+        )
+        new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
+            f"input_hint_block.{orig_index}.bias"
+        )
+
+        orig_index += 2
+
+        diffusers_index = 0
+
+        while diffusers_index < 6:
+            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
+                f"input_hint_block.{orig_index}.weight"
+            )
+            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
+                f"input_hint_block.{orig_index}.bias"
+            )
+            diffusers_index += 1
+            orig_index += 2
+
+        new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
+            f"input_hint_block.{orig_index}.weight"
+        )
+        new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
+            f"input_hint_block.{orig_index}.bias"
+        )
+
+        # down blocks
+        for i in range(num_input_blocks):
+            new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
+            new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
+
+        # mid block
+        new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
+        new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
+
+    return new_checkpoint
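+
+# Usage sketch (paths and config file are hypothetical): convert an original LDM
+# checkpoint into diffusers-style UNet weights.
+#
+#   original_config = OmegaConf.load("v1-inference.yaml")
+#   state_dict = torch.load("ldm_checkpoint.ckpt", map_location="cpu")["state_dict"]
+#   unet_config = create_unet_diffusers_config(original_config, image_size=512)
+#   unet_sd = convert_ldm_unet_checkpoint(state_dict, unet_config, path="ldm_checkpoint.ckpt")
+#   # unet = af_UNet2DConditionModel(**unet_config); unet.load_state_dict(unet_sd)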
+
+
+def convert_ldm_vae_checkpoint(checkpoint, config, only_decoder=False, only_encoder=False):
+    # extract state dict for VAE
+    vae_state_dict = {}
+    vae_key = "first_stage_model."
+    keys = list(checkpoint.keys())
+    for key in keys:
+        if key.startswith(vae_key):
+            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+
+    new_checkpoint = {}
+
+    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
+    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
+    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
+    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
+    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
+    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
+
+    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
+    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
+    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
+    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
+    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
+    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
+
+    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
+    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
+    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
+    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
+
+    # Retrieves the keys for the encoder down blocks only
+    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
+    down_blocks = {
+        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+    }
+
+    # Retrieves the keys for the decoder up blocks only
+    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
+    up_blocks = {
+        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
+    }
+
+    for i in range(num_down_blocks):
+        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+
+        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
+                f"encoder.down.{i}.downsample.conv.weight"
+            )
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
+                f"encoder.down.{i}.downsample.conv.bias"
+            )
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+    num_mid_res_blocks = 2
+    for i in range(1, num_mid_res_blocks + 1):
+        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+    paths = renew_vae_attention_paths(mid_attentions)
+    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+    conv_attn_to_linear(new_checkpoint)
+
+    for i in range(num_up_blocks):
+        block_id = num_up_blocks - 1 - i
+        resnets = [
+            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+        ]
+
+        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+                f"decoder.up.{block_id}.upsample.conv.weight"
+            ]
+            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+                f"decoder.up.{block_id}.upsample.conv.bias"
+            ]
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+    num_mid_res_blocks = 2
+    for i in range(1, num_mid_res_blocks + 1):
+        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+    paths = renew_vae_attention_paths(mid_attentions)
+    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+    conv_attn_to_linear(new_checkpoint)
+
+    if only_decoder:
+        new_checkpoint = {k: v for k, v in new_checkpoint.items() if k.startswith('decoder') or k.startswith('post_quant')}
+    elif only_encoder:
+        new_checkpoint = {k: v for k, v in new_checkpoint.items() if k.startswith('encoder') or k.startswith('quant')}
+
+    return new_checkpoint
+
+def convert_ldm_clip_checkpoint(checkpoint):
+    keys = list(checkpoint.keys())
+
+    text_model_dict = {}
+    for key in keys:
+        if key.startswith("cond_stage_model.transformer"):
+            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
+
+    return text_model_dict
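+
+# Usage sketch (the model id and strict loading are assumptions, not taken from this repo):
+#   from transformers import CLIPTextModel
+#   text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+#   text_encoder.load_state_dict(convert_ldm_clip_checkpoint(ldm_checkpoint))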
+
+def convert_lora_model_level(state_dict, unet, text_encoder=None, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
+    """convert lora in model level instead of pipeline leval
+    """
+
+    visited = []
+
+    # directly update weight in diffusers model
+    for key in state_dict:
+        # keys usually look like:
+        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
+
+        # alpha is fixed via the function argument, so skip the stored ".alpha" entries
+        if ".alpha" in key or key in visited:
+            continue
+
+        if "text" in key:
+            layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
+            assert text_encoder is not None, (
+                'text_encoder must be passed since lora contains text encoder layers')
+            curr_layer = text_encoder
+        else:
+            layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
+            curr_layer = unet
+
+        # walk down the module tree to the target layer; name fragments that themselves
+        # contain "_" are re-joined whenever attribute lookup fails
+        temp_name = layer_infos.pop(0)
+        while len(layer_infos) > -1:
+            try:
+                curr_layer = curr_layer.__getattr__(temp_name)
+                if len(layer_infos) > 0:
+                    temp_name = layer_infos.pop(0)
+                elif len(layer_infos) == 0:
+                    break
+            except Exception:
+                if len(temp_name) > 0:
+                    temp_name += "_" + layer_infos.pop(0)
+                else:
+                    temp_name = layer_infos.pop(0)
+
+        pair_keys = []
+        if "lora_down" in key:
+            pair_keys.append(key.replace("lora_down", "lora_up"))
+            pair_keys.append(key)
+        else:
+            pair_keys.append(key)
+            pair_keys.append(key.replace("lora_up", "lora_down"))
+
+        # update weight
+        # NOTE: LoCon/LyCORIS-style conv layers; this branch may still have bugs
+        if 'conv_in' in pair_keys[0]:
+            weight_up = state_dict[pair_keys[0]].to(torch.float32)
+            weight_down = state_dict[pair_keys[1]].to(torch.float32)
+            weight_up = weight_up.view(weight_up.size(0), -1)
+            weight_down = weight_down.view(weight_down.size(0), -1)
+            shape = [e for e in curr_layer.weight.data.shape]
+            shape[1] = 4
+            curr_layer.weight.data[:, :4, ...] += alpha * (weight_up @ weight_down).view(*shape)
+        elif 'conv' in pair_keys[0]:
+            weight_up = state_dict[pair_keys[0]].to(torch.float32)
+            weight_down = state_dict[pair_keys[1]].to(torch.float32)
+            weight_up = weight_up.view(weight_up.size(0), -1)
+            weight_down = weight_down.view(weight_down.size(0), -1)
+            shape = [e for e in curr_layer.weight.data.shape]
+            curr_layer.weight.data += alpha * (weight_up @ weight_down).view(*shape)
+        elif len(state_dict[pair_keys[0]].shape) == 4:
+            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
+            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
+            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
+        else:
+            weight_up = state_dict[pair_keys[0]].to(torch.float32)
+            weight_down = state_dict[pair_keys[1]].to(torch.float32)
+            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
+
+        # update visited list
+        for item in pair_keys:
+            visited.append(item)
+
+    return unet, text_encoder
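+
+# Usage sketch (assumes a kohya-style LoRA state dict; the file path is hypothetical):
+#   from safetensors.torch import load_file
+#   lora_sd = load_file("models/example_lora.safetensors")
+#   unet, text_encoder = convert_lora_model_level(lora_sd, unet, text_encoder, alpha=0.6)
+# The update is merged in place (weight += alpha * up @ down), so restoring the original
+# weights requires reloading the base checkpoint.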
+
+def denormalize_spectrogram(
+    data: torch.Tensor,
+    max_value: float = 200, 
+    min_value: float = 1e-5, 
+    power: float = 1, 
+    inverse: bool = False,
+) -> torch.Tensor:
+    
+    max_value = np.log(max_value)
+    min_value = np.log(min_value)
+
+    # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner
+    data = torch.flip(data, [1])
+
+    assert len(data.shape) == 3, "Expected 3 dimensions, got {}".format(len(data.shape))
+    
+    if data.shape[0] == 1:
+        data = data.repeat(3, 1, 1)
+        
+    assert data.shape[0] == 3, "Expected 3 channels, got {}".format(data.shape[0])
+    data = data[0]
+
+    # Reverse the power curve
+    data = torch.pow(data, 1 / power)
+
+    # Invert
+    if inverse:
+        data = 1 - data
+
+    # Rescale to max value
+    spectrogram = data * (max_value - min_value) + min_value
+
+    return spectrogram
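+
+# NOTE: the output is a log-amplitude spectrogram rescaled to [log(min_value), log(max_value)],
+# undoing the [0, 1] normalization used for the image-like spectrogram representation.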
+
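+# ToTensor1D reuses torchvision's ToTensor for 1D signals (e.g. audio waveforms) by adding a
+# dummy channel axis before conversion and squeezing it back out afterwards.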
+class ToTensor1D(torchvision.transforms.ToTensor):
+
+    def __call__(self, tensor: np.ndarray):
+        tensor_2d = super(ToTensor1D, self).__call__(tensor[..., np.newaxis])
+
+        return tensor_2d.squeeze_(0)
+
+def scale(old_value, old_min, old_max, new_min, new_max):
+    old_range = (old_max - old_min)
+    new_range = (new_max - new_min)
+    new_value = (((old_value - old_min) * new_range) / old_range) + new_min
+
+    return new_value
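+
+# e.g. scale(5, 0, 10, 0, 1) == 0.5: a plain linear remap from [old_min, old_max]
+# to [new_min, new_max].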
+
+def read_frames_with_moviepy(video_path, max_frame_nums=None):
+    clip = VideoFileClip(video_path)
+    duration = clip.duration
+    frames = [frame for frame in clip.iter_frames()]
+    # uniformly subsample up to max_frame_nums frames; keep every frame otherwise
+    if max_frame_nums is not None:
+        frames_idx = np.linspace(0, len(frames) - 1, max_frame_nums, dtype=int)
+    else:
+        frames_idx = np.arange(len(frames))
+    return np.array(frames)[frames_idx, ...], duration
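+
+# Usage sketch (the path and frame budget are hypothetical):
+#   frames, duration = read_frames_with_moviepy("input/video.mp4", max_frame_nums=150)
+#   # frames: (N, H, W, 3) uint8 array of uniformly sampled frames; duration in seconds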
+
+def read_frames_with_moviepy_resample(video_path, save_path):
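+    # NOTE: despite the name, this path extracts frames with ffmpeg at 15 fps, applies
+    # ImageNet-style normalization, and returns a (C, T, H, W) tensor.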
+    vision_transform_list = [
+        transforms.Resize((128, 128)),
+        transforms.CenterCrop((112, 112)),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ]
+    video_transform = transforms.Compose(vision_transform_list)
+    os.makedirs(save_path, exist_ok=True)
+    command = f'ffmpeg -v quiet -y -i \"{video_path}\" -f image2 -vf \"scale=-1:360,fps=15\" -qscale:v 3 \"{save_path}\"/frame%06d.jpg'
+    os.system(command)
+    frame_list = glob.glob(f'{save_path}/*.jpg')
+    frame_list.sort()
+    convert_tensor = transforms.ToTensor()
+    frame_list = [convert_tensor(np.array(Image.open(frame))) for frame in frame_list]
+    imgs = torch.stack(frame_list, dim=0)
+    imgs = video_transform(imgs)
+    imgs = imgs.permute(1, 0, 2, 3)
+    return imgs
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..80258daac773e340a742f01ce92f402f57194cbb
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,21 @@
+decord==0.6.0
+diffusers==0.20.0
+einops==0.7.0
+imageio==2.27.0
+ipdb==0.13.13
+librosa==0.9.2
+moviepy==1.0.3
+numpy==1.23.5
+omegaconf==2.3.0
+opencv_python==4.8.0.76
+Pillow==10.2.0
+pydub==0.25.1
+safetensors==0.3.3
+scipy==1.12.0
+soundfile==0.12.1
+torch==2.1.2
+torchaudio==2.1.2
+torchvision==0.16.2
+tqdm==4.65.0
+transformers==4.32.1
+xformers==0.0.23.post1