import gradio as gr
import pandas as pd
import plotly.express as px


# Columns passed to plotly as ``custom_data`` by the quantization plots, in order:
# the hover templates are built by enumerating this list, so ``customdata[i]``
# is the value of ``QUANT_DATA[i]``. Keep each column listed only once, otherwise
# the tooltip repeats it.
QUANT_DATA = [
    # open llm
    "Model 🤗",
    "Arch 🏛️",
    "Params (B)",
    "Open LLM Score (%)",
    # deployment settings (vanilla baseline, then the custom-kernel run)
    "DType 📥",
    "Backend 🏭",
    "Optimization 🛠️",
    "Quantization 🗜️",
    "Optimization 🛠️ Custom Kernel",
    "Quantization 🗜️ Custom Kernel",
    # primary measurements
    "Prefill Latency (s)",
    "Prefill Latency (s) Custom Kernel",
    "Decode Throughput (tokens/s)",
    "Decode Throughput (tokens/s) Custom Kernel",
    # speedups
    "Prefill Latency Speedup (%)",
    "Decode Throughput Speedup (%)",
]


def get_quant_df(
    llm_perf_df,
    custom_kernels=(
        "GPTQ.4bit+ExllamaV1",
        "GPTQ.4bit+ExllamaV2",
        "AWQ.4bit+GEMM",
        "AWQ.4bit+GEMV",
    ),
    max_speedup=1000,
):
    """Pair each custom-kernel quantized run with its vanilla baseline and compute speedups.

    The baseline rows are those on the ``"pytorch"`` backend with quantization
    ``"None"``, optimization ``"None"`` and dtype ``"float16"``. For every scheme in
    ``custom_kernels``, the rows whose ``"Quantization 🗜️"`` equals that scheme are
    inner-joined to the baseline on ``"Model 🤗"``. Overlapping columns from the
    custom-kernel run get the ``" Custom Kernel"`` suffix.

    Args:
        llm_perf_df: Benchmark results, one row per experiment. It is not modified.
        custom_kernels: Quantization values to compare against the baseline, in the
            order their rows appear in the result. Must be non-empty, otherwise
            ``pd.concat`` raises ``ValueError``.
        max_speedup: Rows whose prefill or decode speedup (in %) is not strictly
            below this value are dropped. Rows with a NaN speedup are dropped too.

    Returns:
        A new DataFrame with the merged columns plus ``"Prefill Latency Speedup (%)"``
        and ``"Decode Throughput Speedup (%)"``, both rounded to two decimals. The
        index comes from ``pd.concat`` unchanged and may contain duplicates.
    """
    # Separate the vanilla (baseline) runs from the custom-kernel experiments.
    vanilla_df = llm_perf_df[
        (llm_perf_df["Backend 🏭"] == "pytorch")
        & (llm_perf_df["Quantization 🗜️"] == "None")
        & (llm_perf_df["Optimization 🛠️"] == "None")
        & (llm_perf_df["DType 📥"] == "float16")
    ]
    # One baseline-vs-kernel merge per custom kernel, stacked row-wise.
    merged_dfs = [
        pd.merge(
            vanilla_df,
            llm_perf_df[llm_perf_df["Quantization 🗜️"] == kernel],
            on=["Model 🤗"],
            suffixes=["", " Custom Kernel"],
        )
        for kernel in custom_kernels
    ]
    quant_df = pd.concat(merged_dfs)

    # Latency: lower is better, so compute baseline / custom.
    # Throughput: higher is better, so compute custom / baseline.
    prefill_ratio = quant_df["Prefill Latency (s)"] / quant_df["Prefill Latency (s) Custom Kernel"]
    decode_ratio = (
        quant_df["Decode Throughput (tokens/s) Custom Kernel"] / quant_df["Decode Throughput (tokens/s)"]
    )
    # Round *after* subtracting 100; rounding first leaves float artifacts such as
    # 23.450000000000003 in the stored (and hovered) values.
    quant_df["Prefill Latency Speedup (%)"] = (prefill_ratio * 100 - 100).round(2)
    quant_df["Decode Throughput Speedup (%)"] = (decode_ratio * 100 - 100).round(2)

    # Drop implausible outliers; NaN speedups fail the comparison and are dropped as well.
    keep = (quant_df["Prefill Latency Speedup (%)"] < max_speedup) & (
        quant_df["Decode Throughput Speedup (%)"] < max_speedup
    )
    return quant_df[keep]


