manu committed
Commit 4f3a756 · verified · 1 Parent(s): 068f2e8

Update app.py

Files changed (1): app.py (+92, -39)
app.py CHANGED
@@ -10,6 +10,8 @@ from PIL import Image
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
+from pqdm.processes import pqdm
+
 from colpali_engine.models import ColQwen2, ColQwen2Processor
 
 
@@ -30,42 +32,50 @@ def encode_image_to_base64(image):
     return base64.b64encode(buffered.getvalue()).decode("utf-8")
 
 
-def query_gpt4o_mini(query, images, api_key):
+DEFAULT_SYSTEM_PROMPT = """
+You are a smart assistant designed to answer questions about a PDF document.
+You are given relevant information in the form of PDF pages preceded by their metadata (PDF title, page number, surrounding context).
+Use them to construct a short response to the question, and cite your sources (page number, pdf title).
+If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
+Give detailed and extensive answers, only containing info in the pages you are given.
+You can answer using information contained in plots and figures if necessary.
+Answer in the same language as the query.
+"""
+
+def query_gpt4o_mini(query, images, api_key, system_prompt=DEFAULT_SYSTEM_PROMPT):
     """Calls OpenAI's GPT-4o-mini with the query and image data."""
 
     if api_key and api_key.startswith("sk"):
         try:
             from openai import OpenAI
-
-            base64_images = [encode_image_to_base64(image[0]) for image in images]
+
             client = OpenAI(api_key=api_key.strip())
-            PROMPT = """
-            You are a smart assistant designed to answer questions about a PDF document.
-            You are given relevant information in the form of PDF pages. Use them to construct a short response to the question, and cite your sources (page numbers, etc).
-            If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
-            Give detailed and extensive answers, only containing info in the pages you are given.
-            You can answer using information contained in plots and figures if necessary.
-            Answer in the same language as the query.
-
-            Query: {query}
-            PDF pages:
+            prompt = f"""
+            {system_prompt}
+            Query: {query}
+            PDF pages:
             """
-
+
+            messages = [{"type": "text", "text": prompt}]
+            for im, capt in images:
+                if capt is not None:
+                    messages.append({
+                        "type": "text",
+                        "text": capt
+                    })
+                messages.append({
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/jpeg;base64,{encode_image_to_base64(im)}"
+                    },
+                })
+
             response = client.chat.completions.create(
                 model="gpt-4o-mini",
                 messages=[
                     {
                         "role": "user",
-                        "content": [
-                            {
-                                "type": "text",
-                                "text": PROMPT.format(query=query)
-                            }] + [{
-                            "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/jpeg;base64,{im}"
-                            },
-                        } for im in base64_images]
+                        "content": messages
                     }
                 ],
                 max_tokens=500,
@@ -77,7 +87,7 @@ def query_gpt4o_mini(query, images, api_key):
     return "Enter your OpenAI API key to get a custom response"
 
 
-def search(query: str, ds, images, k, api_key):
+def search(query: str, ds, images, metadatas, k, api_key):
     k = min(k, len(ds))
     device = "cuda:0" if torch.cuda.is_available() else "cpu"
     if device != model.device:
@@ -95,7 +105,9 @@ def search(query: str, ds, images, k, api_key):
 
     results = []
     for idx in top_k_indices:
-        results.append((images[idx], f"Page {idx}"))
+        img = images[idx]
+        meta = metadatas[idx]
+        results.append((img, f"Document: {meta['title']}, Page: {meta['page']}, Context: {meta['context']}"))
 
     # Generate response from GPT-4o-mini
     ai_response = query_gpt4o_mini(query, results, api_key)
@@ -103,22 +115,62 @@ def search(query: str, ds, images, k, api_key):
     return results, ai_response
 
 
-def index(files, ds):
+def index(files, ds, api_key):
     print("Converting files")
-    images = convert_files(files)
+    images, metadatas = convert_files(files, api_key)
     print(f"Files converted with {len(images)} images.")
-    return index_gpu(images, ds)
-
-
-
-def convert_files(files):
+    ds = index_gpu(images, ds)
+    print(f"Indexed {len(ds)} images.")
+    return f"Uploaded and converted {len(images)} pages", ds, images, metadatas
+
+DEFAULT_CONTEXT_PROMPT = """
+You are a smart assistant designed to extract context of PDF pages.
+Give detailed and extensive answers, only containing info in the pages you are given.
+You can answer using information contained in plots and figures if necessary.
+Answer in the same language as the query.
+"""
+
+def extract_context(images, api_key, window=10):
+    """Extracts context from images."""
+    prompt = "Give the general context about these pages."
+    window_contexts = []
+
+    args = [(prompt, (images[max(i-window+1, 0):i+1], None), api_key, DEFAULT_CONTEXT_PROMPT)
+            for i in range(0, len(images), window)]
+    window_contexts = pqdm(args, query_gpt4o_mini, n_jobs=8)
+
+    # for i in tqdm(range(0, len(images), window), desc="Extracting context", total=len(images)//window):
+    #     window_images = images[max(i-window+1, 0):i+1]
+    #     window_images = [(image, None) for image in window_images]
+    #     window_contexts.append(query_gpt4o_mini(prompt, window_images, api_key, system_prompt=DEFAULT_CONTEXT_PROMPT))
+
+    contexts = []
+    for i in range(len(images)):
+        context = window_contexts[i//window]
+        contexts.append(context)
+
+    assert len(contexts) == len(images)
+    return contexts
+
+def extract_metadata(file, images, api_key, window=10):
+    """Extracts metadata from pdfs. Extract page number, file name, and authors."""
+    title = file.split("/")[-1]
+    contexts = extract_context(images, api_key, window=window)
+    return [{"title": title, "page": i+1, "context": contexts[i]} for i in range(len(images))]
+
+def convert_files(files, api_key):
     images = []
+    metadatas = []
+
     for f in files:
-        images.extend(convert_from_path(f, thread_count=4))
+        file_images = convert_from_path(f, thread_count=4)
+        file_metadatas = extract_metadata(f, file_images, api_key)
+        images.extend(file_images)
+        metadatas.extend(file_metadatas)
 
     if len(images) >= 500:
         raise gr.Error("The number of images in the dataset should be less than 500.")
-    return images
+    return images, metadatas
 
 
 def index_gpu(images, ds):
@@ -141,7 +193,7 @@ def index_gpu(images, ds):
         batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
         embeddings_doc = model(**batch_doc)
         ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
-    return f"Uploaded and converted {len(images)} pages", ds, images
+    return ds
 
 
 
@@ -166,6 +218,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             api_key = gr.Textbox(placeholder="Enter your OpenAI KEY here (optional)", label="API key")
             embeds = gr.State(value=[])
             imgs = gr.State(value=[])
+            metadatas = gr.State(value=[])
 
         with gr.Column(scale=3):
             gr.Markdown("## 2️⃣ Search")
@@ -178,8 +231,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
     output_text = gr.Textbox(label="AI Response", placeholder="Generated response based on retrieved documents")
 
-    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
-    search_button.click(search, inputs=[query, embeds, imgs, k, api_key], outputs=[output_gallery, output_text])
+    convert_button.click(index, inputs=[file, embeds, api_key], outputs=[message, embeds, imgs, metadatas])
+    search_button.click(search, inputs=[query, embeds, imgs, metadatas, k, api_key], outputs=[output_gallery, output_text])
 
 if __name__ == "__main__":
-    demo.queue(max_size=5).launch(debug=True)
+    demo.queue(max_size=5).launch(debug=True, server_name="0.0.0.0")
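A note on the refactor above: `query_gpt4o_mini` now takes `images` as a list of `(PIL.Image, caption)` pairs and interleaves each non-None caption as a text part immediately before its image, which is how the per-page metadata reaches the model. A minimal sketch of a call, assuming the function from this commit is in scope (the file name, query, and key are placeholders):

    from PIL import Image

    pages = [
        (Image.new("RGB", (64, 64), "white"), "Document: report.pdf, Page: 1, Context: intro"),
        (Image.new("RGB", (64, 64), "white"), None),  # a None caption is skipped
    ]
    answer = query_gpt4o_mini("What does page 1 cover?", pages, api_key="sk-...")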
 
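`extract_context` fans the window prompts out with `pqdm`. Worth knowing: `pqdm` passes each list element to the worker as a single positional argument unless `argument_type='args'` is supplied, which unpacks each tuple as `worker(*args)`. The sketch below passes that flag explicitly and uses plain `i:i+window` slices, where the commit slices `max(i-window+1, 0):i+1` (making the first window a single page); the worker is a stand-in for `query_gpt4o_mini`:

    from pqdm.processes import pqdm

    def summarize_window(prompt, pages, api_key, system_prompt):
        # Stand-in worker: pretend to summarize one window of pages.
        return f"context for pages {pages[0]}-{pages[-1]}"

    window = 10
    pages = list(range(23))  # 23 page indices -> 3 windows
    args = [("Give the general context about these pages.",
             pages[i:i + window], "sk-demo", "You extract context of PDF pages.")
            for i in range(0, len(pages), window)]

    # argument_type='args' makes pqdm call summarize_window(*t) for each tuple t
    window_contexts = pqdm(args, summarize_window, n_jobs=8, argument_type='args')
    assert len(window_contexts) == 3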
 
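The tail of `extract_context` then maps every page back to its window summary via `window_contexts[i // window]`. Since the windows come from `range(0, len(images), window)`, there are exactly ceil(len(images) / window) summaries and the index never overruns; a runnable check with illustrative numbers:

    window = 10
    n_pages = 23
    n_windows = -(-n_pages // window)  # ceiling division -> 3
    window_contexts = [f"ctx{w}" for w in range(n_windows)]
    contexts = [window_contexts[i // window] for i in range(n_pages)]
    assert len(contexts) == n_pages
    assert (contexts[9], contexts[10], contexts[22]) == ("ctx0", "ctx1", "ctx2")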
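Finally, the metadata built at index time is what `search` folds into each gallery caption, and that caption doubles as the text part GPT-4o-mini sees before the corresponding page image. The record shape and caption format, with made-up values:

    meta = {"title": "report.pdf", "page": 3, "context": "Chapter 1 covers ..."}
    caption = f"Document: {meta['title']}, Page: {meta['page']}, Context: {meta['context']}"

The Gradio wiring follows the same flow: `index` now returns four values to match `outputs=[message, embeds, imgs, metadatas]`, while the embedding step `index_gpu` returns only the updated `ds`.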