Update app.py

app.py CHANGED
@@ -10,6 +10,8 @@ from PIL import Image
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
+from pqdm.processes import pqdm
+
 from colpali_engine.models import ColQwen2, ColQwen2Processor
 
 
@@ -30,42 +32,50 @@ def encode_image_to_base64(image):
     return base64.b64encode(buffered.getvalue()).decode("utf-8")
 
 
-def query_gpt4o_mini(query, images, api_key):
+DEFAULT_SYSTEM_PROMPT = """
+You are a smart assistant designed to answer questions about a PDF document.
+You are given relevant information in the form of PDF pages preceded by their metadata (PDF title, page number, surrounding context).
+Use them to construct a short response to the question, and cite your sources (page number, pdf title).
+If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
+Give detailed and extensive answers, only containing info in the pages you are given.
+You can answer using information contained in plots and figures if necessary.
+Answer in the same language as the query.
+"""
+
+def query_gpt4o_mini(query, images, api_key, system_prompt=DEFAULT_SYSTEM_PROMPT):
     """Calls OpenAI's GPT-4o-mini with the query and image data."""
 
     if api_key and api_key.startswith("sk"):
         try:
             from openai import OpenAI
-
-            base64_images = [encode_image_to_base64(image[0]) for image in images]
+
             client = OpenAI(api_key=api_key.strip())
-            PROMPT = """
-
-
-
-            Give detailed and extensive answers, only containing info in the pages you are given.
-            You can answer using information contained in plots and figures if necessary.
-            Answer in the same language as the query.
-
-            Query: {query}
-            PDF pages:
+            prompt = f"""
+            {system_prompt}
+            Query: {query}
+            PDF pages:
             """
-
+
+            messages = [{"type": "text", "text": prompt}]
+            for im, capt in images:
+                if capt is not None:
+                    messages.append({
+                        "type": "text",
+                        "text": capt
+                    })
+                messages.append({
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/jpeg;base64,{encode_image_to_base64(im)}"
+                    },
+                })
+
             response = client.chat.completions.create(
                 model="gpt-4o-mini",
                 messages=[
                     {
                         "role": "user",
-                        "content": [
-                            {
-                                "type": "text",
-                                "text": PROMPT.format(query=query)
-                            }] + [{
-                            "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/jpeg;base64,{im}"
-                            },
-                        } for im in base64_images]
+                        "content": messages
                     }
                 ],
                 max_tokens=500,
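The rewritten query_gpt4o_mini drops the old PROMPT.format / base64_images construction and instead builds one flat content list that interleaves text parts (the prompt, then each page's metadata caption) with image_url parts. Below is a minimal, self-contained sketch of that message shape against the OpenAI chat.completions API, assuming an OPENAI_API_KEY in the environment; the file names and the query are illustrative, not taken from the app.

# Sketch of the message shape built by the new query_gpt4o_mini: one text
# part for the prompt, then an optional caption and an image_url part per
# retrieved page. Assumes openai>=1.x and Pillow; file names are illustrative.
import base64
from io import BytesIO

from openai import OpenAI
from PIL import Image


def to_data_url(image: Image.Image) -> str:
    """Encode a PIL image as a base64 JPEG data URL."""
    buf = BytesIO()
    image.convert("RGB").save(buf, format="JPEG")
    return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode("utf-8")


client = OpenAI()  # reads OPENAI_API_KEY from the environment

pages = [
    (Image.open("page_1.png"), "Document: report.pdf, Page: 1, Context: ..."),
    (Image.open("page_2.png"), None),  # caption is optional
]

content = [{"type": "text", "text": "Query: What does figure 2 show?"}]
for img, caption in pages:
    if caption is not None:
        content.append({"type": "text", "text": caption})
    content.append({"type": "image_url", "image_url": {"url": to_data_url(img)}})

response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": content}],
    max_tokens=500,
)
print(response.choices[0].message.content)

Placing each caption immediately before its image is what lets the model cite the PDF title and page number that the new system prompt asks for.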
@@ -77,7 +87,7 @@ def query_gpt4o_mini(query, images, api_key):
     return "Enter your OpenAI API key to get a custom response"
 
 
-def search(query: str, ds, images, k, api_key):
+def search(query: str, ds, images, metadatas, k, api_key):
     k = min(k, len(ds))
     device = "cuda:0" if torch.cuda.is_available() else "cpu"
     if device != model.device:
@@ -95,7 +105,9 @@ def search(query: str, ds, images, k, api_key):
 
     results = []
    for idx in top_k_indices:
-
+        img = images[idx]
+        meta = metadatas[idx]
+        results.append((img, f"Document: {meta['title']}, Page: {meta['page']}, Context: {meta['context']}"))
 
     # Generate response from GPT-4o-mini
     ai_response = query_gpt4o_mini(query, results, api_key)
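With metadatas threaded into search, each retrieved item is now an (image, caption) pair rather than a bare image, and the caption packs in the page's title, number and extracted context. A toy illustration of that shape, with made-up values; the same tuple doubles as a gr.Gallery item and as the (im, capt) pair query_gpt4o_mini unpacks.

# Illustrative shape of one search() result after this change (values made up).
from PIL import Image

page_img = Image.new("RGB", (640, 480), "white")  # placeholder page image
meta = {"title": "report.pdf", "page": 3, "context": "Quarterly results overview"}
result = (page_img, f"Document: {meta['title']}, Page: {meta['page']}, Context: {meta['context']}")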
@@ -103,22 +115,62 @@ def search(query: str, ds, images, k, api_key):
     return results, ai_response
 
 
-def index(files, ds):
+def index(files, ds, api_key):
     print("Converting files")
-    images = convert_files(files)
+    images, metadatas = convert_files(files, api_key)
     print(f"Files converted with {len(images)} images.")
-
-
-
-
-def convert_files(files):
+    ds = index_gpu(images, ds)
+    print(f"Indexed {len(ds)} images.")
+    return f"Uploaded and converted {len(images)} pages", ds, images, metadatas
+
+DEFAULT_CONTEXT_PROMPT = """
+You are a smart assistant designed to extract context of PDF pages.
+Give detailed and extensive answers, only containing info in the pages you are given.
+You can answer using information contained in plots and figures if necessary.
+Answer in the same language as the query.
+"""
+
+def extract_context(images, api_key, window=10):
+    """Extracts context from images."""
+    prompt = "Give the general context about these pages."
+    window_contexts = []
+
+    args = [(prompt, (images[max(i-window+1, 0):i+1], None), api_key, DEFAULT_CONTEXT_PROMPT)
+            for i in range(0, len(images), window)]
+    window_contexts = pqdm(args, query_gpt4o_mini, n_jobs=8)
+
+    # for i in tqdm(range(0, len(images), window), desc="Extracting context", total=len(images)//window):
+    #     window_images = images[max(i-window+1, 0):i+1]
+    #     window_images = [(image, None) for image in window_images]
+    #     window_contexts.append(query_gpt4o_mini(prompt, window_images, api_key, system_prompt=DEFAULT_CONTEXT_PROMPT))
+
+    contexts = []
+    for i in range(len(images)):
+        context = window_contexts[i//window]
+        contexts.append(context)
+
+    assert len(contexts) == len(images)
+    return contexts
+
+def extract_metadata(file, images, api_key, window=10):
+    """Extracts metadata from pdfs. Extract page number, file name, and authors."""
+    title = file.split("/")[-1]
+    contexts = extract_context(images, api_key, window=window)
+    return [{"title": title, "page": i+1, "context": contexts[i]} for i in range(len(images))]
+
+def convert_files(files, api_key):
     images = []
+    metadatas = []
+
     for f in files:
-        images.extend(convert_from_path(f, thread_count=4))
+        file_images = convert_from_path(f, thread_count=4)
+        file_metadatas = extract_metadata(f, file_images, api_key)
+        images.extend(file_images)
+        metadatas.extend(file_metadatas)
 
     if len(images) >= 500:
         raise gr.Error("The number of images in the dataset should be less than 500.")
-    return images
+    return images, metadatas
 
 
 def index_gpu(images, ds):
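extract_context fans the per-window "general context" requests out with pqdm instead of the commented-out sequential tqdm loop. The sketch below reproduces that fan-out with a dummy worker in place of the real GPT call, so it runs without any API key; the window size and page names are made up, and argument_type="args" is, as far as I know, what tells pqdm to unpack each tuple into positional arguments.

# Standalone sketch of the windowed fan-out performed by extract_context,
# with a dummy worker instead of a real GPT call.
from pqdm.processes import pqdm


def describe_window(prompt, pages, window_id):
    """Stand-in for query_gpt4o_mini: summarize one window of pages."""
    return f"{prompt} -> window {window_id} covering {len(pages)} pages"


if __name__ == "__main__":
    pages = [f"page_{i}" for i in range(25)]  # stand-ins for PIL images
    window = 10

    # One task per window of up to `window` pages, mirroring the app's slicing.
    args = [
        ("Give the general context about these pages.",
         pages[max(i - window + 1, 0):i + 1],
         i // window)
        for i in range(0, len(pages), window)
    ]
    window_contexts = pqdm(args, describe_window, n_jobs=4, argument_type="args")

    # Fan the per-window context back out to every page in that window.
    contexts = [window_contexts[i // window] for i in range(len(pages))]
    assert len(contexts) == len(pages)
    print(contexts[0], contexts[-1], sep="\n")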
@@ -141,7 +193,7 @@ def index_gpu(images, ds):
         batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
         embeddings_doc = model(**batch_doc)
         ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
-    return
+    return ds
 
 
 
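The only change in this hunk is that index_gpu now returns the ds list it has been extending, so the caller (and ultimately the embeds state wired to it) receives the updated value. For reference, the accumulation pattern in the unchanged lines above, reduced to a toy example with random tensors; the shapes are made up.

# Each batched embedding tensor is split along dim 0 with torch.unbind and
# the per-page tensors are collected in a plain list, which is then returned.
import torch

ds = []                                   # plays the role of the embeddings store
for _ in range(3):                        # stand-in for the DataLoader loop
    embeddings_doc = torch.randn(4, 128)  # fake batch of 4 page embeddings
    ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

print(len(ds), ds[0].shape)               # 12 tensors, each of shape (128,)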
@@ -166,6 +218,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             api_key = gr.Textbox(placeholder="Enter your OpenAI KEY here (optional)", label="API key")
             embeds = gr.State(value=[])
             imgs = gr.State(value=[])
+            metadatas = gr.State(value=[])
 
         with gr.Column(scale=3):
             gr.Markdown("## 2️⃣ Search")
@@ -178,8 +231,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
     output_text = gr.Textbox(label="AI Response", placeholder="Generated response based on retrieved documents")
 
-    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
-    search_button.click(search, inputs=[query, embeds, imgs, k, api_key], outputs=[output_gallery, output_text])
+    convert_button.click(index, inputs=[file, embeds, api_key], outputs=[message, embeds, imgs, metadatas])
+    search_button.click(search, inputs=[query, embeds, imgs, metadatas, k, api_key], outputs=[output_gallery, output_text])
 
 if __name__ == "__main__":
-    demo.queue(max_size=5).launch(debug=True)
+    demo.queue(max_size=5).launch(debug=True, server_name="0.0.0.0")
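The click bindings rely on Gradio's state round-trip: whatever a handler returns overwrites the components listed in outputs, including gr.State values such as embeds, imgs and the new metadatas, which is why index and index_gpu now return their results instead of mutating them in place. The added server_name="0.0.0.0" just makes the server listen on all interfaces rather than localhost only. Below is a minimal, self-contained sketch of that state round-trip; the component names are illustrative, not the app's.

# Minimal sketch of passing a gr.State through two click handlers: the value
# returned by a handler replaces the State listed in its outputs.
import gradio as gr


def add_item(item, items):
    items = items + [item]
    return f"{len(items)} item(s) stored", items


def show_items(items):
    return ", ".join(items) if items else "(empty)"


with gr.Blocks() as demo:
    store = gr.State(value=[])  # plays the role of embeds / imgs / metadatas
    item = gr.Textbox(label="Item")
    add_btn = gr.Button("Add")
    show_btn = gr.Button("Show")
    message = gr.Textbox(label="Message")
    listing = gr.Textbox(label="Stored items")

    add_btn.click(add_item, inputs=[item, store], outputs=[message, store])
    show_btn.click(show_items, inputs=[store], outputs=[listing])

if __name__ == "__main__":
    demo.launch()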