Update app.py
Browse files
app.py
CHANGED
@@ -23,13 +23,14 @@ from huggingface_hub import list_datasets, HfApi, hf_hub_download
|
|
23 |
ANNOTATION_CONFIG_FILE = "annotation_config.json"
|
24 |
OUTPUT_FILE_PATH = "dataset.jsonl"
|
25 |
|
26 |
-
|
27 |
|
28 |
def load_llm_config():
|
29 |
params = load_params()
|
30 |
return (
|
31 |
params.get('PROVIDER', ''),
|
32 |
params.get('BASE_URL', ''),
|
|
|
33 |
params.get('WORKSPACE', ''),
|
34 |
params.get('API_KEY', ''),
|
35 |
params.get('max_tokens', 2048),
|
@@ -41,10 +42,11 @@ def load_llm_config():
|
|
41 |
|
42 |
|
43 |
|
44 |
-
def save_llm_config(provider, base_url, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty):
|
45 |
save_params({
|
46 |
'PROVIDER': provider,
|
47 |
'BASE_URL': base_url,
|
|
|
48 |
'WORKSPACE': workspace,
|
49 |
'API_KEY': api_key,
|
50 |
'max_tokens': max_tokens,
|
@@ -56,6 +58,8 @@ def save_llm_config(provider, base_url, workspace, api_key, max_tokens, temperat
|
|
56 |
return "LLM configuration saved successfully"
|
57 |
|
58 |
|
|
|
|
|
59 |
|
60 |
|
61 |
def load_annotation_config():
|
@@ -493,7 +497,7 @@ def update_chat_context(row_data, index, total, quality, high_quality_tags, low_
|
|
493 |
|
494 |
|
495 |
|
496 |
-
async def run_generate_dataset(num_workers, num_generations, output_file_path,
|
497 |
if loaded_dataset is None:
|
498 |
return "Error: No dataset loaded. Please load a dataset before generating.", ""
|
499 |
|
@@ -501,7 +505,7 @@ async def run_generate_dataset(num_workers, num_generations, output_file_path, l
|
|
501 |
for _ in range(num_generations):
|
502 |
topic_selected = random.choice(TOPICS)
|
503 |
system_message_selected = random.choice(SYSTEM_MESSAGES_VODALUS)
|
504 |
-
data = await generate_data(topic_selected, PROMPT_1, system_message_selected, output_file_path,
|
505 |
if data:
|
506 |
generated_data.append(json.dumps(data))
|
507 |
|
@@ -621,6 +625,13 @@ def load_dataset_wrapper(dataset_name, split):
|
|
621 |
dataset, message = load_huggingface_dataset(dataset_name, split)
|
622 |
return dataset, message
|
623 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
624 |
|
625 |
def get_popular_datasets():
|
626 |
return [
|
@@ -913,10 +924,12 @@ with demo:
|
|
913 |
with gr.Tab("LLM Configuration"):
|
914 |
with gr.Row():
|
915 |
provider = gr.Dropdown(choices=["local-model", "anything-llm"], label="LLM Provider")
|
916 |
-
base_url = gr.Textbox(label="Base URL (for local model)")
|
917 |
with gr.Row():
|
918 |
-
|
919 |
-
|
|
|
|
|
|
|
920 |
|
921 |
with gr.Accordion("Advanced Options", open=False):
|
922 |
with gr.Row():
|
@@ -1045,18 +1058,18 @@ with demo:
|
|
1045 |
|
1046 |
start_generation_btn.click(
|
1047 |
run_generate_dataset,
|
1048 |
-
inputs=[num_workers, num_generations, output_file_path,
|
1049 |
outputs=[generation_status, generation_output]
|
1050 |
)
|
1051 |
|
1052 |
demo.load(
|
1053 |
load_llm_config,
|
1054 |
-
outputs=[provider, base_url, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty]
|
1055 |
)
|
1056 |
|
1057 |
save_llm_config_btn.click(
|
1058 |
save_llm_config,
|
1059 |
-
inputs=[provider, base_url, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty],
|
1060 |
outputs=[llm_config_status]
|
1061 |
)
|
1062 |
|
@@ -1071,28 +1084,34 @@ with demo:
|
|
1071 |
outputs=[chatbot]
|
1072 |
)
|
1073 |
|
1074 |
-
|
1075 |
-
|
1076 |
-
inputs=[
|
1077 |
-
outputs=[
|
1078 |
)
|
1079 |
|
1080 |
-
|
1081 |
-
|
1082 |
-
|
1083 |
-
|
1084 |
-
)
|
1085 |
|
1086 |
-
|
1087 |
-
|
1088 |
-
|
1089 |
-
|
1090 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
1091 |
|
1092 |
# Modify the start_generation_btn.click to include the loaded dataset
|
1093 |
start_generation_btn.click(
|
1094 |
run_generate_dataset,
|
1095 |
-
inputs=[num_workers, num_generations, output_file_path,
|
1096 |
outputs=[generation_status, generation_output]
|
1097 |
)
|
1098 |
|
|
|
23 |
ANNOTATION_CONFIG_FILE = "annotation_config.json"
|
24 |
OUTPUT_FILE_PATH = "dataset.jsonl"
|
25 |
|
26 |
+
llm_provider_state = State("")
|
27 |
|
28 |
def load_llm_config():
|
29 |
params = load_params()
|
30 |
return (
|
31 |
params.get('PROVIDER', ''),
|
32 |
params.get('BASE_URL', ''),
|
33 |
+
params.get('MODEL', ''), # Add this line
|
34 |
params.get('WORKSPACE', ''),
|
35 |
params.get('API_KEY', ''),
|
36 |
params.get('max_tokens', 2048),
|
|
|
42 |
|
43 |
|
44 |
|
45 |
+
def save_llm_config(provider, base_url, model, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty):
|
46 |
save_params({
|
47 |
'PROVIDER': provider,
|
48 |
'BASE_URL': base_url,
|
49 |
+
'MODEL': model, # Add this line
|
50 |
'WORKSPACE': workspace,
|
51 |
'API_KEY': api_key,
|
52 |
'max_tokens': max_tokens,
|
|
|
58 |
return "LLM configuration saved successfully"
|
59 |
|
60 |
|
61 |
+
def update_model_visibility(provider):
    """Build a Gradio update that shows or hides a component by provider.

    Args:
        provider: The currently selected LLM provider identifier.

    Returns:
        A ``gr.update`` with ``visible=True`` when ``provider`` is
        ``"local-model"`` or ``"openai"``, and ``visible=False`` otherwise.
    """
    # Only these providers take a model name.
    show_model = provider in ("local-model", "openai")
    return gr.update(visible=show_model)
|
63 |
|
64 |
|
65 |
def load_annotation_config():
|
|
|
497 |
|
498 |
|
499 |
|
500 |
+
async def run_generate_dataset(num_workers, num_generations, output_file_path, llm_provider, dataset):
|
501 |
if loaded_dataset is None:
|
502 |
return "Error: No dataset loaded. Please load a dataset before generating.", ""
|
503 |
|
|
|
505 |
for _ in range(num_generations):
|
506 |
topic_selected = random.choice(TOPICS)
|
507 |
system_message_selected = random.choice(SYSTEM_MESSAGES_VODALUS)
|
508 |
+
data = await generate_data(topic_selected, PROMPT_1, system_message_selected, output_file_path, llm_provider)
|
509 |
if data:
|
510 |
generated_data.append(json.dumps(data))
|
511 |
|
|
|
625 |
dataset, message = load_huggingface_dataset(dataset_name, split)
|
626 |
return dataset, message
|
627 |
|
628 |
+
def update_field_visibility(provider):
    """Build four Gradio visibility updates for the provider-specific fields.

    The first two updates are visible only for ``"local-model"``; the last two
    are visible only for ``"anything-llm"``. Any other provider hides all four.

    Args:
        provider: The currently selected LLM provider identifier.

    Returns:
        A 4-tuple of ``gr.update`` objects.
        NOTE(review): presumably ordered as (base_url, model, workspace,
        api_key) to match the LLM Configuration tab -- confirm against the
        event wiring, which is not visible here.
    """
    show_local_fields = provider == "local-model"
    show_anything_llm_fields = provider == "anything-llm"
    return (
        gr.update(visible=show_local_fields),
        gr.update(visible=show_local_fields),
        gr.update(visible=show_anything_llm_fields),
        gr.update(visible=show_anything_llm_fields),
    )
|
635 |
|
636 |
def get_popular_datasets():
|
637 |
return [
|
|
|
924 |
with gr.Tab("LLM Configuration"):
|
925 |
with gr.Row():
|
926 |
provider = gr.Dropdown(choices=["local-model", "anything-llm"], label="LLM Provider")
|
|
|
927 |
with gr.Row():
|
928 |
+
base_url = gr.Textbox(label="Base URL (for local model)", visible=False)
|
929 |
+
model = gr.Textbox(label="Model (for local model)", visible=False)
|
930 |
+
with gr.Row():
|
931 |
+
workspace = gr.Textbox(label="Workspace (for AnythingLLM)", visible=False)
|
932 |
+
api_key = gr.Textbox(label="API Key (for AnythingLLM)", visible=False)
|
933 |
|
934 |
with gr.Accordion("Advanced Options", open=False):
|
935 |
with gr.Row():
|
|
|
1058 |
|
1059 |
start_generation_btn.click(
|
1060 |
run_generate_dataset,
|
1061 |
+
inputs=[num_workers, num_generations, output_file_path, llm_provider, dataset],
|
1062 |
outputs=[generation_status, generation_output]
|
1063 |
)
|
1064 |
|
1065 |
demo.load(
|
1066 |
load_llm_config,
|
1067 |
+
outputs=[provider, base_url, model, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty]
|
1068 |
)
|
1069 |
|
1070 |
save_llm_config_btn.click(
|
1071 |
save_llm_config,
|
1072 |
+
inputs=[provider, base_url, model, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty],
|
1073 |
outputs=[llm_config_status]
|
1074 |
)
|
1075 |
|
|
|
1084 |
outputs=[chatbot]
|
1085 |
)
|
1086 |
|
1087 |
+
provider.change(
|
1088 |
+
lambda x: x,
|
1089 |
+
inputs=[provider],
|
1090 |
+
outputs=[llm_provider_state]
|
1091 |
)
|
1092 |
|
1093 |
+
# search_button.click(
|
1094 |
+
# search_huggingface_datasets,
|
1095 |
+
# inputs=[dataset_search],
|
1096 |
+
# outputs=[dataset_results, dataset_input]
|
1097 |
+
# )
|
1098 |
|
1099 |
+
# dataset_results.change(
|
1100 |
+
# lambda choice: choice,
|
1101 |
+
# inputs=[dataset_results],
|
1102 |
+
# outputs=[dataset_input]
|
1103 |
+
# )
|
1104 |
+
|
1105 |
+
# load_dataset_button.click(
|
1106 |
+
# load_dataset_wrapper,
|
1107 |
+
# inputs=[dataset_input, dataset_split],
|
1108 |
+
# outputs=[loaded_dataset, dataset_status]
|
1109 |
+
# )
|
1110 |
|
1111 |
# Start generation when the button is clicked, passing the selected LLM provider state (the loaded dataset is not among the inputs here).
|
1112 |
start_generation_btn.click(
|
1113 |
run_generate_dataset,
|
1114 |
+
inputs=[num_workers, num_generations, output_file_path, llm_provider_state],
|
1115 |
outputs=[generation_status, generation_output]
|
1116 |
)
|
1117 |
|