laolida-w commited on
Commit
5bb2d2a
·
verified ·
1 Parent(s): 697a4fe

Update eval_dataset.py

Browse files
Files changed (1) hide show
  1. eval_dataset.py +10 -10
eval_dataset.py CHANGED
@@ -134,12 +134,12 @@ class SingleRegionCaptionDataset(Dataset):
134
  aspect_ratio = model_inputs["aspect_ratio"]
135
 
136
  ret = dict(
137
- input_ids=input_ids.cuda().unsqueeze(0),
138
- attention_mask=attention_mask.cuda().to(self.data_dtype).unsqueeze(0),
139
- pixel_values=pixel_values.cuda().to(self.data_dtype).flatten(0, 1),
140
- global_mask_values=mask_values.cuda().to(self.data_dtype).squeeze(),
141
  bboxes=[data_dict["bboxes"]],
142
- aspect_ratios=aspect_ratio.unsqueeze(0).cuda(),
143
  )
144
  return ret
145
 
@@ -273,11 +273,11 @@ class MultiRegionDataset(Dataset):
273
  aspect_ratio = model_inputs["aspect_ratio"]
274
 
275
  ret = dict(
276
- input_ids=input_ids.cuda().unsqueeze(0),
277
- attention_mask=attention_mask.cuda().to(self.data_dtype).unsqueeze(0),
278
- pixel_values=pixel_values.cuda().to(self.data_dtype).flatten(0, 1),
279
- global_mask_values=mask_values.cuda().to(self.data_dtype).squeeze(),
280
  bboxes=[data_dict["bboxes"]],
281
- aspect_ratios=aspect_ratio.unsqueeze(0).cuda(),
282
  )
283
  return ret
 
134
  aspect_ratio = model_inputs["aspect_ratio"]
135
 
136
  ret = dict(
137
+ input_ids=input_ids.unsqueeze(0),
138
+ attention_mask=attention_mask.to(self.data_dtype).unsqueeze(0),
139
+ pixel_values=pixel_values.to(self.data_dtype).flatten(0, 1),
140
+ global_mask_values=mask_values.to(self.data_dtype).squeeze(),
141
  bboxes=[data_dict["bboxes"]],
142
+ aspect_ratios=aspect_ratio.unsqueeze(0),
143
  )
144
  return ret
145
 
 
273
  aspect_ratio = model_inputs["aspect_ratio"]
274
 
275
  ret = dict(
276
+ input_ids=input_ids.unsqueeze(0),
277
+ attention_mask=attention_mask.to(self.data_dtype).unsqueeze(0),
278
+ pixel_values=pixel_values.to(self.data_dtype).flatten(0, 1),
279
+ global_mask_values=mask_values.to(self.data_dtype).squeeze(),
280
  bboxes=[data_dict["bboxes"]],
281
+ aspect_ratios=aspect_ratio.unsqueeze(0),
282
  )
283
  return ret