import os

import gradio as gr
import pandas as pd
import plotly.express as px
from huggingface_hub.file_download import hf_hub_download


from src.utils import process_model_name, process_model_arch
from src.assets.css_html_js import custom_css
from src.assets.text_content import (
    TITLE,
    ABOUT_TEXT,
    INTRODUCTION_TEXT,
    EXAMPLE_CONFIG_TEXT,
    CITATION_BUTTON_LABEL,
    CITATION_BUTTON_TEXT,
)

# Optional Hugging Face Hub token read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN", None)
LOGO_URL = "https://huggingface.co/spaces/optimum/llm-perf-leaderboard/resolve/main/huggy_bench.png"
LLM_PERF_DATASET_REPO = "optimum/llm-perf-dataset"
# Benchmark machine identifier -> hardware label shown as the tab title.
MACHINE_TO_HARDWARE = {"hf-dgx-01": "A100-80GB 🖥️"}
# Raw dataframe column -> display header. The insertion order is the column
# order of the leaderboard and must stay aligned with ALL_COLUMNS_DATATYPES.
# Arrows tell the reader whether a higher (⬆️) or lower (⬇️) value is better.
ALL_COLUMNS_MAPPING = {
    "Model": "Model 🤗",
    "Arch": "Arch 🏛️",
    "Size": "Params (B) 📏",
    # deployment settings
    "backend.name": "Backend 🏭",
    "backend.torch_dtype": "Dtype 📥",
    "optimization": "Optimization 🛠️",
    "quantization": "Quantization 🗜️",
    # measurements
    "Score": "Open LLM Score (%) ⬆️",
    "decode.throughput(tokens/s)": "Decode Throughput (tokens/s) ⬆️",
    "generate.throughput(tokens/s)": "E2E Throughput (tokens/s) ⬆️",
    "forward.latency(s)": "Prefill Latency (s) ⬇️",
    "generate.latency(s)": "E2E Latency (s) ⬇️",
    "generate.max_memory_allocated(MB)": "Allocated Memory (MB) ⬇️",
    "generate.max_memory_reserved(MB)": "Reserved Memory (MB) ⬇️",
    "generate.max_memory_used(MB)": "Used Memory (MB) ⬇️",
    # tokens per kWh: more tokens for the same energy is better, hence ⬆️
    "generate.energy_consumption(tokens/kWh)": "Energy (tokens/kWh) ⬆️",
}
# Sort keys (raw column names) and their per-key direction.
# NOTE(review): throughput is sorted ascending, i.e. among equal scores the
# slowest configuration comes first — confirm this is intended.
SORTING_COLUMN = ["Score", "generate.throughput(tokens/s)"]
SORTING_ASCENDING = [False, True]
# One Gradio datatype per entry of ALL_COLUMNS_MAPPING, in the same order.
ALL_COLUMNS_DATATYPES = [
    # open llm
    "markdown",  # Model
    "markdown",  # Arch
    "number",  # Size
    # deployment settings
    "str",  # backend.name
    "str",  # backend.torch_dtype
    "str",  # optimization
    "str",  # quantization
    # measurements
    "number",  # Score
    "number",  # decode.throughput(tokens/s)
    "number",  # generate.throughput(tokens/s)
    "number",  # forward.latency(s)
    "number",  # generate.latency(s)
    "number",  # generate.max_memory_allocated(MB)
    "number",  # generate.max_memory_reserved(MB)
    "number",  # generate.max_memory_used(MB)
    "number",  # generate.energy_consumption(tokens/kWh)
]
# download data
# Arguments shared by every download from the LLM-Perf dataset repository;
# only the file name differs between calls.
_DATASET_DOWNLOAD_KWARGS = {
    "repo_id": LLM_PERF_DATASET_REPO,
    "local_dir": "dataset",
    "repo_type": "dataset",
    "token": HF_TOKEN,
}

# Open LLM leaderboard data, read from the path hf_hub_download reports
# instead of a hand-built copy of it.
OPEN_LLM = pd.read_csv(
    hf_hub_download(filename="open-llm.csv", **_DATASET_DOWNLOAD_KWARGS)
)

