davanstrien HF staff commited on
Commit
4e1ec1c
β€’
1 Parent(s): 21187af
Files changed (2) hide show
  1. app.py +116 -3
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,7 +1,120 @@
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  demo.launch()
 
1
+ import spaces
2
  import gradio as gr
3
 
4
+ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
5
+ from qwen_vl_utils import process_vision_info
6
+ import torch
7
+ import os
8
+ import json
9
 
10
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
11
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
12
+ "Qwen/Qwen2-VL-7B-Instruct",
13
+ torch_dtype=torch.bfloat16,
14
+ # attn_implementation="flash_attention_2",
15
+ device_map="auto",
16
+ )
17
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
18
+ from pydantic import BaseModel
19
+ from typing import Tuple
20
+
21
+ class GeneralRetrievalQuery(BaseModel):
22
+ broad_topical_query: str
23
+ broad_topical_explanation: str
24
+ specific_detail_query: str
25
+ specific_detail_explanation: str
26
+ visual_element_query: str
27
+ visual_element_explanation: str
28
+
29
+ def get_retrieval_prompt(prompt_name: str) -> Tuple[str, GeneralRetrievalQuery]:
30
+ if prompt_name != "general":
31
+ raise ValueError("Only 'general' prompt is available in this version")
32
+
33
+ prompt = """You are an AI assistant specialized in document retrieval tasks. Given an image of a document page, your task is to generate retrieval queries that someone might use to find this document in a large corpus.
34
+
35
+ Please generate 3 different types of retrieval queries:
36
+
37
+ 1. A broad topical query: This should cover the main subject of the document.
38
+ 2. A specific detail query: This should focus on a particular fact, figure, or point made in the document.
39
+ 3. A visual element query: This should reference a chart, graph, image, or other visual component in the document, if present.
40
+
41
+ Important guidelines:
42
+ - Ensure the queries are relevant for retrieval tasks, not just describing the page content.
43
+ - Frame the queries as if someone is searching for this document, not asking questions about its content.
44
+ - Make the queries diverse and representative of different search strategies.
45
+
46
+ For each query, also provide a brief explanation of why this query would be effective in retrieving this document.
47
+
48
+ Format your response as a JSON object with the following structure:
49
+
50
+ {
51
+ "broad_topical_query": "Your query here",
52
+ "broad_topical_explanation": "Brief explanation",
53
+ "specific_detail_query": "Your query here",
54
+ "specific_detail_explanation": "Brief explanation",
55
+ "visual_element_query": "Your query here",
56
+ "visual_element_explanation": "Brief explanation"
57
+ }
58
+
59
+ If there are no relevant visual elements, replace the third query with another specific detail query.
60
+
61
+ Here is the document image to analyze:
62
+ <image>
63
+
64
+ Generate the queries based on this image and provide the response in the specified JSON format."""
65
+
66
+ return prompt, GeneralRetrievalQuery
67
+
68
+
69
+
70
+ prompt, pydantic_model = get_retrieval_prompt("general")
71
+
72
+ @spaces.GPU
73
+ def generate_response(image):
74
+ messages = [
75
+ {
76
+ "role": "user",
77
+ "content": [
78
+ {
79
+ "type": "image",
80
+ "image": image,
81
+ },
82
+ {"type": "text", "text": prompt},
83
+ ],
84
+ }
85
+ ]
86
+
87
+ text = processor.apply_chat_template(
88
+ messages, tokenize=False, add_generation_prompt=True
89
+ )
90
+
91
+ image_inputs, video_inputs = process_vision_info(messages)
92
+
93
+ inputs = processor(
94
+ text=[text],
95
+ images=image_inputs,
96
+ videos=video_inputs,
97
+ padding=True,
98
+ return_tensors="pt",
99
+ )
100
+ inputs = inputs.to("cuda")
101
+
102
+ generated_ids = model.generate(**inputs, max_new_tokens=200)
103
+ generated_ids_trimmed = [
104
+ out_ids[len(in_ids) :]
105
+ for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
106
+ ]
107
+
108
+ output_text = processor.batch_decode(
109
+ generated_ids_trimmed,
110
+ skip_special_tokens=True,
111
+ clean_up_tokenization_spaces=False,
112
+ )
113
+ try:
114
+ data = json.loads(output_text[0])
115
+ return data
116
+ except Exception:
117
+ return {}
118
+
119
+ demo = gr.Interface(fn=generate_response, inputs=gr.Image(type='pil'), outputs="json")
120
  demo.launch()
requirements.txt CHANGED
@@ -5,4 +5,4 @@ torch
5
  datasets
6
  huggingface_hub[hf_transfer]
7
  polars
8
- transformers@git+https://github.com/huggingface/transformers.git
 
5
  datasets
6
  huggingface_hub[hf_transfer]
7
  polars
8
+ transformers @git+https://github.com/huggingface/transformers.git