File size: 1,972 Bytes
ab5f5f1
 
 
 
 
 
76b423c
0232cf1
dc685a9
76b423c
 
ab5f5f1
0232cf1
 
 
 
76b423c
ab5f5f1
 
 
 
 
 
5345cba
ab5f5f1
 
5345cba
0232cf1
ab5f5f1
0232cf1
ab5f5f1
 
 
 
 
76b423c
 
 
 
ab5f5f1
 
 
 
57896bb
ab5f5f1
 
5345cba
 
ab5f5f1
76b423c
ab5f5f1
 
 
 
 
5345cba
4f5bf6c
ab5f5f1
 
 
 
 
 
0232cf1
4f5bf6c
ab5f5f1
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import gradio as gr
import plotly.express as px

# Columns passed to ``px.scatter`` as ``custom_data`` and shown in the hover
# tooltip of the latency/score/memory plot. Every name must be a column of the
# DataFrame being plotted. ORDER MATTERS: the hovertemplate in
# ``get_lat_score_mem_fig`` refers to each entry by its position
# (``customdata[i]``), so reordering changes which label is paired with which
# value only if the two are built from different lists.
SCORE_MEMORY_LATENCY_DATA = [
    "Model 🤗",
    "Backend 🏭",
    "Precision 📥",
    "Params (B)",
    "Quantization 🗜️",
    "Attention 👁️",
    "Kernel ⚛️",
    "Open LLM Score (%)",
    "Prefill (s)",
    "Decode (tokens/s)",
    "Memory (MB)",
    "End-to-End (s)",
    "Architecture 🏛️",
]


def get_lat_score_mem_fig(llm_perf_df):
    """Build the "Latency vs. Score vs. Memory" bubble chart.

    Each plotted point is one row of ``llm_perf_df``:

    * x axis: ``"End-to-End (s)"``;
    * y axis: ``"Open LLM Score (%)"``;
    * bubble size: ``"Memory (MB)"``;
    * colour: ``"Architecture 🏛️"``.

    Hovering a point shows every column listed in ``SCORE_MEMORY_LATENCY_DATA``.

    Rows with a missing ``"Memory (MB)"`` value are excluded from the chart.

    Args:
        llm_perf_df: pandas DataFrame containing every column named in
            ``SCORE_MEMORY_LATENCY_DATA``. It is not modified.

    Returns:
        The plotly figure produced by ``px.scatter`` with the layout applied.
    """
    # Drop rows without a memory measurement. A NaN bubble size cannot be
    # drawn meaningfully, yet such rows would still create legend entries and
    # pull the latency quantile used for the x-range below. ``dropna`` returns
    # a new frame, so the caller's DataFrame is left untouched.
    copy_df = llm_perf_df.dropna(subset=["Memory (MB)"])

    fig = px.scatter(
        copy_df,
        size="Memory (MB)",
        x="End-to-End (s)",
        y="Open LLM Score (%)",
        color="Architecture 🏛️",
        custom_data=SCORE_MEMORY_LATENCY_DATA,
        color_discrete_sequence=px.colors.qualitative.Light24,
    )
    # One "<b>column:</b> value" line per custom_data entry. The index i must
    # match the position of the column in SCORE_MEMORY_LATENCY_DATA, which is
    # guaranteed because both come from the same list.
    fig.update_traces(
        hovertemplate="<br>".join(
            [
                f"<b>{column}:</b> %{{customdata[{i}]}}"
                for i, column in enumerate(SCORE_MEMORY_LATENCY_DATA)
            ]
        )
    )
    fig.update_layout(
        title={
            "text": "Latency vs. Score vs. Memory",
            "xanchor": "center",
            "yanchor": "top",
            "y": 0.95,
            "x": 0.5,
        },
        xaxis_title="Time To Generate 64 Tokens (s)",
        yaxis_title="Open LLM Score (%)",
        legend_title="LLM Architecture",
        width=1200,
        height=600,
    )
    # Clip the x axis at the 90th percentile of end-to-end latency so a few
    # very slow models don't squash the rest of the chart. Points beyond it
    # fall outside the visible range; the UI text advertises this 90% cut.
    fig.update_xaxes(range=[-0.5, copy_df["End-to-End (s)"].quantile(0.90)])

    return fig


def create_lat_score_mem_plot(llm_perf_df):
    """Add the latency/score/memory section to the current Gradio layout.

    Two explanatory HTML snippets are emitted first, followed by a Plot
    component holding the figure built by ``get_lat_score_mem_fig``.

    Args:
        llm_perf_df: DataFrame forwarded unchanged to ``get_lat_score_mem_fig``.

    Returns:
        The created ``gr.components.Plot`` component.
    """
    # Explanatory text rendered above the chart, in display order.
    hint_snippets = (
        "👆 Hover over the points 👆 for additional information. ",
        "📊 We only show the top 90% LLMs based on latency ⌛",
    )
    for snippet in hint_snippets:
        gr.HTML(snippet, elem_id="text")

    # The figure is built after the hints, just before the component that shows it.
    return gr.components.Plot(
        value=get_lat_score_mem_fig(llm_perf_df),
        elem_id="plot",
        show_label=False,
    )