gokaygokay committed
Commit 41c56d2
1 Parent(s): 0c9ea73

Update app.py

Files changed (1)
app.py +18 -11
app.py CHANGED
@@ -18,7 +18,22 @@ hf_hub_download(
     local_dir = "./models"
 )
 
+def initialize_model(model):
+    global llm, llm_model
+    if llm is None or llm_model != model:
+        llm = Llama(
+            model_path=f"models/{model}",
+            flash_attn=True,
+            n_gpu_layers=81,
+            n_batch=1024,
+            n_ctx=8192,
+        )
+        llm_model = model
+    return llm
 
+# Initialize the model with the default model
+default_model = "Reflection-Llama-3.1-70B-Q3_K_M.gguf"
+initialize_model(default_model)
 
 def get_messages_formatter_type(model_name):
     if "Llama" in model_name:
@@ -39,20 +54,12 @@ def respond(
     top_k,
     repeat_penalty,
 ):
-    global llm
-    global llm_model
+    global llm, llm_model
 
     chat_template = get_messages_formatter_type(model)
 
-    if llm is None or llm_model != model:
-        llm = Llama(
-            model_path=f"models/{model}",
-            flash_attn=True,
-            n_gpu_layers=81,
-            n_batch=1024,
-            n_ctx=8192,
-        )
-        llm_model = model
+    if llm_model != model:
+        llm = initialize_model(model)
 
     provider = LlamaCppPythonProvider(llm)
 
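For reference, here is the pattern this commit introduces, lifted out of the diff into a minimal standalone sketch: a module-level cache that loads a GGUF model once and reloads it only when the requested file name changes. It assumes llama-cpp-python is installed and that the GGUF file has already been downloaded to ./models; the file name and the Llama constructor arguments are taken directly from the diff, and the inline comments on them are my reading of the settings, not the author's.

# Minimal sketch of the cached-loader pattern from this commit, assuming
# llama-cpp-python is installed and the GGUF file exists under ./models.
from llama_cpp import Llama

llm = None        # cached Llama instance
llm_model = None  # file name of the currently loaded model

def initialize_model(model):
    """Load models/{model} once; reuse the cached instance on later calls."""
    global llm, llm_model
    if llm is None or llm_model != model:
        llm = Llama(
            model_path=f"models/{model}",
            flash_attn=True,   # flash attention; needs a build that supports it
            n_gpu_layers=81,   # offload all layers of the 70B model to the GPU
            n_batch=1024,
            n_ctx=8192,
        )
        llm_model = model
    return llm

# Warm up at import time with the default model, as the commit does.
# After this call, llm is never None inside the request handler.
initialize_model("Reflection-Llama-3.1-70B-Q3_K_M.gguf")

Because of that import-time warm-up, respond() can drop the old llm is None guard and reload only when llm_model != model, which is presumably why the check in the second hunk was simplified.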