Spaces:
Runtime error
Runtime error
| import gradio as gr # type: ignore | |
| import plotly.express as px # type: ignore | |
| from backend.data import load_cot_data | |
| from backend.envs import API, REPO_ID, TOKEN | |
# Static page assets: both logos are served from the cot-eval repository.
_ASSETS_BASE_URL = "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets"
logo1_url = f"{_ASSETS_BASE_URL}/AI2_Logo_Square.png"
logo2_url = f"{_ASSETS_BASE_URL}/logo_logikon_notext_withborder.png"

# Centered flex row holding the two linked logos (AI2 first, Logikon second).
# The adjacent literals are concatenated at compile time into a single string.
LOGOS = (
    '<div style="display: flex; justify-content: center;">'
    f'<a href="https://allenai.org/"><img src="{logo1_url}" alt="AI2" '
    'style="width: 30vw; min-width: 20px; max-width: 60px;"></a>'
    ' '
    f'<a href="https://logikon.ai"><img src="{logo2_url}" alt="Logikon AI" '
    'style="width: 30vw; min-width: 20px; max-width: 60px; margin-left: 10px;"></a>'
    '</div>'
)

# Page header: the dashboard title followed by the logo row.
TITLE = '<h1 align="center" id="space-title"> Open CoT Dashboard</h1> ' + LOGOS

# Markdown blurb rendered directly below the header.
INTRODUCTION_TEXT = """
Baseline accuracies and marginal accuracy gains for specific models and CoT regimes from the [Open CoT Leaderboard](https://huggingface.co/spaces/logikon/open_cot_leaderboard).
"""
def restart_space() -> None:
    """Ask the Hub API to restart the Space identified by ``REPO_ID``.

    ``API``, ``REPO_ID`` and ``TOKEN`` all come from ``backend.envs``.
    """
    API.restart_space(
        repo_id=REPO_ID,
        token=TOKEN,
    )
# Load the two data frames the whole dashboard is built on. If loading fails,
# log the error, wait briefly and request a restart of the Space.
# NOTE(review): if restart_space() returns and execution continues,
# df_cot_err / df_cot_regimes stay unbound and later module-level code that
# reads them will raise NameError -- confirm the restart ends this process.
try:
    df_cot_err, df_cot_regimes = load_cot_data()
except Exception as load_error:
    print(load_error)
    import time

    # Sleep for 10 seconds before restarting the space.
    time.sleep(10)
    restart_space()
def plot_evals_init(model_id, regex_model_filter, plotly_mode, request: gr.Request):
    """Build the initial plot, honouring a ``?model=...`` URL query parameter.

    When the request carries a ``model`` query parameter whose value occurs in
    ``df_cot_err.model``, that value replaces ``model_id``; otherwise the
    passed-in ``model_id`` is used unchanged. The result is whatever
    ``plot_evals`` returns for the resolved model.
    """
    if request and "model" in request.query_params:
        requested_model = request.query_params["model"]
        known_models = df_cot_err.model.to_list()
        # Unknown model names in the URL are ignored.
        model_id = requested_model if requested_model in known_models else model_id
    return plot_evals(model_id, regex_model_filter, plotly_mode)
def plot_evals(model_id, regex_model_filter, plotly_mode):
    """Scatter-plot base accuracy against marginal accuracy gain, per task.

    Args:
        model_id: Model to highlight. Its rows are always kept in the plot,
            even when the regex filter would exclude them.
        regex_model_filter: Regular expression matched against the ``model``
            column to choose which other models are shown. If applying it
            fails, a warning is shown and all models are displayed.
        plotly_mode: ``"dark"`` selects the ``plotly_dark`` template; any other
            value selects the default ``plotly`` template.

    Returns:
        A ``(figure, model_id)`` tuple: the Plotly figure and the unchanged
        ``model_id``.
    """
    df = df_cot_err.copy()
    # Tag each row so the chosen model can be coloured differently.
    df["selected"] = df_cot_err.model.apply(lambda x: "selected" if x == model_id else "-")

    try:
        df_filter = df.model.str.contains(regex_model_filter)
    except Exception as err:
        # Best effort: an unusable pattern must not break the callback.
        gr.Warning("Failed to apply regex filter", duration=4)
        # Format the exception into the message; concatenating ``str + Exception``
        # would itself raise TypeError inside this handler.
        print(f"Failed to apply regex filter: {err}")
        df_filter = df.model.str.contains(".*")

    # Keep rows that match the filter plus every row of the selected model.
    df = df[df_filter | df.selected.eq("selected")]

    template = "plotly_dark" if plotly_mode == "dark" else "plotly"
    fig = px.scatter(
        df,
        x="base accuracy",
        y="marginal acc. gain",
        color="selected",
        symbol="model",
        facet_col="task",
        facet_col_wrap=3,
        # Fixed category order pairs "selected" with Orange and "-" with Gray.
        category_orders={"selected": ["selected", "-"]},
        color_discrete_sequence=["Orange", "Gray"],
        template=template,
        error_y="acc_gain-err",
        hover_data=["model", "cot accuracy"],
        width=1200,
        height=700,
    )
    fig.update_layout(
        title={"automargin": True},
    )
    return fig, model_id
