# Author: jeremyarancio — last commit 32de7fa ("Change front org"), file size 3.32 kB.
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import spaces
# Demo examples: each entry is an [image path, ingredient text] pair that
# gr.Examples loads into the image and text inputs. The texts are noisy on
# purpose (broken words such as "léci - thine", typos such as "natutel",
# presumably OCR-style errors) so the spellchecker has something to fix.
EXAMPLES = [
    [
        "images/ingredients_1.jpg",
        "24.36% chocolat noir 63% origine non UE (cacao, sucre, beurre de cacao, émulsifiant léci - thine de colza, vanille bourbon gousse), œuf, farine de blé, beurre, sucre, miel, sucre perlé, levure chimique, zeste de citron.",
    ],
    [
        "images/ingredients_2.jpg",
        "farine de froment, œufs, lait entier pasteurisé Aprigine: France), sucre, sel, extrait de vanille naturelle Conditi( 35.",
    ],
    [
        "images/ingredients_3.jpg",
        "tural basmati rice - cooked (98%), rice bran oil, salt",
    ],
    [
        "images/ingredients_4.jpg",
        "Eau de noix de coco 93.9%, Arôme natutel de fruit",
    ],
    [
        "images/ingredients_5.jpg",
        "Sucre, pâte de cacao, beurre de cacao, émulsifiant: léci - thines (soja). Peut contenir des traces de lait. Chocolat noir: cacao: 50% minimum. À conserver à l'abri de la chaleur et de l'humidité. Élaboré en France.",
    ],
]
MODEL_ID = "openfoodfacts/spellcheck-mistral-7b"

# Device anchor: a one-element tensor moved to CUDA at import time.
# `process` sends its input ids to `zero.device`.
zero = torch.Tensor([0]).cuda()

# Tokenizer, with the end-of-sequence token reused as the padding token.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token, tokenizer.pad_token_id = tokenizer.eos_token, tokenizer.eos_token_id

# Causal language model. Two loading options are deliberately not passed:
# attn_implementation="flash_attention_2" (noted as not supported by ZeroGPU)
# and torch_dtype=torch.bfloat16, so the library default dtype is used.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")
@spaces.GPU
def process(text: str) -> str:
    """Generate the model's spelling correction for a list of ingredients.

    The text is wrapped in the instruction prompt, then decoded greedily
    (no sampling) for at most 512 new tokens.

    Args:
        text: Raw list of ingredients to correct.

    Returns:
        The generated correction, with surrounding whitespace stripped.
    """
    prompt = prepare_instruction(text)
    encoded = tokenizer(
        prompt,
        add_special_tokens=True,
        return_tensors="pt",
    )
    input_ids = encoded.input_ids.to(zero.device)  # GPU
    attention_mask = encoded.attention_mask.to(zero.device)
    output = model.generate(
        input_ids,
        # Pass the mask and pad id explicitly: pad == eos for this tokenizer,
        # so generate() cannot infer the mask from the padding.
        attention_mask=attention_mask,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=False,
        max_new_tokens=512,
    )
    # Decode only the newly generated tokens. Slicing the decoded full
    # sequence with [len(prompt):] assumes decode(encode(prompt)) == prompt,
    # which tokenizer normalisation and special-token handling don't guarantee.
    new_tokens = output[0, input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def prepare_instruction(text: str) -> str:
    """Wrap an ingredient list in the spellcheck instruction template.

    The same template serves fine-tuning and inference. The prompt ends with
    the ``###Correction:`` header, so whatever the model generates next is
    the correction.

    Args:
        text (str): List of ingredients.

    Returns:
        str: Instruction.
    """
    header = "###Correct the list of ingredients:\n"
    footer = "\n\n###Correction:\n"
    return "".join([header, text, footer])
##########################
# GRADIO SETUP
##########################
# Creating the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Ingredients Spellcheck")
    gr.Markdown("")
    with gr.Row():
        with gr.Column():
            image = gr.Image(type="pil", label="image_input")
            spellcheck_button = gr.Button(value='Spellcheck')
        with gr.Column():
            ingredients = gr.Textbox(label="List of ingredients")
            correction = gr.Textbox(label="Correction", interactive=False)
    with gr.Row():
        # Clicking an example only fills `image` and `ingredients`; the user then
        # presses "Spellcheck". `fn`/`outputs` are not passed: `process` takes the
        # text only, so a cached run would call process(image, text) and fail.
        # Caching is also disabled explicitly, because Gradio may enable it by
        # default on Hugging Face Spaces.
        gr.Examples(
            examples=EXAMPLES,
            inputs=[
                image,
                ingredients,
            ],
            run_on_click=False,
            cache_examples=False,
        )
    # Only the text is sent to the model; the image is for display only.
    spellcheck_button.click(
        fn=process,
        inputs=[ingredients],
        outputs=[correction],
    )
if __name__ == "__main__":
    # Launch the demo (start the Gradio server) only when this file is run
    # directly, not when it is imported as a module.
    demo.launch()