FSA-test / pipeline.py
TharinduCD's picture
Update pipeline.py
346767e
raw
history blame
No virus
781 Bytes
import numpy as np
import typing
class PreTrainedPipeline:
    """Inference pipeline around a fastText classification model.

    Calling an instance returns ``[[{"label": ..., "score": ...}, ...]]``:
    an outer list with one entry for the single input string, holding that
    string's top-k predictions.
    """

    def __init__(self, path=""):
        """Load the fastText model stored at ``path``.

        Args:
            path: Path of the model file, passed unchanged to
                ``fasttext.load_model``.
        """
        # Local import: the module-level imports do not provide fasttext, and
        # it is only needed once a model is actually loaded.
        import fasttext

        self.model = fasttext.load_model(path)

    def __call__(
        self, inputs: str, top_k: int = 5
    ) -> typing.List[typing.List[typing.Dict[str, typing.Any]]]:
        """Classify ``inputs`` and return its ``top_k`` predictions.

        Args:
            inputs: Text to classify. Newline characters are replaced with
                spaces before prediction.
            top_k: Number of labels to request; forwarded as fastText's ``k``.

        Returns:
            A one-element list containing a list of
            ``{"label": str, "score": float}`` dicts, in the order fastText
            returns them (highest probability first).
        """
        # fastText's predict handles one line at a time and raises ValueError
        # when the text contains a newline.
        text = inputs.replace("\n", " ")
        # For a single string, predict returns two parallel sequences,
        # (labels, probabilities), not a list of (label, score) pairs.
        labels, probabilities = self.model.predict(text, k=top_k)
        # float() converts numpy scalars to plain Python floats so the result
        # is JSON-serializable.
        predictions = [
            {"label": label, "score": float(score)}
            for label, score in zip(labels, probabilities)
        ]
        return [predictions]