# Benchmark machine identifier -> full benchmark report of that machine.
MACHINE_TO_DATAFRAME = {
    machine: pd.read_csv(
        hf_hub_download(
            filename=f"{machine}/full-report.csv", **_DATASET_DOWNLOAD_KWARGS
        )
    )
    for machine in MACHINE_TO_HARDWARE
}


# Display labels for the raw `backend.quantization_scheme` values; anything
# else (including a missing value) is shown as "None".
_QUANTIZATION_LABELS = {"bnb": "BnB.4bit", "gptq": "GPTQ.4bit"}


def _flag_enabled(value):
    """Return True only for a present, truthy flag; a missing (NaN) flag is off."""
    # NaN is truthy in Python, so it must be ruled out before testing the value.
    return bool(value) if pd.notna(value) else False


def _optimization_label(use_bettertransformer, use_flash_attention_2):
    """Return the optimization label for one row; BetterTransformer wins over FA2."""
    if _flag_enabled(use_bettertransformer):
        return "BetterTransformer"
    if _flag_enabled(use_flash_attention_2):
        return "FlashAttentionV2"
    return "None"


def get_benchmark_df(machine="hf-dgx-01"):
    """Build the leaderboard dataframe for one benchmark machine.

    Joins the Open LLM table with the machine's benchmark report on the model
    name, derives the energy, optimization, quantization and decode-throughput
    columns, sorts the rows by SORTING_COLUMN / SORTING_ASCENDING and returns
    only the columns of ALL_COLUMNS_MAPPING under their display headers.

    Args:
        machine: key of MACHINE_TO_DATAFRAME selecting the benchmark report.

    Returns:
        A new DataFrame; the module-level source tables are not modified.

    Raises:
        KeyError: if `machine` is unknown or an expected column is missing.
    """
    # merge on model (inner join: only models present in both tables remain)
    llm_perf = MACHINE_TO_DATAFRAME[machine].copy()
    merged_df = OPEN_LLM.merge(llm_perf, left_on="Model", right_on="model")

    # Invert kWh/token into whole tokens/kWh. Missing or non-positive
    # measurements have no meaningful reciprocal and are left as NaN, which
    # avoids both a sentinel value and a division by zero.
    energy_per_token = merged_df["generate.energy_consumption(kWh/token)"]
    measured = energy_per_token > 0  # NaN compares False
    tokens_per_kwh = (1 / energy_per_token.where(measured)) // 1
    if measured.all():
        # Nothing is missing, so the column can hold plain integers.
        tokens_per_kwh = tokens_per_kwh.astype(int)
    merged_df["generate.energy_consumption(tokens/kWh)"] = tokens_per_kwh

    # add optimization column
    merged_df["optimization"] = [
        _optimization_label(bettertransformer, flash_attention_2)
        for bettertransformer, flash_attention_2 in zip(
            merged_df["backend.to_bettertransformer"],
            merged_df["backend.use_flash_attention_2"],
        )
    ]

    # add quantization scheme
    merged_df["quantization"] = merged_df["backend.quantization_scheme"].map(
        lambda scheme: _QUANTIZATION_LABELS.get(scheme, "None")
    )

    # add decode throughput: tokens per second once the prefill (forward)
    # latency is removed from the end-to-end latency
    # NOTE(review): the constant 1000 presumably is the number of generated
    # tokens per benchmark run — confirm against the benchmark configuration.
    merged_df["decode.throughput(tokens/s)"] = (
        1000 / (merged_df["generate.latency(s)"] - merged_df["forward.latency(s)"])
    ).round(2)

    # sort by metric
    merged_df = merged_df.sort_values(by=SORTING_COLUMN, ascending=SORTING_ASCENDING)

    # keep only the leaderboard columns, under their display headers
    return merged_df[list(ALL_COLUMNS_MAPPING)].rename(columns=ALL_COLUMNS_MAPPING)


