File size: 5,122 Bytes
17306ce
 
2df6b4b
17306ce
2df6b4b
 
 
 
17306ce
 
 
 
 
 
 
 
2df6b4b
 
 
 
b24df6e
2df6b4b
 
 
17306ce
2df6b4b
 
17306ce
 
 
 
 
 
 
 
 
 
 
 
b24df6e
 
17306ce
 
 
 
 
 
 
2df6b4b
 
 
 
17306ce
 
 
 
 
 
 
2df6b4b
17306ce
2df6b4b
 
 
 
 
17306ce
2df6b4b
 
 
6874f2a
 
2df6b4b
 
17306ce
 
 
 
2df6b4b
 
 
 
 
 
 
 
 
b24df6e
 
 
 
 
 
2df6b4b
 
 
 
 
17306ce
2df6b4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17306ce
2df6b4b
 
 
17306ce
2df6b4b
 
32de7fa
17306ce
2df6b4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import logging

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
import torch
import spaces


##########################
# CONFIGURATION
##########################
# Root logging configuration for the whole app: INFO and above, with timestamps.
logging.basicConfig(
    # Use the level constant directly: getLevelName("INFO") only maps the name
    # back to its number as a backward-compatibility quirk of the logging module.
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

# Example images and texts.
# Each entry is [image path, ingredient text], in the same order as the
# (image, ingredients) inputs given to gr.Examples in the Gradio setup.
# The texts appear to be raw OCR output (e.g. "léci - thine", "Aprigine",
# "Conditi( 35.", "natutel"): presumably kept as-is on purpose, since they are
# the kind of input the spellcheck model is meant to correct. Do not "fix" them.
EXAMPLES = [
    ["images/ingredients_1.jpg", "24.36% chocolat noir 63% origine non UE (cacao, sucre, beurre de cacao, émulsifiant léci - thine de colza, vanille bourbon gousse), œuf, farine de blé, beurre, sucre, miel, sucre perlé, levure chimique, zeste de citron."],
    ["images/ingredients_2.jpg", "farine de froment, œufs, lait entier pasteurisé Aprigine: France), sucre, sel, extrait de vanille naturelle Conditi( 35."],
    # NOTE(review): example 3 is disabled; the reason is not recorded here.
    # ["images/ingredients_3.jpg", "tural basmati rice - cooked (98%), rice bran oil, salt"],
    ["images/ingredients_4.jpg", "Eau de noix de coco 93.9%, Arôme natutel de fruit"],
    ["images/ingredients_5.jpg", "Sucre, pâte de cacao, beurre de cacao, émulsifiant: léci - thines (soja). Peut contenir des traces de lait. Chocolat noir: cacao: 50% minimum. À conserver à l'abri de la chaleur et de l'humidité. Élaboré en France."],
]

# Hugging Face Hub identifier of the fine-tuned spellcheck model
# (used for both the tokenizer and the model weights).
MODEL_ID = "openfoodfacts/spellcheck-mistral-7b"

# Markdown shown at the top of the demo page. The model card link is built from
# MODEL_ID so that the link and the loaded model cannot drift apart.
PRESENTATION = f"""# 🍊 Ingredients Spellcheck - Open Food Facts

Open Food Facts is a non-profit organization building the largest open food database in the world. 🌎

When a product is added to the database, all its details, such as allergens, additives, or nutritional values, are either written down by the contributor
or automatically extracted from the product pictures using OCR.

However, the information extracted by OCR often contains typos and errors due to bad-quality pictures: low definition, curved products, light reflections, etc.

To solve this problem, we developed an 🍊 **Ingredient Spellcheck** 🍊, a model capable of correcting typos in a list of ingredients following a defined guideline.
The model, based on Mistral-7B-v0.3, was fine-tuned on thousands of corrected lists of ingredients extracted from the database. More information is available in the model card.

## Project in progress

## 👇 Links

* Open Food Facts website: https://world.openfoodfacts.org/discover
* Open Food Facts GitHub: https://github.com/openfoodfacts
* Spellcheck project: https://github.com/openfoodfacts/openfoodfacts-ai/tree/develop/spellcheck
* Model card: https://huggingface.co/{MODEL_ID}
"""

# Probe tensor used only to pick the inference device: `process` moves its
# input ids to `zero.device`.
# NOTE(review): `.cuda()` is called unconditionally at import time. This
# presumably relies on the `spaces` (ZeroGPU) runtime handling CUDA before a GPU
# is attached; on a plain CPU-only machine this line would likely raise.
# Confirm before running outside a ZeroGPU Space.
zero = torch.Tensor([0]).cuda()

# Seed the random number generators (via transformers.set_seed) so generation is
# as reproducible as possible; seeding alone does not guarantee bit-identical
# results. `process` decodes greedily (do_sample=False), so the seed would only
# matter if sampling were enabled.
set_seed(42)


##########################
# LOADING
##########################
# Tokenizer. The end-of-sequence token is reused as the padding token
# (presumably because the tokenizer ships without one; confirm on the model card).
logging.info("Load tokenizer from %s.", MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token, tokenizer.pad_token_id = tokenizer.eos_token, tokenizer.eos_token_id

# Model. Weights are dispatched with device_map="auto".
# Options deliberately NOT passed:
#   - attn_implementation="flash_attention_2": not supported by ZeroGPU;
#   - torch_dtype=torch.bfloat16: left disabled (reason not recorded), so the
#     library's default dtype is used.
logging.info("Load model from %s.", MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")


##########################
# FUNCTIONS
##########################
@spaces.GPU
def process(text: str) -> str:
    """Run the spellcheck model on a list of ingredients and return its correction.

    Uses the module-level `tokenizer` and `model`, with the prompt template from
    `prepare_instruction`, and greedy decoding.

    Args:
        text (str): List of ingredients (e.g. OCR output).

    Returns:
        str: The corrected list generated by the model, with surrounding
            whitespace stripped. Empty or whitespace-only input returns "" without
            calling the model.
    """
    if not text or not text.strip():
        # Nothing to correct: skip the (expensive) generation entirely.
        return ""

    prompt = prepare_instruction(text)
    encoded = tokenizer(
        prompt,
        add_special_tokens=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(zero.device)  # device of the `zero` probe
    attention_mask = encoded["attention_mask"].to(zero.device)
    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            # pad == eos (set at load time); passing it explicitly avoids the
            # "pad_token_id not set" fallback warning without changing results.
            pad_token_id=tokenizer.pad_token_id,
            do_sample=False,  # greedy decoding
            max_new_tokens=512,
        )
    # A decoder-only model returns prompt + continuation. Drop the prompt at the
    # token level: slicing the decoded string by len(prompt) is fragile, because
    # decoding does not always reproduce the prompt character for character
    # (whitespace clean-up, byte-fallback tokens). That would cut off or leak
    # characters at the start of the answer.
    generated_ids = output[0][input_ids.shape[-1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()


def prepare_instruction(text: str) -> str:
    """Wrap a list of ingredients in the model's instruction prompt.

    The template must match the one used during fine-tuning exactly, for both
    fine-tuning and inference.

    Args:
        text (str): List of ingredients.

    Returns:
        str: The instruction prompt, ending right where the model should start
            writing the correction.
    """
    header = "###Correct the list of ingredients:\n"
    footer = "\n\n###Correction:\n"
    return "".join((header, text, footer))


##########################
# GRADIO SETUP
##########################
# Demo layout: example image on the left; ingredient text, run button and the
# model's correction on the right; clickable examples below.
with gr.Blocks() as demo:

    gr.Markdown(PRESENTATION)

    with gr.Row():
        with gr.Column():
            image = gr.Image(type="pil", label="image_input", interactive=False)

        with gr.Column():
            ingredients = gr.Textbox(label="List of ingredients")
            spellcheck_button = gr.Button(value='Run spellcheck')
            correction = gr.Textbox(label="Correction", interactive=False)

    with gr.Row():
        # Clicking an example only fills the image and ingredients fields; the
        # spellcheck is then run with the button. No `fn` is passed: `process`
        # takes a single text argument, so wiring it to these two inputs would
        # fail if example caching or run-on-click were ever enabled.
        gr.Examples(
            examples=EXAMPLES,
            inputs=[
                image,
                ingredients,
            ],
        )
    spellcheck_button.click(
        fn=process,
        inputs=[ingredients],
        outputs=[correction],
    )


if __name__ == "__main__":
    # Launch the demo
    demo.launch()