Nekochu committed
Commit b0302a5
1 Parent(s): bee5b00

Update app.py

Files changed (1)
  1. app.py +8 -4
app.py CHANGED
@@ -25,6 +25,12 @@ if not torch.cuda.is_available():
 
 models_cache = {}
 
+def load_model(model_id: str):
+    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.use_default_system_prompt = False
+    return model, tokenizer
+
 @spaces.GPU(duration=120)
 def generate(
     model_id: str,
@@ -37,14 +43,12 @@ def generate(
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
+    # Load the model if it's not already loaded
     if model_id not in models_cache:
-        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        tokenizer.use_default_system_prompt = False
+        model, tokenizer = load_model(model_id)
         models_cache[model_id] = (model, tokenizer)
     else:
         model, tokenizer = models_cache[model_id]
-
     conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
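
For reference, here is a minimal standalone sketch of the lazy load-and-cache pattern this commit factors out of generate. The load_model body mirrors the diff above; the get_model helper name is added here for illustration only and does not appear in app.py.

# Minimal sketch of the load-and-cache pattern from this commit.
# Assumes transformers and bitsandbytes are installed; get_model is a
# hypothetical helper added for illustration, not part of app.py.
from transformers import AutoModelForCausalLM, AutoTokenizer

models_cache = {}

def load_model(model_id: str):
    # 4-bit loading requires bitsandbytes; device_map="auto" places
    # the weights on the available GPU(s) automatically.
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="auto", load_in_4bit=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False
    return model, tokenizer

def get_model(model_id: str):
    # First call for a model_id pays the load cost; later calls reuse it.
    if model_id not in models_cache:
        models_cache[model_id] = load_model(model_id)
    return models_cache[model_id]

Keeping the cache outside the GPU-decorated function means repeated calls to generate for the same model_id reuse the already-loaded weights instead of reloading them on every request.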