Muennighoff committed
Commit 91f94e3
1 Parent(s): 18652d8
Files changed (1)
  1. data_utils.py +9 -9
data_utils.py CHANGED
@@ -21,7 +21,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from transformers import PreTrainedTokenizerFast
 
-from . import seqio_tokenizer as vocab
+from .seqio_tokenizer import SentencePieceVocabulary, HfTokenizerWrapper, OLMoTokenizerWrapper
 from .constants import *
 from .utils import pop_metadata
 from .util import is_url
@@ -43,21 +43,21 @@ def build_tokenizer(
         return cache[cache_key]
 
     if tokenizer_type == 'llama':
-        tok = vocab.SentencePieceVocabulary(
+        tok = SentencePieceVocabulary(
             os.path.join(tokenizer_dir, "llama_tokenizer.model"),
             extra_ids=DEFAULT_EXTRA_IDS,
             reverse_extra_ids=True,
             extra_tokens=EXTRA_TOKENS if has_extra_token else None,
         )
     elif tokenizer_type == 'yi':
-        tok = vocab.SentencePieceVocabulary(
+        tok = SentencePieceVocabulary(
             os.path.join(tokenizer_dir, "yi_tokenizer.model"),
             extra_ids=DEFAULT_EXTRA_IDS,
             reverse_extra_ids=True,
             extra_tokens=EXTRA_TOKENS if has_extra_token else None,
         )
     elif tokenizer_type == 'mistral':
-        tok = vocab.SentencePieceVocabulary(
+        tok = SentencePieceVocabulary(
             os.path.join(tokenizer_dir, "mistral_tokenizer.model"),
             extra_ids=DEFAULT_EXTRA_IDS,
             reverse_extra_ids=True,
@@ -65,14 +65,14 @@ def build_tokenizer(
         )
 
     elif tokenizer_type == "mistral0.3":
-        tok = vocab.SentencePieceVocabulary(
+        tok = SentencePieceVocabulary(
             os.path.join(tokenizer_dir, "mistral0.3_tokenizer.model.v3"),
             extra_ids=DEFAULT_EXTRA_IDS,
             reverse_extra_ids=True,
             extra_tokens=EXTRA_TOKENS if has_extra_token else None,
         )
     elif tokenizer_type == 'gemma':
-        tok = vocab.SentencePieceVocabulary(
+        tok = SentencePieceVocabulary(
             os.path.join(tokenizer_dir, "gemma_tokenizer.model"),
             extra_ids=DEFAULT_EXTRA_IDS,
             reverse_extra_ids=True,
@@ -114,7 +114,7 @@ def build_tokenizer(
             ids = tokenizer.encode(tok, add_special_tokens=False)
             assert ids == [pad_tokenizer_to + ix]
 
-        tok = vocab.HfTokenizerWrapper(tokenizer, bos_token_id=bos_token_id, adds_space=adds_space)
+        tok = HfTokenizerWrapper(tokenizer, bos_token_id=bos_token_id, adds_space=adds_space)
     elif tokenizer_type.startswith("olmo-"):
         from olmo.tokenizer import Tokenizer
         assert Path(tokenizer_type[5:]).is_file()
@@ -123,7 +123,7 @@ def build_tokenizer(
             eos_token_id=olmo_eos_token_id,
             pad_token_id=-1,
         )
-        tok = vocab.OLMoTokenizerWrapper(tokenizer, bos_token_id=olmo_bos_token_id, adds_space=adds_space)
+        tok = OLMoTokenizerWrapper(tokenizer, bos_token_id=olmo_bos_token_id, adds_space=adds_space)
     else:
         raise NotImplementedError(tokenizer_type)
     cache[cache_key] = tok
@@ -131,7 +131,7 @@ def build_tokenizer(
 
 
 def get_special_token_ids(tokenizer):
-    if isinstance(tokenizer, (vocab.HfTokenizerWrapper, vocab.OLMoTokenizerWrapper)):
+    if isinstance(tokenizer, (HfTokenizerWrapper, OLMoTokenizerWrapper)):
         ids = tokenizer.encode("".join(EXTRA_TOKENS))
         if len(ids) == len(EXTRA_TOKENS) + 1:
             ids = ids[1:]
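
Note: the change is purely mechanical; it swaps the `seqio_tokenizer as vocab` module alias for direct imports of the three classes data_utils.py actually uses, with no change to the tokenizer-construction logic. For orientation, a minimal usage sketch follows. The argument names tokenizer_type, tokenizer_dir, and has_extra_token are inferred from the hunks above; the full build_tokenizer signature is not shown in this diff, so treat the call as hypothetical.

# Hypothetical usage sketch; argument names are inferred from the hunks above
# and may not match the real build_tokenizer signature. Import path assumes
# data_utils.py lives in an importable package (here called "mypackage").
from mypackage.data_utils import build_tokenizer, get_special_token_ids

tok = build_tokenizer(
    tokenizer_type="llama",               # selects the llama_tokenizer.model branch
    tokenizer_dir="/path/to/tokenizers",  # directory holding the *.model files
    has_extra_token=True,                 # append EXTRA_TOKENS to the vocabulary
)
special_ids = get_special_token_ids(tok)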