import os

import torch
from transformers import LlamaConfig, PreTrainedModel

from model import Moose


class MooseModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the project-local ``Moose`` network.

    All computation is delegated to the wrapped ``Moose`` instance stored in
    ``self.model``; this class only adds the ``PreTrainedModel`` base and a
    simple checkpoint loader.
    """

    def __init__(self, config):
        """Build the wrapper and its inner ``Moose`` model.

        Args:
            config: Configuration object passed to ``PreTrainedModel.__init__``.
        """
        super().__init__(config)
        # NOTE(review): Moose() is constructed with no arguments, so `config`
        # does not influence the architecture here — confirm this is intended.
        self.model = Moose()

    def forward(self, *inputs, **kwargs):
        """Run the wrapped ``Moose`` model, forwarding all arguments unchanged."""
        return self.model(*inputs, **kwargs)

    @classmethod
    def from_pretrained(cls, model_path):
        """Load a config and weights from a local checkpoint directory.

        This overrides ``PreTrainedModel.from_pretrained`` with a simpler
        loader. It reads the config with ``LlamaConfig.from_pretrained`` and
        the weights from ``model.pth`` inside ``model_path``.

        Args:
            model_path: Directory holding the config and ``model.pth``.
                Presumably it contains a Llama-style ``config.json``. TODO
                confirm against how checkpoints are saved.

        Returns:
            A new ``MooseModel`` with the checkpoint weights loaded. The load
            uses ``load_state_dict``'s default strict mode.

        Raises:
            Whatever ``LlamaConfig.from_pretrained``, ``torch.load`` or
            ``load_state_dict`` raise. For example, strict
            ``load_state_dict`` raises ``RuntimeError`` on missing or
            unexpected keys.
        """
        # Was `LLamaConfig`: that name is undefined, so this call raised
        # NameError. transformers' class is spelled `LlamaConfig`.
        config = LlamaConfig.from_pretrained(model_path)
        model = cls(config)

        weights_path = os.path.join(model_path, "model.pth")
        # Without map_location="cpu", tensors saved on a GPU would be
        # restored onto that GPU, which fails on a CPU-only machine.
        # load_state_dict copies the values into the model's own parameters,
        # so mapping to CPU first does not change the final result.
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only load trusted checkpoints, or pass weights_only=True
        # (torch >= 1.13).
        state_dict = torch.load(weights_path, map_location="cpu")

        # NOTE(review): the weights live under `self.model`, so the checkpoint
        # keys must carry the "model." prefix. A state dict saved directly
        # from a bare Moose instance would fail this strict load. Verify how
        # model.pth is produced.
        model.load_state_dict(state_dict)
        return model