jeremyarancio
committed on
Commit
•
adf79f2
1
Parent(s):
88e1248
Update handler
Browse files- handler.py +6 -4
handler.py
CHANGED
@@ -24,16 +24,18 @@ class EndpointHandler():
|
|
24 |
"""
|
25 |
LOGGER.info(f"Received data: {data}")
|
26 |
# Get inputs
|
27 |
-
prompt = data.pop("prompt",
|
28 |
parameters = data.pop("parameters", None)
|
|
|
|
|
29 |
# Preprocess
|
30 |
-
|
31 |
# Forward
|
32 |
LOGGER.info(f"Start generation.")
|
33 |
if parameters is not None:
|
34 |
-
output = self.model.generate(**
|
35 |
else:
|
36 |
-
output = self.model.generate(**
|
37 |
# Postprocess
|
38 |
prediction = self.tokenizer.decode(output[0])
|
39 |
LOGGER.info(f"Generated text: {prediction}")
|
|
|
24 |
"""
|
25 |
LOGGER.info(f"Received data: {data}")
|
26 |
# Get inputs
|
27 |
+
prompt = data.pop("prompt", None)
|
28 |
parameters = data.pop("parameters", None)
|
29 |
+
if prompt is None:
|
30 |
+
raise ValueError("Missing prompt.")
|
31 |
# Preprocess
|
32 |
+
inputs = self.tokenizer(prompt, return_tensors="pt")
|
33 |
# Forward
|
34 |
LOGGER.info(f"Start generation.")
|
35 |
if parameters is not None:
|
36 |
+
output = self.model.generate(**inputs, **parameters)
|
37 |
else:
|
38 |
+
output = self.model.generate(**inputs)
|
39 |
# Postprocess
|
40 |
prediction = self.tokenizer.decode(output[0])
|
41 |
LOGGER.info(f"Generated text: {prediction}")
|