import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr
import re

# Check if GPU is available and use it if possible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the model and tokenizer
mme_model_name = 'sperkins2116/ConfliBERT-BC-MMEs'
mme_model = AutoModelForSequenceClassification.from_pretrained(mme_model_name).to(device)
mme_tokenizer = AutoTokenizer.from_pretrained(mme_model_name)
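# Put the model in inference mode. from_pretrained() already returns the model in eval
# mode, so this is not strictly required, but it makes the intent explicit alongside torch.no_grad().
mme_model.eval()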

# Class names for the binary classifier (index 0 = Negative, index 1 = Positive)
class_names = ['Negative', 'Positive']

def handle_error_message(e, default_limit=512):
    """Convert a raised exception into a user-facing HTML error message."""
    error_message = str(e)
    # PyTorch size-mismatch errors look like:
    # "The size of tensor a (600) must match the size of tensor b (512) ..." (illustrative numbers)
    pattern = re.compile(r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)")
    match = pattern.search(error_message)
    if match:
        input_size, model_limit = match.groups()
        return f"<span style='color: red; font-weight: bold;'>Error: The input text ({input_size} tokens) exceeds the model's limit of {model_limit} tokens.</span>"
    return f"<span style='color: red; font-weight: bold;'>Error: The input text exceeds the model's limit of {default_limit} tokens.</span>"

def mme_classification(text):
    try:
        # Tokenize the input and move the tensors to the same device as the model
        inputs = mme_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
        with torch.no_grad():
            outputs = mme_model(**inputs)
        # Predicted class index and the softmax probability of the winning class
        predicted_class = torch.argmax(outputs.logits, dim=1).item()
        confidence = torch.softmax(outputs.logits, dim=1).max().item() * 100

        if predicted_class == 1:  # Positive class
            result = f"<span style='color: green; font-weight: bold;'>Positive: The text contains evidence of a multinational military exercise. (Confidence: {confidence:.2f}%)</span>"
        else:  # Negative class
            result = f"<span style='color: red; font-weight: bold;'>Negative: The text does not contain evidence of a multinational military exercise. (Confidence: {confidence:.2f}%)</span>"
        return result
    except Exception as e:
        return handle_error_message(e)
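
# Quick sanity check (hypothetical example sentence); uncomment to verify the model loads
# and classifies before launching the UI:
# print(mme_classification("NATO and partner nations held a joint naval exercise in the Baltic Sea."))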

# Thin wrapper used by the Gradio interface defined below
def chatbot(text):
    return mme_classification(text)

css = """
body {
    background-color: #f0f8ff;
    font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
    color: black; /* Ensure text is visible in dark mode */
}

h1 {
    color: #2e8b57;
    text-align: center;
    font-size: 2em;
}

h2 {
    color: #ff8c00;
    text-align: center;
    font-size: 1.5em;
}

.gradio-container {
    max-width: 100%;
    margin: 10px auto;
    padding: 10px;
    background-color: #ffffff;
    border-radius: 10px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}

.gr-input, .gr-output {
    background-color: #ffffff;
    border: 1px solid #ddd;
    border-radius: 5px;
    padding: 10px;
    font-size: 1em;
    color: black; /* Ensure text is visible in dark mode */
}

.gr-title {
    font-size: 1.5em;
    font-weight: bold;
    color: #2e8b57;
    margin-bottom: 10px;
    text-align: center;
}

.gr-description {
    font-size: 1.2em;
    color: #ff8c00;
    margin-bottom: 10px;
    text-align: center;
}

.header {
    display: flex;
    justify-content: center;
    align-items: center;
    padding: 10px;
    flex-wrap: wrap;
}

.header-title-center a {
    font-size: 4em;
    font-weight: bold;
    color: darkorange;
    text-align: center;
    display: block;
}

.gr-button {
    background-color: #ff8c00;
    color: white;
    border: none;
    padding: 10px 20px;
    font-size: 1em;
    border-radius: 5px;
    cursor: pointer;
}

.gr-button:hover {
    background-color: #ff4500;
}

.footer {
    text-align: center;
    margin-top: 10px;
    font-size: 0.9em;
    color: black; /* Ensure text is visible in dark mode */
    width: 100%;
}

.footer a {
    color: #2e8b57;
    font-weight: bold;
    text-decoration: none;
}

.footer a:hover {
    text-decoration: underline;
}

.footer .inline {
    display: inline;
    color: black; /* Ensure text is visible in dark mode */
}
"""

with gr.Blocks(css=css) as demo:
    # Use elem_classes (not elem_id) so the .header class selector in the CSS above applies;
    # the Markdown's inner div already carries the header-title-center class.
    with gr.Row(elem_classes=["header"]):
        gr.Markdown("<div class='header-title-center'><a href='https://eventdata.utdallas.edu/conflibert/'>ConfliBERT-MME</a></div>")
    
    gr.Markdown("<span style='color: black;'>Provide the text for MME Classification.</span>")
    
    text_input = gr.Textbox(lines=5, placeholder="Enter the text here...", label="Text")
    
    output = gr.HTML(label="Output")
    
    submit_button = gr.Button("Submit", elem_classes=["gr-button"])
    submit_button.click(fn=chatbot, inputs=text_input, outputs=output)
    
    gr.Markdown("<div class='footer'><a href='https://eventdata.utdallas.edu/'>UTD Event Data</a> | <a href='https://www.utdallas.edu/'>University of Texas at Dallas</a> | <a href='https://www.wvu.edu/'>West Virginia University</a></div>")
    gr.Markdown("<div class='footer'><span class='inline'>Developed By: <a href='https://www.linkedin.com/in/sultan-alsarra-phd-56977a63/' target='_blank'>Sultan Alsarra</a> | Finetuned By: Spencer Perkins</span></div>")

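# share=True asks Gradio to also create a temporary public link in addition to the local
# server; use share=False (the default) to keep the demo local-only.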
demo.launch(share=True)