def get_benchmark_table(bench_df):
    """Return a display-ready copy of the leaderboard dataframe.

    The model and architecture columns are passed through
    `process_model_name` / `process_model_arch`, and the score of every row
    using a 4-bit quantization scheme gets a trailing "**" marker.

    Args:
        bench_df: dataframe using the display headers of the leaderboard;
            it is not modified. It may be empty (e.g. a filter without hits).

    Returns:
        A new DataFrame with the same columns as `bench_df`.
    """
    score_column = "Open LLM Score (%) ⬆️"
    quantized_schemes = ("BnB.4bit", "GPTQ.4bit")

    table_df = bench_df.copy()
    # transform
    table_df["Model 🤗"] = table_df["Model 🤗"].apply(process_model_name)
    table_df["Arch 🏛️"] = table_df["Arch 🏛️"].apply(process_model_arch)
    # Mark the scores of quantized rows. Built element-wise rather than with a
    # row-wise DataFrame.apply, which on an empty frame can hand back a whole
    # DataFrame and break the single-column assignment.
    table_df[score_column] = [
        f"{score}**" if scheme in quantized_schemes else score
        for score, scheme in zip(table_df[score_column], table_df["Quantization 🗜️"])
    ]
    return table_df


def get_benchmark_chart(bench_df):
    """Return a scatter plot of latency vs. score for the given leaderboard rows.

    Each point is one row of `bench_df`: end-to-end latency on the x axis,
    Open LLM score on the y axis, marker size from the allocated memory and
    colour from the processed architecture. Every leaderboard column is
    attached to the points so the hover box can list all of them.

    Args:
        bench_df: dataframe using the display headers of the leaderboard;
            it is not modified.

    Returns:
        The configured plotly figure.
    """
    hover_columns = list(ALL_COLUMNS_MAPPING.values())

    plot_df = bench_df.copy()
    # transform
    plot_df["Arch 🏛️"] = plot_df["Arch 🏛️"].apply(process_model_arch)

    # plot
    scatter = px.scatter(
        plot_df,
        x="E2E Latency (s) ⬇️",
        y="Open LLM Score (%) ⬆️",
        color="Arch 🏛️",
        size="Allocated Memory (MB) ⬇️",
        custom_data=hover_columns,
        color_discrete_sequence=px.colors.qualitative.Light24,
    )

    # Centered title anchored near the top of the figure.
    centered_title = dict(
        text="Latency vs. Score vs. Memory",
        y=0.95,
        x=0.5,
        xanchor="center",
        yanchor="top",
    )
    scatter.update_layout(
        title=centered_title,
        xaxis_title="Per 1000 Tokens Latency (s)",
        yaxis_title="Open LLM Score (%)",
        legend_title="LLM Architecture",
        width=1200,
        height=600,
    )

    # One "<b>header:</b> value" line per column; the value is looked up in
    # the point's customdata by position.
    hover_lines = []
    for position, header in enumerate(hover_columns):
        hover_lines.append(
            "<b>{}:</b> %{{customdata[{}]}}".format(header, position)
        )
    scatter.update_traces(hovertemplate="<br>".join(hover_lines))

    return scatter


