Med Tiouti committed on
Commit
ca90067
1 Parent(s): 94bff1a

Setup for runpod

Browse files
Files changed (2) hide show
  1. app.py +7 -21
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  # retrievers
3
  from langchain.chains import RetrievalQA
 
4
 
5
  import textwrap
6
  import time
@@ -49,7 +50,7 @@ repetition_penalty = 1.15
49
 
50
  pipe = pipeline(
51
  task = "text-generation",
52
- model = model,
53
  tokenizer = tokenizer,
54
  pad_token_id = tokenizer.eos_token_id,
55
  max_length = max_len,
@@ -136,27 +137,12 @@ def process_llm_response(llm_response):
136
  ans += "\n Sand Hill Road podcast episodes based on your question : \n" + sources_used
137
  return ans,sources_used
138
 
139
- def llm_ans(query):
140
- start = time.time()
141
- llm_response = qa_chain(query)
142
  ans,sources_used = process_llm_response(llm_response)
143
- end = time.time()
 
144
 
145
- time_elapsed = int(round(end - start, 0))
146
- time_elapsed_str = f'\n\nTime elapsed: {time_elapsed} s'
147
- return ans, sources_used ,time_elapsed_str
148
 
149
 
150
- def predict(message, history):
151
- # output = message # debug mode
152
-
153
- output = str(llm_ans(message)[0]).replace("\n", "<br/>")
154
- return output
155
-
156
- demo = gr.ChatInterface(
157
- predict,
158
- title = f' Sand Hill Road Podcast Chatbot'
159
- )
160
-
161
- demo.queue()
162
- demo.launch(debug=True,share=True)
 
1
  import gradio as gr
2
  # retrievers
3
  from langchain.chains import RetrievalQA
4
+ import runpod
5
 
6
  import textwrap
7
  import time
 
50
 
51
  pipe = pipeline(
52
  task = "text-generation",
53
+ model = "daryl149/llama-2-13b-chat-hf",
54
  tokenizer = tokenizer,
55
  pad_token_id = tokenizer.eos_token_id,
56
  max_length = max_len,
 
137
  ans += "\n Sand Hill Road podcast episodes based on your question : \n" + sources_used
138
  return ans,sources_used
139
 
140
+ def text_generation(job):
141
+ llm_response = qa_chain(job_input = job["prompt"])
 
142
  ans,sources_used = process_llm_response(llm_response)
143
+
144
+ return str(ans).replace("\n", "<br/>")
145
 
 
 
 
146
 
147
 
148
+ runpod.serverless.start({"handler": text_generation})
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -6,4 +6,5 @@ sentence_transformers
6
  accelerate
7
  bitsandbytes
8
  xformers
 
9
  einops
 
6
  accelerate
7
  bitsandbytes
8
  xformers
9
+ runpod
10
  einops