nawhgnuj committed on
Commit
65e6974
1 Parent(s): 09408e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -49,6 +49,10 @@ quantization_config = BitsAndBytesConfig(
49
  bnb_4bit_quant_type="nf4")
50
 
51
  tokenizer = AutoTokenizer.from_pretrained(MODEL)
 
 
 
 
52
  model = AutoModelForCausalLM.from_pretrained(
53
  MODEL,
54
  torch_dtype=torch.bfloat16,
@@ -89,19 +93,22 @@ Importantly, always respond to and rebut the previous speaker's points in Trump'
89
  conversation.append({"role": "user", "content": message})
90
 
91
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
 
92
 
93
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
94
 
95
  generate_kwargs = dict(
96
  input_ids=input_ids,
 
97
  max_new_tokens=max_new_tokens,
98
  do_sample=True,
99
  top_p=top_p,
100
  top_k=top_k,
101
  temperature=temperature,
102
- repetition_penalty=repetition_penalty,
103
- eos_token_id=[128001,128008,128009],
104
  streamer=streamer,
 
105
  )
106
 
107
  with torch.no_grad():
 
49
  bnb_4bit_quant_type="nf4")
50
 
51
  tokenizer = AutoTokenizer.from_pretrained(MODEL)
52
+ if tokenizer.pad_token is None:
53
+ tokenizer.pad_token = tokenizer.eos_token
54
+ tokenizer.pad_token_id = tokenizer.eos_token_id
55
+
56
  model = AutoModelForCausalLM.from_pretrained(
57
  MODEL,
58
  torch_dtype=torch.bfloat16,
 
93
  conversation.append({"role": "user", "content": message})
94
 
95
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
96
+ attention_mask = torch.ones_like(input_ids)
97
 
98
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
99
 
100
  generate_kwargs = dict(
101
  input_ids=input_ids,
102
+ attention_mask=attention_mask,
103
  max_new_tokens=max_new_tokens,
104
  do_sample=True,
105
  top_p=top_p,
106
  top_k=top_k,
107
  temperature=temperature,
108
+ pad_token_id=tokenizer.pad_token_id,
109
+ eos_token_id=tokenizer.eos_token_id,
110
  streamer=streamer,
111
+ repetition_penalty=repetition_penalty,
112
  )
113
 
114
  with torch.no_grad():