Mayuri committed (verified)
Commit 2a5630b · 1 Parent(s): c566753

Upload 10 files

Files changed (10)
  1. lcm.py +117 -0
  2. main_v3.py +140 -0
  3. models.py +402 -0
  4. models/model.safetensors +3 -0
  5. models/model_org.safetensors +3 -0
  6. sar_1.png +0 -0
  7. sar_2.png +0 -0
  8. sar_3.png +0 -0
  9. sar_4.png +0 -0
  10. utils.py +347 -0
lcm.py ADDED
@@ -0,0 +1,117 @@
+ # First, make sure the required libraries are installed.
+ # You can install them with:
+ # pip install gradio diffusers transformers torch
+ 
+ import gradio as gr
+ from diffusers import StableDiffusionImg2ImgPipeline
+ import torch
+ from PIL import Image
+ import requests
+ from io import BytesIO
+ 
+ # List of available diffusion models
+ AVAILABLE_MODELS = {
+     "Stable Diffusion v1.4": "CompVis/stable-diffusion-v1-4",
+     "Stable Diffusion v1.5": "runwayml/stable-diffusion-v1-5",
+     "Stable Diffusion 2.1": "stabilityai/stable-diffusion-2-1",
+     # Add more models here as needed
+ }
+ 
+ # URLs of the sample images
+ SAMPLE_IMAGES = {
+     "Landscape": "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/samples/landscape.jpg",
+     "Portrait": "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/samples/portrait.jpg",
+     "Animal": "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/samples/animal.jpg",
+ }
+ 
+ # Cache loaded pipelines so each model is only loaded once
+ model_cache = {}
+ 
+ 
+ def load_model(model_name):
+     if model_name in model_cache:
+         return model_cache[model_name]
+     model_id = AVAILABLE_MODELS[model_name]
+     # Use the img2img pipeline so the input image can condition the generation
+     # (the plain StableDiffusionPipeline does not accept an init image).
+     pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+         model_id,
+         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+     )
+     pipe = pipe.to("cuda") if torch.cuda.is_available() else pipe.to("cpu")
+     model_cache[model_name] = pipe
+     return pipe
+ 
+ 
+ def process_image(model_name, input_image, sample_choice):
+     # If the user picked a sample image, download it instead of using the upload
+     if sample_choice != "Upload image":
+         url = SAMPLE_IMAGES.get(sample_choice, SAMPLE_IMAGES["Landscape"])
+         response = requests.get(url)
+         input_image = Image.open(BytesIO(response.content)).convert("RGB")
+ 
+     # Load the selected model
+     pipe = load_model(model_name)
+ 
+     # Generate an image (a fixed text prompt is used here; adjust it to the actual model)
+     prompt = "A transformed version of the input image."
+ 
+     with torch.autocast("cuda" if torch.cuda.is_available() else "cpu"):
+         generated_image = pipe(prompt=prompt, image=input_image, strength=0.8).images[0]
+ 
+     return input_image, generated_image
+ 
+ 
+ # Define the Gradio interface
+ def main():
+     with gr.Blocks() as demo:
+         gr.Markdown("# Diffusers Model Demo")
+         gr.Markdown("Pick a model, upload an image or choose a sample image, then click Convert to see the result.")
+ 
+         with gr.Row():
+             model_dropdown = gr.Dropdown(
+                 choices=list(AVAILABLE_MODELS.keys()),
+                 value=list(AVAILABLE_MODELS.keys())[0],
+                 label="Model"
+             )
+ 
+         with gr.Row():
+             sample_radio = gr.Radio(
+                 choices=["Upload image"] + list(SAMPLE_IMAGES.keys()),
+                 value="Upload image",
+                 label="Image source"
+             )
+ 
+         with gr.Row():
+             input_image = gr.Image(
+                 type="pil",
+                 label="Uploaded image",
+                 visible=True
+             )
+             sample_image = gr.Image(
+                 type="pil",
+                 label="Sample image",
+                 visible=False
+             )
+ 
+         # Show either the upload widget or the sample preview, depending on the radio choice
+         def toggle_image(choice):
+             return (
+                 gr.update(visible=(choice == "Upload image")),
+                 gr.update(visible=(choice != "Upload image")),
+             )
+ 
+         sample_radio.change(toggle_image, inputs=sample_radio, outputs=[input_image, sample_image])
+ 
+         convert_button = gr.Button("Convert")
+ 
+         with gr.Row():
+             original_output = gr.Image(label="Original")
+             generated_output = gr.Image(label="Generated")
+ 
+         convert_button.click(
+             process_image,
+             inputs=[model_dropdown, input_image, sample_radio],
+             outputs=[original_output, generated_output]
+         )
+ 
+     demo.launch(server_port=16006)
+ 
+ 
+ if __name__ == "__main__":
+     main()
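For reference, a minimal standalone sketch of the img2img call that `process_image` wraps, assuming diffusers' `StableDiffusionImg2ImgPipeline` and a hypothetical local file `input.jpg`:

# Minimal sketch (not part of the commit): run a single img2img pass,
# mirroring the strength=0.8 setting used in process_image().
import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)

init = Image.open("input.jpg").convert("RGB").resize((512, 512))  # hypothetical input file
result = pipe(prompt="A transformed version of the input image.",
              image=init, strength=0.8).images[0]
result.save("output.png")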
main_v3.py ADDED
@@ -0,0 +1,140 @@
+ import os
+ 
+ import torch as th
+ from PIL import Image
+ from torchvision import transforms
+ 
+ import gradio as gr
+ from diffusers import LCMScheduler
+ from safetensors import safe_open
+ 
+ from models import SAR2OptUNetv3
+ 
+ # Preprocessing for the single-channel SAR condition image
+ transform_sar = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((256, 256)),
+     transforms.Normalize((0.5,), (0.5,)),
+ ])
+ 
+ AVAILABLE_MODELS = {
+     "Sen12:LCM-Model": "models/model.safetensors",
+     "Sen12:Org-Model": "models/model_org.safetensors",
+ }
+ 
+ device = th.device('cuda:0' if th.cuda.is_available() else 'cpu')
+ 
+ 
+ def safe_load(model_path):
+     # Load a .safetensors checkpoint into a plain state dict
+     assert "safetensors" in model_path
+     state_dict = {}
+     with safe_open(model_path, framework="pt", device="cpu") as f:
+         for k in f.keys():
+             state_dict[k] = f.get_tensor(k)
+     return state_dict
+ 
+ 
+ unet_model = SAR2OptUNetv3(
+     sample_size=256,
+     in_channels=4,
+     out_channels=3,
+     layers_per_block=2,
+     block_out_channels=(128, 128, 256, 256, 512, 512),
+     down_block_types=(
+         "DownBlock2D",
+         "DownBlock2D",
+         "DownBlock2D",
+         "DownBlock2D",
+         "AttnDownBlock2D",
+         "DownBlock2D",
+     ),
+     up_block_types=(
+         "UpBlock2D",
+         "AttnUpBlock2D",
+         "UpBlock2D",
+         "UpBlock2D",
+         "UpBlock2D",
+         "UpBlock2D",
+     ),
+ )
+ print('UNet constructed; safetensors checkpoints are loaded on demand in predict().')
+ 
+ lcm_scheduler = LCMScheduler(num_train_timesteps=1000)
+ 
+ unet_model.to(device)
+ unet_model.eval()
+ 
+ 
+ def predict(condition, nums_step, model_name):
+     # Load the requested checkpoint into the shared UNet
+     unet_checkpoint = AVAILABLE_MODELS[model_name]
+     unet_model.load_state_dict(safe_load(unet_checkpoint), strict=True)
+     unet_model.eval().to(device)
+     with th.no_grad():
+         lcm_scheduler.set_timesteps(int(nums_step), device=device)
+         timesteps = lcm_scheduler.timesteps
+         pred_latent = th.randn(size=[1, 3, 256, 256], device=device)
+         condition = condition.convert("L")
+         condition = transform_sar(condition)
+         condition = th.unsqueeze(condition, 0)
+         condition = condition.to(device)
+         for timestep in timesteps:
+             # Concatenate the noisy sample with the SAR condition along the channel axis
+             latent_to_pred = th.cat((pred_latent, condition), dim=1)
+             model_pred = unet_model(latent_to_pred, timestep)
+             pred_latent, denoised = lcm_scheduler.step(
+                 model_output=model_pred,
+                 timestep=timestep,
+                 sample=pred_latent,
+                 return_dict=False)
+         sample = denoised.cpu()
+ 
+     # Map from [-1, 1] to an 8-bit RGB image
+     sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
+     sample = sample.permute(0, 2, 3, 1)
+     sample = sample.contiguous()
+     sample = sample.cpu().numpy()
+     sample = sample.squeeze(0)
+     sample = Image.fromarray(sample)
+     return sample
+ 
+ 
+ demo = gr.Interface(
+     fn=predict,
+     inputs=[
+         gr.Image(type="pil"),
+         gr.Slider(1, 1000, step=1, label="nums_step"),
+         gr.Dropdown(
+             choices=list(AVAILABLE_MODELS.keys()),
+             value=list(AVAILABLE_MODELS.keys())[0],
+             label="Choose the Model"),
+     ],
+     outputs=gr.Image(type="pil"),
+     examples=[
+         [os.path.join(os.path.dirname(__file__), "sar_1.png"), 8, "Sen12:LCM-Model"],
+         [os.path.join(os.path.dirname(__file__), "sar_2.png"), 16, "Sen12:LCM-Model"],
+         [os.path.join(os.path.dirname(__file__), "sar_3.png"), 500, "Sen12:Org-Model"],
+         [os.path.join(os.path.dirname(__file__), "sar_4.png"), 1000, "Sen12:Org-Model"],
+     ],
+     title="SAR to Optical Image 🚀",
+     description="""
+ # 🎯 Instructions
+ This project converts SAR images into optical images using conditional diffusion.
+ 
+ Upload a SAR image to obtain the corresponding optical image.
+ 
+ ## 📢 Inputs
+ - `condition`: the SAR image you want to translate.
+ - `nums_step`: the number of inference steps.
+ 
+ ## 🎉 Outputs
+ - The corresponding optical image.
+ 
+ **Paper**: [Guided Diffusion for Image Generation](https://arxiv.org/abs/2105.05233)
+ 
+ **GitHub**: https://github.com/Coordi777/Conditional_SAR2OPT
+ """
+ )
+ 
+ if __name__ == "__main__":
+     demo.launch(server_port=16006)
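A small sketch of calling `predict` directly, without the Gradio UI, assuming this file is importable as `main_v3` and the bundled `sar_1.png` and checkpoints are present:

# Sketch (not part of the commit): run one SAR-to-optical translation from Python.
from PIL import Image
from main_v3 import predict

sar = Image.open("sar_1.png")                      # one of the bundled example inputs
optical = predict(sar, nums_step=8, model_name="Sen12:LCM-Model")
optical.save("optical_1.png")                      # hypothetical output path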
models.py ADDED
@@ -0,0 +1,402 @@
+ from diffusers import StableDiffusionPipeline
+ from diffusers import AutoencoderKL, UNet2DConditionModel, UNet2DModel
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import os
+ import json
+ 
+ 
+ class SAR2OptUNet(UNet2DConditionModel):
+     # UNet2DConditionModel whose forward pass is written out explicitly so the
+     # conditioning inputs can be passed and inspected directly.
+ 
+     def forward(self, sample, timestep, encoder_hidden_states, timestep_cond, cross_attention_kwargs,
+                 added_cond_kwargs):
+         default_overall_up_factor = 2 ** self.num_upsamplers
+         forward_upsample_size = False
+         upsample_size = None
+ 
+         if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+             forward_upsample_size = True
+ 
+         # 1. time
+         timesteps = timestep
+         if not torch.is_tensor(timesteps):
+             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+             # This would be a good case for the `match` statement (Python 3.10+)
+             is_mps = sample.device.type == "mps"
+             if isinstance(timestep, float):
+                 dtype = torch.float32 if is_mps else torch.float64
+             else:
+                 dtype = torch.int32 if is_mps else torch.int64
+             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+         elif len(timesteps.shape) == 0:
+             timesteps = timesteps[None].to(sample.device)
+ 
+         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+         timesteps = timesteps.expand(sample.shape[0])
+ 
+         t_emb = self.time_proj(timesteps)
+         t_emb = t_emb.to(dtype=sample.dtype)
+ 
+         emb = self.time_embedding(t_emb, timestep_cond)
+         aug_emb = None
+ 
+         if added_cond_kwargs is not None:
+             if 'sar' in added_cond_kwargs:
+                 image_embs = added_cond_kwargs.get("image_embeds")
+                 aug_emb = self.add_embedding(image_embs)
+             else:
+                 raise ValueError(
+                     f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                 )
+ 
+         emb = emb + aug_emb if aug_emb is not None else emb
+         if self.time_embed_act is not None:
+             emb = self.time_embed_act(emb)
+ 
+         # 2. pre-process
+         sample = self.conv_in(sample)
+ 
+         # 3. down
+         down_block_res_samples = (sample,)
+ 
+         for downsample_block in self.down_blocks:
+             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+                 sample, res_samples = downsample_block(
+                     hidden_states=sample,
+                     temb=emb,
+                     encoder_hidden_states=encoder_hidden_states,
+                     attention_mask=None,
+                     cross_attention_kwargs=cross_attention_kwargs,
+                     encoder_attention_mask=None,
+                 )
+             else:
+                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+ 
+             down_block_res_samples += res_samples
+ 
+         # 4. mid
+         if self.mid_block is not None:
+             sample = self.mid_block(
+                 sample,
+                 emb,
+                 encoder_hidden_states=encoder_hidden_states,
+                 attention_mask=None,
+                 cross_attention_kwargs=cross_attention_kwargs,
+                 encoder_attention_mask=None,
+             )
+ 
+         # 5. up
+         for i, upsample_block in enumerate(self.up_blocks):
+             is_final_block = i == len(self.up_blocks) - 1
+ 
+             res_samples = down_block_res_samples[-len(upsample_block.resnets):]
+             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+ 
+             # if we have not reached the final block and need to forward the
+             # upsample size, we do it here
+             if not is_final_block and forward_upsample_size:
+                 upsample_size = down_block_res_samples[-1].shape[2:]
+ 
+             if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+                 sample = upsample_block(
+                     hidden_states=sample,
+                     temb=emb,
+                     res_hidden_states_tuple=res_samples,
+                     encoder_hidden_states=encoder_hidden_states,
+                     cross_attention_kwargs=cross_attention_kwargs,
+                     upsample_size=upsample_size,
+                     attention_mask=None,
+                     encoder_attention_mask=None,
+                 )
+             else:
+                 sample = upsample_block(
+                     hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+                 )
+ 
+         # 6. post-process
+         if self.conv_norm_out:
+             sample = self.conv_norm_out(sample)
+             sample = self.conv_act(sample)
+         sample = self.conv_out(sample)
+ 
+         return sample
+ 
+ 
+ class SAREncoder(nn.Module):
+     # Convolutional encoder that maps a SAR image to a sequence of 1280-dim tokens.
+     def __init__(self, in_channels, ngf=50):
+         super(SAREncoder, self).__init__()
+         self.ngf = ngf
+         self.encoder = nn.Sequential(
+             # Encoder 1
+             nn.Conv2d(in_channels=in_channels, out_channels=self.ngf, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(self.ngf),
+             nn.LeakyReLU(0.2, inplace=True),
+ 
+             # Encoder 2
+             nn.Conv2d(in_channels=self.ngf, out_channels=self.ngf * 2, kernel_size=3, stride=2, padding=1),  # half
+             nn.BatchNorm2d(self.ngf * 2),
+             nn.LeakyReLU(0.2, inplace=True),
+ 
+             # Encoder 3
+             nn.Conv2d(in_channels=self.ngf * 2, out_channels=self.ngf * 4, kernel_size=3, stride=2, padding=1),  # half
+             nn.BatchNorm2d(self.ngf * 4),
+             nn.LeakyReLU(0.2, inplace=True),
+ 
+             # Encoder 4
+             nn.Conv2d(in_channels=self.ngf * 4, out_channels=self.ngf * 5, kernel_size=3, stride=2, padding=1),  # half
+             nn.BatchNorm2d(self.ngf * 5),
+             nn.LeakyReLU(0.2, inplace=True),
+         )
+ 
+     def forward(self, x):
+         bz = x.shape[0]
+         out = self.encoder(x).reshape(bz, -1, 1280)
+         return out
+ 
+ 
+ class SAR2OptUNetv2(UNet2DConditionModel):
+     # UNet2DConditionModel with a built-in SAR encoder that produces the
+     # cross-attention states when they are not supplied explicitly.
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         in_channels = 1
+         self.ngf = 2
+         self.sar_encoder = nn.Sequential(
+             # Encoder 1
+             nn.Conv2d(in_channels=in_channels, out_channels=self.ngf, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(self.ngf),
+             nn.LeakyReLU(0.2, inplace=True),
+ 
+             # Encoder 2
+             nn.Conv2d(in_channels=self.ngf, out_channels=self.ngf * 2, kernel_size=3, stride=2, padding=1),  # half
+             nn.BatchNorm2d(self.ngf * 2),
+             nn.LeakyReLU(0.2, inplace=True),
+ 
+             # Encoder 3
+             nn.Conv2d(in_channels=self.ngf * 2, out_channels=self.ngf * 4, kernel_size=3, stride=2, padding=1),  # half
+             nn.BatchNorm2d(self.ngf * 4),
+             nn.LeakyReLU(0.2, inplace=True),
+ 
+             # Encoder 4
+             nn.Conv2d(in_channels=self.ngf * 4, out_channels=self.ngf * 5, kernel_size=3, stride=2, padding=1),  # half
+             nn.BatchNorm2d(self.ngf * 5),
+             nn.LeakyReLU(0.2, inplace=True),
+         )
+ 
+     def forward(self, sample, timestep, sar_image=None,
+                 encoder_hidden_states=None,
+                 timestep_cond=None, cross_attention_kwargs=None,
+                 added_cond_kwargs=None):
+ 
+         if encoder_hidden_states is None:
+             assert sar_image is not None
+             bz = sample.shape[0]
+             encoder_hidden_states = self.sar_encoder(sar_image).reshape(bz, -1, 1280)
+ 
+         default_overall_up_factor = 2 ** self.num_upsamplers
+         forward_upsample_size = False
+         upsample_size = None
+ 
+         if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+             forward_upsample_size = True
+ 
+         timesteps = timestep
+         if not torch.is_tensor(timesteps):
+             is_mps = sample.device.type == "mps"
+             if isinstance(timestep, float):
+                 dtype = torch.float32 if is_mps else torch.float64
+             else:
+                 dtype = torch.int32 if is_mps else torch.int64
+             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+         elif len(timesteps.shape) == 0:
+             timesteps = timesteps[None].to(sample.device)
+ 
+         timesteps = timesteps.expand(sample.shape[0])
+ 
+         t_emb = self.time_proj(timesteps)
+         t_emb = t_emb.to(dtype=sample.dtype)
+ 
+         emb = self.time_embedding(t_emb, timestep_cond)
+         aug_emb = None
+ 
+         if added_cond_kwargs is not None:
+             if 'sar' in added_cond_kwargs:
+                 image_embs = added_cond_kwargs.get("image_embeds")
+                 aug_emb = self.add_embedding(image_embs)
+             else:
+                 raise ValueError(
+                     f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                 )
+ 
+         emb = emb + aug_emb if aug_emb is not None else emb
+         if self.time_embed_act is not None:
+             emb = self.time_embed_act(emb)
+ 
+         # 2. pre-process
+         sample = self.conv_in(sample)
+ 
+         # 3. down
+         down_block_res_samples = (sample,)
+ 
+         for downsample_block in self.down_blocks:
+             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+                 sample, res_samples = downsample_block(
+                     hidden_states=sample,
+                     temb=emb,
+                     encoder_hidden_states=encoder_hidden_states,
+                     attention_mask=None,
+                     cross_attention_kwargs=cross_attention_kwargs,
+                     encoder_attention_mask=None,
+                 )
+             else:
+                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+ 
+             down_block_res_samples += res_samples
+ 
+         # 4. mid
+         if self.mid_block is not None:
+             sample = self.mid_block(
+                 sample,
+                 emb,
+                 encoder_hidden_states=encoder_hidden_states,
+                 attention_mask=None,
+                 cross_attention_kwargs=cross_attention_kwargs,
+                 encoder_attention_mask=None,
+             )
+ 
+         # 5. up
+         for i, upsample_block in enumerate(self.up_blocks):
+             is_final_block = i == len(self.up_blocks) - 1
+ 
+             res_samples = down_block_res_samples[-len(upsample_block.resnets):]
+             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+ 
+             # if we have not reached the final block and need to forward the
+             # upsample size, we do it here
+             if not is_final_block and forward_upsample_size:
+                 upsample_size = down_block_res_samples[-1].shape[2:]
+ 
+             if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+                 sample = upsample_block(
+                     hidden_states=sample,
+                     temb=emb,
+                     res_hidden_states_tuple=res_samples,
+                     encoder_hidden_states=encoder_hidden_states,
+                     cross_attention_kwargs=cross_attention_kwargs,
+                     upsample_size=upsample_size,
+                     attention_mask=None,
+                     encoder_attention_mask=None,
+                 )
+             else:
+                 sample = upsample_block(
+                     hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+                 )
+ 
+         # 6. post-process
+         if self.conv_norm_out:
+             sample = self.conv_norm_out(sample)
+             sample = self.conv_act(sample)
+         sample = self.conv_out(sample)
+ 
+         return sample
+ 
+ 
+ class SAR2OptUNetv3(UNet2DModel):
+     # UNet2DModel variant used by main_v3.py: the SAR condition is concatenated
+     # with the noisy sample along the channel axis, so forward only needs
+     # (sample, timestep).
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+ 
+     def forward(self, sample, timestep):
+         if self.config.center_input_sample:
+             sample = 2 * sample - 1.0
+ 
+         # 1. time
+         timesteps = timestep
+         if not torch.is_tensor(timesteps):
+             timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+         elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+             timesteps = timesteps[None].to(sample.device)
+ 
+         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+         timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
+ 
+         t_emb = self.time_proj(timesteps)
+         t_emb = t_emb.to(dtype=self.dtype)
+         emb = self.time_embedding(t_emb)
+ 
+         # 2. pre-process
+         skip_sample = sample
+         sample = self.conv_in(sample)
+ 
+         # 3. down
+         down_block_res_samples = (sample,)
+         for downsample_block in self.down_blocks:
+             if hasattr(downsample_block, "skip_conv"):
+                 sample, res_samples, skip_sample = downsample_block(
+                     hidden_states=sample, temb=emb, skip_sample=skip_sample
+                 )
+             else:
+                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+ 
+             down_block_res_samples += res_samples
+ 
+         # 4. mid
+         sample = self.mid_block(sample, emb)
+ 
+         # 5. up
+         skip_sample = None
+         for upsample_block in self.up_blocks:
+             res_samples = down_block_res_samples[-len(upsample_block.resnets):]
+             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+ 
+             if hasattr(upsample_block, "skip_conv"):
+                 sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
+             else:
+                 sample = upsample_block(sample, res_samples, emb)
+ 
+         # 6. post-process
+         sample = self.conv_norm_out(sample)
+         sample = self.conv_act(sample)
+         sample = self.conv_out(sample)
+ 
+         if skip_sample is not None:
+             sample += skip_sample
+ 
+         if self.config.time_embedding_type == "fourier":
+             timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
+             sample = sample / timesteps
+ 
+         return sample
+ 
+ 
+ # 3*64*64
+ if __name__ == '__main__':
+     model = SAR2OptUNetv2(
+         sample_size=256,
+         in_channels=3,
+         out_channels=3,
+         layers_per_block=2,
+         block_out_channels=(128, 128, 256, 256, 512, 512),
+         down_block_types=(
+             "DownBlock2D",
+             "DownBlock2D",
+             "DownBlock2D",
+             "DownBlock2D",
+             "AttnDownBlock2D",
+             "DownBlock2D",
+         ),
+         up_block_types=(
+             "UpBlock2D",
+             "AttnUpBlock2D",
+             "UpBlock2D",
+             "UpBlock2D",
+             "UpBlock2D",
+             "UpBlock2D",
+         ),
+     )
+     model.to("cuda")
+     opt_image = torch.randn(8, 3, 256, 256).to("cuda")
+     sar_image = torch.randn(8, 1, 256, 256).to("cuda")
+ 
+     timestep = torch.tensor(1.0)
+     re = model(opt_image, timestep, sar_image, None, None, None)
+     print(re.shape)
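The `__main__` block above only exercises `SAR2OptUNetv2`. For completeness, a sketch of how main_v3.py drives `SAR2OptUNetv3`: the SAR image is concatenated with the noisy sample along the channel axis (3 + 1 = 4 input channels) and the network predicts a 3-channel output. Shapes only; no trained weights are loaded here.

# Sketch (not part of the commit): shape check for SAR2OptUNetv3 with the
# same configuration that main_v3.py uses.
import torch
from models import SAR2OptUNetv3

unet = SAR2OptUNetv3(
    sample_size=256, in_channels=4, out_channels=3, layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D",
                      "DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D", "UpBlock2D",
                    "UpBlock2D", "UpBlock2D", "UpBlock2D"),
)
noise = torch.randn(1, 3, 256, 256)   # noisy optical sample
sar = torch.randn(1, 1, 256, 256)     # single-channel SAR condition
out = unet(torch.cat((noise, sar), dim=1), timestep=10)
print(out.shape)  # torch.Size([1, 3, 256, 256])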
models/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:34833bcdbebf7767daa0015ca6bc0a0c444c68d84fad6f7aa96a10f1653cf1d7
+ size 454745716
models/model_org.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:788ed3e1601923a5245e430b89ff3522c3ab8c46b928d8a1275778a27cf2f8cf
+ size 454745716
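These two files are Git LFS pointers to the actual ~455 MB checkpoints. A quick, hypothetical way to inspect what a downloaded checkpoint contains, using the same `safetensors.safe_open` API that `utils.safe_load` relies on:

# Sketch (not part of the commit): list tensors and count parameters.
from safetensors import safe_open

with safe_open("models/model.safetensors", framework="pt", device="cpu") as f:
    n_params = 0
    for key in f.keys():
        n_params += f.get_tensor(key).numel()
    print(f"{len(f.keys())} tensors, {n_params / 1e6:.1f}M parameters")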
sar_1.png ADDED
sar_2.png ADDED
sar_3.png ADDED
sar_4.png ADDED
utils.py ADDED
@@ -0,0 +1,347 @@
+ import ast
+ from safetensors import safe_open
+ import torch
+ from dataclasses import dataclass
+ from typing import Optional, Union, List
+ 
+ 
+ def update_args_from_yaml(group, args, parser):
+     for key, value in group.items():
+         if isinstance(value, dict):
+             update_args_from_yaml(value, args, parser)
+         else:
+             if value == 'None' or value == 'null':
+                 value = None
+             else:
+                 arg_type = next((action.type for action in parser._actions if action.dest == key), str)
+ 
+                 if arg_type is ast.literal_eval:
+                     pass
+                 elif arg_type is not None and not isinstance(value, arg_type):
+                     try:
+                         value = arg_type(value)
+                     except ValueError as e:
+                         raise ValueError(f"Cannot convert {key} to {arg_type}: {e}")
+ 
+             setattr(args, key, value)
+ 
+ 
+ def safe_load(model_path):
+     assert "safetensors" in model_path
+     state_dict = {}
+     with safe_open(model_path, framework="pt", device="cpu") as f:
+         for k in f.keys():
+             state_dict[k] = f.get_tensor(k)
+     return state_dict
+ 
+ 
+ @dataclass
+ class DDIMSchedulerStepOutput:
+     prev_sample: torch.Tensor  # x_{t-1}
+     pred_original_sample: Optional[torch.Tensor] = None  # x0
+ 
+ 
+ @dataclass
+ class DDIMSchedulerConversionOutput:
+     pred_epsilon: torch.Tensor
+     pred_original_sample: torch.Tensor
+     pred_velocity: torch.Tensor
+ 
+ 
+ class DDIMScheduler:
+     prediction_types = ["epsilon", "sample", "v_prediction"]
+ 
+     def __init__(
+         self,
+         num_train_timesteps: int,
+         num_inference_timesteps: int,
+         betas: torch.Tensor,
+         set_alpha_to_one: bool = True,
+         set_inference_timesteps_from_pure_noise: bool = True,
+         inference_timesteps: Union[str, List[int]] = "trailing",
+         device: Optional[Union[str, torch.device]] = None,
+         dtype: torch.dtype = torch.float32,
+         skip_step: bool = False,
+         original_inference_step: int = 20,
+         steps_offset: int = 0,
+     ):
+         assert num_train_timesteps > 0
+         assert num_train_timesteps >= num_inference_timesteps
+         assert num_train_timesteps == betas.size(0)
+         assert betas.ndim == 1
+         # self.user_name = user_name
+         # self.run_time = Recorder.format_time()
+         # self.task_name = 'AutoAIGC_%s' % str(self.run_time)
+         self.module_name = 'AutoAIGC'
+         self.config_list = {"num_train_timesteps": num_train_timesteps,
+                             "num_inference_timesteps": num_inference_timesteps,
+                             "betas": betas,
+                             "set_alpha_to_one": set_alpha_to_one,
+                             "set_inference_timesteps_from_pure_noise": set_inference_timesteps_from_pure_noise,
+                             "inference_timesteps": inference_timesteps}
+         self.module_info = str(self.config_list)
+ 
+         # self.upload_logger(user_name=user_name)
+ 
+         device = device or betas.device
+ 
+         self.num_train_timesteps = num_train_timesteps
+         self.num_inference_steps = num_inference_timesteps
+         self.steps_offset = steps_offset
+ 
+         self.betas = betas  # .to(device=device, dtype=dtype)
+         self.alphas = 1.0 - self.betas
+         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+         self.final_alpha_cumprod = torch.tensor(1.0, device=device, dtype=dtype) if set_alpha_to_one else self.alphas_cumprod[0]
+ 
+         if isinstance(inference_timesteps, torch.Tensor):
+             assert len(inference_timesteps) == num_inference_timesteps
+             self.timesteps = inference_timesteps.cpu().numpy().tolist()
+         elif set_inference_timesteps_from_pure_noise:
+             if inference_timesteps == "trailing":
+                 # [999, 949, 899, 849, 799, 749, 699, 649, 599, 549, 499, 449, 399, 349, 299, 249, 199, 149, 99, 49]
+                 if skip_step:  # ?
+                     original_timesteps = torch.arange(num_train_timesteps - 1, -1, -num_train_timesteps / original_inference_step, device=device).round().int().tolist()
+                     skipping_step = len(original_timesteps) // num_inference_timesteps
+                     self.timesteps = original_timesteps[::skipping_step][:num_inference_timesteps]
+                 else:  # [999, 899, 799, 699, 599, 499, 399, 299, 199, 99]
+                     self.timesteps = torch.arange(num_train_timesteps - 1, -1, -num_train_timesteps / num_inference_timesteps, device=device).round().int().tolist()
+             elif inference_timesteps == "linspace":
+                 # Fixed DDIM timestep. Make sure the timestep starts from 999.
+                 # Example 20 steps:
+                 # [999, 946, 894, 841, 789, 736, 684, 631, 578, 526, 473, 421, 368, 315, 263, 210, 158, 105, 53, 0]
+                 # [999, 888, 777, 666, 555, 444, 333, 222, 111, 0]
+                 self.timesteps = torch.linspace(0, num_train_timesteps - 1, num_inference_timesteps, device=device).round().int().flip(0).tolist()
+             elif inference_timesteps == "leading":
+                 step_ratio = num_train_timesteps // num_inference_timesteps
+                 # creates integer timesteps by multiplying by ratio
+                 # casting to int to avoid issues when num_inference_step is power of 3
+                 self.timesteps = torch.arange(0, num_inference_timesteps).mul(step_ratio).round().flip(dims=[0])  # .clone().long()
+                 # self.timesteps += self.steps_offset
+ 
+                 # Original SD and DDIM paper may have a bug: <https://github.com/huggingface/diffusers/issues/2585>
+                 # The inference timestep does not start from 999.
+                 # Example 20 steps:
+                 # [950, 900, 850, 800, 750, 700, 650, 600, 550, 500, 450, 400, 350, 300, 250, 200, 150, 100, 50, 0]
+                 # [900, 800, 700, 600, 500, 400, 300, 200, 100, 0]
+                 # self.timesteps = torch.arange(0, num_train_timesteps, num_train_timesteps // num_inference_timesteps, device=self.device, dtype=torch.int).flip(0)
+                 # self.timesteps = list(reversed(range(0, num_train_timesteps, num_train_timesteps // num_inference_timesteps)))
+             else:
+                 raise NotImplementedError
+ 
+         elif inference_timesteps == "leading":
+             # Original SD and DDIM paper may have a bug: <https://github.com/huggingface/diffusers/issues/2585>
+             # The inference timestep does not start from 999.
+             # Example 20 steps:
+             # [950, 900, 850, 800, 750, 700, 650, 600, 550, 500, 450, 400, 350, 300, 250, 200, 150, 100, 50, 0]
+             # [900, 800, 700, 600, 500, 400, 300, 200, 100, 0]
+             # self.timesteps = torch.arange(0, num_train_timesteps, num_train_timesteps // num_inference_timesteps, device=self.device, dtype=torch.int).flip(0)
+             self.timesteps = list(reversed(range(0, num_train_timesteps, num_train_timesteps // num_inference_timesteps)))
+ 
+         else:
+             self.timesteps = list(reversed(range(0, num_train_timesteps, num_train_timesteps // num_inference_timesteps)))
+             # raise NotImplementedError
+ 
+         self.to(device=device)
+ 
+     def to(self, device):
+         self.betas = self.betas.to(device)
+         self.alphas_cumprod = self.alphas_cumprod.to(device)
+         self.final_alpha_cumprod = self.final_alpha_cumprod.to(device)
+         # self.timesteps = self.timesteps.to(device)
+         return self
+ 
+     def step(
+         self,
+         model_output: torch.Tensor,
+         model_output_type: str,
+         timestep: Union[torch.Tensor, int],
+         sample: torch.Tensor,
+         eta: float = 0.0,
+         clip_sample: bool = False,
+         dynamic_threshold: Optional[float] = None,
+         variance_noise: Optional[torch.Tensor] = None,
+     ) -> DDIMSchedulerStepOutput:
+         # 1. get previous step value (t-1)
+         if isinstance(timestep, int):
+             idx = self.timesteps.index(timestep)
+             prev_timestep = self.timesteps[idx + 1] if idx < self.num_inference_steps - 1 else None
+ 
+             # 2. compute alphas, betas
+             alpha_prod_t = self.alphas_cumprod[timestep]
+             alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep is not None else self.final_alpha_cumprod
+             beta_prod_t = 1 - alpha_prod_t
+             beta_prod_t_prev = 1 - alpha_prod_t_prev
+         else:
+             timesteps = torch.tensor(self.timesteps).to(timestep.device)
+             idx = timestep.reshape(-1, 1).eq(timesteps.reshape(1, -1)).nonzero()[:, 1]  # find the index idx of timestep within timesteps
+             # take the element at idx + 1, i.e. the next timestep; if idx + 1 runs past the end it is clamped to self.num_inference_steps - 1
+             prev_timestep = timesteps[idx.add(1).clamp_max(self.num_inference_steps - 1)]
+ 
+             assert (prev_timestep is not None)
+             # 2. compute alphas, betas
+             alpha_prod_t = self.alphas_cumprod[timestep]
+             alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]
+             alpha_prod_t_prev = torch.where(prev_timestep < 0, self.final_alpha_cumprod, alpha_prod_t_prev)
+             beta_prod_t = 1 - alpha_prod_t
+             beta_prod_t_prev = 1 - alpha_prod_t_prev
+ 
+             bs = timestep.size(0)
+             alpha_prod_t = alpha_prod_t.view(bs, 1, 1, 1)
+             alpha_prod_t_prev = alpha_prod_t_prev.view(bs, 1, 1, 1)
+             beta_prod_t = beta_prod_t.view(bs, 1, 1, 1)
+             beta_prod_t_prev = beta_prod_t_prev.view(bs, 1, 1, 1)
+ 
+         # rcfg
+         self.stock_alpha_prod_t_prev = alpha_prod_t_prev
+         self.stock_beta_prod_t_prev = beta_prod_t_prev
+ 
+         # 3. compute predicted original sample (x_0) from the predicted noise
+         model_output_conversion = self.convert_output(model_output, model_output_type, sample, timestep)
+         pred_original_sample = model_output_conversion.pred_original_sample
+         pred_epsilon = model_output_conversion.pred_epsilon
+ 
+         # 4. Clip or threshold "predicted x_0"
+         if clip_sample:
+             pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+             pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon
+ 
+         if dynamic_threshold is not None:
+             # Dynamic thresholding in https://arxiv.org/abs/2205.11487
+             dynamic_max_val = pred_original_sample \
+                 .flatten(1) \
+                 .abs() \
+                 .float() \
+                 .quantile(dynamic_threshold, dim=1) \
+                 .type_as(pred_original_sample) \
+                 .clamp_min(1) \
+                 .view(-1, *([1] * (pred_original_sample.ndim - 1)))
+             pred_original_sample = pred_original_sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+             pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon
+ 
+         # 5. compute variance: "sigma_t(η)" -> see formula (16) from https://arxiv.org/pdf/2010.02502.pdf
+         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+         variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+         std_dev_t = eta * variance ** (0.5)
+ 
+         # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+         pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
+ 
+         # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+         prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+ 
+         # 8. add "random noise" if needed.
+         if eta > 0:
+             if variance_noise is None:
+                 variance_noise = torch.randn_like(model_output)
+             prev_sample = prev_sample + std_dev_t * variance_noise
+ 
+         return DDIMSchedulerStepOutput(
+             prev_sample=prev_sample,  # x_{t-1}
+             pred_original_sample=pred_original_sample  # x0
+         )
+ 
+     def add_noise(
+         self,
+         original_samples: torch.Tensor,
+         noise: torch.Tensor,
+         timesteps: Union[torch.Tensor, int],
+         replace_noise=True
+     ) -> torch.Tensor:
+         alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (original_samples.ndim - 1)))
+         if replace_noise:
+             indices = (timesteps == 999).nonzero()
+             if indices.numel() > 0:
+                 alpha_prod_t[indices] = 0
+         return alpha_prod_t ** (0.5) * original_samples + (1 - alpha_prod_t) ** (0.5) * noise
+ 
+     def add_noise_lcm(
+         self,
+         original_samples: torch.Tensor,
+         noise: torch.Tensor,
+         timestep: Union[torch.Tensor, int],
+     ) -> torch.Tensor:
+         if isinstance(timestep, int):
+             # 1. get previous step value (t-1)
+             idx = self.timesteps.index(timestep)
+             prev_timestep = self.timesteps[idx + 1] if idx < self.num_inference_steps - 1 else None
+ 
+             # 2. compute alphas, betas
+             alpha_prod_t = self.alphas_cumprod[timestep]
+             alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep is not None else self.final_alpha_cumprod
+             beta_prod_t = 1 - alpha_prod_t
+             beta_prod_t_prev = 1 - alpha_prod_t_prev
+         else:
+             timesteps = torch.tensor(self.timesteps).to(timestep.device)
+             idx = timestep.reshape(-1, 1).eq(timesteps.reshape(1, -1)).nonzero()[:, 1]  # find the index idx of timestep within timesteps
+             prev_timestep = timesteps[idx.add(1).clamp_max(self.num_inference_steps - 1)]
+ 
+             assert (prev_timestep is not None)
+             # 2. compute alphas, betas
+             alpha_prod_t = self.alphas_cumprod[timestep]
+             alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]
+             alpha_prod_t_prev = torch.where(prev_timestep < 0, self.final_alpha_cumprod, alpha_prod_t_prev)
+             beta_prod_t = 1 - alpha_prod_t
+             beta_prod_t_prev = 1 - alpha_prod_t_prev
+ 
+             bs = timestep.size(0)
+             alpha_prod_t = alpha_prod_t.view(bs, 1, 1, 1)
+             alpha_prod_t_prev = alpha_prod_t_prev.view(bs, 1, 1, 1)
+             beta_prod_t = beta_prod_t.view(bs, 1, 1, 1)
+             beta_prod_t_prev = beta_prod_t_prev.view(bs, 1, 1, 1)
+ 
+         alpha_prod_t_prev = alpha_prod_t_prev.reshape(-1, *([1] * (original_samples.ndim - 1)))
+         return alpha_prod_t_prev ** (0.5) * original_samples + (1 - alpha_prod_t_prev) ** (0.5) * noise
+ 
+     def convert_output(
+         self,
+         model_output: torch.Tensor,
+         model_output_type: str,
+         sample: torch.Tensor,
+         timesteps: Union[torch.Tensor, int]
+     ) -> DDIMSchedulerConversionOutput:
+         assert model_output_type in self.prediction_types
+ 
+         alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (sample.ndim - 1)))
+         beta_prod_t = 1 - alpha_prod_t
+ 
+         if model_output_type == "epsilon":
+             pred_epsilon = model_output
+             pred_original_sample = (sample - beta_prod_t ** (0.5) * pred_epsilon) / alpha_prod_t ** (0.5)
+             pred_velocity = alpha_prod_t ** (0.5) * pred_epsilon - (1 - alpha_prod_t) ** (0.5) * pred_original_sample
+         elif model_output_type == "sample":
+             pred_original_sample = model_output
+             pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+             pred_velocity = alpha_prod_t ** (0.5) * pred_epsilon - (1 - alpha_prod_t) ** (0.5) * pred_original_sample
+         elif model_output_type == "v_prediction":
+             pred_velocity = model_output
+             pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+             pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+         else:
+             raise ValueError("Unknown prediction type")
+ 
+         return DDIMSchedulerConversionOutput(
+             pred_epsilon=pred_epsilon,
+             pred_original_sample=pred_original_sample,
+             pred_velocity=pred_velocity)
+ 
+     def get_velocity(
+         self,
+         sample: torch.Tensor,
+         noise: torch.Tensor,
+         timesteps: torch.Tensor
+     ) -> torch.FloatTensor:
+         alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (sample.ndim - 1)))
+         return alpha_prod_t ** (0.5) * noise - (1 - alpha_prod_t) ** (0.5) * sample
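An illustrative sketch of how this custom DDIMScheduler could drive a sampling loop with an epsilon-prediction model. The linear beta schedule is an assumption, and `denoiser` is a hypothetical stand-in for a trained network such as SAR2OptUNetv3.

# Sketch (not part of the commit): 20-step DDIM sampling with utils.DDIMScheduler.
import torch
from utils import DDIMScheduler

betas = torch.linspace(1e-4, 0.02, 1000)           # assumed linear beta schedule
scheduler = DDIMScheduler(num_train_timesteps=1000,
                          num_inference_timesteps=20,
                          betas=betas)              # default "trailing": 999, 949, ..., 49

def denoiser(x, t):                                 # placeholder epsilon predictor
    return torch.zeros_like(x)

sample = torch.randn(1, 3, 256, 256)                # start from pure noise
for t in scheduler.timesteps:
    eps = denoiser(sample, t)
    out = scheduler.step(eps, model_output_type="epsilon", timestep=t, sample=sample)
    sample = out.prev_sample                        # x_{t-1}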