Fix: skip PAD/EOS tokens only when the new `self.skip_pad_tokens` flag is set (default `True`).

modeling_vila.py (+3 -1):
@@ -428,6 +428,8 @@ class VILAPretrainedModel(PreTrainedModel):
         # print("DEBUG", len(self.tokenizer.added_tokens_encoder.keys()), self.tokenizer.added_tokens_encoder.keys())
         NUM_EXTRA_TOKENS = len(self.tokenizer.added_tokens_encoder.keys())
 
+        self.skip_pad_tokens = True
+
         # TODO: SENTINEL_TOKEN is not added, need to check with Zhijian
         self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS
         # XGrammar tokenizer and grammar compiler

@@ -651,7 +653,7 @@ class VILAForCasualLM(VILAPretrainedModel):
                     input = media_embeds[name].popleft()
                     label = torch.full([input.shape[0]], IGNORE_INDEX, device=labels[k].device, dtype=labels[k].dtype)
                     # print(f"{self.tokenizer.padding_side} [media] {k=} {pos=}, {self.tokenizer.batch_decode(input_ids[k][pos:pos+1])}"); python_input()
-                elif input_ids[k][pos].item() in (self.tokenizer.pad_token_id, self.tokenizer.eos_token_id):
+                elif self.skip_pad_tokens and input_ids[k][pos].item() in (self.tokenizer.pad_token_id, self.tokenizer.eos_token_id):
                     end = pos + 1
                     pos = end
                     # print(f"[skip PAD/EOS] {k=} {pos=}, {self.tokenizer.batch_decode(input_ids[k][pos:end])}"); python_input()
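In short, the loop that splices media embeddings into the token stream now drops PAD/EOS positions only when `self.skip_pad_tokens` is true, so the previous behaviour is kept by default but can be switched off. Below is a minimal sketch of that effect, not the VILA implementation: the flat `merge_tokens` helper and its arguments are illustrative assumptions, and only the flag, the PAD/EOS membership test, and the "end = pos + 1; pos = end" advance come from the diff above.

# Minimal sketch of the behaviour the flag controls; NOT the VILA code.
def merge_tokens(input_ids, pad_token_id, eos_token_id, skip_pad_tokens=True):
    """Walk a list of token ids and return the ones that would be embedded."""
    kept = []
    pos = 0
    while pos < len(input_ids):
        if skip_pad_tokens and input_ids[pos] in (pad_token_id, eos_token_id):
            # Mirrors the patched elif branch: consume one token, emit nothing.
            end = pos + 1
            pos = end
            continue
        kept.append(input_ids[pos])  # stand-in for "embed this text token"
        pos += 1
    return kept

ids = [5, 6, 0, 0]  # pretend pad_token_id == 0, eos_token_id == 2
print(merge_tokens(ids, pad_token_id=0, eos_token_id=2))                         # -> [5, 6]
print(merge_tokens(ids, pad_token_id=0, eos_token_id=2, skip_pad_tokens=False))  # -> [5, 6, 0, 0]

If PAD/EOS tokens should be kept instead of skipped, the attribute can simply be overwritten on the loaded model; the checkpoint path below is a placeholder:

model = VILAForCasualLM.from_pretrained("path/to/vila-checkpoint")  # placeholder path
model.skip_pad_tokens = False  # keep PAD/EOS tokens in the merged sequence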