Update modeling_rotary_indictrans.py
Browse files
modeling_rotary_indictrans.py
CHANGED
|
@@ -27,6 +27,7 @@ from einops import rearrange, repeat
|
|
| 27 |
from torch.amp import autocast
|
| 28 |
from torch import einsum
|
| 29 |
|
|
|
|
| 30 |
from transformers.modeling_utils import PreTrainedModel
|
| 31 |
from .configuration_rotary_indictrans import RotaryIndicTransConfig
|
| 32 |
|
|
@@ -1496,7 +1497,7 @@ class RotaryIndicTransModel(RotaryIndicTransPreTrainedModel):
|
|
| 1496 |
|
| 1497 |
|
| 1498 |
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->RotaryIndicTrans
|
| 1499 |
-
class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel):
|
| 1500 |
base_model_prefix = "model"
|
| 1501 |
_tied_weights_keys = None
|
| 1502 |
_label_smoothing = 0.0
|
|
|
|
| 27 |
from torch.amp import autocast
|
| 28 |
from torch import einsum
|
| 29 |
|
| 30 |
+
from transformers.generation import GenerationMixin
|
| 31 |
from transformers.modeling_utils import PreTrainedModel
|
| 32 |
from .configuration_rotary_indictrans import RotaryIndicTransConfig
|
| 33 |
|
|
|
|
| 1497 |
|
| 1498 |
|
| 1499 |
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->RotaryIndicTrans
|
| 1500 |
+
class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel, GenerationMixin):
|
| 1501 |
base_model_prefix = "model"
|
| 1502 |
_tied_weights_keys = None
|
| 1503 |
_label_smoothing = 0.0
|