import logging

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
import torch
import spaces


##########################
# CONFIGURATION
##########################
logging.basicConfig(
    level=logging.getLevelName("INFO"),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

# Example images and texts (the OCR typos in these strings are intentional:
# they are what the spellcheck model is expected to correct)
EXAMPLES = [
    ["images/ingredients_1.jpg", "24.36% chocolat noir 63% origine non UE (cacao, sucre, beurre de cacao, émulsifiant léci - thine de colza, vanille bourbon gousse), œuf, farine de blé, beurre, sucre, miel, sucre perlé, levure chimique, zeste de citron."],
    ["images/ingredients_2.jpg", "farine de froment, œufs, lait entier pasteurisé Aprigine: France), sucre, sel, extrait de vanille naturelle Conditi( 35."],
    # ["images/ingredients_3.jpg", "tural basmati rice - cooked (98%), rice bran oil, salt"],
    ["images/ingredients_4.jpg", "Eau de noix de coco 93.9%, Arôme natutel de fruit"],
    ["images/ingredients_5.jpg", "Sucre, pâte de cacao, beurre de cacao, émulsifiant: léci - thines (soja). Peut contenir des traces de lait. Chocolat noir: cacao: 50% minimum. À conserver à l'abri de la chaleur et de l'humidité. Élaboré en France."],
]

MODEL_ID = "openfoodfacts/spellcheck-mistral-7b"

PRESENTATION = """# 🍊 Ingredients Spellcheck - Open Food Facts

Open Food Facts is a non-profit organization building the largest open food database in the world. 🌎

When a product is added to the database, all its details, such as allergens, additives, or nutritional values, are either written down by the contributor or automatically extracted from the product pictures using OCR.

However, the information extracted by OCR often contains typos and errors caused by poor-quality pictures: low definition, curved product, light reflection, etc.

To solve this problem, we developed an 🍊 **Ingredients Spellcheck** 🍊, a model capable of correcting typos in a list of ingredients by following a defined guideline. The model, based on Mistral-7B-v0.3, was fine-tuned on thousands of corrected lists of ingredients extracted from the database. More information is available in the model card.

### *Project in progress* 🏗️

## 👇 Links

* Open Food Facts website: https://world.openfoodfacts.org/discover
* Open Food Facts Github: https://github.com/openfoodfacts
* Spellcheck project: https://github.com/openfoodfacts/openfoodfacts-ai/tree/develop/spellcheck
* Model card: https://huggingface.co/openfoodfacts/spellcheck-mistral-7b
"""

# CPU/GPU device
zero = torch.Tensor([0]).cuda()

# Seed Transformers generation to make it as reproducible as possible
# (seeding does not guarantee 100% reproducibility)
set_seed(42)


##########################
# LOADING
##########################
# Tokenizer
logging.info(f"Load tokenizer from {MODEL_ID}.")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# Model
logging.info(f"Load model from {MODEL_ID}.")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    # attn_implementation="flash_attention_2",  # Not supported by ZERO GPU
    # torch_dtype=torch.bfloat16,
)


##########################
# FUNCTIONS
##########################
@spaces.GPU
def process(text: str) -> str:
    """Generate the corrected list of ingredients for the input text.

    Uses the module-level tokenizer and causal model loaded above.

    Args:
        text (str): List of ingredients to correct.

    Returns:
        str: Corrected list of ingredients.
    """
    prompt = prepare_instruction(text)
    input_ids = tokenizer(
        prompt, add_special_tokens=True, return_tensors="pt"
    ).input_ids
    with torch.no_grad():
        output = model.generate(
            input_ids.to(zero.device),  # GPU
            do_sample=False,
            max_new_tokens=512,
        )
    # Decode the full sequence, then strip the prompt to keep only the correction
    return tokenizer.decode(output[0], skip_special_tokens=True)[len(prompt):].strip()


def prepare_instruction(text: str) -> str:
    """Prepare the instruction prompt for fine-tuning and inference.

    Identical to the instruction used during training.

    Args:
        text (str): List of ingredients.

    Returns:
        str: Instruction prompt.
    """
    instruction = (
        "###Correct the list of ingredients:\n"
        + text
        + "\n\n###Correction:\n"
    )
    return instruction


##########################
# GRADIO SETUP
##########################
with gr.Blocks() as demo:
    gr.Markdown(PRESENTATION)
    with gr.Row():
        with gr.Column():
            image = gr.Image(type="pil", label="image_input", interactive=False)
        with gr.Column():
            ingredients = gr.Textbox(label="List of ingredients")
            spellcheck_button = gr.Button(value="Run spellcheck")
            correction = gr.Textbox(label="Correction", interactive=False)
    with gr.Row():
        gr.Examples(
            fn=process,
            examples=EXAMPLES,
            inputs=[
                image,
                ingredients,
            ],
        )
    spellcheck_button.click(
        fn=process,
        inputs=[ingredients],
        outputs=[correction],
    )


if __name__ == "__main__":
    # Launch the demo
    demo.launch()
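
# Minimal sketch (not part of the Space itself) of how `process` could be
# exercised directly for a quick sanity check, assuming a CUDA device is
# available and that `@spaces.GPU` is a no-op outside a ZeroGPU Space.
# The ingredient string below is a hypothetical example, not from the dataset:
#
#     corrected = process("farine de blé, œufs, lat entier pasteurisé, sucre")
#     print(corrected)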