jeremyarancio
commited on
Commit
•
2487e56
1
Parent(s):
5a04ada
Update handler
Browse files- handler.py +4 -4
handler.py
CHANGED
@@ -23,15 +23,15 @@ class EndpointHandler():
|
|
23 |
"""
|
24 |
LOGGER.info(f"Received data: {data}")
|
25 |
# Get inputs
|
26 |
-
|
27 |
parameters = data.pop("parameters", None)
|
28 |
# Preprocess
|
29 |
-
|
30 |
# Forward
|
31 |
if parameters is not None:
|
32 |
-
outputs = self.model.generate(
|
33 |
else:
|
34 |
-
outputs = self.model.generate(
|
35 |
# Postprocess
|
36 |
prediction = self.tokenizer.decode(outputs[0])
|
37 |
LOGGER.info(f"Generated text: {prediction}")
|
|
|
23 |
"""
|
24 |
LOGGER.info(f"Received data: {data}")
|
25 |
# Get inputs
|
26 |
+
inputs = data.pop("inputs", data)
|
27 |
parameters = data.pop("parameters", None)
|
28 |
# Preprocess
|
29 |
+
inputs_ids = self.tokenizer(inputs, return_tensors="pt").inputs_ids
|
30 |
# Forward
|
31 |
if parameters is not None:
|
32 |
+
outputs = self.model.generate(inputs_ids, **parameters)
|
33 |
else:
|
34 |
+
outputs = self.model.generate(inputs_ids)
|
35 |
# Postprocess
|
36 |
prediction = self.tokenizer.decode(outputs[0])
|
37 |
LOGGER.info(f"Generated text: {prediction}")
|