File size: 6,265 Bytes
256a159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
from typing import List

import torch
from mmpretrain.structures import DataSample


class OTTERMMBenchPromptConstructor:
    """Prompt constructor for OTTER on MMBench.

    Args:
        user_label (str): Label prefixing the user turn of the prompt.
            Defaults to `''`.
        model_label (str): Label prefixing the model turn of the prompt.
            Defaults to `''`.
    """

    def __init__(self, user_label: str = '', model_label: str = '') -> None:
        # Special tokens consumed by the OTTER tokenizer.
        self.image_token = '<image>'
        self.reply_token = '<answer>'
        self.user_label = user_label
        self.model_label = model_label

    def __call__(self, inputs: dict) -> dict:
        """Construct prompt.

        Args:
            inputs (dict): Input data containing image and data_samples.

        Returns:
            dict: A dict containing prompt, image and data_samples.
        """
        # Stack the per-sample image tensors into one batched tensor.
        images = torch.cat(
            [image.unsqueeze(0) for image in inputs['inputs']], dim=0)
        data_samples = list(inputs['data_samples'])
        prompt = self._process(data_samples)
        return {
            'image': images,
            'data_samples': data_samples,
            'prompt': prompt,
        }

    def _process(self, data_samples: List[DataSample]) -> str:
        """Process data sample to prompt.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        data_sample = data_samples[0]
        question = data_sample.get('question')
        options = data_sample.get('options')
        context = data_sample.get('context')
        # e.g. <image>User: What is the color of the sky? A: Blue B: Red C: Green D: Yellow GPT:<answer>  # noqa
        if context is not None:
            prompt = f'{self.image_token}{self.user_label} {context} {question} {options} {self.model_label}:{self.reply_token}'  # noqa
        else:
            prompt = f'{self.image_token}{self.user_label} {question} {options} {self.model_label}:{self.reply_token}'  # noqa

        return prompt


class OTTERCOCOCaotionPromptConstructor(OTTERMMBenchPromptConstructor):
    """Prompt constructor for OTTER on COCO Caption.

    NOTE(review): the class name misspells "Caption"; it is kept as-is
    because external configs reference this identifier.
    """

    def _process(self, data_samples: List[DataSample]) -> str:
        """Return the fixed captioning prompt; ``data_samples`` is unused."""
        # e.g. <image>User: a photo of GPT:<answer>  # noqa
        return (f'{self.image_token}{self.user_label} a photo of '
                f'{self.model_label}:{self.reply_token}')


class OTTERScienceQAPromptConstructor(OTTERMMBenchPromptConstructor):
    """Prompt constructor for OTTER on ScienceQA."""

    # Maps a choice index to its option letter.
    choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}

    def _process(self, data_samples: List[DataSample]) -> str:
        """Process data sample to prompt.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        # Batch size is asserted to be 1, so read the single sample
        # directly instead of building throwaway per-field lists.
        data_sample = data_samples[0]
        question = 'Question: ' + data_sample.get('question') + '\n'
        labelled_choices = [
            f'({self.choice_mapping[i]}) ' + item
            for i, item in enumerate(data_sample.get('choices'))
        ]
        choice = 'Choices: ' + ' '.join(labelled_choices) + '\n'
        context = 'Context: ' + data_sample.get('hint') + '\n'
        prompt = f'{self.image_token}{self.user_label} {context} {question} {choice} The answer is {self.model_label}:{self.reply_token}'  # noqa
        return prompt


class OTTERVQAPromptConstructor(OTTERMMBenchPromptConstructor):
    """Prompt constructor for OTTER on VQA."""

    def _process(self, data_samples: List[DataSample]) -> str:
        """Build an open-ended VQA prompt from a single data sample.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        question = data_samples[0].get('question')
        # NOTE(review): fixed the duplicated word ("with with") in the
        # instruction text; confirm downstream eval baselines are not
        # pinned to the old prompt string.
        prompt = f'{self.image_token}{self.user_label} {question}. Answer it with few words. {self.model_label}:{self.reply_token}'  # noqa
        return prompt


class OTTERVSRPromptConstructor(OTTERMMBenchPromptConstructor):
    """Prompt constructor for OTTER on VSR."""

    def _process(self, data_samples: List[DataSample]) -> str:
        """Build a yes/no verification prompt from a single data sample.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        question = data_samples[0].get('question')
        return f'{self.image_token}{self.user_label} {question}. Is the above description correct? Answer yes or no. {self.model_label}:{self.reply_token}'  # noqa


class OTTERSEEDBenchPromptConstructor(OTTERMMBenchPromptConstructor):
    """Prompt constructor for OTTER on SEEDBench."""

    def _process(self, data_samples: List[DataSample]) -> str:
        """Process data sample to prompt.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        question = data_samples[0].get('question')
        return (f'{self.image_token}{self.user_label} {question} '
                f'{self.model_label}:{self.reply_token}')


class OTTERMMEPromptConstructor(OTTERMMBenchPromptConstructor):
    """Prompt constructor for OTTER on MME.

    Uses the ``user_label``/``model_label`` arguments inherited from
    ``OTTERMMBenchPromptConstructor``. (The previous docstring documented
    ``image_prompt``/``reply_prompt`` parameters that do not exist.)
    """

    def _process(self, data_samples: List[DataSample]) -> str:
        """Process data sample to prompt.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        question = data_samples[0].get('question')
        prompt = f'{self.image_token}{self.user_label} {question} {self.model_label}:{self.reply_token}'  # noqa
        return prompt