Mayuri committed
Commit 566a959 · verified · 1 parent: 2a5630b

Rename main_v3.py to app.py
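(On Hugging Face Spaces, a Gradio app is loaded from a file named `app.py` by convention, and the Space manages the server port itself; this likely motivates both the rename and the dropped `server_port` argument at the bottom of the diff.)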

Files changed (1): main_v3.py → app.py (+140 −140)
main_v3.py → app.py RENAMED
Apart from the final line, which drops the hard-coded server port, the file content is unchanged by the rename:

```diff
@@ -1,140 +1,140 @@
 import gradio as gr
 import argparse
 import os
 
 import pandas as pd
 from PIL import Image
 import numpy as np
 import torch as th
 from torchvision import transforms
 
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, LCMScheduler
 import gc
 from safetensors import safe_open
 
 from models import SAR2OptUNetv3
 from utils import update_args_from_yaml, safe_load
 
 # Preprocess the single-channel SAR input: to tensor, resize, scale to [-1, 1].
 transform_sar = transforms.Compose([
     transforms.ToTensor(),
     transforms.Resize((256, 256)),
     transforms.Normalize((0.5,), (0.5,)),
 ])
 AVAILABLE_MODELS = {
     "Sen12:LCM-Model": "models/model.safetensors",
     "Sen12:Org-Model": "models/model_org.safetensors",
 }
 
 device = th.device('cuda:0' if th.cuda.is_available() else 'cpu')
 
 # Note: this local definition shadows the safe_load imported from utils above.
 def safe_load(model_path):
     assert "safetensors" in model_path
     state_dict = {}
     with safe_open(model_path, framework="pt", device="cpu") as f:
         for k in f.keys():
             state_dict[k] = f.get_tensor(k)
     return state_dict
 
 # 4 input channels: 3 noisy image channels plus 1 SAR condition channel.
 unet_model = SAR2OptUNetv3(
     sample_size=256,
     in_channels=4,
     out_channels=3,
     layers_per_block=2,
     block_out_channels=(128, 128, 256, 256, 512, 512),
     down_block_types=(
         "DownBlock2D",
         "DownBlock2D",
         "DownBlock2D",
         "DownBlock2D",
         "AttnDownBlock2D",
         "DownBlock2D",
     ),
     up_block_types=(
         "UpBlock2D",
         "AttnUpBlock2D",
         "UpBlock2D",
         "UpBlock2D",
         "UpBlock2D",
         "UpBlock2D",
     ),
 )
 
 print('load unet safetensors done!')
 lcm_scheduler = LCMScheduler(num_train_timesteps=1000)
 
 unet_model.to(device)
 unet_model.eval()
 
 model_kwargs = {}
 
 
 def predict(condition, nums_step, model_name):
     # Swap in the checkpoint selected in the UI before sampling.
     unet_checkpoint = AVAILABLE_MODELS[model_name]
     unet_model.load_state_dict(safe_load(unet_checkpoint), strict=True)
     unet_model.eval().to(device)
     with th.no_grad():
         lcm_scheduler.set_timesteps(nums_step, device=device)
         timesteps = lcm_scheduler.timesteps
         pred_latent = th.randn(size=[1, 3, 256, 256], device=device)
         condition = condition.convert("L")
         condition = transform_sar(condition)
         condition = th.unsqueeze(condition, 0)
         condition = condition.to(device)
         for timestep in timesteps:
             # Concatenate the current sample with the SAR condition along channels.
             latent_to_pred = th.cat((pred_latent, condition), dim=1)
             model_pred = unet_model(latent_to_pred, timestep)
             pred_latent, denoised = lcm_scheduler.step(
                 model_output=model_pred,
                 timestep=timestep,
                 sample=pred_latent,
                 return_dict=False)
         sample = denoised.cpu()
 
     # Map from [-1, 1] to uint8 [0, 255] and from NCHW to HWC for PIL.
     sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
     sample = sample.permute(0, 2, 3, 1)
     sample = sample.contiguous()
     sample = sample.cpu().numpy()
     sample = sample.squeeze(0)
     sample = Image.fromarray(sample)
     return sample
 
 
 demo = gr.Interface(
     fn=predict,
     inputs=[gr.Image(type="pil"),
             gr.Slider(1, 1000),
             gr.Dropdown(
                 choices=list(AVAILABLE_MODELS.keys()),
                 value=list(AVAILABLE_MODELS.keys())[0],
                 label="Choose the Model"), ],
     # gr.Radio(["Sent", "GF3"], label="Model", info="Which model do you want to use?"), ],
     outputs=gr.Image(type="pil"),
     examples=[
         [os.path.join(os.path.dirname(__file__), "sar_1.png"), 8, "Sen12:LCM-Model"],
         [os.path.join(os.path.dirname(__file__), "sar_2.png"), 16, "Sen12:LCM-Model"],
         [os.path.join(os.path.dirname(__file__), "sar_3.png"), 500, "Sen12:Org-Model"],
         [os.path.join(os.path.dirname(__file__), "sar_4.png"), 1000, "Sen12:Org-Model"],
     ],
     title="SAR to Optical Image🚀",
     description="""
 # 🎯 Instruction
 This project converts SAR images into optical images with a conditional diffusion model.
 
 Input a SAR image, and the corresponding optical image will be generated.
 
 ## 📢 Inputs
 - `condition`: the SAR image you want to translate.
 - `timestep_respacing`: the number of iteration steps used during inference.
 
 ## 🎉 Outputs
 - The corresponding optical image.
 
 **Paper** : [Guided Diffusion for Image Generation](https://arxiv.org/abs/2105.05233)
 
 **Github** : https://github.com/Coordi777/Conditional_SAR2OPT
 """
 )
 
 if __name__ == "__main__":
-    demo.launch(server_port=16006)
+    demo.launch()
```
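For a quick local check of the renamed module, here is a minimal sketch that drives `predict` directly. The image and output file names are assumptions, the checkpoints under `models/` must be present, and importing `app` runs the module-level model setup (the `__main__` guard keeps the server from launching):

```python
# Hypothetical smoke test for app.py; file and image names are assumptions.
from PIL import Image

from app import predict  # module import builds the UNet; no server is launched

sar = Image.open("sar_1.png")                 # single-channel SAR input
optical = predict(sar, 8, "Sen12:LCM-Model")  # 8 LCM sampling steps
optical.save("optical_1.png")
```

When running outside Spaces, the pre-commit behavior of pinning the port is still available via `demo.launch(server_port=16006)`.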