qq8933 committed on
Commit 06d5b88
1 Parent(s): 2056078

Update README.md

Files changed (1):
  1. README.md +92 -15
README.md CHANGED
@@ -7,7 +7,7 @@ tags:
- lora
- generated_from_trainer
model-index:
- - name: PRM_DPO_GEMMA_ZD_8_18_1
  results: []
---

@@ -18,17 +18,97 @@ should probably proofread and complete it, then remove this comment. -->

This model is a fine-tuned version of [google/gemma-2-2b-it](https://huggingface.co/google/gemma-2-2b-it) on the prm_dpo dataset.

- ## Model description
-
- More information needed
-
- ## Intended uses & limitations
-
- More information needed
-
- ## Training and evaluation data
-
- More information needed

## Training procedure

@@ -48,9 +128,6 @@ The following hyperparameters were used during training:
- lr_scheduler_type: linear
- num_epochs: 1.0

- ### Training results
-
-

### Framework versions
 
- lora
- generated_from_trainer
model-index:
+ - name: PPRM-gemma-2-2b-it
  results: []
---
 

This model is a fine-tuned version of [google/gemma-2-2b-it](https://huggingface.co/google/gemma-2-2b-it) on the prm_dpo dataset.

+ # Citation
+ ```bibtex
+ @article{zhang2024llama,
+   title={LLaMA-Berry: Pairwise Optimization for O1-like Olympiad-Level Mathematical Reasoning},
+   author={Zhang, Di and Wu, Jianbo and Lei, Jingdi and Che, Tong and Li, Jiatong and Xie, Tong and Huang, Xiaoshui and Zhang, Shufei and Pavone, Marco and Li, Yuqiang and others},
+   journal={arXiv preprint arXiv:2410.02884},
+   year={2024}
+ }
+
+ @article{zhang2024accessing,
+   title={Accessing GPT-4 level Mathematical Olympiad Solutions via Monte Carlo Tree Self-refine with LLaMa-3 8B},
+   author={Zhang, Di and Li, Jiatong and Huang, Xiaoshui and Zhou, Dongzhan and Li, Yuqiang and Ouyang, Wanli},
+   journal={arXiv preprint arXiv:2406.07394},
+   year={2024}
+ }
+ ```
+
+ ## Model usage
+
+ `server.py`
+ ```python
+ import json
+
+ import torch
+ from fastapi import FastAPI, HTTPException
+ from peft import PeftModel
+ from pydantic import BaseModel
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # Initialize FastAPI
+ app = FastAPI()
+
+ # Device configuration
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load the base model, then attach the PPRM LoRA adapter
+ model_name = "google/gemma-2-2b-it"
+ lora_checkpoint_path = "qq8933/PPRM-gemma-2-2b-it"
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ base_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map="cuda")
+ model = PeftModel.from_pretrained(base_model, lora_checkpoint_path, device_map="cuda")
+
+ # Token ids used to read off the pairwise judgment
+ yes_token_id = tokenizer.convert_tokens_to_ids("yes")
+ no_token_id = tokenizer.convert_tokens_to_ids("no")
+
+ # Request model
+ class InputRequest(BaseModel):
+     text: str
+
+ # Predict function: compare two candidate answers to the same problem
+ def predict(question, answer_1, answer_2):
+     prompt_template = """Problem:\n\n{}\n\nFirst Answer:\n\n{}\n\nSecond Answer:\n\n{}\n\nIs First Answer better than Second Answer?\n\n"""
+     input_text = prompt_template.format(question, answer_1, answer_2)
+     input_text = tokenizer.apply_chat_template(
+         [{'role': 'user', 'content': input_text}], tokenize=False, add_generation_prompt=True
+     )
+     inputs = tokenizer(input_text, return_tensors="pt").to(device)
+     with torch.no_grad():
+         generated_outputs = model.generate(
+             **inputs, max_new_tokens=2, output_scores=True, return_dict_in_generate=True
+         )
+     # The logits of the first generated token carry the "yes"/"no" decision
+     scores = generated_outputs.scores
+     first_token_logits = scores[0]
+     yes_logit = first_token_logits[0, yes_token_id].item()
+     no_logit = first_token_logits[0, no_token_id].item()
+
+     return {
+         "yes_logit": yes_logit,
+         "no_logit": no_logit,
+         "logit_difference": yes_logit - no_logit,
+     }
+
+ # Define API endpoint
+ @app.post("/predict")
+ async def get_prediction(input_request: InputRequest):
+     try:
+         payload = json.loads(input_request.text)
+         question, answer_1, answer_2 = payload["question"], payload["answer_1"], payload["answer_2"]
+         result = predict(question, answer_1, answer_2)
+         return result
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+ ```
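The endpoint returns raw logits rather than a probability. If a single preference score is more convenient, the two logits can be normalized with a two-way softmax, which reduces to a sigmoid of their difference. A minimal post-processing sketch (an editorial suggestion, not part of the commit):

```python
import math

def preference_probability(yes_logit: float, no_logit: float) -> float:
    """P(first answer is better), via softmax over the yes/no logits."""
    # For two classes, softmax(yes) == sigmoid(yes_logit - no_logit)
    return 1.0 / (1.0 + math.exp(no_logit - yes_logit))

# e.g. yes_logit=2.0, no_logit=-1.0 -> ~0.95 in favor of the first answer
print(preference_probability(2.0, -1.0))
```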
+
+ ```bash
+ uvicorn server:app --host 0.0.0.0 --port $MASTER_PORT --workers 1
+ ```
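For completeness, a minimal client-side sketch (an editorial addition, not part of the commit) that queries the endpoint above; it assumes the server is reachable at http://localhost:8000 and uses the `requests` library:

```python
import json

import requests

# The endpoint expects the question/answer payload serialized into the `text` field
payload = {
    "question": "What is 2 + 2?",
    "answer_1": "2 + 2 = 4.",
    "answer_2": "The answer is 5.",
}
response = requests.post(
    "http://localhost:8000/predict",
    json={"text": json.dumps(payload)},
)
response.raise_for_status()
scores = response.json()
# A positive logit_difference suggests the model prefers the first answer
print(scores["logit_difference"])
```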

## Training procedure

 
- lr_scheduler_type: linear
- num_epochs: 1.0

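A hypothetical sketch of how the hyperparameters listed above might map onto `transformers.TrainingArguments` (the training script itself is not part of this commit, and the rest of the hyperparameter list is truncated in this view):

```python
from transformers import TrainingArguments

# Hypothetical reconstruction: only the two hyperparameters visible in this
# diff are set; learning rate, batch size, etc. are not shown here.
args = TrainingArguments(
    output_dir="PPRM-gemma-2-2b-it",  # assumed output path
    lr_scheduler_type="linear",
    num_train_epochs=1.0,
)
```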

### Framework versions