BeardedMonster committed on
Commit
1005bd7
1 Parent(s): 0b93e56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -3
app.py CHANGED
@@ -7,6 +7,10 @@ import aiohttp
7
  import json
8
  import torch
9
  import re
 
 
 
 
10
 
11
  repo_name = "BeardedMonster/SabiYarn-125M"
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -147,6 +151,14 @@ async def generate_from_api(user_input, generation_config):
147
 
148
  return "FAILED"
149
 
 
 
 
 
 
 
 
 
150
 
151
  # Sample texts
152
  sample_texts = {
@@ -245,10 +257,12 @@ if st.button("Generate"):
245
  # Attempt the asynchronous API call
246
  generation_config["max_new_tokens"] = min(max_new_tokens, 1024 - len(tokenizer.tokenize(wrapped_input)))
247
  # generated_text = asyncio.run(generate_from_api(wrapped_input, generation_config))
 
 
248
 
249
- loop = asyncio.new_event_loop()
250
- asyncio.set_event_loop(loop)
251
- generated_text = loop.run_until_complete(generate_from_api(wrapped_input, generation_config))
252
  # except Exception as e:
253
  # print(f"API call failed: {e}. Using local model for text generation.")
254
  # Use the locally loaded model for text generation
 
7
  import json
8
  import torch
9
  import re
10
+ import nest_asyncio
11
+ from hashlib import md5
12
+
13
+ nest_asyncio.apply()
14
 
15
  repo_name = "BeardedMonster/SabiYarn-125M"
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
151
 
152
  return "FAILED"
153
 
154
+ def generate_cache_key(user_input, generation_config):
155
+ key_data = f"{user_input}_{json.dumps(generation_config, sort_keys=True)}"
156
+ return md5(key_data.encode()).hexdigest()
157
+
158
+ @st.cache_data(show_spinner=False)
159
+ def get_cached_response(user_input, generation_config):
160
+ return asyncio.run(generate_from_api(user_input, generation_config))
161
+
162
 
163
  # Sample texts
164
  sample_texts = {
 
257
  # Attempt the asynchronous API call
258
  generation_config["max_new_tokens"] = min(max_new_tokens, 1024 - len(tokenizer.tokenize(wrapped_input)))
259
  # generated_text = asyncio.run(generate_from_api(wrapped_input, generation_config))
260
+ cache_key = generate_cache_key(user_input, generation_config)
261
+ generated_text = get_cached_response(user_input, generation_config)
262
 
263
+ # loop = asyncio.new_event_loop()
264
+ # asyncio.set_event_loop(loop)
265
+ # generated_text = loop.run_until_complete(generate_from_api(wrapped_input, generation_config))
266
  # except Exception as e:
267
  # print(f"API call failed: {e}. Using local model for text generation.")
268
  # Use the locally loaded model for text generation