Update README.md
Browse files
README.md
CHANGED
@@ -105,10 +105,60 @@ async def get_prediction(input_request: InputRequest):
|
|
105 |
raise HTTPException(status_code=500, detail=str(e))
|
106 |
|
107 |
```
|
108 |
-
|
109 |
```
|
110 |
uvicorn server:app --host 0.0.0.0 --port $MASTER_PORT --workers 1
|
111 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
## Training procedure
|
114 |
|
|
|
105 |
raise HTTPException(status_code=500, detail=str(e))
|
106 |
|
107 |
```
|
108 |
+
Run the PPRM server:
|
109 |
```
|
110 |
uvicorn server:app --host 0.0.0.0 --port $MASTER_PORT --workers 1
|
111 |
```
|
112 |
+
Send a request to the PPRM server:
|
113 |
+
```
|
114 |
+
# qeustion,answer_1,answer_2 = 'What is the capital of France?', 'Berlin', 'Paris'
|
115 |
+
# {'yes_logit': -24.26136016845703, 'no_logit': 19.517587661743164, 'logit_difference': -43.778947830200195}
|
116 |
+
# Is answer_1 better than answer_2? yes or no
|
117 |
+
# Entry point for the reward model
|
118 |
+
def request_prediction(
    qeustion, answer_1, answer_2, url="http://10.140.24.56:10085/predict"
):
    """
    Send one question/answer-pair comparison to the reward-model server.

    The question and both answers are serialized into a single JSON string,
    which is sent as the "text" field of the JSON request body.

    Args:
    - qeustion: The question both answers respond to. The misspelled name is
      kept on purpose: it is also the JSON key sent to the server.
    - answer_1: The first candidate answer.
    - answer_2: The second candidate answer.
    - url (str): The prediction endpoint to POST to. Defaults to
      'http://10.140.24.56:10085/predict'.

    Returns:
    - The decoded JSON body of the server's response.

    Raises:
    - An HTTP error from `raise_for_status()` when the server replies with
      an unsuccessful status code.
    """
    # The server receives the comparison as one JSON-encoded string.
    encoded_pair = json.dumps(
        {"qeustion": qeustion, "answer_1": answer_1, "answer_2": answer_2}
    )

    reply = requests.post(
        url,
        json={"text": encoded_pair},
        headers={"Content-Type": "application/json"},
        timeout=TIMEOUT_PRM,
    )
    # Surface 4xx/5xx replies as exceptions instead of parsing an error body.
    reply.raise_for_status()
    return reply.json()
|
141 |
+
|
142 |
+
def cal_reward(question, ans, ans2="I don't know"):
    """
    Return the reward model's preference for `ans` over `ans2`.

    Args:
    - question: The question both answers respond to.
    - ans: The answer being scored.
    - ans2: The answer to compare against. Defaults to "I don't know".

    Returns:
    - 1 if `ans2` is one of DUMMY_ANSWERS (checked first).
    - 0 if `ans` is one of DUMMY_ANSWERS.
    - Otherwise the two-way softmax probability of the "yes" logit computed
      from the "yes_logit" and "no_logit" fields of a server response.

    The servers in `prm_servers` are tried in random order. Any failure
    (request error, malformed response) moves on to the next server. If
    every server fails, a message is printed and the whole round is retried,
    so this call blocks until some server answers.
    """
    if ans2 in DUMMY_ANSWERS:  # e.g. "I don't know"
        return 1
    if ans in DUMMY_ANSWERS:
        return 0

    # Retry with a loop instead of self-recursion: a long outage must not
    # end in RecursionError once the interpreter's stack limit is reached.
    while True:
        # A shallow copy is enough to shuffle without reordering the
        # shared server list.
        urls = list(prm_servers)
        random.shuffle(urls)
        for url in urls:
            try:
                response = request_prediction(question, ans, ans2, url)
                yes_logit = response["yes_logit"]
                no_logit = response["no_logit"]
                # Shift both logits by their maximum before exponentiating.
                # The result is the same softmax, but exp() never overflows
                # and the denominator is always >= 1.
                peak = max(yes_logit, no_logit)
                yes_weight = math.exp(yes_logit - peak)
                no_weight = math.exp(no_logit - peak)
                return yes_weight / (yes_weight + no_weight)
            except Exception:
                # Deliberate best-effort failover: try the next server.
                continue
        print(Exception("All prm servers are down"))
|
161 |
+
```
|
162 |
|
163 |
## Training procedure
|
164 |
|