salsarra committed on
Commit
067a7ae
·
verified ·
1 Parent(s): 2232a70

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +332 -0
app.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import tensorflow as tf
3
+ from tf_keras import models, layers
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, TFAutoModelForQuestionAnswering
5
+ import gradio as gr
6
+ import re
7
+
8
+ # Check if GPU is available and use it if possible
9
# Run the PyTorch models on CUDA when it is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Question answering: TensorFlow model, so it is not moved to `device` ---
qa_model_name = 'salsarra/ConfliBERT-QA'
qa_model = TFAutoModelForQuestionAnswering.from_pretrained(qa_model_name)
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)

# --- Named entity recognition: PyTorch token classifier ---
ner_model_name = 'eventdata-utd/conflibert-named-entity-recognition'
ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name).to(device)
ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)

# --- Binary classification: PyTorch sequence classifier ---
clf_model_name = 'eventdata-utd/conflibert-binary-classification'
clf_model = AutoModelForSequenceClassification.from_pretrained(clf_model_name).to(device)
clf_tokenizer = AutoTokenizer.from_pretrained(clf_model_name)

# --- Multilabel classification: PyTorch sequence classifier ---
multi_clf_model_name = 'eventdata-utd/conflibert-satp-relevant-multilabel'
multi_clf_model = AutoModelForSequenceClassification.from_pretrained(multi_clf_model_name).to(device)
multi_clf_tokenizer = AutoTokenizer.from_pretrained(multi_clf_model_name)

# Display names for the classifier outputs, indexed by class position.
class_names = ['Negative', 'Positive']
multi_class_names = ["Armed Assault", "Bombing or Explosion", "Kidnapping", "Other"]  # Updated labels

# Colour used to render each NER label; labels missing here fall back to
# black at the point of use.
ner_labels = {
    'Organisation': 'blue',
    'Person': 'red',
    'Location': 'green',
    'Quantity': 'orange',
    'Weapon': 'purple',
    'Nationality': 'cyan',
    'Temporal': 'magenta',
    'DocumentReference': 'brown',
    'MilitaryPlatform': 'yellow',
    'Money': 'pink',
}
45
+
46
# Exception-message shapes recognised as "input longer than the model allows".
# Each pattern captures two numbers, reported as (inserted text size, model limit).
# Compiled once at import time instead of on every call.
_OVER_LIMIT_PATTERNS = (
    re.compile(r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)"),
    re.compile(r"indices\[0,(\d+)\] = \d+ is not in \[0, (\d+)\)"),
)


def handle_error_message(e, default_limit=512):
    """Turn an exception into a red, bold HTML error line for the UI.

    The text of ``e`` is searched against the known over-limit message
    shapes. When one matches, the two captured numbers are reported as the
    inserted text size and the model limit.

    Args:
        e: The exception (or any object whose ``str()`` is the error text).
        default_limit: Limit quoted in the generic message used when no
            known pattern matches.

    Returns:
        An HTML ``<span>`` string describing the over-limit error.
    """
    error_message = str(e)
    size_clause = "inserted text size"
    limit = default_limit

    for pattern in _OVER_LIMIT_PATTERNS:
        match = pattern.search(error_message)
        if match:
            number_1, number_2 = match.groups()
            size_clause = f"inserted text size {number_1}"
            limit = number_2
            break

    return (
        "<span style='color: red; font-weight: bold;'>"
        f"Error: Text Input is over limit where {size_clause} "
        f"is larger than model limits of {limit}</span>"
    )
