chantera committed
Commit e5c54e7 · 1 Parent(s): 7665e4b
Files changed (1):
  1. modeling_vila.py +3 -1
modeling_vila.py CHANGED
@@ -428,6 +428,8 @@ class VILAPretrainedModel(PreTrainedModel):
         # print("DEBUG", len(self.tokenizer.added_tokens_encoder.keys()), self.tokenizer.added_tokens_encoder.keys())
         NUM_EXTRA_TOKENS = len(self.tokenizer.added_tokens_encoder.keys())
 
+        self.skip_pad_tokens = True
+
         # TODO: SENTINEL_TOKEN is not added, need to check with Zhijian
         self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS
         # XGrammar tokenizer and grammar compiler
@@ -651,7 +653,7 @@ class VILAForCasualLM(VILAPretrainedModel):
                 input = media_embeds[name].popleft()
                 label = torch.full([input.shape[0]], IGNORE_INDEX, device=labels[k].device, dtype=labels[k].dtype)
                 # print(f"{self.tokenizer.padding_side} [media] {k=} {pos=}, {self.tokenizer.batch_decode(input_ids[k][pos:pos+1])}"); python_input()
-            elif input_ids[k][pos].item() in (self.tokenizer.pad_token_id, self.tokenizer.eos_token_id):
+            elif self.skip_pad_tokens and input_ids[k][pos].item() in (self.tokenizer.pad_token_id, self.tokenizer.eos_token_id):
                 end = pos + 1
                 pos = end
                 # print(f"[skip PAD/EOS] {k=} {pos=}, {self.tokenizer.batch_decode(input_ids[k][pos:end])}"); python_input()