"""Client-server interface custom implementation for seizure detection models."""

from pathlib import Path

from concrete import fhe

from common import SEIZURE_DETECTION_MODEL_PATH
from seizure_detection import SeizureDetector


class FHEServer:
    """Server interface to run a FHE circuit for seizure detection."""

    def __init__(self, model_path):
        """Load the server-side FHE circuit.

        Args:
            model_path (Union[str, Path]): The directory containing the ``server.zip``
                artifact (as written by ``FHEDev.save``).
        """
        # Accept plain strings as well as Path objects; the "/" join below needs a Path.
        self.model_path = Path(model_path)

        # Load the FHE circuit
        self.server = fhe.Server.load(self.model_path / "server.zip")

    def run(self, serialized_encrypted_image, serialized_evaluation_keys):
        """Run seizure detection on the server over an encrypted image.

        Args:
            serialized_encrypted_image (bytes): The encrypted and serialized image.
            serialized_evaluation_keys (bytes): The serialized evaluation keys.

        Returns:
            bytes: The serialized encrypted output of the circuit.
        """
        # Deserialize the encrypted input image and the evaluation keys
        encrypted_image = fhe.Value.deserialize(serialized_encrypted_image)
        evaluation_keys = fhe.EvaluationKeys.deserialize(serialized_evaluation_keys)

        # Execute the seizure detection in FHE
        encrypted_output = self.server.run(encrypted_image, evaluation_keys=evaluation_keys)

        # Serialize the encrypted output so it can be sent back to the client
        return encrypted_output.serialize()


class FHEDev:
    """Development interface to export the compiled seizure detection model's FHE artifacts."""

    def __init__(self, seizure_detector, model_path):
        """Initialize the development interface.

        Args:
            seizure_detector (SeizureDetector): The seizure detection model. It must expose an
                ``fhe_circuit`` attribute, which ``save`` requires to be set (compiled).
            model_path (Union[str, Path]): The directory where the artifacts are saved. It is
                created, including parents, if it does not already exist.
        """
        self.seizure_detector = seizure_detector
        # Coerce so that string paths (as previously documented) support mkdir and "/".
        self.model_path = Path(model_path)
        self.model_path.mkdir(parents=True, exist_ok=True)

    def save(self):
        """Export all needed artifacts for the client and server interfaces.

        Writes ``server.zip`` and ``client.zip`` into ``model_path``.

        Raises:
            RuntimeError: If the model has not been compiled (``fhe_circuit`` is None).
        """
        fhe_circuit = self.seizure_detector.fhe_circuit

        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if fhe_circuit is None:
            raise RuntimeError("The model must be compiled before saving it.")

        # Save the circuit for the server, using via_mlir in order to handle cross-platform
        # execution
        fhe_circuit.server.save(self.model_path / "server.zip", via_mlir=True)

        # Save the circuit for the client
        fhe_circuit.client.save(self.model_path / "client.zip")


class FHEClient:
    """Client interface to encrypt and decrypt FHE data associated with a SeizureDetector."""

    def __init__(self, key_dir=None, model_path=None):
        """Load the client-side FHE artifacts.

        Args:
            key_dir (Optional[Union[str, Path]]): The directory where the keys are stored.
                Defaults to None.
            model_path (Optional[Union[str, Path]]): The directory containing ``client.zip``.
                Defaults to None, which uses ``SEIZURE_DETECTION_MODEL_PATH``.

        Raises:
            FileNotFoundError: If the model directory does not exist.
        """
        if model_path is None:
            model_path = SEIZURE_DETECTION_MODEL_PATH
        self.model_path = Path(model_path)
        self.key_dir = key_dir

        # Explicit check (not ``assert``) so the validation survives ``python -O``.
        if not self.model_path.exists():
            raise FileNotFoundError(
                f"{self.model_path} does not exist. Please specify a valid path."
            )

        # Load the client
        self.client = fhe.Client.load(self.model_path / "client.zip", self.key_dir)

        # Used by deserialize_decrypt_post_process to post-process decrypted outputs
        self.seizure_detector = SeizureDetector()

    def generate_private_and_evaluation_keys(self, force=False):
        """Generate the private and evaluation keys.

        Args:
            force (bool): If True, regenerate the keys even if they already exist.
        """
        self.client.keygen(force)

    def get_serialized_evaluation_keys(self):
        """Get the serialized evaluation keys.

        Returns:
            bytes: The serialized evaluation keys.
        """
        return self.client.evaluation_keys.serialize()

    def encrypt_serialize(self, input_image):
        """Encrypt the clear input image and serialize the resulting ciphertext.

        Args:
            input_image (numpy.ndarray): The clear image to encrypt and serialize.

        Returns:
            bytes: The encrypted and serialized image, ready to be sent to the server.
        """
        # Encrypt the image
        encrypted_image = self.client.encrypt(input_image)

        # Serialize the encrypted image to be sent to the server
        return encrypted_image.serialize()

    def deserialize_decrypt_post_process(self, serialized_encrypted_output):
        """Deserialize and decrypt the server output, then post-process it in the clear.

        Args:
            serialized_encrypted_output (bytes): The serialized and encrypted output.

        Returns:
            The result of ``SeizureDetector.post_processing`` on the decrypted output
            (expected to be a bool indicating seizure detection — confirm against
            ``SeizureDetector``).
        """
        # Deserialize the encrypted output
        encrypted_output = fhe.Value.deserialize(serialized_encrypted_output)

        # Decrypt the output
        output = self.client.decrypt(encrypted_output)

        # Post-process the output in the clear
        return self.seizure_detector.post_processing(output)