tangzhy committed on
Commit: b7bc525
Parent: b57680f

Update app.py

Files changed (1): app.py (+2 -2)
app.py CHANGED
@@ -8,7 +8,7 @@ import torch
 from transformers import (
     AutoModelForCausalLM,
     BitsAndBytesConfig,
-    GemmaTokenizerFast,
+    AutoTokenizer,
     TextIteratorStreamer,
 )
 
@@ -29,7 +29,7 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 model_id = "google/gemma-2-9b-it"
-tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
+tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
     device_map="auto",
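
For context (not part of the commit), a minimal sketch of what the change relies on, assuming transformers is installed and you have access to the gated google/gemma-2-9b-it checkpoint: AutoTokenizer.from_pretrained with use_fast=True resolves the tokenizer class from the checkpoint config, which for this model is the fast Gemma tokenizer, so the rest of app.py keeps working unchanged.

from transformers import AutoTokenizer

model_id = "google/gemma-2-9b-it"
# AutoTokenizer picks the concrete tokenizer class from the checkpoint config;
# for this model it is expected to resolve to the fast Gemma tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
print(type(tokenizer).__name__)  # expected: GemmaTokenizerFast

# Tokenizing a prompt works exactly as it did with the explicit class.
input_ids = tokenizer("Hello, Gemma!", return_tensors="pt").input_ids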