import logging
from typing import Dict, List, Optional

import torch
from gliner import GLiNER

# Configure logging
logging.basicConfig(level=logging.INFO)

# Default confidence threshold passed to the extractor's predict_entities call.
TAU = 0.3

# Labels used when the caller does not supply entity_types.
DEFAULT_ENTITY_TYPES = (
    "brand",
    "color_finish",
    "style",
    "collection",
    "dimension",
    "feature",
    "product_type",
    "part_number",
)


class EntityExtractor:
    """Wraps a GLiNER model to pull labelled entities out of free text."""

    def __init__(self, extractor_model: str):
        """
        Initializes the EntityExtractor class with an extractor model.

        The model is loaded and moved to CUDA when it is available,
        otherwise to the CPU.

        Args:
            extractor_model (str): The model name for the entity extractor.
        """
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.extractor = self.load_extractor(extractor_model).to(self.device)

    @staticmethod
    def load_extractor(model_name: str) -> GLiNER:
        """Loads the entity extractor model."""
        return GLiNER.from_pretrained(model_name, load_tokenizer=True)

    def extract_entities(
        self,
        text: str,
        entity_types: Optional[List[str]] = None,
        threshold: float = TAU,
    ) -> List[Dict[str, str]]:
        """
        Extracts entities of the requested types from a piece of text.

        Args:
            text (str): The text to run the extractor on.
            entity_types (Optional[List[str]]): Labels to look for. When
                None, DEFAULT_ENTITY_TYPES is used.
            threshold (float): Threshold forwarded to the extractor's
                predict_entities call. Defaults to TAU.

        Returns:
            List[Dict[str, str]]: One dict per prediction, with the
            predicted text under "entity" and its label under
            "entity_type".
        """
        if entity_types is None:
            # Copy so each call hands the model its own fresh list.
            entity_types = list(DEFAULT_ENTITY_TYPES)

        predictions = self.extractor.predict_entities(
            text,
            entity_types,
            threshold=threshold,
            flat_ner=True,
            multi_label=False,
        )

        return [
            {"entity": item["text"], "entity_type": item["label"]}
            for item in predictions
        ]