jeremyarancio's picture
Update
8af8877
raw
history blame
5.13 kB
import logging
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
import torch
import spaces
##########################
# CONFIGURATION
##########################
# Root logger setup: INFO level, records rendered as
# "<timestamp> - <logger name> - <level name> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Demo examples: one [image path, ingredients text] pair per entry.
# NOTE(review): the texts look like raw OCR output; their typos are presumably
# intentional (they are the spellcheck input) - do not correct them.
EXAMPLES = [
    [
        "images/ingredients_1.jpg",
        "24.36% chocolat noir 63% origine non UE (cacao, sucre, beurre de cacao, émulsifiant léci - thine de colza, vanille bourbon gousse), œuf, farine de blé, beurre, sucre, miel, sucre perlé, levure chimique, zeste de citron.",
    ],
    [
        "images/ingredients_2.jpg",
        "farine de froment, œufs, lait entier pasteurisé Aprigine: France), sucre, sel, extrait de vanille naturelle Conditi( 35.",
    ],
    # ["images/ingredients_3.jpg", "tural basmati rice - cooked (98%), rice bran oil, salt"],
    [
        "images/ingredients_4.jpg",
        "Eau de noix de coco 93.9%, Arôme natutel de fruit",
    ],
    [
        "images/ingredients_5.jpg",
        "Sucre, pâte de cacao, beurre de cacao, émulsifiant: léci - thines (soja). Peut contenir des traces de lait. Chocolat noir: cacao: 50% minimum. À conserver à l'abri de la chaleur et de l'humidité. Élaboré en France.",
    ],
]
MODEL_ID = "openfoodfacts/spellcheck-mistral-7b"
PRESENTATION = """# 🍊 Ingredients Spellcheck - Open Food Facts
Open Food Facts is a non-profit organization building the largest open food database in the world. 🌎
When a product is added to the database, all its details, such as allergens, additives, or nutritional values, are either wrote down by the contributor,
or automatically extracted from the product pictures using OCR.
However, it often happens the information extracted by OCR contains typos and errors due to bad quality pictures: low-definition, curved product, light reflection, etc...
To solve this problem, we developed an 🍊 **Ingredient Spellcheck** 🍊, a model capable of correcting typos in a list of ingredients following a defined guideline.
The model, based on Mistral-7B-v0.3, was fine-tuned on thousand of corrected lists of ingredients extracted from the database. More information in the model card.
### *Project in progress* 🏗️
## 👇 Links
* Open Food Facts website: https://world.openfoodfacts.org/discover
* Open Food Facts Github: https://github.com/openfoodfacts
* Spellcheck project: https://github.com/openfoodfacts/openfoodfacts-ai/tree/develop/spellcheck
* Model card: https://huggingface.co/openfoodfacts/spellcheck-mistral-7b
"""
# Device handle: a one-element tensor placed on the CUDA device at import time.
# Only its `.device` attribute is read later (in `process`) to move the prompt
# tensors onto the GPU.
# NOTE(review): presumably the module-level CUDA pattern expected by the
# `spaces` (ZeroGPU) runtime - confirm before moving or removing this line.
zero = torch.Tensor([0]).cuda()
# Seed the transformers RNGs so generation is as reproducible as possible
# (this does not guarantee 100% reproducibility).
set_seed(42)
##########################
# LOADING
##########################
# Tokenizer: loaded from MODEL_ID when this module is imported.
logging.info(f"Load tokenizer from {MODEL_ID}.")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Reuse the end-of-sequence token as the padding token (string and id are both set).
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
# Model: causal LM loaded from MODEL_ID when this module is imported.
# NOTE(review): both the attention-implementation and dtype options below are
# commented out, so the library defaults apply - confirm this is intended.
logging.info(f"Load model from {MODEL_ID}.")
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
# attn_implementation="flash_attention_2", # Not supported by ZERO GPU
# torch_dtype=torch.bfloat16,
)
##########################
# FUNCTIONS
##########################
@spaces.GPU
def process(text: str) -> str:
    """Generate the corrected list of ingredients for ``text``.

    The text is wrapped in the instruction prompt, tokenized with the
    module-level ``tokenizer`` and completed by the module-level ``model``
    using greedy decoding (``do_sample=False``).

    Args:
        text (str): List of ingredients to correct.

    Returns:
        str: The generated correction with surrounding whitespace removed,
            or an empty string when ``text`` is empty or blank.
    """
    # Nothing to correct: skip tokenization and generation entirely.
    if not text or not text.strip():
        return ""

    prompt = prepare_instruction(text)
    encoded = tokenizer(
        prompt,
        add_special_tokens=True,
        return_tensors="pt",
    )
    input_ids = encoded.input_ids.to(zero.device)  # GPU
    # The pad token is the EOS token, so the mask cannot be inferred from the
    # ids alone; pass the tokenizer's mask explicitly.
    attention_mask = encoded.attention_mask.to(zero.device)

    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            do_sample=False,
            max_new_tokens=512,
        )

    # For a causal LM the returned sequence is the prompt tokens followed by
    # the new tokens. Drop the prompt at token level: decoded text is not
    # guaranteed to reproduce the prompt character-for-character, so slicing
    # the decoded string by len(prompt) can cut the correction at the wrong place.
    prompt_length = input_ids.shape[1]
    generated_ids = output[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
def prepare_instruction(text: str) -> str:
    """Wrap a list of ingredients in the spellcheck instruction template.

    Used for fine-tuning and inference; the template is identical to the
    instruction used during training.

    Args:
        text (str): List of ingredients

    Returns:
        str: Instruction.
    """
    header = "###Correct the list of ingredients:\n"
    footer = "\n\n###Correction:\n"
    return header + text + footer
##########################
# GRADIO SETUP
##########################
# Page layout: presentation text, then an image preview next to the
# ingredients input / run button / correction output, then the examples.
with gr.Blocks() as demo:
    gr.Markdown(PRESENTATION)
    with gr.Row():
        with gr.Column():
            # Read-only preview, filled only when an example is selected.
            image = gr.Image(type="pil", label="image_input", interactive=False)
        with gr.Column():
            ingredients = gr.Textbox(label="List of ingredients")
            spellcheck_button = gr.Button(value='Run spellcheck')
            correction = gr.Textbox(label="Correction", interactive=False)
    with gr.Row():
        # Selecting an example only pre-fills the image and the ingredients
        # text. No `fn` is passed: without `outputs` it is never called, and
        # `process` takes a single text argument, not (image, text).
        gr.Examples(
            examples=EXAMPLES,
            inputs=[
                image,
                ingredients,
            ],
        )
    # The spellcheck itself runs only when the button is clicked.
    spellcheck_button.click(
        fn=process,
        inputs=[ingredients],
        outputs=[correction],
    )

if __name__ == "__main__":
    # Launch the demo
    demo.launch()