akhaliq committed
Commit
259d504
1 Parent(s): 4c6948c

Create app.py

Files changed (1): app.py (+77, -0)
app.py ADDED
@@ -0,0 +1,77 @@
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
+ from PIL import Image
+ import torch
+
+ # Load the processor and model
+ processor = AutoProcessor.from_pretrained(
+     'allenai/Molmo-7B-D-0924',
+     trust_remote_code=True,
+     torch_dtype='auto',
+     device_map='auto'
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(
+     'allenai/Molmo-7B-D-0924',
+     trust_remote_code=True,
+     torch_dtype='auto',
+     device_map='auto'
+ )
+
+ def process_image_and_text(image, text):
+     # Process the image and text
+     inputs = processor.process(
+         images=[Image.fromarray(image)],
+         text=text
+     )
+
+     # Move inputs to the correct device and make a batch of size 1
+     inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
+
+     # Generate output
+     output = model.generate_from_batch(
+         inputs,
+         GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
+         tokenizer=processor.tokenizer
+     )
+
+     # Only get generated tokens; decode them to text
+     generated_tokens = output[0, inputs['input_ids'].size(1):]
+     generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+     return generated_text
+
+ def chatbot(image, text, history):
+     if image is None:
+         return history + [(text, "Please upload an image first.")], history
+
+     response = process_image_and_text(image, text)
+     history.append((text, response))
+     return history, history
+
+ # Define the Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Image Chatbot with Molmo-7B-D-0924")
+
+     with gr.Row():
+         image_input = gr.Image(type="numpy")
+         chatbot_output = gr.Chatbot()
+
+     text_input = gr.Textbox(placeholder="Ask a question about the image...")
+     submit_button = gr.Button("Submit")
+
+     state = gr.State([])
+
+     submit_button.click(
+         chatbot,
+         inputs=[image_input, text_input, state],
+         outputs=[chatbot_output, state]
+     )
+
+     text_input.submit(
+         chatbot,
+         inputs=[image_input, text_input, state],
+         outputs=[chatbot_output, state]
+     )
+
+ demo.launch()
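
For a quick sanity check of the inference path outside the Gradio UI, the helper can be called directly. This is a minimal sketch, assuming the definitions above have already been executed in the same Python session (so processor, model, and process_image_and_text exist); the image path "example.jpg" is a hypothetical placeholder.

import numpy as np
from PIL import Image

# Hypothetical test image; any RGB image on disk will do.
test_image = np.array(Image.open("example.jpg").convert("RGB"))

# Reuses process_image_and_text() from app.py above and prints the model's answer.
print(process_image_and_text(test_image, "Describe this image."))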