import logging
from typing import List, Tuple
from dataclasses import dataclass

import torch
from transformers import ProcessorMixin, AutoProcessor, AutoTokenizer

from src.arguments import DataArguments, ModelArguments

logger = logging.getLogger(__name__)

@dataclass
class TrainCollator:
    data_args: DataArguments
    model_args: ModelArguments
    processor: ProcessorMixin

    def __call__(self, examples):
        """
        :param examples: [{qry:..., qry_image:..., pos_text:..., pos_image:...}] * batch_size
        """
        qry_inputs = self._get_batch_inputs(examples, 0, 1)
        pos_inputs = self._get_batch_inputs(examples, 2, 3)
        if "hard_neg" in self.data_args.dataset_name:
            hard_neg_inputs = self._get_batch_inputs(examples, 4, 5)
            return qry_inputs, pos_inputs, hard_neg_inputs
        return qry_inputs, pos_inputs

    def _get_batch_inputs(self, examples, text_idx, image_idx):
        input_ids, pixel_values = [], []
        image_mask, image_sizes, image_grid_thw = [], [], []
        for example in examples:
            text, image = example[text_idx], example[image_idx]
            has_image = image is not None
            image_mask.append(1 if has_image else 0)
            # Unified processor call logic across backbones
            if self.model_args.model_backbone == "llava_next":
                inputs = self.processor(
                    text=text,
                    images=image if has_image else None,
                    return_tensors="pt",
                    max_length=self.data_args.max_len,
                    truncation=True
                )
            elif self.model_args.model_backbone in ["qwen", "qwen2_vl"]:
                inputs = self.processor(
                    text=[text],
                    images=[image] if has_image else None,
                    return_tensors="pt",
                    max_length=self.data_args.max_len,
                    truncation=True
                )
            else:
                inputs = self.processor(
                    text=text,
                    images=[image] if has_image else None,
                    return_tensors="pt",
                    max_length=self.data_args.max_len,
                    truncation=True
                )
            if has_image:
                if self.model_args.model_backbone == "qwen":
                    pixel_values.append(inputs['pixel_values'].unsqueeze(0))
                else:
                    pixel_values.append(inputs['pixel_values'])
            input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))
            if "image_sizes" in inputs:
                image_sizes.append(inputs['image_sizes'])
            if "image_grid_thw" in inputs:
                image_grid_thw.append(inputs['image_grid_thw'])
        # Each input_ids entry has shape (seq_len, 1) so pad_sequence can pad the
        # variable-length sequences; the trailing dim is squeezed away afterwards.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.processor.tokenizer.pad_token_id
        ).squeeze(2)
        attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id)
        inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'image_mask': torch.tensor(image_mask, dtype=torch.float)
        }
        if any(image_mask):
            if pixel_values:
                inputs['pixel_values'] = torch.cat(pixel_values, dim=0)
            if image_sizes:
                inputs['image_sizes'] = torch.cat(image_sizes, dim=0)
            if image_grid_thw:
                inputs['image_grid_thw'] = torch.cat(image_grid_thw, dim=0)
        if self.model_args.model_backbone == "internvl_2_5":
            inputs['image_flags'] = inputs['image_mask'].to(torch.long)
        return inputs

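
# Usage sketch (illustrative, not part of the original module): TrainCollator is
# meant to be passed to a torch DataLoader as `collate_fn`. The helper name,
# batch size, and num_workers below are arbitrary choices for illustration;
# `train_dataset`, `data_args`, and `model_args` are assumed to come from the
# surrounding training script (the collator only reads `dataset_name`,
# `max_len`, and `model_backbone` from them).
def build_train_dataloader_example(train_dataset, data_args, model_args, processor):
    from torch.utils.data import DataLoader
    collator = TrainCollator(data_args=data_args, model_args=model_args, processor=processor)
    # Each batch is (qry_inputs, pos_inputs), plus hard-negative inputs when the
    # dataset name contains "hard_neg".
    return DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collator, num_workers=2)
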
@dataclass
class EvalCollator:
    data_args: DataArguments
    model_args: ModelArguments
    processor: ProcessorMixin

    def __call__(self, examples):
        """
        :param examples: qry, qry_image, pos_text, pos_image
        """
        inputs = self._get_batch_inputs(examples)
        return inputs

    def _get_batch_inputs(self, examples):
        input_ids, pixel_values, image_sizes = [], [], []
        image_mask = []
        image_exist = False
        for example in examples:
            text, image = example
            has_image = image is not None
            image_mask.append(1 if has_image else 0)
            if self.model_args.model_backbone == "internvl_2_5":
                inputs = self.processor(
                    text=text,
                    images=[image] if has_image else None,
                    return_tensors="pt",
                    max_length=self.data_args.max_len,
                    truncation=True
                )
                input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))
                if has_image:
                    pixel_values.append(inputs['pixel_values'])
                    if 'image_sizes' in inputs:
                        image_sizes.append(inputs['image_sizes'])
                continue
            if image is None:
                if self.model_args.model_backbone == "llava_next":
                    inputs = self.processor(images=None, text=text, return_tensors="pt")
                else:
                    inputs = self.processor(text, None, return_tensors="pt",
                                            max_length=self.data_args.max_len, truncation=True)
                input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))
                pixel_values.append(None)
                image_sizes.append(None)
            else:
                image_exist = True
                if self.model_args.model_backbone == "llava_next":
                    inputs = self.processor(images=image, text=text, return_tensors="pt")
                else:
                    inputs = self.processor(text, [image], return_tensors="pt",
                                            max_length=self.data_args.max_len, truncation=True)
                input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))
                pixel_values.append(inputs['pixel_values'])
                image_sizes.append(inputs['image_sizes'])
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.processor.tokenizer.pad_token_id
        ).squeeze(2)
        attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id)
        if self.model_args.model_backbone == "internvl_2_5":
            inputs = {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'image_mask': torch.tensor(image_mask, dtype=torch.float)
            }
            if any(image_mask):
                if pixel_values:
                    inputs['pixel_values'] = torch.cat(pixel_values, dim=0)
                if image_sizes:
                    inputs['image_sizes'] = torch.cat(image_sizes, dim=0)
            # InternVL consumes per-example image flags rather than the image mask.
            inputs['image_flags'] = inputs['image_mask'].to(torch.long)
            del inputs['image_mask']
        else:
            if not image_exist:
                # Text-only batch: provide dummy tensors so downstream code that
                # expects vision inputs still receives well-formed keys.
                dummy_pixel_values = torch.zeros(input_ids.shape[0], 1)
                dummy_image_sizes = torch.ones(input_ids.shape[0], 1)
                inputs = {
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                    'pixel_values': dummy_pixel_values,
                    'image_sizes': dummy_image_sizes,
                }
            else:
                # Mixed batch: fill text-only entries with zero/one placeholders
                # shaped like the real image tensors so they can be concatenated.
                pixel_values_shape = list(set(v.shape for v in pixel_values if v is not None))[0]
                pixel_values = [v if v is not None else torch.zeros(pixel_values_shape) for v in pixel_values]
                pixel_values = torch.cat(pixel_values, dim=0)
                image_sizes_shape = list(set(v.shape for v in image_sizes if v is not None))[0]
                image_sizes = [v if v is not None else torch.ones(image_sizes_shape) for v in image_sizes]
                image_sizes = torch.cat(image_sizes, dim=0)
                inputs = {
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                    'pixel_values': pixel_values,
                    'image_sizes': image_sizes,
                }
        return inputs

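
# Usage sketch (illustrative, not part of the original module): EvalCollator
# returns a single padded batch dict, so a typical evaluation loop moves the
# tensors to the model device and encodes them without gradients. `encode_fn`
# is a stand-in for whatever encoding entry point the surrounding code
# provides; it is not defined in this module.
def encode_eval_split_example(eval_dataset, data_args, model_args, processor, encode_fn, device="cuda"):
    from torch.utils.data import DataLoader
    collator = EvalCollator(data_args=data_args, model_args=model_args, processor=processor)
    loader = DataLoader(eval_dataset, batch_size=16, collate_fn=collator)
    embeddings = []
    with torch.no_grad():
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            embeddings.append(encode_fn(**batch))
    return torch.cat(embeddings, dim=0)
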
@dataclass
class CLIPCollator:
    data_args: DataArguments
    vis_processors: AutoProcessor
    txt_processors: AutoTokenizer

    def __call__(self, examples):
        """
        :param examples: qry, qry_image, pos_text, pos_image
        """
        inputs = self._get_batch_inputs(examples)
        return inputs

    def _get_batch_inputs(self, examples):
        input_ids, pixel_values, attention_mask = [], [], []
        image_exist, text_exist = False, False
        for example in examples:
            text, image = example
            if image is not None:
                if image.mode == 'L':
                    image = image.convert('RGB')
                image_inputs = self.vis_processors(images=image, return_tensors="pt")
                image_exist = True
                pixel_values.append(image_inputs['pixel_values'])
            if text:
                text_exist = True
                text_inputs = self.txt_processors(
                    text,
                    padding=getattr(self.data_args, "padding", True),
                    max_length=self.data_args.max_len,
                    truncation=True,
                    return_tensors="pt"
                )
                input_ids.append(text_inputs["input_ids"].squeeze(0))
        if text_exist:
            input_ids = torch.nn.utils.rnn.pad_sequence(
                input_ids, batch_first=True, padding_value=self.txt_processors.pad_token_id
            )
            attention_mask = input_ids.ne(self.txt_processors.pad_token_id)
        if image_exist:
            pixel_values = torch.cat(pixel_values, dim=0)
        if text_exist and image_exist:
            assert input_ids.size()[0] == pixel_values.size()[0]
        inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'pixel_values': pixel_values,
        }
        return inputs

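
# Usage sketch (illustrative, not part of the original module): CLIPCollator
# pairs a Hugging Face CLIP processor (for images) with a CLIP tokenizer (for
# text). "openai/clip-vit-base-patch32" is an assumed checkpoint; `data_args`
# comes from the surrounding script, and only `max_len` (and optionally
# `padding`) is read here.
def build_clip_collator_example(data_args):
    vis_processors = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
    txt_processors = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    return CLIPCollator(data_args=data_args, vis_processors=vis_processors, txt_processors=txt_processors)
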
@dataclass
class OpenCLIPCollator:
    data_args: DataArguments
    vis_processors: AutoProcessor
    txt_processors: AutoTokenizer

    def __call__(self, examples):
        """
        :param examples: qry, qry_image, pos_text, pos_image
        """
        inputs = self._get_batch_inputs(examples)
        return inputs

    def _get_batch_inputs(self, examples):
        input_ids, pixel_values, attention_mask = [], [], []
        image_exist, text_exist = False, False
        for example in examples:
            text, image = example
            if image is not None:
                if image.mode == 'L':
                    image = image.convert('RGB')
                # open_clip transforms map a PIL image to a CHW tensor; add a batch dim.
                image_inputs = self.vis_processors(image).unsqueeze(0)
                image_exist = True
                pixel_values.append(image_inputs)
            if text:
                text_exist = True
                # The open_clip tokenizer returns a (1, context_len) LongTensor padded with 0.
                text_inputs = self.txt_processors(text)
                input_ids.append(text_inputs)
        if text_exist:
            input_ids = torch.cat(input_ids, dim=0)
            attention_mask = input_ids.ne(0)
        if image_exist:
            pixel_values = torch.cat(pixel_values, dim=0)
        if text_exist and image_exist:
            assert input_ids.size()[0] == pixel_values.size()[0]
        inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'pixel_values': pixel_values,
        }
        return inputs

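
# Usage sketch (illustrative, not part of the original module): OpenCLIPCollator
# expects an open_clip image transform for `vis_processors` (PIL image -> CHW
# tensor) and an open_clip tokenizer for `txt_processors` (str -> (1, 77)
# LongTensor). The model name and pretrained tag below are assumptions for
# illustration only.
def build_open_clip_collator_example(data_args):
    import open_clip
    _, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="laion2b_s34b_b79k")
    tokenizer = open_clip.get_tokenizer("ViT-B-32")
    return OpenCLIPCollator(data_args=data_args, vis_processors=preprocess, txt_processors=tokenizer)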