gokaygokay committed
Commit 9f18ec6
1 Parent(s): eede49b

Update app.py

Files changed (1)
  1. app.py +12 -19
app.py CHANGED
@@ -18,22 +18,7 @@ hf_hub_download(
     local_dir = "./models"
 )
 
-def initialize_model(model):
-    global llm, llm_model
-    if llm is None or llm_model != model:
-        llm = Llama(
-            model_path=f"models/{model}",
-            flash_attn=True,
-            n_gpu_layers=81,
-            n_batch=1024,
-            n_ctx=8192,
-        )
-        llm_model = model
-    return llm
 
-# Initialize the model with the default model
-default_model = "Reflection-Llama-3.1-70B-Q3_K_M.gguf"
-initialize_model(default_model)
 
 def get_messages_formatter_type(model_name):
     if "Llama" in model_name:
@@ -42,7 +27,7 @@ def get_messages_formatter_type(model_name):
         raise ValueError(f"Unsupported model: {model_name}")
 
 
-@spaces.GPU(duration=60)
+@spaces.GPU(duration=120)
 def respond(
     message,
     history: list[tuple[str, str]],
@@ -54,12 +39,20 @@ def respond(
     top_k,
     repeat_penalty,
 ):
-    global llm, llm_model
+    global llm
+    global llm_model
 
     chat_template = get_messages_formatter_type(model)
 
-    if llm_model != model:
-        llm = initialize_model(model)
+    if llm is None or llm_model != model:
+        llm = Llama(
+            model_path=f"models/{model}",
+            flash_attn=True,
+            n_gpu_layers=81,
+            n_batch=1024,
+            n_ctx=8192,
+        )
+        llm_model = model
 
     provider = LlamaCppPythonProvider(llm)
 
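Taken together, the change removes the module-level initialize_model() helper and the eager load of the default Reflection-Llama-3.1-70B-Q3_K_M.gguf, folds the same lazy-load logic directly into the @spaces.GPU-decorated respond() handler, and raises the GPU allocation from 60 to 120 seconds. Below is a minimal sketch of the pattern app.py lands on; the file is only partially visible in this diff, so the imports, the globals, and the elided respond() parameters are inferred rather than confirmed:

import spaces
from llama_cpp import Llama
from llama_cpp_agent.providers import LlamaCppPythonProvider

llm = None        # cached Llama instance, reused across calls
llm_model = None  # filename of the GGUF currently loaded

@spaces.GPU(duration=120)  # GPU is attached only while this function runs
def respond(message, history: list[tuple[str, str]], model, top_k, repeat_penalty):
    # (the diff elides several sampling parameters between `history`
    # and `top_k`; they are omitted here as well)
    global llm, llm_model
    # Lazy (re)load: build the model on the first call, or whenever the
    # caller selects a different GGUF file than the one resident.
    if llm is None or llm_model != model:
        llm = Llama(
            model_path=f"models/{model}",
            flash_attn=True,
            n_gpu_layers=81,   # offload all layers to the GPU
            n_batch=1024,
            n_ctx=8192,
        )
        llm_model = model
    provider = LlamaCppPythonProvider(llm)
    # ... formatter selection and generation continue as in app.py

On a ZeroGPU Space the GPU is attached only for the duration of a @spaces.GPU call, so moving the Llama(...) construction out of module scope and into respond() keeps the model load inside the allocated window; the bump from 60 to 120 seconds plausibly budgets for the first-request load of the 70B Q3_K_M file.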