59
+
60
+ # Define the functions for each task
61
def question_answering(context, question):
    """Answer *question* from *context* and return the answer as HTML.

    The question/context pair is tokenized (with truncation) and passed to
    the QA model. The answer is the token span from the arg-max start
    position through the arg-max end position, decoded back to a string.

    Returns:
        A green, bold HTML ``<span>`` containing the answer, or the HTML
        produced by ``handle_error_message`` if anything raises.
    """
    try:
        encoded = qa_tokenizer(question, context, return_tensors='tf', truncation=True)
        prediction = qa_model(encoded)

        # The end index is made exclusive by adding one, so it can be used
        # directly as a slice bound.
        start_idx = tf.argmax(prediction.start_logits, axis=1).numpy()[0]
        end_idx = tf.argmax(prediction.end_logits, axis=1).numpy()[0] + 1

        token_ids = encoded['input_ids'].numpy()[0]
        span_tokens = qa_tokenizer.convert_ids_to_tokens(token_ids[start_idx:end_idx])
        answer = qa_tokenizer.convert_tokens_to_string(span_tokens)
    except Exception as e:
        return handle_error_message(e)
    return f"<span style='color: green; font-weight: bold;'>{answer}</span>"
71
+
72
def replace_unk(tokens):
    """Return a new list with every ``[UNK]`` in each token replaced by an apostrophe."""
    cleaned = []
    for piece in tokens:
        cleaned.append(piece.replace('[UNK]', "'"))
    return cleaned
74
+
75
def named_entity_recognition(text):
    """Tag entities in *text* and return highlighted HTML plus a legend.

    The text is tokenized (with truncation) and run through the NER model,
    and each token gets its arg-max label with any ``B-``/``I-`` style
    prefix stripped. Word pieces starting with ``##`` are glued onto the
    preceding token, which keeps the label of its first piece.

    Returns:
        An HTML string: a ``<div>`` of space-separated tokens, coloured per
        ``ner_labels`` when the label is not ``O``, followed by a legend of
        the labels found. If anything raises, the HTML produced by
        ``handle_error_message`` is returned instead.
    """
    try:
        # The model lives on `device`, so the encoded inputs must be moved
        # there too; otherwise inference fails whenever CUDA is in use.
        inputs = ner_tokenizer(text, return_tensors='pt', truncation=True).to(device)
        with torch.no_grad():
            outputs = ner_model(**inputs)

        ner_results = outputs.logits.argmax(dim=2).squeeze().tolist()
        tokens = ner_tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze().tolist())
        tokens = replace_unk(tokens)

        entities = []
        # A dict used as an insertion-ordered set, so the legend lists labels
        # in order of first appearance rather than in arbitrary set order.
        seen_labels = {}
        for token, label_id in zip(tokens, ner_results):
            label = ner_model.config.id2label[label_id].split('-')[-1]
            if token.startswith('##'):
                # Continuation piece: append to the previous word, dropping
                # the '##' marker. A leading continuation piece is skipped.
                if entities:
                    entities[-1][0] += token[2:]
            else:
                entities.append([token, label])
                # Only labels attached to a rendered word reach the legend,
                # so the legend always matches what is highlighted.
                if label != 'O':
                    seen_labels[label] = None

        rendered = []
        for token, label in entities:
            if label != 'O':
                color = ner_labels.get(label, 'black')
                rendered.append(f"<span style='color: {color}; font-weight: bold;'>{token}</span> ")
            else:
                rendered.append(f"{token} ")
        highlighted_text = "".join(rendered)

        legend_items = "".join(
            f"<li style='color: {ner_labels.get(label, 'black')}; font-weight: bold; "
            f"display: inline; margin-right: 10px;'>{label}</li>"
            for label in seen_labels
        )
        legend = (
            "<div><strong>NER Tags Found:</strong>"
            "<ul style='list-style-type: disc; padding-left: 20px;'>"
            f"{legend_items}</ul></div>"
        )

        return f"<div>{highlighted_text}</div>{legend}"
    except Exception as e:
        return handle_error_message(e)
