import json
from typing import List, Dict, Any, Optional

import gradio as gr
import numpy as np
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer
import onnxruntime as rt

# Define the Lionguard model repository
REPO_PATH = "govtech/lionguard-v1"


def load_config() -> Dict[str, Any]:
    """
    Load the configuration for the Lionguard model.

    Returns:
        Dict[str, Any]: The configuration dictionary.
    """
    config_path = hf_hub_download(repo_id=REPO_PATH, filename="config.json")
    with open(config_path, 'r') as f:
        return json.load(f)


def get_embeddings(device: torch.device, data: List[str], config: Dict[str, Any]) -> np.ndarray:
    """
    Generate embeddings for the input data using the specified model configuration.

    Args:
        device (torch.device): The device to use for computations.
        data (List[str]): The input text data.
        config (Dict[str, Any]): The model configuration.

    Returns:
        np.ndarray: The generated embeddings.
    """
    tokenizer = AutoTokenizer.from_pretrained(config['embedding']['tokenizer'])
    model = AutoModel.from_pretrained(config['embedding']['model'])
    model.eval()
    model.to(device)

    # Embed the input in batches to bound memory usage.
    batch_size = config['embedding']['batch_size']
    num_batches = int(np.ceil(len(data) / batch_size))
    output = []
    for i in range(num_batches):
        sentences = data[i * batch_size:(i + 1) * batch_size]
        encoded_input = tokenizer(
            sentences,
            max_length=config['embedding']['max_length'],
            padding=True,
            truncation=True,
            return_tensors='pt'
        )
        encoded_input = encoded_input.to(device)
        with torch.no_grad():
            model_output = model(**encoded_input)
            # Use the [CLS] token representation as the sentence embedding.
            sentence_embeddings = model_output[0][:, 0]
        # L2-normalize so the downstream classifiers receive unit-length vectors.
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        output.extend(sentence_embeddings.cpu().numpy())
    return np.array(output)


def predict(text: str, config: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """
    Predict probabilities for all Lionguard categories given an input text.

    Args:
        text (str): The input text to predict on.
        config (Dict[str, Any]): The model configuration.

    Returns:
        Dict[str, Dict[str, Any]]: A dictionary containing prediction results for each category.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    embeddings = get_embeddings(device, [text], config)
    X_input = np.array(embeddings, dtype=np.float32)
    results = {}
    for category, details in config['classifier'].items():
        # Download the per-category ONNX classifier and run it on the embedding.
        local_model_fp = hf_hub_download(repo_id=REPO_PATH, filename=details['model_fp'])
        session = rt.InferenceSession(local_model_fp)
        input_name = session.get_inputs()[0].name
        outputs = session.run(None, {input_name: X_input})
        if details['calibrated']:
            # Calibrated classifiers return per-class probability maps;
            # take the probability of the positive class (label 1).
            scores = [output[1] for output in outputs[1]]
        else:
            scores = outputs[1].flatten()
        # Apply the three operating-point thresholds to the raw score.
        results[category] = {
            'score': float(scores[0]),
            'predictions': {
                'high_recall': 1 if scores[0] >= details['threshold']['high_recall'] else 0,
                'balanced': 1 if scores[0] >= details['threshold']['balanced'] else 0,
                'high_precision': 1 if scores[0] >= details['threshold']['high_precision'] else 0
            }
        }
    return results
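
# For reference, the code in this file assumes config.json has roughly the
# following shape. This is an illustrative sketch inferred from the keys
# accessed in this file, not the authoritative schema; the actual values
# live in the govtech/lionguard-v1 repository:
#
# {
#     "embedding": {
#         "tokenizer": "<HF tokenizer id>",
#         "model": "<HF embedding model id>",
#         "max_length": ...,
#         "batch_size": ...
#     },
#     "classifier": {
#         "<category>": {
#             "model_fp": "<path to the ONNX classifier within the repo>",
#             "calibrated": true,
#             "threshold": {"high_recall": ..., "balanced": ..., "high_precision": ...}
#         },
#         ...
#     }
# }
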
""" if not text.strip(): return None results = predict(text, config) formatted_results = [] for category, result in results.items(): formatted_results.append({ "Category": category, "Score": f"{result['score']:.3f}", "High Recall": result['predictions']['high_recall'], "Balanced": result['predictions']['balanced'], "High Precision": result['predictions']['high_precision'] }) return pd.DataFrame(formatted_results).sort_values("Score", ascending=False) def create_interface(config: Dict[str, Any]) -> gr.Interface: """ Create the Gradio interface for the Lionguard demo. Args: config (Dict[str, Any]): The model configuration. Returns: gr.Interface: The Gradio interface object. """ return gr.Interface( fn=lambda text: predict_and_format(text, config), inputs=gr.Textbox(lines=3, placeholder="Enter text here..."), outputs=gr.DataFrame(label="Prediction Results"), title="🦁 Lionguard Demo", description="Lionguard is a Singapore-contextualized moderation classifier that can serve against unsafe LLM outputs.", allow_flagging="never" ) if __name__ == "__main__": config = load_config() iface = create_interface(config) iface.launch()