def filter_query(
    text,
    backends,
    datatypes,
    optimizations,
    quantizations,
    score,
    memory,
    machine,
):
    """Filter one machine's leaderboard and rebuild its table and plot.

    Args:
        text: search string matched case-insensitively against the model
            name. It is tried as a regular expression first; if it is not a
            valid pattern (e.g. "c++" or "(") it is matched literally.
        backends: backend names to keep.
        datatypes: load dtypes to keep.
        optimizations: optimization labels to keep.
        quantizations: quantization labels to keep.
        score: minimum Open LLM score (inclusive).
        memory: maximum allocated memory in MB (inclusive).
        machine: benchmark machine identifier passed to `get_benchmark_df`.

    Returns:
        A `(table, chart)` tuple built from the rows passing every filter.
    """
    import re  # only needed to recognise an invalid search pattern

    raw_df = get_benchmark_df(machine=machine)

    pattern = text or ""  # a missing search string means "no name filter"
    model_names = raw_df["Model 🤗"]
    try:
        name_matches = model_names.str.contains(pattern, case=False, na=False)
    except re.error:
        # Free-form user input is not always a valid regex; fall back to a
        # plain substring search instead of failing the whole request.
        name_matches = model_names.str.contains(
            pattern, case=False, na=False, regex=False
        )

    keep = (
        name_matches
        & raw_df["Backend 🏭"].isin(backends)
        & raw_df["Dtype 📥"].isin(datatypes)
        & raw_df["Optimization 🛠️"].isin(optimizations)
        & raw_df["Quantization 🗜️"].isin(quantizations)
        & (raw_df["Open LLM Score (%) ⬆️"] >= score)
        & (raw_df["Allocated Memory (MB) ⬇️"] <= memory)
    )
    filtered_df = raw_df[keep]

    filtered_table = get_benchmark_table(filtered_df)
    filtered_chart = get_benchmark_chart(filtered_df)
    return filtered_table, filtered_chart


