Feature Extraction · Safetensors · English · minicpmv · VisRAG · custom_code
import torch
from dataclasses import dataclass
from transformers.utils import ModelOutput
from typing import Optional
from .modeling_minicpmv import MiniCPMV
from .modeling_minicpm import MiniCPMForCausalLM
from .resampler import Resampler
from concurrent.futures import ThreadPoolExecutor


def transform_image_mp(img_list, transform, device, max_workers=None):
    """Apply `transform` to every image in `img_list` using a thread pool."""
    pixel_values = []

    # Use a ThreadPoolExecutor to transform each sample's images concurrently.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for img_batch in img_list:
            img_inps = list(executor.map(transform, img_batch))
            img_inps = [img_inp.to(device) for img_inp in img_inps]
            pixel_values.append(img_inps if img_inps else [])

    return pixel_values
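
# Example (illustrative): for img_list = [[img_a, img_b], []] the call returns
# [[transform(img_a).to(device), transform(img_b).to(device)], []], i.e. the
# nested per-sample structure is preserved and image-free samples stay empty.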


@dataclass
class BaseModelOutputWithAttentionMask(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    attention_mask: Optional[torch.Tensor] = None
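    # The attention mask is returned alongside the hidden states so that callers
    # can ignore padded positions, e.g. when pooling token-level states into a
    # single embedding per input.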

class VisRAG_Ret(MiniCPMV): # MiniCPMV subclass, ultimately a causal LM
    def fused_tokenize(
        self,
        data_list=None, # List[str] 
        img_list=None, # List[List[PIL.Image]]
        tokenizer=None,
        max_inp_length: Optional[int] = None,
        vision_hidden_states=None, # default None
        return_vision_hidden_states=False,
        **kwargs):
        
        assert data_list is not None
        bs = len(data_list)
        if img_list is None:
            img_list = [[] for _ in range(bs)]
        assert bs == len(img_list)

        model_inputs = self._process_list(tokenizer, data_list, max_inp_length, padding_side="right")
        
        if vision_hidden_states is None:
            pixel_values = transform_image_mp(img_list, self.transform, self.device, max_workers=8)
            model_inputs["pixel_values"] = pixel_values
        else:
            model_inputs["vision_hidden_states"] = vision_hidden_states

        return model_inputs
    
    def prepare_context(self, inputs, tokenizer):
        text_, image_ = inputs
        if not isinstance(text_, str):
            raise NotImplementedError(f"chatml format expected: the outermost type must be str, but got {type(text_)}")

        # 1. add text
        content = text_
        
        # 2. add image
        if image_:
            if self.config.slice_mode:
                images, final_placeholder = self.get_slice_image_placeholder(
                    image_, tokenizer
                ) # crop one image into multiple sub images -> List[Image]
                content = final_placeholder + "\n" + content
            else:
                images = [image_] # only keep one image without cropping -> List[Image]
                content = (
                    tokenizer.im_start
                    + tokenizer.unk_token * self.config.query_num
                    + tokenizer.im_end
                    + "\n"
                    + content
                )
        else:
            images = []
        
        return content, images
    
    def forward(
        self,
        text, # List[str], one string per example
        image, # List[PIL.Image], one image per example
        tokenizer,
        vision_hidden_states=None,
        max_inp_length=2048,
        **kwargs):
        
        processed_image = []
        processed_text = []
        
        with ThreadPoolExecutor(max_workers=8) as executor:
            contexts = list(executor.map(lambda inputs: self.prepare_context(inputs, tokenizer), zip(text, image)))
        
        for context in contexts:
            content_, image_ = context
            processed_text.append(content_)
            processed_image.append(image_)
        
        model_inputs = self.fused_tokenize(
            data_list=processed_text, # List[str]
            img_list=processed_image, # List[List[PIL.Image]]
            tokenizer=tokenizer,
            max_inp_length=max_inp_length
        )
        
        # Vision encoder forward: image features are computed (or reused) and merged
        # with the text embeddings into a single inputs_embeds tensor.
        model_inputs["inputs_embeds"], vision_hidden_states = self.get_vllm_embedding(model_inputs)
        vlm_outputs = self.llm.model(
            input_ids=None, # omitted: image and text are already merged into inputs_embeds
            position_ids=None,
            inputs_embeds=model_inputs["inputs_embeds"],
            attention_mask=model_inputs["attention_mask"],
            return_dict=True
        )
        
        return BaseModelOutputWithAttentionMask(
            last_hidden_state=vlm_outputs.last_hidden_state,
            attention_mask=model_inputs["attention_mask"]
        )
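

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original model code). Assumptions:
    # the checkpoint path below is a placeholder for the published VisRAG-Ret
    # weights, the query text and dummy image are illustrative, and masked mean
    # pooling is only one reasonable way to reduce the returned token-level
    # hidden states to a single embedding per input.
    import torch.nn.functional as F
    from PIL import Image
    from transformers import AutoModel, AutoTokenizer

    ckpt = "path/to/VisRAG-Ret"  # placeholder checkpoint path
    tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
    model = AutoModel.from_pretrained(ckpt, trust_remote_code=True).eval()

    texts = ["What is retrieval-augmented generation?"]     # one string per example
    images = [Image.new("RGB", (448, 448), color="white")]  # one (dummy) image per example

    with torch.no_grad():
        outputs = model(text=texts, image=images, tokenizer=tokenizer)

    # Mean-pool the non-padded token states, then L2-normalize to get embeddings.
    mask = outputs.attention_mask.unsqueeze(-1).to(outputs.last_hidden_state.dtype)
    embeddings = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-6)
    embeddings = F.normalize(embeddings, p=2, dim=-1)
    print(embeddings.shape)  # (batch_size, hidden_size)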