howard-hou commited on
Commit
1d2fc64
·
verified ·
1 Parent(s): 7fabc1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -17
app.py CHANGED
@@ -147,32 +147,28 @@ def pil_image_to_base64(pil_image):
147
  base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
148
  return base64_image
149
 
150
- image_cache = {}
151
  ln0_weight = model.w['blocks.0.ln0.weight'].to(torch.float32).to(device)
152
  ln0_bias = model.w['blocks.0.ln0.bias'].to(torch.float32).to(device)
153
- def compute_image_state(image):
154
- base64_image = pil_image_to_base64(image)
155
- if base64_image in image_cache:
156
- image_state = image_cache[base64_image]
157
- else:
158
- image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
159
- image = image.to(device)
160
- image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
161
- # apply layer norm to image feature, very important
162
- image_features = F.layer_norm(image_features,
163
- (image_features.shape[-1],),
164
- weight=ln0_weight,
165
- bias=ln0_bias)
166
- _, image_state = model.forward(embs=image_features, state=None)
167
- image_cache[base64_image] = image_state
168
  return image_state
169
 
170
  def chatbot(image, question):
171
  if image is None:
172
  yield "Please upload an image."
173
  return
174
- image_state = compute_image_state(image)
175
  input_text = generate_prompt(question)
 
 
176
  for output in generate(input_text, image_state):
177
  yield output
178
 
 
147
  base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
148
  return base64_image
149
 
150
+
151
  ln0_weight = model.w['blocks.0.ln0.weight'].to(torch.float32).to(device)
152
  ln0_bias = model.w['blocks.0.ln0.bias'].to(torch.float32).to(device)
153
+ def compute_image_state(image, prefix_tokens):
154
+ image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
155
+ image = image.to(device)
156
+ image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
157
+ # apply layer norm to image feature, very important
158
+ image_features = F.layer_norm(image_features,
159
+ (image_features.shape[-1],),
160
+ weight=ln0_weight,
161
+ bias=ln0_bias)
162
+ _, image_state = model.forward(tokens=prefix_tokens, embs=image_features, state=None)
 
 
 
 
 
163
  return image_state
164
 
165
  def chatbot(image, question):
166
  if image is None:
167
  yield "Please upload an image."
168
  return
 
169
  input_text = generate_prompt(question)
170
+ prefix_tokens = pipeline.encode(input_text)[-ctx_limit:]
171
+ image_state = compute_image_state(image, prefix_tokens)
172
  for output in generate(input_text, image_state):
173
  yield output
174