113
+
114
def text_classification(text):
    """Classify *text* as conflict-related or not and return an HTML verdict.

    The text is tokenized (truncated and padded), moved to ``device`` and
    scored by the binary classifier. The reported confidence is the largest
    softmax probability, as a percentage.

    Returns:
        A green ``<span>`` for class 1 ("Positive") or a red ``<span>`` for
        any other class ("Negative"), each including the confidence. If
        anything raises, the HTML produced by ``handle_error_message`` is
        returned instead.
    """
    try:
        inputs = clf_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
        with torch.no_grad():
            outputs = clf_model(**inputs)

        predicted_class = torch.argmax(outputs.logits, dim=1).item()
        confidence = torch.softmax(outputs.logits, dim=1).max().item() * 100

        if predicted_class == 1:  # Positive class
            color = 'green'
            verdict = "Positive: The text is related to conflict, violence, or politics."
        else:  # Negative class
            color = 'red'
            verdict = "Negative: The text is not related to conflict, violence, or politics."

        return (
            f"<span style='color: {color}; font-weight: bold;'>"
            f"{verdict} (Confidence: {confidence:.2f}%)</span>"
        )
    except Exception as e:
        return handle_error_message(e)
130
+
131
def multilabel_classification(text):
    """Score *text* against every multilabel class and return HTML.

    Each logit is passed through a sigmoid independently. Classes scoring at
    least 0.5 are rendered green and the rest red, each with its score as a
    percentage, joined by " / ".

    Returns:
        The joined HTML spans; a plain-text error when the number of scores
        differs from ``len(multi_class_names)``; or the HTML produced by
        ``handle_error_message`` if anything raises.
    """
    try:
        encoded = multi_clf_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
        with torch.no_grad():
            model_out = multi_clf_model(**encoded)
        probabilities = torch.sigmoid(model_out.logits).squeeze().tolist()

        n_scores = len(probabilities)
        n_names = len(multi_class_names)
        if n_scores != n_names:
            return f"Error: Number of predicted classes ({n_scores}) does not match number of class names ({n_names})."

        rendered = []
        for name, prob in zip(multi_class_names, probabilities):
            colour = 'green' if prob >= 0.5 else 'red'
            rendered.append(
                f"<span style='color: {colour}; font-weight: bold;'>{name} (Confidence: {prob * 100:.2f}%)</span>"
            )
        return " / ".join(rendered)
    except Exception as e:
        return handle_error_message(e)
151
+
152
+ # Define the Gradio interface
153
def chatbot(task, text=None, context=None, question=None):
    """Route a UI request to the handler for the selected *task*.

    "Question Answering" needs both *context* and *question*; the other
    three tasks need *text*. Missing inputs or an unknown task produce a
    plain-text prompt instead of a model call.

    Returns:
        The selected handler's HTML result, or a prompt string.
    """
    if task == "Question Answering":
        if not (context and question):
            return "Please provide both context and question for the Question Answering task."
        return question_answering(context, question)

    if task == "Named Entity Recognition":
        if not text:
            return "Please provide text for the Named Entity Recognition task."
        return named_entity_recognition(text)

    if task == "Text Classification":
        if not text:
            return "Please provide text for the Text Classification task."
        return text_classification(text)

    if task == "Multilabel Classification":
        if not text:
            return "Please provide text for the Multilabel Classification task."
        return multilabel_classification(text)

    return "Please select a valid task."
