File size: 6,265 Bytes
256a159 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
from typing import List
import torch
from mmpretrain.structures import DataSample
class OTTERMMBenchPromptConstructor:
    """Prompt constructor for OTTER on MMBench.

    Builds prompts of the form
    ``<image>{user_label} {context} {question} {options} {model_label}:<answer>``.

    Args:
        user_label (str): Label prepended to the user turn
            (e.g. ``'User:'``). Defaults to ``''``.
        model_label (str): Label prepended to the model turn
            (e.g. ``'GPT'``). Defaults to ``''``.
    """

    def __init__(self, user_label: str = '', model_label: str = '') -> None:
        self.image_token = '<image>'
        self.reply_token = '<answer>'
        self.user_label = user_label
        self.model_label = model_label

    def __call__(self, inputs: dict) -> dict:
        """Construct prompt.

        Args:
            inputs (dict): Input data containing image and data_samples.

        Returns:
            dict: A dict containing prompt, images and data_samples.
        """
        # Stack per-sample image tensors into one batch tensor.
        images = torch.cat(
            [image.unsqueeze(0) for image in inputs['inputs']], dim=0)
        data_samples = list(inputs['data_samples'])
        prompt = self._process(data_samples)
        return {
            'image': images,
            'data_samples': data_samples,
            'prompt': prompt,
        }

    def _process(self, data_samples: 'List[DataSample]') -> str:
        """Process data sample to prompt.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        data_sample = data_samples[0]
        question = data_sample.get('question')
        options = data_sample.get('options')
        context = data_sample.get('context')
        # e.g. <image>User: What is the color of the sky? A: Blue B: Red C: Green D: Yellow GPT:<answer> # noqa
        if context is not None:
            prompt = f'{self.image_token}{self.user_label} {context} {question} {options} {self.model_label}:{self.reply_token}'  # noqa
        else:
            prompt = f'{self.image_token}{self.user_label} {question} {options} {self.model_label}:{self.reply_token}'  # noqa
        return prompt
class OTTERCOCOCaotionPromptConstructor(OTTERMMBenchPromptConstructor):
    """Prompt constructor for OTTER on COCO Caption.

    NOTE(review): class name contains a typo ("Caotion"); kept as-is
    because external config files may reference it by name.
    """

    def _process(self, data_samples: List[DataSample]) -> str:
        """Return the fixed captioning prompt; ``data_samples`` is unused.

        e.g. ``<image>User: a photo of GPT:<answer>``
        """
        parts = (self.image_token, self.user_label, ' a photo of ',
                 self.model_label, ':', self.reply_token)
        return ''.join(parts)
class OTTERScienceQAPromptConstructor(OTTERMMBenchPromptConstructor):
    """Prompt constructor for OTTER on ScienceQA."""

    # Maps a choice index to its option letter, e.g. 0 -> 'A'.
    choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}

    def _process(self, data_samples: 'List[DataSample]') -> str:
        """Process data sample to prompt.

        The original built whole-batch lists and then used only index 0,
        despite asserting batch size 1; this works on the single sample
        directly.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        data_sample = data_samples[0]
        question = 'Question: ' + data_sample.get('question') + '\n'
        # Label each choice with its letter: '(A) foo (B) bar ...'.
        labelled = [
            f'({self.choice_mapping[i]}) ' + item
            for i, item in enumerate(data_sample.get('choices'))
        ]
        choice = 'Choices: ' + ' '.join(labelled) + '\n'
        # NOTE(review): assumes 'hint' is always a str — a None hint would
        # raise TypeError here, exactly as in the original.
        context = 'Context: ' + data_sample.get('hint') + '\n'
        prompt = f'{self.image_token}{self.user_label} {context} {question} {choice} The answer is {self.model_label}:{self.reply_token}'  # noqa
        return prompt
class OTTERVQAPromptConstructor(OTTERMMBenchPromptConstructor):
    """Prompt constructor for OTTER on VQA."""

    def _process(self, data_samples: List[DataSample]) -> str:
        """Build a VQA prompt asking for a short answer.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        question = data_samples[0].get('question')
        # Fixed duplicated word in the instruction: the original emitted
        # "Answer it with with few words."
        prompt = f'{self.image_token}{self.user_label} {question}. Answer it with few words. {self.model_label}:{self.reply_token}'  # noqa
        return prompt
class OTTERVSRPromptConstructor(OTTERMMBenchPromptConstructor):
    """Prompt constructor for OTTER on VSR."""

    def _process(self, data_samples: List[DataSample]) -> str:
        """Turn the single VSR statement into a yes/no verification prompt.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        statement = data_samples[0].get('question')
        prompt = (f'{self.image_token}{self.user_label} {statement}. '
                  f'Is the above description correct? Answer yes or no. '
                  f'{self.model_label}:{self.reply_token}')
        return prompt
class OTTERSEEDBenchPromptConstructor(OTTERMMBenchPromptConstructor):
    """Prompt constructor for OTTER on SEED-Bench."""

    def _process(self, data_samples: List[DataSample]) -> str:
        """Process data sample to prompt.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        sample = data_samples[0]
        return (f'{self.image_token}{self.user_label} '
                f"{sample.get('question')} "
                f'{self.model_label}:{self.reply_token}')
class OTTERMMEPromptConstructor(OTTERMMBenchPromptConstructor):
    """Prompt constructor for OTTER on MME.

    Adds no parameters of its own; ``user_label`` and ``model_label`` are
    inherited from the base constructor (the original docstring documented
    nonexistent ``image_prompt``/``reply_prompt`` args).
    """

    def _process(self, data_samples: 'List[DataSample]') -> str:
        """Process data sample to prompt.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        question = data_samples[0].get('question')
        prompt = f'{self.image_token}{self.user_label} {question} {self.model_label}:{self.reply_token}'  # noqa
        return prompt
|