ethux committed
Commit 88d793f
1 Parent(s): 8850a0d

Create app.py

Files changed (1)
  app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
+ import gradio as gr
+ from gradio.data_classes import FileData
+ from huggingface_hub import snapshot_download
+ from pathlib import Path
+ import base64
+ import spaces
+ import os
+
+ from mistral_inference.transformer import Transformer
+ from mistral_inference.generate import generate
+
+ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+ from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk
+ from mistral_common.protocol.instruct.request import ChatCompletionRequest
+
+ models_path = Path.home().joinpath('pixtral', 'Pixtral')
+ models_path.mkdir(parents=True, exist_ok=True)
+
+ snapshot_download(repo_id="mistral-community/pixtral-12b-240910",
+                   allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
+                   local_dir=models_path)
+
+ tokenizer = MistralTokenizer.from_file(f"{models_path}/tekken.json")
+ model = Transformer.from_folder(models_path)
+
+ def image_to_base64(image_path):
+     with open(image_path, 'rb') as img:
+         encoded_string = base64.b64encode(img.read()).decode('utf-8')
+     return f"data:image/jpeg;base64,{encoded_string}"
+
+ @spaces.GPU
+ def run_inference(image_path, prompt):
+     image_data = image_to_base64(image_path)
+     completion_request = ChatCompletionRequest(messages=[UserMessage(content=[ImageURLChunk(image_url=image_data), TextChunk(text=prompt)])])
+
+     encoded = tokenizer.encode_chat_completion(completion_request)
+
+     images = encoded.images
+     tokens = encoded.tokens
+
+     out_tokens, _ = generate([tokens], model, images=[images], max_tokens=512, temperature=0.45, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
+     result = tokenizer.decode(out_tokens[0])
+     return [[prompt, result]]
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         image_box = gr.Image(type="filepath")
+
+         chatbot = gr.Chatbot(
+             scale=2,
+             height=750
+         )
+     text_box = gr.Textbox(
+         placeholder="Enter text and press enter, or upload an image",
+         container=False,
+     )
+
+
+     btn = gr.Button("Submit")
+     clicked = btn.click(run_inference,
+                         [image_box, text_box],
+                         chatbot
+                         )
+
+ demo.queue().launch()
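
Once the Space built from this commit is running, it can also be exercised programmatically. Below is a minimal sketch using the gradio_client package; the Space URL is a placeholder, and the "/run_inference" endpoint name is an assumption derived from the handler's function name, so check the app's "Use via API" panel for the actual values.

from gradio_client import Client, handle_file

# Hypothetical Space URL; replace with the real one for this app.
client = Client("https://example-user-pixtral-demo.hf.space")

# First argument feeds image_box (a file path), second feeds text_box (the prompt).
# api_name is assumed from the run_inference handler; verify via "Use via API".
result = client.predict(
    handle_file("cat.jpg"),
    "Describe this image.",
    api_name="/run_inference",
)
print(result)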