from typing import Sequence, Union

import numpy as np
import scipy.stats

from .registry import registry


@registry.register_metric('mse')
def mean_squared_error(target: Sequence[float],
                       prediction: Sequence[float]) -> float:
    """Return the mean of the squared element-wise differences.

    Args:
        target: Reference values.
        prediction: Predicted values, combined element-wise with ``target``.

    Returns:
        ``mean((target - prediction) ** 2)`` as computed by NumPy.
    """
    target_array = np.asarray(target)
    prediction_array = np.asarray(prediction)
    return np.mean(np.square(target_array - prediction_array))


@registry.register_metric('mae')
def mean_absolute_error(target: Sequence[float],
                        prediction: Sequence[float]) -> float:
    """Return the mean of the absolute element-wise differences.

    Args:
        target: Reference values.
        prediction: Predicted values, combined element-wise with ``target``.

    Returns:
        ``mean(|target - prediction|)`` as computed by NumPy.
    """
    target_array = np.asarray(target)
    prediction_array = np.asarray(prediction)
    return np.mean(np.abs(target_array - prediction_array))


@registry.register_metric('spearmanr')
def spearmanr(target: Sequence[float],
              prediction: Sequence[float]) -> float:
    """Return the Spearman rank correlation of ``target`` and ``prediction``.

    Args:
        target: Reference values.
        prediction: Predicted values.

    Returns:
        The ``correlation`` field of ``scipy.stats.spearmanr``'s result.
    """
    target_array = np.asarray(target)
    prediction_array = np.asarray(prediction)
    return scipy.stats.spearmanr(target_array, prediction_array).correlation


@registry.register_metric('accuracy')
def accuracy(target: Union[Sequence[int], Sequence[Sequence[int]]],
             prediction: Union[Sequence[float], Sequence[Sequence[float]]]) -> float:
    """Return the fraction of labels matched by the argmax of the scores.

    Two input layouts are handled, selected by the type of ``target[0]``:

    * ``target[0]`` is an ``int`` instance: ``target`` is treated as a flat
      collection of labels, and the predicted label for each entry is the
      argmax over the last axis of ``prediction``. Every entry is counted.
    * otherwise: ``target`` and ``prediction`` are iterated pairwise. For
      each pair the argmax over the last axis of the scores is compared
      with the labels, and positions whose label equals ``-1`` are left out
      of both the numerator and the denominator.

    Args:
        target: Labels, either flat or one collection of labels per item.
        prediction: Scores; the last axis is reduced with ``argmax``.

    Returns:
        Number of correct positions divided by the number of counted
        positions. In the per-item layout, NaN is returned when no position
        is counted (no pairs, or every label is ``-1``).
    """
    if isinstance(target[0], int):
        # Flat layout: one label per entry, no masking.
        return np.mean(np.asarray(target) == np.asarray(prediction).argmax(-1))

    correct = 0
    total = 0
    for label, score in zip(target, prediction):
        label_array = np.asarray(label)
        pred_array = np.asarray(score).argmax(-1)
        # Labels equal to -1 are excluded from the comparison entirely.
        mask = label_array != -1
        is_correct = label_array[mask] == pred_array[mask]
        correct += is_correct.sum()
        total += is_correct.size

    if total == 0:
        # Nothing was scored; avoid 0 / 0 (ZeroDivisionError or a NumPy
        # RuntimeWarning depending on the operand types) and report NaN.
        return float('nan')
    return correct / total