salsarra commited on
Commit
b3a5224
·
verified ·
1 Parent(s): ca48898

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +169 -0
app.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import gradio as gr
4
+ import re
5
+
6
+ # Check if GPU is available and use it if possible
7
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8
+
9
+ # Load the model and tokenizer
10
+ mme_model_name = 'sperkins2116/ConfliBERT-BC-MMEs'
11
+ mme_model = AutoModelForSequenceClassification.from_pretrained(mme_model_name).to(device)
12
+ mme_tokenizer = AutoTokenizer.from_pretrained(mme_model_name)
13
+
14
+ # Define the class names for text classification
15
+ class_names = ['Negative', 'Positive']
16
+
17
+ def handle_error_message(e, default_limit=512):
18
+ error_message = str(e)
19
+ pattern = re.compile(r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)")
20
+ match = pattern.search(error_message)
21
+ if match:
22
+ number_1, number_2 = match.groups()
23
+ return f"<span style='color: red; font-weight: bold;'>Error: Text Input is over limit where inserted text size {number_1} is larger than model limits of {number_2}</span>"
24
+ return f"<span style='color: red; font-weight: bold;'>Error: Text Input is over limit where inserted text size is larger than model limits of {default_limit}</span>"
25
+
26
+ def mme_classification(text):
27
+ try:
28
+ inputs = mme_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
29
+ with torch.no_grad():
30
+ outputs = mme_model(**inputs)
31
+ logits = outputs.logits.squeeze().tolist()
32
+ predicted_class = torch.argmax(outputs.logits, dim=1).item()
33
+ confidence = torch.softmax(outputs.logits, dim=1).max().item() * 100
34
+
35
+ if predicted_class == 1: # Positive class
36
+ result = f"<span style='color: green; font-weight: bold;'>Positive: The text contains evidence of a multinational military exercise. (Confidence: {confidence:.2f}%)</span>"
37
+ else: # Negative class
38
+ result = f"<span style='color: red; font-weight: bold;'>Negative: The text does not contain evidence of a multinational military exercise. (Confidence: {confidence:.2f}%)</span>"
39
+ return result
40
+ except Exception as e:
41
+ return handle_error_message(e)
42
+
43
+ # Define the Gradio interface
44
+ def chatbot(text):
45
+ return mme_classification(text)
46
+
47
+ css = """
48
+ body {
49
+ background-color: #f0f8ff;
50
+ font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
51
+ color: black; /* Ensure text is visible in dark mode */
52
+ }
53
+
54
+ h1 {
55
+ color: #2e8b57;
56
+ text-align: center;
57
+ font-size: 2em;
58
+ }
59
+
60
+ h2 {
61
+ color: #ff8c00;
62
+ text-align: center;
63
+ font-size: 1.5em;
64
+ }
65
+
66
+ .gradio-container {
67
+ max-width: 100%;
68
+ margin: 10px auto;
69
+ padding: 10px;
70
+ background-color: #ffffff;
71
+ border-radius: 10px;
72
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
73
+ }
74
+
75
+ .gr-input, .gr-output {
76
+ background-color: #ffffff;
77
+ border: 1px solid #ddd;
78
+ border-radius: 5px;
79
+ padding: 10px;
80
+ font-size: 1em;
81
+ color: black; /* Ensure text is visible in dark mode */
82
+ }
83
+
84
+ .gr-title {
85
+ font-size: 1.5em;
86
+ font-weight: bold;
87
+ color: #2e8b57;
88
+ margin-bottom: 10px;
89
+ text-align: center;
90
+ }
91
+
92
+ .gr-description {
93
+ font-size: 1.2em;
94
+ color: #ff8c00;
95
+ margin-bottom: 10px;
96
+ text-align: center;
97
+ }
98
+
99
+ .header {
100
+ display: flex;
101
+ justify-content: center;
102
+ align-items: center;
103
+ padding: 10px;
104
+ flex-wrap: wrap;
105
+ }
106
+
107
+ .header-title-center a {
108
+ font-size: 4em; /* Increased font size */
109
+ font-weight: bold; /* Made text bold */
110
+ color: darkorange; /* Darker orange color */
111
+ text-align: center;
112
+ display: block;
113
+ }
114
+
115
+ .gr-button {
116
+ background-color: #ff8c00;
117
+ color: white;
118
+ border: none;
119
+ padding: 10px 20px;
120
+ font-size: 1em;
121
+ border-radius: 5px;
122
+ cursor: pointer;
123
+ }
124
+
125
+ .gr-button:hover {
126
+ background-color: #ff4500;
127
+ }
128
+
129
+ .footer {
130
+ text-align: center;
131
+ margin-top: 10px;
132
+ font-size: 0.9em; /* Updated font size */
133
+ color: black; /* Ensure text is visible in dark mode */
134
+ width: 100%;
135
+ }
136
+
137
+ .footer a {
138
+ color: #2e8b57;
139
+ font-weight: bold;
140
+ text-decoration: none;
141
+ }
142
+
143
+ .footer a:hover {
144
+ text-decoration: underline;
145
+ }
146
+
147
+ .footer .inline {
148
+ display: inline;
149
+ color: black; /* Ensure text is visible in dark mode */
150
+ }
151
+ """
152
+
153
+ with gr.Blocks(css=css) as demo:
154
+ with gr.Row(elem_id="header"):
155
+ gr.Markdown("<div class='header-title-center'><a href='https://eventdata.utdallas.edu/conflibert/'>ConfliBERT-MME</a></div>", elem_id="header-title-center")
156
+
157
+ gr.Markdown("<span style='color: black;'>Provide the text for MME Classification.</span>")
158
+
159
+ text_input = gr.Textbox(lines=5, placeholder="Enter the text here...", label="Text")
160
+
161
+ output = gr.HTML(label="Output")
162
+
163
+ submit_button = gr.Button("Submit", elem_id="gr-button")
164
+ submit_button.click(fn=chatbot, inputs=text_input, outputs=output)
165
+
166
+ gr.Markdown("<div class='footer'><a href='https://eventdata.utdallas.edu/'>UTD Event Data</a> | <a href='https://www.utdallas.edu/'>University of Texas at Dallas</a> | <a href='https://www.wvu.edu/'>West Virginia University</a></div>")
167
+ gr.Markdown("<div class='footer'><span class='inline'>Developed By: <a href='https://www.linkedin.com/in/sultan-alsarra-phd-56977a63/' target='_blank'>Sultan Alsarra</a> | Finetuned By: Spencer Perkins</span></div>")
168
+
169
+ demo.launch(share=True)