import gradio as gr # !python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'" import json import torch from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer, StoppingCriteria, StoppingCriteriaList, GenerationConfig # from google.colab import userdata import os #sft_model = "somosnlp/gemma-FULL-RAC-Colombia_v2" sft_model = "somosnlp/RecetasDeLaAbuela_mistral-7b-instruct-v0.2-bnb-4bit" base_model_name = "mistralai/Mistral-7B-Instruct-v0.2" bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16 ) max_seq_length=400 # if torch.cuda.get_device_capability()[0] >= 8: # # print("Flash Attention") # attn_implementation="flash_attention_2" # else: # attn_implementation=None attn_implementation=None #base_model = AutoModelForCausalLM.from_pretrained(model_name,return_dict=True,torch_dtype=torch.float16,) #base_model = AutoModelForCausalLM.from_pretrained(model_name,return_dict=True,device_map="auto", torch_dtype=torch.float16,) base_model = AutoModelForCausalLM.from_pretrained(base_model_name, return_dict=True, device_map = {"":0}, attn_implementation = attn_implementation, # A100 o H100).eval() tokenizer = AutoTokenizer.from_pretrained(sft_model, max_length = max_seq_length) ft_model = PeftModel.from_pretrained(base_model, sft_model) model = ft_model.merge_and_unload() model.save_pretrained(".") model.to('cuda') tokenizer.save_pretrained(".") class ListOfTokensStoppingCriteria(StoppingCriteria): """ Clase para definir un criterio de parada basado en una lista de tokens específicos. """ def __init__(self, tokenizer, stop_tokens): self.tokenizer = tokenizer # Codifica cada token de parada y guarda sus IDs en una lista self.stop_token_ids_list = [tokenizer.encode(stop_token, add_special_tokens=False) for stop_token in stop_tokens] def __call__(self, input_ids, scores, **kwargs): # Verifica si los últimos tokens generados coinciden con alguno de los conjuntos de tokens de parada for stop_token_ids in self.stop_token_ids_list: len_stop_tokens = len(stop_token_ids) if len(input_ids[0]) >= len_stop_tokens: if input_ids[0, -len_stop_tokens:].tolist() == stop_token_ids: return True return False # Uso del criterio de parada personalizado stop_tokens = [""] # Lista de tokens de parada # Inicializa tu criterio de parada con el tokenizer y la lista de tokens de parada stopping_criteria = ListOfTokensStoppingCriteria(tokenizer, stop_tokens) # Añade tu criterio de parada a una StoppingCriteriaList stopping_criteria_list = StoppingCriteriaList([stopping_criteria]) def generate_text(prompt, max_length=2100): # prompt="""What were the main contributions of Eratosthenes to the development of mathematics in ancient Greece?""" prompt=prompt.replace("\n", "").replace("¿","").replace("?","") #EXAMPLE input_text = f'''system You are a helpful AI assistant. Responde en formato json. Eres un experto cocinero de la cocina hispanoamericana. user ¿{prompt}? model ''' inputs = tokenizer.encode(input_text, return_tensors="pt", add_special_tokens=False).to("cuda:0") max_new_tokens=max_length generation_config = GenerationConfig( max_new_tokens=max_new_tokens, temperature=0.32, #top_p=0.9, top_k=50, # 45 repetition_penalty=1.04, #1.1 do_sample=True, ) outputs = model.generate(generation_config=generation_config, input_ids=inputs, stopping_criteria=stopping_criteria_list,) return tokenizer.decode(outputs[0], skip_special_tokens=False) #True def mostrar_respuesta(pregunta): try: res= generate_text(pregunta, max_length=500) inicio_json = res.find('{') fin_json = res.rfind('}') + 1 json_str = res[inicio_json:fin_json] json_obj = json.loads(json_str) # print(json_obj) return json_obj["Respuesta"] except: json_obj={} json_obj['Respuesta']='Error' return json_obj # Ejemplos de preguntas ejemplos = [ ["¿Dime la receta de la tortilla de patatatas?"], ["¿Dime la receta del ceviche?"], ["¿Como se cocinan unos autenticos frijoles?"], ] iface = gr.Interface( fn=mostrar_respuesta, inputs=gr.Textbox(label="Pregunta"), outputs=[ gr.Textbox(label="Respuesta", lines=2), ], title="Recetas de la Abuel@", description="Introduce tu pregunta sobre recetas de cocina.", examples=ejemplos, ) iface.queue(max_size=14).launch() # share=True,debug=True