davanstrien HF staff commited on
Commit
2b64754
1 Parent(s): f31fa4a

more prompts

Browse files
Files changed (1) hide show
  1. app.py +23 -7
app.py CHANGED
@@ -6,10 +6,12 @@ import uuid
6
  from huggingface_hub import InferenceClient, CommitScheduler, hf_hub_download
7
  from openai import OpenAI
8
  from huggingface_hub import get_token, login
9
- from prompts import detailed_genre_description_prompt, basic_prompt
10
  import random
11
  import os
12
  from pathlib import Path
 
 
13
 
14
  HF_TOKEN = os.getenv("HF_TOKEN")
15
 
@@ -94,13 +96,27 @@ def create_client(model_id):
94
  # )
95
 
96
 
97
- def generate_prompt():
98
- if random.choice([True, False]):
99
- return detailed_genre_description_prompt()
100
- else:
101
- return basic_prompt()
102
-
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  def get_and_store_prompt():
105
  prompt = generate_prompt()
106
  print(prompt) # Keep this for debugging
 
6
  from huggingface_hub import InferenceClient, CommitScheduler, hf_hub_download
7
  from openai import OpenAI
8
  from huggingface_hub import get_token, login
9
+ from prompts import detailed_genre_description_prompt, basic_prompt, very_basic_prompt
10
  import random
11
  import os
12
  from pathlib import Path
13
+ import random
14
+ from typing import Callable, List, Tuple
15
 
16
  HF_TOKEN = os.getenv("HF_TOKEN")
17
 
 
96
  # )
97
 
98
 
 
 
 
 
 
 
99
 
100
+ def weighted_random_choice(choices: List[Tuple[Callable, float]]) -> Callable:
101
+ total = sum(weight for _, weight in choices)
102
+ r = random.uniform(0, total)
103
+ upto = 0
104
+ for choice, weight in choices:
105
+ if upto + weight >= r:
106
+ return choice
107
+ upto += weight
108
+ assert False, "Shouldn't get here"
109
+
110
+ def generate_prompt() -> str:
111
+ prompt_choices = [
112
+ (detailed_genre_description_prompt, 0.4),
113
+ (basic_prompt, 0.3),
114
+ (very_basic_prompt, 0.3),
115
+ ]
116
+
117
+ selected_prompt_func = weighted_random_choice(prompt_choices)
118
+ return selected_prompt_func()
119
+
120
  def get_and_store_prompt():
121
  prompt = generate_prompt()
122
  print(prompt) # Keep this for debugging