def styled_model_table_init(model_id, request: gr.Request):
    """Build the initial details table, honouring a ``?model=...`` URL parameter.

    When the request carries a ``model`` query parameter whose value occurs in
    ``df_cot_regimes.model``, that value replaces ``model_id``; otherwise the
    passed-in ``model_id`` is used unchanged. The result is whatever
    ``styled_model_table`` returns for the resolved model.
    """
    if request and "model" in request.query_params:
        requested_model = request.query_params["model"]
        known_models = df_cot_regimes.model.to_list()
        # Unknown model names in the URL are ignored.
        model_id = requested_model if requested_model in known_models else model_id
    return styled_model_table(model_id)
def styled_model_table(model_id):
    """Return a styled table of per-task CoT regime results for ``model_id``.

    Rows of ``df_cot_regimes`` belonging to the model are reduced to the task,
    regime and accuracy columns, ``temperature`` is renamed to ``temp``,
    ``ReflectBeforeRun`` in ``cot_chain`` is shortened to ``Reflect``, and the
    rows are sorted by task and chain. The result is a pandas ``Styler``.
    """
    shown_columns = [
        'task', 'cot_chain', 'best_of',
        'temperature', 'top_k', 'top_p',
        'acc_base', 'acc_cot', 'acc_gain',
    ]

    def make_pretty(styler):
        # Each Styler method below returns the styler, so the calls are chained.
        return (
            styler
            .hide(axis="index")
            .format(precision=1)
            # Absolute accuracies share one colour scale from 20 to 100.
            .background_gradient(
                axis=None,
                subset=["acc_base", "acc_cot"],
                vmin=20, vmax=100, cmap="YlGnBu",
            )
            # Gains use a diverging scale centred on zero.
            .background_gradient(
                axis=None,
                subset=["acc_gain"],
                vmin=-20, vmax=20, cmap="coolwarm",
            )
            # NOTE(review): no column named 'B' is selected above; this entry
            # looks like a leftover from an example -- confirm and remove.
            .set_table_styles({
                'task': [{'selector': '',
                          'props': [('font-weight', 'bold')]}],
                'B': [{'selector': 'td',
                       'props': 'color: blue;'}]
            }, overwrite=False)
        )

    is_selected_model = df_cot_regimes.model.eq(model_id)
    df_cot_model = df_cot_regimes.loc[is_selected_model, shown_columns]
    df_cot_model = (
        df_cot_model
        .rename(columns={"temperature": "temp"})
        .replace({'cot_chain': 'ReflectBeforeRun'}, "Reflect")
        .sort_values(["task", "cot_chain"])
        .reset_index(drop=True)
    )
    return df_cot_model.style.pipe(make_pretty)
# Assemble the dashboard: header, a control row, the details table and the plot.
demo = gr.Blocks()

with demo:
    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT)

    # Controls share a single row; `scale` sets their relative widths.
    with gr.Row():
        selected_model = gr.Dropdown(
            list(df_cot_err.model.unique()),
            value="allenai/tulu-2-70b",
            label="Model",
            info="with performance details below",
            scale=2,
        )
        regex_model_filter = gr.Textbox(
            ".*",
            label="Regex",
            info="to filter models shown in plots",
            scale=2,
        )
        plotly_mode = gr.Radio(
            ["dark", "light"],
            value="light",
            label="Theme",
            info="of plots",
            scale=1,
        )
        submit = gr.Button("Update", scale=1)

    table = gr.DataFrame()
    plot = gr.Plot(label="evals")

    # "Update" redraws both the plot and the table for the current controls.
    submit.click(plot_evals, [selected_model, regex_model_filter, plotly_mode], [plot, selected_model])
    submit.click(styled_model_table, selected_model, table)

    # On page load the *_init variants also look at the URL's `model` parameter.
    demo.load(plot_evals_init, [selected_model, regex_model_filter, plotly_mode], [plot, selected_model])
    demo.load(styled_model_table_init, selected_model, table)

demo.launch()