sashasax commited on
Commit
717c1f9
·
1 Parent(s): 2ac1f2a

initial commit

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms
5
+ import PIL
6
+ from PIL import Image
7
+ import os
8
+ from typing import Tuple
9
+
10
+
11
+ def setup_model(device: torch.device) -> Tuple[torch.nn.Module, int]:
12
+ image_size = 384
13
+ model = torch.hub.load('alexsax/omnidata_models', 'surface_normal_dpt_hybrid_384')
14
+ model.to(device)
15
+ model.eval()
16
+
17
+ return model, image_size
18
+
19
+ def setup_transforms(image_size: int) -> transforms.Compose:
20
+ return transforms.Compose([
21
+ transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR),
22
+ transforms.CenterCrop(image_size),
23
+ transforms.ToTensor(),
24
+ ])
25
+
26
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
27
+ model, image_size = setup_model(device)
28
+ trans_totensor = setup_transforms(image_size)
29
+
30
+ def estimate_surface_normal(input_image: PIL.Image.Image) -> PIL.Image.Image:
31
+ with torch.no_grad():
32
+ img_tensor = trans_totensor(input_image)[:3].unsqueeze(0).to(device)
33
+
34
+ if img_tensor.shape[1] == 1:
35
+ img_tensor = img_tensor.repeat_interleave(3, 1)
36
+
37
+ output = model(img_tensor).clamp(min=0, max=1)
38
+ output_image = transforms.ToPILImage()(output[0])
39
+
40
+ return output_image
41
+
42
+ iface = gr.Interface(
43
+ fn=estimate_surface_normal,
44
+ inputs=gr.Image(type="pil"),
45
+ outputs=gr.Image(type="pil"),
46
+ title="Monocular Surface Normal Estimation: Omnidata DPT-Hybrid",
47
+ description="Upload an image to estimate monocular surface normals.",
48
+ examples=[
49
+ "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/test1_rgb.png?raw=true",
50
+ "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test2.png?raw=true",
51
+ "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test3.png?raw=true",
52
+ "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test4.png?raw=true",
53
+ "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test5.png?raw=true",
54
+ "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test6.png?raw=true",
55
+ "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test7.png?raw=true",
56
+ "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test8.png?raw=true",
57
+ "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test9.png?raw=true",
58
+ "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test10.png?raw=true",
59
+ ],
60
+ )
61
+
62
+ if __name__ == "__main__":
63
+ iface.launch()