def get_quant_decode_fig(llm_perf_df):
    """Return a box plot of decode-throughput speedup grouped by model architecture.

    Each point is one row of ``get_quant_df(llm_perf_df)``. Boxes are coloured by the
    custom-kernel quantization scheme, and hovering a point lists every column named
    in ``QUANT_DATA``.

    Args:
        llm_perf_df: Benchmark results, forwarded unchanged to ``get_quant_df``.

    Returns:
        The configured plotly figure.
    """
    speedup_df = get_quant_df(llm_perf_df)
    figure = px.box(
        speedup_df,
        x="Arch 🏛️",
        y="Decode Throughput Speedup (%)",
        color="Quantization 🗜️ Custom Kernel",
        color_discrete_sequence=px.colors.qualitative.Light24,
        custom_data=QUANT_DATA,
        points="all",
    )

    # customdata[pos] carries the value of QUANT_DATA[pos] for the hovered point.
    hover_rows = (f"<b>{label}:</b> %{{customdata[{pos}]}}" for pos, label in enumerate(QUANT_DATA))
    figure.update_traces(hovertemplate="<br>".join(hover_rows))

    title_spec = dict(
        text="Decode Throughput Speedup per Architecture",
        x=0.5,
        y=0.95,
        xanchor="center",
        yanchor="top",
    )
    figure.update_layout(
        title=title_spec,
        xaxis_title="LLM Architecture",
        yaxis_title="Decode Speedup (%)",
        legend_title="Quantization Scheme",
        width=1200,
        height=600,
    )
    return figure


def get_quant_prefill_fig(llm_perf_df):
    """Return a box plot of prefill-latency speedup grouped by model architecture.

    Each point is one row of ``get_quant_df(llm_perf_df)``. Boxes are coloured by the
    custom-kernel quantization scheme, and hovering a point lists every column named
    in ``QUANT_DATA``.

    Args:
        llm_perf_df: Benchmark results, forwarded unchanged to ``get_quant_df``.

    Returns:
        The configured plotly figure.
    """
    speedup_df = get_quant_df(llm_perf_df)
    figure = px.box(
        speedup_df,
        x="Arch 🏛️",
        y="Prefill Latency Speedup (%)",
        color="Quantization 🗜️ Custom Kernel",
        color_discrete_sequence=px.colors.qualitative.Light24,
        custom_data=QUANT_DATA,
        points="all",
    )

    # customdata[pos] carries the value of QUANT_DATA[pos] for the hovered point.
    hover_rows = (f"<b>{label}:</b> %{{customdata[{pos}]}}" for pos, label in enumerate(QUANT_DATA))
    figure.update_traces(hovertemplate="<br>".join(hover_rows))

    title_spec = dict(
        text="Prefill Latency Speedup per Architecture",
        x=0.5,
        y=0.95,
        xanchor="center",
        yanchor="top",
    )
    figure.update_layout(
        title=title_spec,
        xaxis_title="LLM Architecture",
        yaxis_title="Prefill Speedup (%)",
        legend_title="Quantization Scheme",
        width=1200,
        height=600,
    )
    return figure


def create_quant_plots(llm_perf_df):
    """Create the quantization section's hover hint and its two plot components.

    A ``gr.HTML`` hint is created first, then one ``gr.components.Plot`` for the
    prefill-latency speedup figure and one for the decode-throughput speedup figure.

    Args:
        llm_perf_df: Benchmark results used to build both figures.

    Returns:
        Tuple ``(prefill_plot, decode_plot)`` of gradio Plot components.
    """
    gr.HTML("👆 Hover over the points 👆 for additional information.", elem_id="text")

    # Built in order: prefill first, then decode (matches the returned tuple).
    figures = (get_quant_prefill_fig(llm_perf_df), get_quant_decode_fig(llm_perf_df))
    # NOTE(review): both plots share elem_id="plot"; kept as-is since CSS may target it.
    prefill_plot, decode_plot = [
        gr.components.Plot(value=fig, elem_id="plot", show_label=False) for fig in figures
    ]
    return prefill_plot, decode_plot