nielsr HF staff committed on
Commit c62a436 · 1 Parent(s): d1536f3

Update app.py

Files changed (1): app.py +8 -3
app.py CHANGED
@@ -1,5 +1,5 @@
 import gradio as gr
-from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration
+from transformers import AutoProcessor, AutoImageProcessor, AutoModelForCausalLM, VisionEncoderDecoderModel, BlipForConditionalGeneration
 import torch
 
 torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
@@ -11,6 +11,9 @@ git_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
 blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
 blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
 
+vitgpt_processor = AutoImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+vitgpt_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 git_model.to(device)
@@ -31,7 +34,9 @@ def generate_captions(image):
 
     caption_blip = generate_caption(blip_processor, blip_model, image)
 
-    return caption_git, caption_blip
+    caption_vitgpt = generate_caption(vitgpt_processor, vitgpt_model, image)
+
+    return caption_git, caption_blip, caption_vitgpt
 
 
 examples = [["cats.jpg"], ["stop_sign.png"]]
@@ -42,7 +47,7 @@ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2102.033
 
 interface = gr.Interface(fn=generate_captions,
                          inputs=gr.inputs.Image(type="pil"),
-                         outputs=[gr.outputs.Textbox(label="Caption generated by GIT"), gr.outputs.Textbox(label="Caption generated by BLIP")],
+                         outputs=[gr.outputs.Textbox(label="Caption generated by GIT"), gr.outputs.Textbox(label="Caption generated by BLIP"), gr.outputs.Textbox(label="Caption generated by ViT+GPT-2")],
                          examples=examples,
                          title=title,
                          description=description,