---
base_model: google/gemma-2-2b-it
library_name: peft
license: other
tags:
- llama-factory
- lora
- generated_from_trainer
model-index:
- name: PPRM-gemma-2-2b-it
  results: []
---
|
|
|
<!-- This model card has been generated automatically according to the information the Trainer had access to. You |
|
should probably proofread and complete it, then remove this comment. --> |
|
|
|
# PRM_DPO_GEMMA_ZD_8_18_1 |
|
|
|
This model is a fine-tuned version of [google/gemma-2-2b-it](https://huggingface.co/google/gemma-2-2b-it) on the prm_dpo dataset. |
|
|
|
# Citation |
|
``` |
|
@article{zhang2024llama, |
|
title={LLaMA-Berry: Pairwise Optimization for O1-like Olympiad-Level Mathematical Reasoning}, |
|
author={Zhang, Di and Wu, Jianbo and Lei, Jingdi and Che, Tong and Li, Jiatong and Xie, Tong and Huang, Xiaoshui and Zhang, Shufei and Pavone, Marco and Li, Yuqiang and others}, |
|
journal={arXiv preprint arXiv:2410.02884}, |
|
year={2024} |
|
} |
|
|
|
@article{zhang2024accessing, |
|
title={Accessing GPT-4 level Mathematical Olympiad Solutions via Monte Carlo Tree Self-refine with LLaMa-3 8B}, |
|
author={Zhang, Di and Li, Jiatong and Huang, Xiaoshui and Zhou, Dongzhan and Li, Yuqiang and Ouyang, Wanli}, |
|
journal={arXiv preprint arXiv:2406.07394}, |
|
year={2024} |
|
} |
|
|
|
|
|
``` |
|
|
|
## Model usage |
|
|
|
Save the following as `server.py`:
|
``` |
|
import json |
|
from fastapi import FastAPI, HTTPException |
|
from pydantic import BaseModel |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from peft import PeftModel |
|
import torch |
|
|
|
# Initialize FastAPI
app = FastAPI()

# Device configuration: use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hub identifiers of the base model and of the LoRA adapter applied on top of it.
model_name = "google/gemma-2-2b-it"
lora_checkpoint_path = "qq8933/PPRM-gemma-2-2b-it"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Load onto the detected device instead of a hard-coded 'cuda', so the CPU
# fallback chosen above is actually honoured. On GPU hosts str(device) is
# still exactly "cuda".
base_model = AutoModelForCausalLM.from_pretrained(
    model_name, trust_remote_code=True, device_map=str(device)
)
model = PeftModel.from_pretrained(
    base_model, lora_checkpoint_path, device_map=str(device)
)

# Vocabulary ids whose first-step logits are compared by predict().
# NOTE(review): assumes "yes" and "no" each exist as a single token in the
# tokenizer vocabulary -- confirm against the Gemma tokenizer.
yes_token_id = tokenizer.convert_tokens_to_ids("yes")
no_token_id = tokenizer.convert_tokens_to_ids("no")
|
|
|
# Request model |
|
class InputRequest(BaseModel):
    """Request body accepted by the ``/predict`` endpoint.

    ``text`` carries a JSON-encoded object; the endpoint decodes it with
    ``json.loads`` and reads the problem and the two answers from it.
    """

    # JSON string, decoded by the endpoint handler.
    text: str
|
|
|
# Predict function |
|
def predict(qeustion, answer_1, answer_2):
    """Score whether ``answer_1`` is better than ``answer_2`` for a problem.

    The three texts are inserted into the comparison prompt, wrapped in the
    chat template, and the model is asked for the next token. The logits of
    the "yes" and "no" tokens at that first generated position are returned.

    Args:
        qeustion: The problem statement (parameter name kept as-is for
            compatibility with existing callers).
        answer_1: The first candidate answer.
        answer_2: The second candidate answer.

    Returns:
        dict: ``{"yes_logit": float, "no_logit": float,
        "logit_difference": float}`` where ``logit_difference`` is
        ``yes_logit - no_logit``.
    """
    prompt_template = """Problem:\n\n{}\n\nFirst Answer:\n\n{}\n\nSecond Answer:\n\n{}\n\nIs First Answer better than Second Answer?\n\n"""
    input_text = prompt_template.format(qeustion, answer_1, answer_2)
    input_text = tokenizer.apply_chat_template(
        [{'role': 'user', 'content': input_text}], tokenize=False, add_generation_prompt=True
    )
    # NOTE(review): the rendered chat template may already contain a BOS token
    # and this tokenizer call may add another one -- confirm this matches how
    # the adapter was trained before changing it.
    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    with torch.no_grad():
        # Only the scores of the first generated position are read below, so a
        # single new token is enough; a second decoding step would be wasted.
        generated_outputs = model.generate(
            **inputs, max_new_tokens=1, output_scores=True, return_dict_in_generate=True
        )

    # scores holds one tensor per generated step; row 0 is the single prompt.
    first_token_logits = generated_outputs.scores[0]
    yes_logit = first_token_logits[0, yes_token_id].item()
    no_logit = first_token_logits[0, no_token_id].item()

    return {
        "yes_logit": yes_logit,
        "no_logit": no_logit,
        "logit_difference": yes_logit - no_logit
    }
|
|
|
# Define API endpoint |
|
@app.post("/predict")
async def get_prediction(input_request: InputRequest):
    """Handle ``POST /predict``: decode the request and return the scores.

    ``input_request.text`` must be a JSON object with the keys ``qeustion``
    (the misspelled key that existing clients send; the correctly spelled
    ``question`` is accepted as well), ``answer_1`` and ``answer_2``.

    Returns:
        The dictionary produced by ``predict``.

    Raises:
        HTTPException: status 400 when the payload is not valid JSON or lacks
            a required key; status 500 when scoring itself fails.
    """
    # Client-side mistakes (bad JSON, missing keys, non-object payload) are
    # reported as 400 instead of escaping as an unhandled server error.
    try:
        payload = json.loads(input_request.text)
        qeustion = payload['qeustion'] if 'qeustion' in payload else payload['question']
        answer_1 = payload['answer_1']
        answer_2 = payload['answer_2']
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        raise HTTPException(
            status_code=400, detail=f"Invalid request payload: {e!r}"
        ) from e

    try:
        return predict(qeustion, answer_1, answer_2)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
``` |
|
Run the PPRM server:
|
``` |
|
uvicorn server:app --host 0.0.0.0 --port $MASTER_PORT --workers 1 |
|
``` |
|
Query the PPRM server from a client:
|
``` |
|
# Example:
#   qeustion, answer_1, answer_2 = 'What is the capital of France?', 'Berlin', 'Paris'
#   -> {'yes_logit': -24.26136016845703, 'no_logit': 19.517587661743164, 'logit_difference': -43.778947830200195}
# The server answers: is answer_1 better than answer_2? (yes or no)
# Entry point of the reward model.
def request_prediction(
    qeustion, answer_1, answer_2, url="http://10.140.24.56:10085/predict"
):
    """
    Sends a POST request to the FastAPI server to get a pairwise prediction.

    Args:
    - qeustion (str): The problem statement. The spelling matches the JSON key
      that is sent to the server.
    - answer_1 (str): The first candidate answer.
    - answer_2 (str): The second candidate answer.
    - url (str): The API endpoint URL. Defaults to
      'http://10.140.24.56:10085/predict'.

    Returns:
    - dict: The JSON response from the API containing prediction results.

    Raises:
    - requests.HTTPError: If the server answers with an unsuccessful status code.
    """
    # The server expects a single "text" field holding a JSON-encoded object.
    payload = {
        "text": json.dumps(
            {"qeustion": qeustion, "answer_1": answer_1, "answer_2": answer_2}
        )
    }

    # `json=` serializes the payload and sets the Content-Type header itself,
    # so no explicit header is needed.
    response = requests.post(url, json=payload, timeout=TIMEOUT_PRM)
    response.raise_for_status()  # Raises an HTTPError if the response code was unsuccessful
    return response.json()  # Return the JSON response as a dictionary
|
|
|
def cal_reward(question, ans, ans2="I don't know"):
    """Return how strongly the reward model prefers ``ans`` over ``ans2``.

    Args:
        question: The problem statement.
        ans: The answer being scored.
        ans2: The answer it is compared against. Defaults to "I don't know".

    Returns:
        ``1`` when ``ans2`` is a dummy answer, ``0`` when ``ans`` is a dummy
        answer, otherwise the softmax probability of "yes" over the pair
        (yes_logit, no_logit) reported by a PRM server, as a float in [0, 1].

    The servers in ``prm_servers`` are tried in random order. When every one
    of them fails, a message is printed and the whole round is retried, so
    this call blocks until some server answers.
    """
    import time  # local import: only needed for the pause between retry rounds

    if ans2 in DUMMY_ANSWERS:  # I don't know
        return 1
    if ans in DUMMY_ANSWERS:
        return 0

    # Retry in a loop rather than by recursion, so a long outage cannot end
    # in a RecursionError.
    while True:
        urls = list(prm_servers)
        random.shuffle(urls)
        for url in urls:
            try:
                response = request_prediction(question, ans, ans2, url)
                diff = response["yes_logit"] - response["no_logit"]
            except Exception:
                # Best effort: any failure of this server means "try the next one".
                continue
            # Numerically stable form of exp(yes) / (exp(yes) + exp(no)):
            # exp() is only ever applied to a non-positive number, so large
            # logits cannot overflow.
            if diff >= 0:
                return 1.0 / (1.0 + math.exp(-diff))
            exp_diff = math.exp(diff)
            return exp_diff / (1.0 + exp_diff)
        print("All prm servers are down")
        # Short pause so an outage does not turn into a busy loop.
        time.sleep(1)
|
``` |
|
|
|
## Training procedure |
|
|
|
### Training hyperparameters |
|
|
|
The following hyperparameters were used during training: |
|
- learning_rate: 5e-05 |
|
- train_batch_size: 4 |
|
- eval_batch_size: 8 |
|
- seed: 42 |
|
- distributed_type: multi-GPU |
|
- num_devices: 16 |
|
- gradient_accumulation_steps: 2 |
|
- total_train_batch_size: 128 |
|
- total_eval_batch_size: 128 |
|
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08 |
|
- lr_scheduler_type: linear |
|
- num_epochs: 1.0 |
|
|
|
|
|
### Framework versions |
|
|
|
- PEFT 0.11.1 |
|
- Transformers 4.44.0 |
|
- Pytorch 2.3.1 |
|
- Datasets 2.20.0 |
|
- Tokenizers 0.19.1 |