File size: 3,285 Bytes
2df6b4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import spaces


# Example (image, OCR text) pairs for the demo's Examples widget.
# The texts deliberately contain OCR misspellings for the model to correct.
EXAMPLES = [
    ["images/ingredients_1.jpg", "24.36% chocolat noir 63% origine non UE (cacao, sucre, beurre de cacao, émulsifiant léci - thine de colza, vanille bourbon gousse), œuf, farine de blé, beurre, sucre, miel, sucre perlé, levure chimique, zeste de citron."],
    ["images/ingredients_2.jpg", "farine de froment, œufs, lait entier pasteurisé Aprigine: France), sucre, sel, extrait de vanille naturelle Conditi( 35."],
    ["images/ingredients_3.jpg", "tural basmati rice - cooked (98%), rice bran oil, salt"],
    ["images/ingredients_4.jpg", "Eau de noix de coco 93.9%, Arôme natutel de fruit"],
    ["images/ingredients_5.jpg", "Sucre, pâte de cacao, beurre de cacao, émulsifiant: léci - thines (soja). Peut contenir des traces de lait. Chocolat noir: cacao: 50% minimum. À conserver à l'abri de la chaleur et de l'humidité. Élaboré en France."],
]
# Hugging Face Hub id of the fine-tuned spellcheck model.
MODEL_ID = "openfoodfacts/spellcheck-mistral-7b"


# Sentinel CUDA tensor; its .device is used later to route inputs to the GPU.
# NOTE(review): allocating on CUDA at import time looks like the HF ZeroGPU
# pattern (with the @spaces.GPU decorator below) — confirm the Space has a GPU.
zero = torch.Tensor([0]).cuda()

# Tokenizer: reuse EOS as the padding token (no dedicated pad token is set).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# Model: weights in bfloat16, placed automatically via device_map, using
# FlashAttention-2 kernels (requires a compatible GPU and flash-attn install).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

@spaces.GPU
def process(text: str) -> str:
    """Generate a spellcheck correction for a list of ingredients.

    Args:
        text (str): Raw (possibly OCR-garbled) list of ingredients.

    Returns:
        str: The model's corrected list of ingredients, stripped of
        surrounding whitespace.
    """
    prompt = prepare_instruction(text)
    input_ids = tokenizer(
        prompt,
        add_special_tokens=True,
        return_tensors="pt",
    ).input_ids
    output = model.generate(
        input_ids.to(zero.device),  # move inputs to the device the model uses
        do_sample=False,            # greedy decoding: deterministic output
        max_new_tokens=512,
    )
    # Strip the prompt at the *token* level before decoding. The previous
    # character-level slice — decode(...)[len(prompt):] — silently truncates
    # or pollutes the correction whenever decoding does not reproduce the
    # prompt byte-for-byte (special-token handling, whitespace normalization).
    generated = output[0][input_ids.shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()


def prepare_instruction(text: str) -> str:
    """Wrap a raw ingredient list in the model's instruction template.

    Args:
        text (str): List of ingredients as extracted from the packaging.

    Returns:
        str: Prompt in the exact format used for fine-tuning and inference.
    """
    return f"###Correct the list of ingredients:\n{text}\n\n###Correction:\n"


##########################
# GRADIO SETUP
##########################

# Build the demo UI: an image + text input column on the left, the model's
# correction on the right, with clickable examples below.
with gr.Blocks() as demo:

    gr.Markdown("# Ingredients Spellcheck")
    gr.Markdown("")
    
    with gr.Row():
        with gr.Column():
            # The image is display-only context for the user; only the text
            # box is ever passed to the model (see the click handler below).
            image = gr.Image(type="pil", label="image_input")
            ingredients = gr.Textbox(label="List of ingredients")
            spellcheck_button = gr.Button(value='Spellcheck')

        with gr.Column():
            correction = gr.Textbox(label="Correction", interactive=False)

    with gr.Row():
        # Examples populate (image, ingredients) but do NOT run the model:
        # run_on_click=False is required here, since `process` takes a single
        # text argument and would fail if called with both inputs.
        gr.Examples(
            fn=process,
            examples=EXAMPLES,
            inputs=[
                image,
                ingredients,
            ],
            outputs=[correction],
            run_on_click=False,
        )

    # Inference is triggered only by the button, with the text box alone.
    spellcheck_button.click(
        fn=process,
        inputs=[ingredients],
        outputs=[correction]
    )


if __name__ == "__main__":
    # Launch the Gradio server (blocking call) when run as a script.
    demo.launch()