from collections import OrderedDict
import torch
import os
import copy
from dataclasses import dataclass
import json
import re
from typing import Dict, Optional, Sequence

import transformers

from llava.constants import (
    IGNORE_INDEX,
)
from torch.utils.data import Dataset
from llava.util.tokenization import (
    preprocess_llama_2,
    preprocess_llama_2_obj_identifier,
    preprocess_multimodal,
    preprocess,
)
from llava import conversation as conversation_lib


class ObjIdentifierDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,
        data_path: str | list,
        scene_to_obj_mapping: str,
        obj_context_feature_type: str = "text",
        mode: str = "train",
        **kwargs,
    ):
        super(ObjIdentifierDataset, self).__init__()
        self.tokenizer = tokenizer
        self.scene_to_obj_mapping = json.load(open(scene_to_obj_mapping, "r"))
        self.update_data(data_path)
        self.obj_context_feature_type = obj_context_feature_type
        self.mode = mode

    def __len__(self):
        return len(self.list_data_dict)

    def update_data(self, data_path: str | list):
        assert (
            self.scene_to_obj_mapping is not None
        ), "scene_to_obj_mapping needs to be set first."
        if isinstance(data_path, str):
            self.list_data_dict = json.load(open(data_path, "r"))
        elif isinstance(data_path, list):
            self.list_data_dict = data_path

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        sources = copy.deepcopy(self.list_data_dict[i])
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME

        ############
        scene_id = sources[0]["scene_id"]
        input_obj_dict = copy.deepcopy(self.scene_to_obj_mapping[scene_id])

        # Prepare the object-centric features. We want the LLM to see:
        #     "%%%% Object-centric context: <obj_id>: <obj_feature>, <obj_id>: <obj_feature>, ..."
        # where each <obj_feature> will later be replaced by the actual feature in
        # vector form; everything else is a pure text string.
        # 1. We first change each object_id to a new object_id, e.g., 'obj_0', 'obj_1', ...,
        #    and replace the old object_id with the new object_id in the text conversation.
        # 2. Tokenize the conversation, and add the object context to the tokenized conversation.
        # 3. Gather and return the necessary information for each object,
        #    so that it can later be embedded into a vector.

        # 1. Change the object_id to a new object_id, and remember the mapping.
        # original object_id: 'wardrobe-0', 'three-seat/multi-seat sofa-1', ...
        # converted obj_id: 'obj_0', 'obj_1', ...
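        # For illustration (hypothetical scene contents), the two passes below turn
        #     {'wardrobe-0': {...}, 'three-seat/multi-seat sofa-1': {...}}
        # into
        #     old_id_to_new_id_mapping = {'wardrobe-0': 'obj_0',
        #                                 'three-seat/multi-seat sofa-1': 'obj_1'}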
        old_id_to_new_id_mapping = {}
        result_obj_dict = OrderedDict()

        # First pass: map the old object_id to a new object_id.
        for old_id, obj_info in input_obj_dict.items():
            # make sure old_id doesn't contain < or >
            assert (
                "<" not in old_id and ">" not in old_id
            ), "object_id in scene graph should not contain < or >"
            new_id = f"obj_{len(old_id_to_new_id_mapping)}"
            old_id_to_new_id_mapping[old_id] = new_id

        # Second pass: create the new object-centric context, and modify the
        # object_id in the text content.
        for old_id, obj_info in input_obj_dict.items():
            new_id = old_id_to_new_id_mapping[old_id]
            # TODO: Determine what information to include in the object-centric context
            result_obj_info_dict = {}
            result_obj_info_dict["category"] = obj_info["category"]
            # result_obj_info_dict["category_id"] = obj_info["category_id"]
            # for relations, we need to replace the old object_id with the new object_id
            # result_obj_info_dict["relations"] = []
            # for relation in obj_info["relations"]:
            #     for local_old_id, local_new_id in old_id_to_new_id_mapping.items():
            #         if local_old_id in relation:
            #             result_obj_info_dict["relations"].append(
            #                 re.sub(rf"<{local_old_id}>", f"<{local_new_id}>", relation)
            #             )
            if "description" in obj_info:
                result_obj_info_dict["description"] = obj_info["description"]
            else:
                # print(f"WARNING: Object {old_id} does not have a description.")
                pass
            # use two decimal places for the centroid and extent
            result_obj_info_dict["centroid"] = (
                f"[{obj_info['centroid'][0]:.2f}, {obj_info['centroid'][1]:.2f}, {obj_info['centroid'][2]:.2f}]"
            )
            result_obj_info_dict["extent"] = (
                f"[{obj_info['extent'][0]:.2f}, {obj_info['extent'][1]:.2f}, {obj_info['extent'][2]:.2f}]"
            )
            result_obj_dict[new_id] = result_obj_info_dict
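        # Illustrative shape of one result_obj_dict entry (hypothetical values):
        #     'obj_0': {'category': 'wardrobe',
        #               'description': 'a warm wooden hue wardrobe with a retro flair',
        #               'centroid': '[1.20, 0.35, 0.60]',
        #               'extent': '[0.80, 0.60, 2.10]'}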

        # Replace the old object_id with the new object_id in the text content.
        # Text conversation example:
        # {
        #     "id": "55f2b905-d367-443d-8f88-ef71b958c81f@LivingRoom-3973@1",
        #     "scene_id": "55f2b905-d367-443d-8f88-ef71b958c81f@LivingRoom-3973",
        #     "conversations": [
        #         {
        #             "from": "human",
        #             "value": "Can you describe the ambiance of this room?"
        #         },
        #         {
        #             "from": "gpt",
        #             "value": "In this Living Room, the arrangement of furniture caters to both
        #                 style and function. The warm wooden hue wardrobe [] stands with a retro
        #                 flair, while the neutral grey, sleek rectangular form multi-seat sofa []
        #                 and neutral grey, sleek rectangular form three-seat [] provide modern
        #                 and comfortable seating options. The sleek black, dark grey, brown,
        #                 rectangular coffee table [] and sleek black, dark grey, brown coffee
        #                 table [] in minimalist style serve as focal points and functional
        #                 pieces for gatherings. The light grey and dark grey (two-tone),
        #                 shell-like backrest, smooth armchair [] and light grey and dark grey
        #                 (two-tone), smooth armchair [] add additional seating, complemented by
        #                 the rich walnut brown top and contrasting light grey base side table []
        #                 for convenience. Suspended above, the gradient of grey to bronze hues,
        #                 cylindrical with abstract cityscape cutouts pendant lamp [] offers a
        #                 decorative element with its unique Chinoiserie design. The room's setup
        #                 is ideal for hosting guests or enjoying quiet evenings, with thoughtful
        #                 placement of each piece to enhance the living experience."
        #         }
        #     ]
        # },
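        # The bracketed slots in the example above carry object tags; the loop below
        # rewrites every "<old_id>" occurrence to its "<obj_N>" alias, e.g.
        # (hypothetical ids):
        #     "...wardrobe [<wardrobe-0>] stands..." -> "...wardrobe [<obj_0>] stands..."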
        for conv in sources[0]["conversations"]:
            for old_id, new_id in old_id_to_new_id_mapping.items():
                conv["value"] = re.sub(rf"<{old_id}>", f"<{new_id}>", conv["value"])

        # If in generate mode, shave off the last conversation turn if it is from the assistant.
        if self.mode == "generate" and sources[0]["conversations"][-1]["from"] == "gpt":
            sources[0]["conversations"] = sources[0]["conversations"][:-1]

        # 2. Tokenize the conversation, and add the object context to the tokenized conversation.
        sources = preprocess_multimodal(
            copy.deepcopy([e["conversations"] for e in sources]),
            is_multimodal=True,
            mm_use_im_start_end=False,
        )
        data_dict = preprocess_llama_2_obj_identifier(
            sources=sources,
            tokenizer=self.tokenizer,
            obj_dict=result_obj_dict,
            obj_context_feature_type=self.obj_context_feature_type,
            mode=self.mode,
        )

        # If in generate mode, add the object context and bbox label to the data_dict,
        # so that we can use them later to compute the metrics.
        if self.mode == "generate":
            data_dict["obj_context"] = result_obj_dict
            if "bbox" in self.list_data_dict[i]:
                data_dict["bbox_label"] = self.list_data_dict[i]["bbox"]
            # full_info_dict is the full information of this data sample, e.g.:
            # {'id': 'scene0643_00$desk-0@0', 'scene_id': 'scene0643_00',
            #  'conversations': [{...}],
            #  'referred_object_id': '0', 'referred_object_text': 'desk',
            #  'grounded_object_reference': 'a brown wooden office desk on the left to the gray shelf.',
            #  'bbox': [0.3769365990161897, -0.06906220592784873, -0.020513275327205656,
            #           1.1370925301275254, 1.5355764355778696, 0.8130822017173767]}
            data_dict["full_info_dict"] = self.list_data_dict[i]

        return data_dict


@dataclass
class DataCollatorForObjIdentifierDataset(object):
    """Collate examples for supervised fine-tuning."""

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer, **kwargs):
        self.tokenizer = tokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple(
            [instance[key] for instance in instances] for key in ("input_ids", "labels")
        )
        # Right-pad to the longest sequence in the batch; labels are padded with
        # IGNORE_INDEX so padding positions do not contribute to the loss.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX
        )
        # Truncate to the tokenizer's maximum context length.
        input_ids = input_ids[:, : self.tokenizer.model_max_length]
        labels = labels[:, : self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
        return batch
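# A minimal usage sketch for training (illustrative; assumes `dataset` and
# `tokenizer` are constructed as in the __main__ block below):
#
#     collator = DataCollatorForObjIdentifierDataset(tokenizer)
#     loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=collator)
#     batch = next(iter(loader))  # keys: input_ids, labels, attention_mask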
@dataclass
class DataCollatorForBatchDecodingObjIdentifierDataset(object):
    """Collate examples for batch decoding."""

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer, **kwargs):
        self.tokenizer = tokenizer

    def pad_sequence(self, input_ids, batch_first, padding_value):
        # torch.nn.utils.rnn.pad_sequence pads at the end of each sequence; to
        # left-pad for batched decoding, flip each sequence, pad, then flip back.
        # (Note: the condition checks "left"; flipping on "right" would invert the
        # requested padding side.)
        if self.tokenizer.padding_side == "left":
            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=batch_first, padding_value=padding_value
        )
        if self.tokenizer.padding_side == "left":
            input_ids = torch.flip(input_ids, [1])
        return input_ids

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids = [instance["input_ids"] for instance in instances]
        input_ids = self.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        batch = dict(input_ids=input_ids)
        if "bbox_label" in instances[0].keys():
            batch["bbox_label"] = [instance["bbox_label"] for instance in instances]
        if "obj_context" in instances[0].keys():
            batch["obj_context"] = [instance["obj_context"] for instance in instances]
        if "full_info_dict" in instances[0].keys():
            batch["full_info_dict"] = [instance["full_info_dict"] for instance in instances]
        return batch


# test the dataset
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        "/data/jianingy/3d-llama/checkpoints/llava-llama-2-7b-chat-lightning-preview",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    conversation_lib.default_conversation = conversation_lib.conv_templates["llava_llama_2"]
    dataset = ObjIdentifierDataset(
        tokenizer,
        data_path="/home/jianingy/research/LLaVA-original/llava/dataset/3dfront/grounded_scene_description_gpt_format.json",
        scene_to_obj_mapping="/home/jianingy/research/LLaVA-original/llava/dataset/3dfront/compressed_organized_data.json",
    )
    print(len(dataset))
    print(dataset[0])
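    # Quick collator smoke test (illustrative; assumes the data file holds at least
    # one sample and that __getitem__ returns "input_ids"/"labels" tensors).
    collator = DataCollatorForObjIdentifierDataset(tokenizer)
    batch = collator([dataset[0], dataset[0]])
    print(batch["input_ids"].shape, batch["labels"].shape, batch["attention_mask"].shape)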