#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
from typing import Dict
import open_clip
from torch import Tensor, nn


class ClipTokenizer(nn.Module):
    def __init__(self, cfg, *args, **kwargs):
        super().__init__()
        self.context_length = cfg["text_cfg"]["context_length"]
        # text_cfg is a plain dict, so use .get (getattr on a dict would always
        # fall back to the default model name)
        model_name = cfg["text_cfg"].get("open_clip_tokenizer", "ViT-B-16")
        self.tokenizer = open_clip.get_tokenizer(model_name)

    def get_vocab_size(self) -> int:
        return len(self.tokenizer.encoder)

    def get_encodings(self) -> Dict[str, int]:
        return self.tokenizer.encoder

    def get_eot_token(self) -> int:
        # Tokenizing an empty string yields a (1, context_length) tensor whose
        # first row starts with [sot_id, eot_id, 0, ...]
        return int(self.tokenizer("")[0, 1])

    def get_sot_token(self) -> int:
        # Tokenizing an empty string yields a (1, context_length) tensor whose
        # first row starts with [sot_id, eot_id, 0, ...]
        return int(self.tokenizer("")[0, 0])

    def forward(self, input_sentence: str, *args, **kwargs) -> Tensor:
        # The tokenizer returns token indices as a LongTensor padded/truncated
        # to `context_length`
        tokenized_sentence = self.tokenizer(input_sentence, self.context_length)
        assert (
            tokenized_sentence.shape[-1] == self.context_length
        ), "Tokenized tensor should be exactly `context_length` long."
        return tokenized_sentence
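

# Usage sketch (assumes an open_clip-style cfg; the values below are
# illustrative, not taken from the original file): build the minimal
# "text_cfg" keys that __init__ reads and tokenize a sample sentence.
if __name__ == "__main__":
    cfg = {
        "text_cfg": {
            "context_length": 77,  # standard CLIP context length (assumption)
            "open_clip_tokenizer": "ViT-B-16",
        }
    }
    clip_tokenizer = ClipTokenizer(cfg)
    tokens = clip_tokenizer("a photo of a dog")  # LongTensor of shape (1, 77)
    print(tokens.shape)
    print("vocab size:", clip_tokenizer.get_vocab_size())
    print("sot/eot ids:", clip_tokenizer.get_sot_token(), clip_tokenizer.get_eot_token())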