muqtasid87 committed
Commit a8c5cb4 · verified · 1 Parent(s): 2a35a01

Update project/app_florence.py

Files changed (1):
project/app_florence.py +222 -222
project/app_florence.py CHANGED
@@ -1,223 +1,223 @@
-import streamlit as st
-from transformers import (
-    AutoModelForCausalLM,
-    AutoProcessor
-)
-import torch
-from PIL import Image
-import time
-import os
-import matplotlib.pyplot as plt
-import matplotlib.patches as patches
-import io
-import numpy as np
-
-
-@st.cache_resource
-def load_model():
-    """Load the model and processor (cached to prevent reloading)"""
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-
-    model = AutoModelForCausalLM.from_pretrained(
-        "microsoft/Florence-2-large-ft",
-        torch_dtype=torch_dtype,
-        trust_remote_code=True
-    ).to(device)
-    processor = AutoProcessor.from_pretrained(
-        "microsoft/Florence-2-large-ft",
-        trust_remote_code=True
-    )
-    return model, processor, device, torch_dtype
-
-def draw_bounding_boxes(image, bboxes, labels):
-    """Draw bounding boxes and labels on the image"""
-    # Convert PIL image to numpy array
-    img_array = np.array(image)
-
-    # Create figure and axis
-    fig, ax = plt.subplots()
-    ax.imshow(img_array)
-
-    # Add each bounding box and label
-    for bbox, label in zip(bboxes, labels):
-        x, y, x2, y2 = bbox
-        width = x2 - x
-        height = y2 - y
-
-        # Create rectangle patch
-        rect = patches.Rectangle(
-            (x, y), width, height,
-            linewidth=2,
-            edgecolor='red',
-            facecolor='none'
-        )
-        ax.add_patch(rect)
-
-        # Add label above the box
-        plt.text(
-            x, y-5,
-            label,
-            color='red',
-            fontsize=12,
-            bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=0)
-        )
-
-    # Remove axes
-    plt.axis('off')
-
-    # Convert plot to image
-    buf = io.BytesIO()
-    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
-    plt.close()
-    buf.seek(0)
-    return Image.open(buf)
-
-def process_image(image, text_input, model, processor, device, torch_dtype):
-    """Process the image and return the model's output"""
-    start_time = time.time()
-
-    task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
-    prompt = task_prompt + text_input if text_input else task_prompt
-
-    inputs = processor(
-        text=prompt,
-        images=image,
-        return_tensors="pt"
-    ).to(device, torch_dtype)
-
-    generated_ids = model.generate(
-        input_ids=inputs["input_ids"],
-        pixel_values=inputs["pixel_values"],
-        max_new_tokens=2048,
-        num_beams=3
-    )
-
-    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
-    parsed_answer = processor.post_process_generation(
-        generated_text,
-        task=task_prompt,
-        image_size=(image.width, image.height)
-    )
-
-    inference_time = time.time() - start_time
-
-    # Create annotated image
-    result = parsed_answer[task_prompt]
-    annotated_image = draw_bounding_boxes(
-        image,
-        result['bboxes'],
-        result['labels']
-    )
-
-    return result, inference_time, annotated_image
-
-def main():
-    # Compact header
-    st.markdown("<h1 style='font-size: 24px;'>🔍 Image Analysis with Florence-2</h1>", unsafe_allow_html=True)
-
-    # Load model and processor
-    with st.spinner("Loading model... This might take a minute."):
-        model, processor, device, torch_dtype = load_model()
-
-    # Initialize session state
-    if 'selected_image' not in st.session_state:
-        st.session_state.selected_image = None
-    if 'result' not in st.session_state:
-        st.session_state.result = None
-    if 'inference_time' not in st.session_state:
-        st.session_state.inference_time = None
-    if 'annotated_image' not in st.session_state:
-        st.session_state.annotated_image = None
-
-    # Main content area
-    col1, col2, col3 = st.columns([1, 1.5, 1])
-
-    with col1:
-        # Input method selection
-        input_option = st.radio("Choose input method:", ["Use example image", "Upload image"], label_visibility="collapsed")
-
-        if input_option == "Upload image":
-            uploaded_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"], label_visibility="collapsed")
-            image_source = uploaded_file
-            if uploaded_file:
-                st.session_state.selected_image = uploaded_file
-        else:
-            image_source = st.session_state.selected_image
-
-        # Default prompt and analysis section
-        default_prompt = "What type of vehicle is this?"
-        prompt = st.text_area("Enter prompt:", value=default_prompt, height=100)
-
-        analyze_col1, analyze_col2 = st.columns([1, 2])
-        with analyze_col1:
-            analyze_button = st.button("Analyze Image", use_container_width=True, disabled=image_source is None)
-
-        # Display selected image and results
-        if image_source:
-            try:
-                if isinstance(image_source, str):
-                    image = Image.open(image_source).convert("RGB")
-                else:
-                    image = Image.open(image_source).convert("RGB")
-                st.image(image, caption="Selected Image", width=300)
-            except Exception as e:
-                st.error(f"Error loading image: {str(e)}")
-
-    # Analysis results
-    if analyze_button and image_source:
-        with st.spinner("Analyzing..."):
-            try:
-                result, inference_time, annotated_image = process_image(image, prompt, model, processor, device, torch_dtype)
-                st.session_state.result = result
-                st.session_state.inference_time = inference_time
-                st.session_state.annotated_image = annotated_image
-            except Exception as e:
-                st.error(f"Error: {str(e)}")
-
-    if st.session_state.result:
-        st.success("Analysis Complete!")
-
-        # Display the annotated image
-        st.image(st.session_state.annotated_image, caption="Analyzed Image with Detections", use_container_width=True)
-
-        # Display raw results and inference time
-        st.markdown("**Raw Results:**")
-        st.json(st.session_state.result)
-        st.markdown(f"*Inference time: {st.session_state.inference_time:.2f} seconds*")
-
-    # Example images section
-    if input_option == "Use example image":
-        st.markdown("### Example Images")
-        example_images = [f for f in os.listdir("images") if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
-
-        if example_images:
-            # Create grid of images
-            cols = st.columns(4)  # Adjust number of columns as needed
-            for idx, img_name in enumerate(example_images):
-                with cols[idx % 4]:
-                    img_path = os.path.join("images", img_name)
-                    img = Image.open(img_path)
-                    img.thumbnail((150, 150))
-
-                    # Make image clickable
-                    if st.button(
-                        "📷",
-                        key=f"img_{idx}",
-                        help=img_name,
-                        use_container_width=True
-                    ):
-                        st.session_state.selected_image = img_path
-                        st.rerun()
-
-                    # Display image with conditional styling
-                    st.image(
-                        img,
-                        caption=img_name,
-                        use_container_width=True,
-                    )
-        else:
-            st.error("No example images found in the 'images' directory")
-
-if __name__ == "__main__":
+import streamlit as st
+from transformers import (
+    AutoModelForCausalLM,
+    AutoProcessor
+)
+import torch
+from PIL import Image
+import time
+import os
+import matplotlib.pyplot as plt
+import matplotlib.patches as patches
+import io
+import numpy as np
+
+
+@st.cache_resource
+def load_model():
+    """Load the model and processor (cached to prevent reloading)"""
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+    model = AutoModelForCausalLM.from_pretrained(
+        "microsoft/Florence-2-large-ft",
+        torch_dtype=torch_dtype,
+        trust_remote_code=True
+    ).to(device)
+    processor = AutoProcessor.from_pretrained(
+        "microsoft/Florence-2-large-ft",
+        trust_remote_code=True
+    )
+    return model, processor, device, torch_dtype
+
+def draw_bounding_boxes(image, bboxes, labels):
+    """Draw bounding boxes and labels on the image"""
+    # Convert PIL image to numpy array
+    img_array = np.array(image)
+
+    # Create figure and axis
+    fig, ax = plt.subplots()
+    ax.imshow(img_array)
+
+    # Add each bounding box and label
+    for bbox, label in zip(bboxes, labels):
+        x, y, x2, y2 = bbox
+        width = x2 - x
+        height = y2 - y
+
+        # Create rectangle patch
+        rect = patches.Rectangle(
+            (x, y), width, height,
+            linewidth=2,
+            edgecolor='red',
+            facecolor='none'
+        )
+        ax.add_patch(rect)
+
+        # Add label above the box
+        plt.text(
+            x, y-5,
+            label,
+            color='red',
+            fontsize=12,
+            bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=0)
+        )
+
+    # Remove axes
+    plt.axis('off')
+
+    # Convert plot to image
+    buf = io.BytesIO()
+    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
+    plt.close()
+    buf.seek(0)
+    return Image.open(buf)
+
+def process_image(image, text_input, model, processor, device, torch_dtype):
+    """Process the image and return the model's output"""
+    start_time = time.time()
+
+    task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
+    prompt = task_prompt + text_input if text_input else task_prompt
+
+    inputs = processor(
+        text=prompt,
+        images=image,
+        return_tensors="pt"
+    ).to(device, torch_dtype)
+
+    generated_ids = model.generate(
+        input_ids=inputs["input_ids"],
+        pixel_values=inputs["pixel_values"],
+        max_new_tokens=2048,
+        num_beams=3
+    )
+
+    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+    parsed_answer = processor.post_process_generation(
+        generated_text,
+        task=task_prompt,
+        image_size=(image.width, image.height)
+    )
+
+    inference_time = time.time() - start_time
+
+    # Create annotated image
+    result = parsed_answer[task_prompt]
+    annotated_image = draw_bounding_boxes(
+        image,
+        result['bboxes'],
+        result['labels']
+    )
+
+    return result, inference_time, annotated_image
+
+def main():
+    # Compact header
+    st.markdown("<h1 style='font-size: 24px;'>🔍 Image Analysis with Florence-2</h1>", unsafe_allow_html=True)
+
+    # Load model and processor
+    with st.spinner("Loading model... This might take a minute."):
+        model, processor, device, torch_dtype = load_model()
+
+    # Initialize session state
+    if 'selected_image' not in st.session_state:
+        st.session_state.selected_image = None
+    if 'result' not in st.session_state:
+        st.session_state.result = None
+    if 'inference_time' not in st.session_state:
+        st.session_state.inference_time = None
+    if 'annotated_image' not in st.session_state:
+        st.session_state.annotated_image = None
+
+    # Main content area
+    col1, col2, col3 = st.columns([1, 1.5, 1])
+
+    with col1:
+        # Input method selection
+        input_option = st.radio("Choose input method:", ["Use example image", "Upload image"], label_visibility="collapsed")
+
+        if input_option == "Upload image":
+            uploaded_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"], label_visibility="collapsed")
+            image_source = uploaded_file
+            if uploaded_file:
+                st.session_state.selected_image = uploaded_file
+        else:
+            image_source = st.session_state.selected_image
+
+        # Default prompt and analysis section
+        default_prompt = "<output from qwen2 eg. bus>"
+        prompt = st.text_area("Enter prompt:", value=default_prompt, height=100)
+
+        analyze_col1, analyze_col2 = st.columns([1, 2])
+        with analyze_col1:
+            analyze_button = st.button("Analyze Image", use_container_width=True, disabled=image_source is None)
+
+        # Display selected image and results
+        if image_source:
+            try:
+                if isinstance(image_source, str):
+                    image = Image.open(image_source).convert("RGB")
+                else:
+                    image = Image.open(image_source).convert("RGB")
+                st.image(image, caption="Selected Image", width=300)
+            except Exception as e:
+                st.error(f"Error loading image: {str(e)}")
+
+    # Analysis results
+    if analyze_button and image_source:
+        with st.spinner("Analyzing..."):
+            try:
+                result, inference_time, annotated_image = process_image(image, prompt, model, processor, device, torch_dtype)
+                st.session_state.result = result
+                st.session_state.inference_time = inference_time
+                st.session_state.annotated_image = annotated_image
+            except Exception as e:
+                st.error(f"Error: {str(e)}")
+
+    if st.session_state.result:
+        st.success("Analysis Complete!")
+
+        # Display the annotated image
+        st.image(st.session_state.annotated_image, caption="Analyzed Image with Detections", use_container_width=True)
+
+        # Display raw results and inference time
+        st.markdown("**Raw Results:**")
+        st.json(st.session_state.result)
+        st.markdown(f"*Inference time: {st.session_state.inference_time:.2f} seconds*")
+
+    # Example images section
+    if input_option == "Use example image":
+        st.markdown("### Example Images")
+        example_images = [f for f in os.listdir("images") if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
+
+        if example_images:
+            # Create grid of images
+            cols = st.columns(4)  # Adjust number of columns as needed
+            for idx, img_name in enumerate(example_images):
+                with cols[idx % 4]:
+                    img_path = os.path.join("images", img_name)
+                    img = Image.open(img_path)
+                    img.thumbnail((150, 150))
+
+                    # Make image clickable
+                    if st.button(
+                        "📷",
+                        key=f"img_{idx}",
+                        help=img_name,
+                        use_container_width=True
+                    ):
+                        st.session_state.selected_image = img_path
+                        st.rerun()
+
+                    # Display image with conditional styling
+                    st.image(
+                        img,
+                        caption=img_name,
+                        use_container_width=True,
+                    )
+        else:
+            st.error("No example images found in the 'images' directory")
+
+if __name__ == "__main__":
     main()
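
The changed line is the app's default prompt: it is now a placeholder for a phrase produced by an upstream model (e.g. Qwen2) rather than a natural-language question, since Florence-2's <CAPTION_TO_PHRASE_GROUNDING> task grounds a phrase as bounding boxes instead of answering questions. A minimal sketch of that flow, reusing the helpers defined in this file; the import path, the phrase "bus", and the image path are illustrative assumptions, not part of this commit:

from PIL import Image
# Assumes app_florence.py is importable as a module (hypothetical path).
from app_florence import load_model, process_image

# Load Florence-2 once, exactly as the Streamlit app does.
model, processor, device, torch_dtype = load_model()

# "bus" stands in for whatever the upstream (e.g. Qwen2) model produced;
# the image path is a hypothetical example.
image = Image.open("images/bus.jpg").convert("RGB")
result, seconds, annotated = process_image(image, "bus", model, processor, device, torch_dtype)

# process_image() prepends the task token, so the model is queried with
# "<CAPTION_TO_PHRASE_GROUNDING>bus" and returns boxes plus matched labels.
print(result["bboxes"], result["labels"], f"{seconds:.2f}s")
annotated.save("bus_annotated.png")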