176
+
177
# Stylesheet passed to gr.Blocks(css=...): page chrome, the three-part
# header, input/output boxes, the submit button and the footer.
css = """
body {
    background-color: #f0f8ff;
    font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
}

h1 {
    color: #2e8b57;
    text-align: center;
    font-size: 2em;
}

h2 {
    color: #ff8c00;
    text-align: center;
    font-size: 1.5em;
}

.gradio-container {
    max-width: 100%;
    margin: 10px auto;
    padding: 10px;
    background-color: #ffffff;
    border-radius: 10px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}

.gr-input, .gr-output {
    background-color: #ffffff;
    border: 1px solid #ddd;
    border-radius: 5px;
    padding: 10px;
    font-size: 1em;
}

.gr-title {
    font-size: 1.5em;
    font-weight: bold;
    color: #2e8b57;
    margin-bottom: 10px;
    text-align: center;
}

.gr-description {
    font-size: 1.2em;
    color: #ff8c00;
    margin-bottom: 10px;
    text-align: center;
}

.header {
    display: flex;
    justify-content: space-between;
    align-items: center;
    padding: 10px;
    flex-wrap: wrap;
}

.header-title-left, .header-title-right, .header-title-center {
    flex: 1 1 30%;
    text-align: center;
}

.header-title-left a, .header-title-center a, .header-title-right a {
    color: inherit;
    text-decoration: none;
    font-size: 1em;
    display: block;
}

.header-title-center a {
    font-size: 4em; /* Increased font size */
    font-weight: bold; /* Made text bold */
    color: darkorange; /* Darker orange color */
}

.header-title-left a {
    color: green; /* Changed color to green */
    font-weight: bold; /* Made text bold */
    font-size: 1.3em; /* Increased font size */
}

.header-title-right a {
    color: green; /* Changed color to green */
    font-weight: bold; /* Made text bold */
    font-size: 1.3em; /* Increased font size */
}

.gr-button {
    background-color: #ff8c00;
    color: white;
    border: none;
    padding: 10px 20px;
    font-size: 1em;
    border-radius: 5px;
    cursor: pointer;
}

.gr-button:hover {
    background-color: #ff4500;
}

.footer {
    text-align: center;
    margin-top: 10px;
    font-size: 0.9em;
    color: #666;
    width: 100%;
}

.footer a {
    color: #2e8b57;
    font-weight: bold;
    text-decoration: none;
}

.footer a:hover {
    text-decoration: underline;
}
"""
297
+
298
# Build the single-page UI: header links, task selector, task-specific
# inputs, an HTML output pane, a submit button and a footer.
with gr.Blocks(css=css) as demo:
    with gr.Row(elem_id="header"):
        gr.Markdown(
            "<div class='header-title-left'><a href='https://eventdata.utdallas.edu/'>UTD Event Data</a></div>",
            elem_id="header-title-left",
        )
        gr.Markdown(
            "<div class='header-title-center'><a href='https://eventdata.utdallas.edu/conflibert/'>ConfliBERT</a></div>",
            elem_id="header-title-center",
        )
        gr.Markdown(
            "<div class='header-title-right'><a href='https://www.utdallas.edu/'>University of Texas at Dallas</a></div>",
            elem_id="header-title-right",
        )

    gr.Markdown("Select a task and provide the necessary inputs.")

    task = gr.Dropdown(
        choices=["Question Answering", "Named Entity Recognition", "Text Classification", "Multilabel Classification"],
        label="Select Task",
    )

    # Only the free-text box starts visible; the context/question pair is
    # revealed when "Question Answering" is selected.
    with gr.Row():
        text_input = gr.Textbox(lines=5, placeholder="Enter the text here...", label="Text")
        context_input = gr.Textbox(lines=5, placeholder="Enter the context here...", label="Context", visible=False)
        question_input = gr.Textbox(lines=2, placeholder="Enter your question here...", label="Question", visible=False)

    output = gr.HTML(label="Output")

    def update_inputs(task):
        """Return visibility updates for (text, context, question) given *task*."""
        wants_qa = task == "Question Answering"
        return (
            gr.update(visible=not wants_qa),
            gr.update(visible=wants_qa),
            gr.update(visible=wants_qa),
        )

    task.change(fn=update_inputs, inputs=task, outputs=[text_input, context_input, question_input])

    def chatbot_interface(task, text, context, question):
        """Gradio callback: forward the four inputs to ``chatbot`` and return its result."""
        return chatbot(task, text, context, question)

    submit_button = gr.Button("Submit", elem_id="gr-button")
    submit_button.click(
        fn=chatbot_interface,
        inputs=[task, text_input, context_input, question_input],
        outputs=output,
    )

    gr.Markdown("<div class='footer'>Developed By: <a href='https://www.linkedin.com/in/sultan-alsarra-phd-56977a63/' target='_blank'>Sultan Alsarra</a></div>")

demo.launch(share=True)