# Demo interface
# Builds the whole UI when this module is executed: header, one tab per
# benchmark machine (leaderboard table + plot), a control panel whose button
# re-filters every machine's table and plot, an about tab and a citation box.
demo = gr.Blocks(css=custom_css)
with demo:
    # logo
    gr.HTML(f'<img src="{LOGO_URL}">', elem_classes="logo")
    # leaderboard title
    gr.HTML(TITLE)
    # introduction text
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="descriptive-text")

    with gr.Tabs(elem_classes="leaderboard-tabs"):
        # Per-machine components, keyed by machine identifier, so the control
        # panel below can wire its button to each of them.
        machine_placeholders = {}
        machine_tables = {}
        machine_plots = {}
        ####################### HARDWARE TABS #######################
        # NOTE(review): hardware tabs take their id from the loop index while
        # the Control Panel and About tabs use the fixed ids 2 and 3; with
        # three or more machines the ids would collide — confirm.
        for i, (machine, hardware) in enumerate(MACHINE_TO_HARDWARE.items()):
            # dummy placeholder of the machine name
            # (hidden textbox, used as an event input so filter_query
            # receives the machine identifier)
            machine_placeholders[machine] = gr.Textbox(value=machine, visible=False)

            with gr.TabItem(hardware, id=i):
                with gr.Tabs(elem_classes="machine-tabs"):
                    # placeholder for full dataframe
                    # (unfiltered data used for the initial table and plot)
                    machine_df = get_benchmark_df(machine=machine)
                    with gr.TabItem("Leaderboard 🏅", id=0):
                        # NOTE(review): "descriptive-text" is passed as
                        # elem_id here and on other components, but as
                        # elem_classes elsewhere in this file — confirm
                        # which one the stylesheet targets.
                        gr.HTML(
                            "👉 Scroll to the right 👉 for additional columns.",
                            elem_id="descriptive-text",
                        )
                        # Original leaderboard table
                        machine_tables[machine] = gr.components.Dataframe(
                            value=get_benchmark_table(machine_df),
                            headers=list(ALL_COLUMNS_MAPPING.values()),
                            datatype=ALL_COLUMNS_DATATYPES,
                            elem_id="machine-table",
                        )
                    with gr.TabItem("Plot 📊", id=1):
                        gr.HTML(
                            "👆 Hover over the points 👆 for additional information.",
                            elem_id="descriptive-text",
                        )
                        # Original leaderboard plot
                        machine_plots[machine] = gr.components.Plot(
                            value=get_benchmark_chart(machine_df),
                            elem_id="machine-plot",
                            show_label=False,
                        )

        ###################### CONTROL PANEL #######################
        with gr.TabItem("Control Panel 🎛️", id=2):
            gr.HTML(
                "Use this control panel to filter the leaderboard's table and plot.",  # noqa: E501
                elem_id="descriptive-text",
            )
            # Row 1: free-text search on the model name.
            with gr.Row():
                with gr.Column():
                    search_bar = gr.Textbox(
                        label="Model 🤗",
                        info="🔍 Search for a model name",
                        elem_id="search-bar",
                    )
            # Row 2: score threshold, memory ceiling and backend selection.
            with gr.Row():
                with gr.Column(scale=1):
                    score_slider = gr.Slider(
                        label="Open LLM Score (%) 📈",
                        info="🎚️ Slide to minimum Open LLM score",
                        value=0,
                        elem_id="threshold-slider",
                    )
                with gr.Column(scale=1):
                    # Range is 0 to 80 * 1024 MB and starts at the maximum,
                    # so no row is excluded by default.
                    memory_slider = gr.Slider(
                        label="Peak Memory (MB) 📈",
                        info="🎚️ Slide to maximum Peak Memory",
                        minimum=0,
                        maximum=80 * 1024,
                        value=80 * 1024,
                        elem_id="memory-slider",
                    )
                with gr.Column(scale=1):
                    backend_checkboxes = gr.CheckboxGroup(
                        label="Backends 🏭",
                        choices=["pytorch", "onnxruntime"],
                        value=["pytorch", "onnxruntime"],
                        info="☑️ Select the backends",
                        elem_id="backend-checkboxes",
                    )
            # Row 3: dtype, optimization and quantization selection; every
            # choice is ticked by default. The optimization and quantization
            # choices are the labels produced by get_benchmark_df.
            with gr.Row():
                with gr.Column(scale=1):
                    datatype_checkboxes = gr.CheckboxGroup(
                        label="Load Dtypes 📥",
                        choices=["float32", "float16"],
                        value=["float32", "float16"],
                        info="☑️ Select the load dtypes",
                        elem_id="dtype-checkboxes",
                    )
                with gr.Column(scale=1):
                    optimization_checkboxes = gr.CheckboxGroup(
                        label="Optimizations 🛠️",
                        choices=["None", "BetterTransformer", "FlashAttentionV2"],
                        value=["None", "BetterTransformer", "FlashAttentionV2"],
                        info="☑️ Select the optimization",
                        elem_id="optimization-checkboxes",
                    )
                with gr.Column(scale=1):
                    quantization_checkboxes = gr.CheckboxGroup(
                        label="Quantizations 🗜️",
                        choices=["None", "BnB.4bit", "GPTQ.4bit"],
                        value=["None", "BnB.4bit", "GPTQ.4bit"],
                        info="☑️ Select the quantization schemes",
                        elem_id="quantization-checkboxes",
                    )
            with gr.Row():
                filter_button = gr.Button(
                    value="Filter 🚀",
                    elem_id="filter-button",
                )
            # Register one click handler per machine: the same button updates
            # every machine's table and plot. The inputs are listed in the
            # positional order of filter_query's parameters.
            for machine in MACHINE_TO_HARDWARE:
                filter_button.click(
                    filter_query,
                    [
                        search_bar,
                        backend_checkboxes,
                        datatype_checkboxes,
                        optimization_checkboxes,
                        quantization_checkboxes,
                        score_slider,
                        memory_slider,
                        machine_placeholders[machine],
                    ],
                    [machine_tables[machine], machine_plots[machine]],
                )

        ####################### ABOUT TAB #######################
        with gr.TabItem("About 📖", id=3):
            gr.HTML(ABOUT_TEXT, elem_classes="descriptive-text")
            gr.Markdown(EXAMPLE_CONFIG_TEXT, elem_classes="descriptive-text")

    ####################### CITATION #######################
    # Collapsed by default; the textbox offers a copy button for the citation.
    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            citation_button = gr.Textbox(
                value=CITATION_BUTTON_TEXT,
                label=CITATION_BUTTON_LABEL,
                elem_id="citation-button",
                show_copy_button=True,
            )

# Launch demo
# Runs unconditionally when the module is executed or imported (there is no
# `if __name__ == "__main__"` guard).
demo.queue().launch()