
GPT-2 Medium Fine-Tuned on Anthropic-hh Dataset

This repository houses a GPT-2 Medium model fine-tuned on the Anthropic-hh dataset. During fine-tuning, the Human utterances were masked so that the loss was computed exclusively on the Assistant's responses.
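For readers curious about what that masking looks like in practice, the sketch below shows one common way to implement it: tokens belonging to Human turns get the label -100 (the ignore index of PyTorch's cross-entropy loss), so only Assistant tokens contribute to the gradient. The helper name and the turn-splitting logic are illustrative assumptions, not the exact training code.

import torch

IGNORE_INDEX = -100  # labels with this value are skipped by the cross-entropy loss

def mask_human_turns(tokenizer, dialogue: str):
    """Illustrative sketch: build (input_ids, labels) so that only
    Assistant tokens contribute to the loss. Not the exact training code."""
    input_ids, labels = [], []
    # Anthropic-hh dialogues alternate "\n\nHuman: ..." / "\n\nAssistant: ..." turns
    for turn in dialogue.split("\n\n"):
        if not turn:
            continue
        ids = tokenizer("\n\n" + turn)["input_ids"]
        input_ids.extend(ids)
        if turn.startswith("Assistant:"):
            labels.extend(ids)                        # learn from Assistant tokens
        else:
            labels.extend([IGNORE_INDEX] * len(ids))  # mask Human tokens
    return torch.tensor(input_ids), torch.tensor(labels)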

Model Information

  • Base Model: GPT-2 Medium (355M parameters, F32)
  • Training Data: Anthropic-hh dataset
  • Fine-Tuning Approach: Supervised fine-tuning with the loss computed only on the Assistant's responses.

How to Use

import torch
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

# Load the tokenizer and model, and move the model to GPU if one is available
tokenizer = GPT2Tokenizer.from_pretrained("RaushanTurganbay/GPT2_instruct_tuned")
model = GPT2LMHeadModel.from_pretrained("RaushanTurganbay/GPT2_instruct_tuned")

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Custom stopping criterion: halt generation once the model starts a new "Human:" turn
class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=None):
        super().__init__()
        self.stops = [stop.to(device) for stop in (stops or [])]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop if the generated sequence ends with any of the stop token sequences
        for stop in self.stops:
            if torch.all(stop == input_ids[0][-len(stop):]).item():
                return True
        return False


def build_stopping_criteria(tokenizer, stop_words):
    # Encode each stop word into token ids and wrap them in a StoppingCriteriaList
    stop_words_ids = [
        tokenizer(stop_word, return_tensors="pt")["input_ids"].squeeze()
        for stop_word in stop_words
    ]
    return StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

# Generate a response, stopping when the model begins a new Human turn
stopping = build_stopping_criteria(tokenizer, ["\n\nHuman:"])
prompt = "\n\nHuman: {your_instruction}\n\nAssistant:"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
outputs = model.generate(**inputs, stopping_criteria=stopping, max_length=150)

print("Model Response:", tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])