magilogi commited on
Commit
aa9beda
β€’
1 Parent(s): b83aee7
Files changed (1) hide show
  1. app.py +31 -2
app.py CHANGED
@@ -28,6 +28,9 @@ explanation_data = {
28
  explanation_df = pd.DataFrame(explanation_data)
29
 
30
  df = pd.read_csv("data/csv/models_data.csv")
 
 
 
31
 
32
  filter_mapping = {
33
  "all": "all",
@@ -65,6 +68,17 @@ def create_scatter_plot(df, x_col, y_col, title, x_title, y_title):
65
  fig.update_traces(marker=dict(size=10), selector=dict(mode='markers'))
66
  return fig
67
 
 
 
 
 
 
 
 
 
 
 
 
68
  def create_bar_plot(df, col, title):
69
  sorted_df = df.sort_values(by=col, ascending=True)
70
  fig = px.bar(sorted_df,
@@ -109,14 +123,20 @@ with gr.Blocks(css="custom.css") as demo:
109
  value=create_bar_plot(df, "medqa_diff", "Impact of Generic2Brand swap on MedQA Accuracy"),
110
  elem_id="bar2"
111
  )
 
 
 
 
112
 
113
  with gr.Row():
114
  gr.Markdown(""" """)
115
 
 
 
116
  with gr.Tabs(elem_classes="tab-buttons"):
117
  with gr.TabItem("πŸ” Evaluation table"):
118
  with gr.Column():
119
- with gr.Accordion("➑️ Filter by Column", open=False):
120
  shown_columns = gr.CheckboxGroup(
121
  choices=df.columns.tolist(),
122
  value=df.columns.tolist(),
@@ -227,7 +247,16 @@ with gr.Blocks(css="custom.css") as demo:
227
  label="Explanation of Scores"
228
  )
229
 
230
-
 
 
 
 
 
 
 
 
 
231
 
232
 
233
 
 
28
  explanation_df = pd.DataFrame(explanation_data)
29
 
30
  df = pd.read_csv("data/csv/models_data.csv")
31
+ df['average_g2b'] = df[['medmcqa_g2b', 'medqa_4options_g2b']].mean(axis=1).round(2)
32
+ df['average_orginal_acc'] = df[['medmcqa_orig_filtered', 'medqa_4options_orig_filtered']].mean(axis=1).round(2)
33
+ df['average_diff'] = df[['medmcqa_diff', 'medqa_diff']].mean(axis=1).round(2)
34
 
35
  filter_mapping = {
36
  "all": "all",
 
68
  fig.update_traces(marker=dict(size=10), selector=dict(mode='markers'))
69
  return fig
70
 
71
+ def create_lm_plot(df, x_col, y_col, title, x_title, y_title):
72
+ fig = px.scatter(df, x=x_col, y=y_col, color='Model', title=title, color_discrete_sequence=px.colors.sequential.solar, trendline='ols')
73
+
74
+ fig.update_layout(
75
+ xaxis_title=x_title,
76
+ yaxis_title=y_title,
77
+ legend_title_text='Model'
78
+ )
79
+ fig.update_traces(marker=dict(size=10), selector=dict(mode='markers'))
80
+ return fig
81
+
82
  def create_bar_plot(df, col, title):
83
  sorted_df = df.sort_values(by=col, ascending=True)
84
  fig = px.bar(sorted_df,
 
123
  value=create_bar_plot(df, "medqa_diff", "Impact of Generic2Brand swap on MedQA Accuracy"),
124
  elem_id="bar2"
125
  )
126
+
127
+
128
+
129
+
130
 
131
  with gr.Row():
132
  gr.Markdown(""" """)
133
 
134
+ default_visible_columns = ['T', 'Model', 'average_original_acc', 'average_g2b','average_diff']
135
+
136
  with gr.Tabs(elem_classes="tab-buttons"):
137
  with gr.TabItem("πŸ” Evaluation table"):
138
  with gr.Column():
139
+ with gr.Accordion("➑️ See All Columns", open=False):
140
  shown_columns = gr.CheckboxGroup(
141
  choices=df.columns.tolist(),
142
  value=df.columns.tolist(),
 
247
  label="Explanation of Scores"
248
  )
249
 
250
+ with gr.Row():
251
+ bar3 = gr.Plot(
252
+ value=create_bar_plot(df, "b4bqa", "Which LLMs are best at matching brand names to generic drug names? (Results from custom task)"),
253
+ elem_id="bar3"
254
+ )
255
+
256
+ with gr.Row():
257
+ scatter_g2b = gr.Plot(
258
+ value=create_lm_plot(df, "b4bqa", "average_g2b", "Does that matching accuracy correlate with biomedical task robustness?", "b4bqa", "average_diff"),
259
+ )
260
 
261
 
262