CK42 commited on
Commit
7af13f4
1 Parent(s): e4c082d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests.exceptions
3
+ from huggingface_hub import HfApi, hf_hub_download
4
+ from huggingface_hub.repocard import metadata_load
5
+
6
+ app = gr.Blocks()
7
+
8
+ def load_agent(model_id_1, model_id_2):
9
+ """
10
+ This function load the agent's video and results
11
+ :return: video_path
12
+ """
13
+ # Load the metrics
14
+ metadata_1 = get_metadata(model_id_1)
15
+
16
+ # Get the accuracy
17
+ results_1 = parse_metrics_accuracy(metadata_1)
18
+
19
+ # Load the video
20
+ video_path_1 = hf_hub_download(model_id_1, filename="replay.mp4")
21
+
22
+ # Load the metrics
23
+ metadata_2 = get_metadata(model_id_2)
24
+
25
+ # Get the accuracy
26
+ results_2 = parse_metrics_accuracy(metadata_2)
27
+
28
+ # Load the video
29
+ video_path_2 = hf_hub_download(model_id_2, filename="replay.mp4")
30
+
31
+ return model_id_1, video_path_1, results_1, model_id_2, video_path_2, results_2
32
+
33
+ def parse_metrics_accuracy(meta):
34
+ if "model-index" not in meta:
35
+ return None
36
+ result = meta["model-index"][0]["results"]
37
+ metrics = result[0]["metrics"]
38
+ accuracy = metrics[0]["value"]
39
+ return accuracy
40
+
41
+ def get_metadata(model_id):
42
+ """
43
+ Get the metadata of the model repo
44
+ :param model_id:
45
+ :return: metadata
46
+ """
47
+ try:
48
+ readme_path = hf_hub_download(model_id, filename="README.md")
49
+ metadata = metadata_load(readme_path)
50
+ print(metadata)
51
+ return metadata
52
+ except requests.exceptions.HTTPError:
53
+ return None
54
+
55
+
56
+
57
+
58
+ with app:
59
+ gr.Markdown(
60
+ """
61
+ # Compare Deep Reinforcement Learning Agents 🤖
62
+
63
+ Type two models id you want to compare or check examples below.
64
+ """)
65
+ with gr.Row():
66
+ model1_input = gr.Textbox(label="Model 1")
67
+ model2_input = gr.Textbox(label="Model 2")
68
+ with gr.Row():
69
+ app_button = gr.Button("Compare models")
70
+ with gr.Row():
71
+ with gr.Column():
72
+ model1_name = gr.Markdown()
73
+ model1_video_output = gr.Video()
74
+ model1_score_output = gr.Textbox(label="Mean Reward +/- Std Reward")
75
+ with gr.Column():
76
+ model2_name = gr.Markdown()
77
+ model2_video_output = gr.Video()
78
+ model2_score_output = gr.Textbox(label="Mean Reward +/- Std Reward")
79
+
80
+ app_button.click(load_agent, inputs=[model1_input, model2_input], outputs=[model1_name, model1_video_output, model1_score_output, model2_name, model2_video_output, model2_score_output])
81
+
82
+ examples = gr.Examples(examples=[["sb3/a2c-AntBulletEnv-v0","sb3/ppo-AntBulletEnv-v0"],
83
+ ["ThomasSimonini/a2c-AntBulletEnv-v0", "sb3/a2c-AntBulletEnv-v0"],
84
+ ["sb3/dqn-SpaceInvadersNoFrameskip-v4", "sb3/a2c-SpaceInvadersNoFrameskip-v4"],
85
+ ["ThomasSimonini/ppo-QbertNoFrameskip-v4","sb3/ppo-QbertNoFrameskip-v4"]],
86
+ inputs=[model1_input, model2_input])
87
+
88
+
89
+ app.launch()