BeTaLabs committed on
Commit
6bdbc55
1 Parent(s): 78d99a8

Update app.py

Files changed (1): app.py (+187, -3)
app.py CHANGED
@@ -14,12 +14,17 @@ import random
 from params import load_params, save_params
 import pandas as pd
 import csv
+from datasets import load_dataset
+from huggingface_hub import list_datasets, HfApi, hf_hub_download
+
 
 
 
 ANNOTATION_CONFIG_FILE = "annotation_config.json"
 OUTPUT_FILE_PATH = "dataset.jsonl"
 
+
+
 def load_llm_config():
     params = load_params()
     return (
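Note: list_datasets is imported here alongside HfApi, but the added code later in this diff only calls it through an HfApi instance, so the direct import appears unused. huggingface_hub also exposes the same search as a module-level function; a minimal sketch (not part of the commit):

    # Module-level equivalent of HfApi().list_datasets
    from huggingface_hub import list_datasets
    dataset_ids = [d.id for d in list_datasets(search="imdb", limit=5)]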
@@ -34,6 +39,8 @@ def load_llm_config():
         params.get('presence_penalty', 0.0)
     )
 
+
+
 def save_llm_config(provider, base_url, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty):
     save_params({
         'PROVIDER': provider,
@@ -49,6 +56,8 @@ def save_llm_config(provider, base_url, workspace, api_key, max_tokens, temperat
     return "LLM configuration saved successfully"
 
 
+
+
 def load_annotation_config():
     try:
         with open(ANNOTATION_CONFIG_FILE, 'r') as f:
@@ -92,6 +101,8 @@ def load_annotation_config():
     }
 
 
+
+
 def load_csv_dataset(file_path):
     data = []
     with open(file_path, 'r') as f:
@@ -100,20 +111,28 @@ def load_csv_dataset(file_path):
             data.append(row)
     return data
 
+
+
 def load_txt_dataset(file_path):
     with open(file_path, 'r') as f:
         return [{"content": line.strip()} for line in f if line.strip()]
 
+
+
 def save_annotation_config(config):
     with open(ANNOTATION_CONFIG_FILE, 'w') as f:
         json.dump(config, f, indent=2)
 
+
+
 def load_jsonl_dataset(file_path):
     if not os.path.exists(file_path):
         return []
     with open(file_path, 'r') as f:
         return [json.loads(line.strip()) for line in f if line.strip()]
 
+
+
 def load_dataset(file):
     if file is None:
         return "", 0, 0, "No file uploaded", "3", [], [], [], ""
@@ -136,6 +155,8 @@ def load_dataset(file):
         first_row = json.dumps(data[0], indent=2)
     return first_row, 0, len(data), f"Row: 1/{len(data)}", "3", [], [], [], ""
 
+
+
 def save_row(file_path, index, row_data):
     file_extension = file_path.split('.')[-1].lower()
 
@@ -150,6 +171,8 @@ def save_row(file_path, index, row_data):
 
     return f"Row {index} saved successfully"
 
+
+
 def save_jsonl_row(file_path, index, row_data):
     with open(file_path, 'r') as f:
         lines = f.readlines()
@@ -159,6 +182,8 @@ def save_jsonl_row(file_path, index, row_data):
     with open(file_path, 'w') as f:
         f.writelines(lines)
 
+
+
 def save_csv_row(file_path, index, row_data):
     df = pd.read_csv(file_path)
     row_dict = json.loads(row_data)
@@ -166,6 +191,8 @@ def save_csv_row(file_path, index, row_data):
         df.at[index, col] = value
     df.to_csv(file_path, index=False)
 
+
+
 def save_txt_row(file_path, index, row_data):
     with open(file_path, 'r') as f:
         lines = f.readlines()
@@ -176,6 +203,8 @@ def save_txt_row(file_path, index, row_data):
     with open(file_path, 'w') as f:
         f.writelines(lines)
 
+
+
 def get_row(file_path, index):
     data = load_jsonl_dataset(file_path)
     if not data:
@@ -184,6 +213,8 @@ def get_row(file_path, index):
         return json.dumps(data[index], indent=2), len(data)
     return "", len(data)
 
+
+
 def json_to_markdown(json_str):
     try:
         data = json.loads(json_str)
@@ -192,6 +223,8 @@ def json_to_markdown(json_str):
     except json.JSONDecodeError:
         return "Error: Invalid JSON format"
 
+
+
 def markdown_to_json(markdown_str):
     sections = re.split(r'#\s+(System|Instruction|Response)\s*\n', markdown_str)
     if len(sections) != 7:  # Should be: ['', 'System', content, 'Instruction', content, 'Response', content]
@@ -204,10 +237,14 @@ def markdown_to_json(markdown_str):
     }
     return json.dumps(json_data, indent=2)
 
+
+
 def navigate_rows(file_path: str, current_index: int, direction: Literal["prev", "next"], metadata_config):
     new_index = max(0, current_index + (-1 if direction == "prev" else 1))
     return load_and_show_row(file_path, new_index, metadata_config)
 
+
+
 def load_and_show_row(file_path, index, metadata_config):
     row_data, total = get_row(file_path, index)
     if not row_data:
@@ -229,6 +266,8 @@ def load_and_show_row(file_path, index, metadata_config):
     return (row_data, index, total, f"Row: {index + 1}/{total}", quality,
             high_quality_tags, low_quality_tags, toxic_tags, other)
 
+
+
 def save_row_with_metadata(file_path, index, row_data, config, quality, high_quality_tags, low_quality_tags, toxic_tags, other):
     data = json.loads(row_data)
     metadata = {
@@ -248,6 +287,8 @@ def save_row_with_metadata(file_path, index, row_data, config, quality, high_qua
     data["metadata"] = metadata
     return save_row(file_path, index, json.dumps(data))
 
+
+
 def update_annotation_ui(config):
     quality_choices = [(item["value"], item["label"]) for item in config["quality_scale"]["scale"]]
     quality_label = gr.Radio(
@@ -271,6 +312,8 @@ def update_annotation_ui(config):
 
     return quality_label, *tag_components, other_description
 
+
+
 def load_config_to_ui(config):
     return (
         config["quality_scale"]["name"],
@@ -280,6 +323,8 @@ def load_config_to_ui(config):
         [[field["name"], field["description"]] for field in config["free_text_fields"]]
     )
 
+
+
 def save_config_from_ui(name, description, scale, categories, fields, topics, all_topics_text):
     if all_topics_text.visible:
         topics_list = [topic.strip() for topic in all_topics_text.split("\n") if topic.strip()]
@@ -299,6 +344,8 @@ def save_config_from_ui(name, description, scale, categories, fields, topics, al
     save_annotation_config(new_config)
     return "Configuration saved successfully", new_config
 
+
+
 # Add this new function to generate the preview
 def generate_preview(row_data, quality, high_quality_tags, low_quality_tags, toxic_tags, other):
     try:
@@ -321,6 +368,8 @@ def generate_preview(row_data, quality, high_quality_tags, low_quality_tags, tox
     except json.JSONDecodeError:
         return "Error: Invalid JSON in the current row data"
 
+
+
 def load_dataset_config():
     params = load_params()
     with open("system_messages.py", "r") as f:
@@ -347,6 +396,8 @@ def load_dataset_config():
         params.get('presence_penalty', 0.0)
     )
 
+
+
 def edit_all_topics_func(topics):
     topics_list = [topic[0] for topic in topics]
     jsonl_rows = "\n".join([json.dumps({"topic": topic}) for topic in topics_list])
@@ -356,6 +407,8 @@ def edit_all_topics_func(topics):
         gr.update(visible=True)
     )
 
+
+
 def update_topics_from_text(text):
     try:
         # Try parsing as JSONL
@@ -366,6 +419,8 @@ def update_topics_from_text(text):
 
     return gr.Dataframe.update(value=[[topic] for topic in topics_list], visible=True), gr.TextArea.update(visible=False)
 
+
+
 def save_dataset_config(system_messages, prompt_1, topics, max_tokens, temperature, top_p, frequency_penalty, presence_penalty):
     # Save VODALUS_SYSTEM_MESSAGE to system_messages.py
     with open("system_messages.py", "w") as f:
@@ -426,6 +481,7 @@ def chat_with_llm(message, history):
         print(f"Error in chat_with_llm: {str(e)}")
         return history + [[message, f"Error: {str(e)}"]]
 
+
 def update_chat_context(row_data, index, total, quality, high_quality_tags, low_quality_tags, toxic_tags, other):
     context = f"""Current app state:
     Row: {index + 1}/{total}
@@ -440,12 +496,16 @@ def update_chat_context(row_data, index, total, quality, high_quality_tags, low_
     return [[None, context]]
 
 
-async def run_generate_dataset(num_workers, num_generations, output_file_path):
+
+async def run_generate_dataset(num_workers, num_generations, output_file_path, loaded_dataset):
+    if loaded_dataset is None:
+        return "Error: No dataset loaded. Please load a dataset before generating.", ""
+
     generated_data = []
     for _ in range(num_generations):
         topic_selected = random.choice(TOPICS)
         system_message_selected = random.choice(SYSTEM_MESSAGES_VODALUS)
-        data = await generate_data(topic_selected, PROMPT_1, system_message_selected, output_file_path)
+        data = await generate_data(topic_selected, PROMPT_1, system_message_selected, output_file_path, loaded_dataset)
         if data:
             generated_data.append(json.dumps(data))
 
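The reworked coroutine now guards against a missing dataset and threads loaded_dataset through to generate_data (whose signature lives in another module). A minimal sketch of driving it outside Gradio; my_dataset is hypothetical, e.g. the object returned by load_huggingface_dataset below:

    # Sketch: run the coroutine directly with asyncio
    import asyncio

    status, preview = asyncio.run(
        run_generate_dataset(
            num_workers=1,
            num_generations=3,
            output_file_path="dataset.jsonl",
            loaded_dataset=my_dataset,  # hypothetical dataset object
        )
    )
    print(status)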
@@ -456,15 +516,21 @@ async def run_generate_dataset(num_workers, num_generations, output_file_path):
 
     return f"Generated {num_generations} entries and saved to {output_file_path}", "\n".join(generated_data[:5]) + "\n..."
 
+
+
 def add_topic_row(data):
     if isinstance(data, pd.DataFrame):
         return pd.concat([data, pd.DataFrame({"Topic": ["New Topic"]})], ignore_index=True)
     else:
         return data + [["New Topic"]]
 
+
+
 def remove_last_topic_row(data):
     return data[:-1] if len(data) > 1 else data
 
+
+
 def edit_all_topics_func(topics):
     topics_list = [topic[0] for topic in topics]
     jsonl_rows = "\n".join([json.dumps({"topic": topic}) for topic in topics_list])
@@ -474,6 +540,8 @@ def edit_all_topics_func(topics):
         gr.update(visible=True)
     )
 
+
+
 def update_topics_from_text(text):
     try:
         # Try parsing as JSONL
@@ -484,6 +552,8 @@ def update_topics_from_text(text):
 
     return gr.Dataframe.update(value=[[topic] for topic in topics_list], visible=True), gr.TextArea.update(visible=False)
 
+
+
 def update_topics_from_text(text):
     try:
         # Try parsing as JSONL
@@ -494,6 +564,82 @@ def update_topics_from_text(text):
 
     return gr.Dataframe.update(value=[[topic] for topic in topics_list], visible=True), gr.TextArea.update(visible=False)
 
+
+
+def search_huggingface_datasets(query):
+    try:
+        api = HfApi()
+        datasets = api.list_datasets(search=query, limit=20)
+        dataset_ids = [dataset.id for dataset in datasets]
+        return gr.update(choices=dataset_ids, visible=True), ""
+    except Exception as e:
+        print(f"Error searching datasets: {str(e)}")
+        return gr.update(choices=["Error: Could not search datasets"], visible=True), ""
+
+
+
+def load_huggingface_dataset(dataset_name, split="train"):
+    try:
+        print(f"Attempting to load dataset: {dataset_name}")
+
+        # Check if dataset_name is a string
+        if not isinstance(dataset_name, str):
+            raise ValueError(f"Expected dataset_name to be a string, but got {type(dataset_name)}")
+
+        # Try loading the dataset without specifying a config
+        full_dataset = load_dataset(dataset_name)
+
+        print(f"Dataset loaded. Available splits: {list(full_dataset.keys())}")
+
+        # Select the appropriate split
+        if split in full_dataset:
+            dataset = full_dataset[split]
+            print(f"Using specified split: {split}")
+        else:
+            available_splits = list(full_dataset.keys())
+            if available_splits:
+                dataset = full_dataset[available_splits[0]]
+                split = available_splits[0]
+                print(f"Specified split not found. Using first available split: {split}")
+            else:
+                raise ValueError("No valid splits found in the dataset")
+
+        return dataset, f"Dataset '{dataset_name}' (split: {split}) loaded successfully."
+    except Exception as e:
+        error_msg = f"Error loading dataset: {str(e)}"
+        print(f"Error details: {error_msg}")
+
+        # If loading fails, try to get the dataset card
+        try:
+            dataset_card = hf_hub_download(repo_id=dataset_name, filename="README.md")
+            with open(dataset_card, 'r') as f:
+                card_content = f.read()
+            return None, f"Dataset couldn't be loaded, but here's the dataset card:\n\n{card_content[:500]}..."
+        except:
+            return None, error_msg
+
+# Wrapper function to handle the Gradio interface
+def load_dataset_wrapper(dataset_name, split):
+    if not dataset_name:
+        return None, "Please enter a dataset name."
+    dataset, message = load_huggingface_dataset(dataset_name, split)
+    return dataset, message
+
+
+def get_popular_datasets():
+    return [
+        "wikipedia",
+        "squad",
+        "glue",
+        "imdb",
+        "wmt16",
+        "common_voice",
+        "cnn_dailymail",
+        "amazon_reviews_multi",
+        "yelp_review_full",
+        "ag_news"
+    ]
+
 css = """
 body, #root {
     margin: 0;
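The helpers added in this hunk wrap Hub search and dataset loading with Gradio-friendly return values: search_huggingface_datasets feeds a Radio component, and load_huggingface_dataset returns (dataset, status_message), falling back to the dataset card on failure. A minimal usage sketch outside the UI, assuming the datasets and huggingface_hub packages are installed and a public dataset name:

    # Sketch: load a split and peek at the first record
    dataset, status = load_huggingface_dataset("imdb", split="train")
    print(status)
    if dataset is not None:
        print(dataset[0])  # a datasets.Dataset row as a dict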
@@ -740,6 +886,20 @@ with demo:
         with gr.Row():
             save_dataset_config_btn = gr.Button("Save Dataset Configuration", variant="primary")
         dataset_config_status = gr.Textbox(label="Status")
+
+        # gr.Markdown("### Hugging Face Dataset")
+        # with gr.Row():
+        #     dataset_search = gr.Textbox(label="Search Datasets")
+        #     search_button = gr.Button("Search")
+        # dataset_input = gr.Textbox(label="Dataset Name", info="Enter a dataset name or select from search results")
+        # dataset_results = gr.Radio(label="Search Results", choices=[], visible=False)
+        # dataset_split = gr.Textbox(label="Dataset Split (optional)", value="train")
+        # load_dataset_button = gr.Button("Load Selected Dataset")
+        # dataset_status = gr.Textbox(label="Dataset Status")
+
+        # Add a state to store the loaded dataset
+        # loaded_dataset = gr.State(None)
+
 
 
     with gr.Tab("Dataset Generation"):
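Note that the entire Hugging Face Dataset panel in this hunk arrives commented out, including the loaded_dataset = gr.State(None) holder, while later hunks bind event handlers to these components; as committed, those names would be undefined at runtime. A minimal sketch of the definitions the handlers expect (essentially the block above, uncommented):

    # Sketch: components referenced by the .click()/.change() handlers below
    dataset_search = gr.Textbox(label="Search Datasets")
    search_button = gr.Button("Search")
    dataset_input = gr.Textbox(label="Dataset Name")
    dataset_results = gr.Radio(label="Search Results", choices=[], visible=False)
    dataset_split = gr.Textbox(label="Dataset Split (optional)", value="train")
    load_dataset_button = gr.Button("Load Selected Dataset")
    dataset_status = gr.Textbox(label="Dataset Status")
    loaded_dataset = gr.State(None)  # holds the loaded dataset object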
@@ -889,7 +1049,7 @@ with demo:
 
     start_generation_btn.click(
         run_generate_dataset,
-        inputs=[num_workers, num_generations, output_file_path],
+        inputs=[num_workers, num_generations, output_file_path, loaded_dataset],
         outputs=[generation_status, generation_output]
     )
 
@@ -915,6 +1075,30 @@ with demo:
         outputs=[chatbot]
     )
 
+    search_button.click(
+        search_huggingface_datasets,
+        inputs=[dataset_search],
+        outputs=[dataset_results, dataset_input]
+    )
+
+    dataset_results.change(
+        lambda choice: choice,
+        inputs=[dataset_results],
+        outputs=[dataset_input]
+    )
+
+    load_dataset_button.click(
+        load_dataset_wrapper,
+        inputs=[dataset_input, dataset_split],
+        outputs=[loaded_dataset, dataset_status]
+    )
+
+    # Modify the start_generation_btn.click to include the loaded dataset
+    start_generation_btn.click(
+        run_generate_dataset,
+        inputs=[num_workers, num_generations, output_file_path, loaded_dataset],
+        outputs=[generation_status, generation_output]
+    )
 
 demo.load(
     lambda: (
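This hunk leaves run_generate_dataset bound to start_generation_btn twice: once in the earlier modified binding and again here. Gradio fires every listener registered on an event, so a single click would likely start generation twice; one consolidated binding is enough:

    # Sketch: keep one binding instead of two
    start_generation_btn.click(
        run_generate_dataset,
        inputs=[num_workers, num_generations, output_file_path, loaded_dataset],
        outputs=[generation_status, generation_output]
    )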