Remove deprecated import, type hint
modeling_qwen3_shared_moe.py  +2 -10
@@ -20,17 +20,12 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
-from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_outputs import (
     MoeCausalLMOutputWithPast,
     MoeModelOutputWithPast,
 )
 from transformers.activations import ACT2FN
-from transformers.
-from transformers.utils import (
-    LossKwargs,
-    logging,
-)
+from transformers.utils import logging
 from transformers.models.mixtral.modeling_mixtral import (
     load_balancing_loss_func,
 )
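
The consolidated import keeps only logging from transformers.utils, since LossKwargs is gone. For context, a minimal sketch of the usual module-level logger pattern in transformers modeling files (not part of this diff; the message text is hypothetical):

    from transformers.utils import logging

    # Standard transformers pattern: one module-level logger keyed to the module name.
    logger = logging.get_logger(__name__)

    # warning_once is the transformers helper that deduplicates repeated warnings.
    logger.warning_once("shared-expert routing enabled")
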
@@ -132,9 +127,6 @@ class Qwen3SharedMoeModel(Qwen3MoeModel):
         )
 
 
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
 class Qwen3SharedMoeForCausalLM(Qwen3MoeForCausalLM):
     config_class = Qwen3SharedMoeConfig
 
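
For context on the deleted class: KwargsForCausalLM followed the transformers typed-kwargs idiom, where TypedDicts are merged so a type checker can validate the keys passed through **kwargs (PEP 692). A minimal self-contained sketch of that idiom; the field names below are hypothetical stand-ins, not the real FlashAttentionKwargs/LossKwargs contents:

    from typing_extensions import TypedDict, Unpack


    class AttentionKwargs(TypedDict, total=False):
        sliding_window: int        # hypothetical key, for illustration only


    class LossKwargs(TypedDict, total=False):
        num_items_in_batch: int    # hypothetical key, for illustration only


    # Merging two TypedDicts yields one keyword schema (the deleted pattern).
    class KwargsForCausalLM(AttentionKwargs, LossKwargs): ...


    def forward(**kwargs: Unpack[KwargsForCausalLM]) -> None:
        # A type checker now flags unknown or mistyped keys at call sites.
        print(sorted(kwargs))


    forward(sliding_window=4096, num_items_in_batch=8)
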
@@ -157,7 +149,7 @@ class Qwen3SharedMoeForCausalLM(Qwen3MoeForCausalLM):
         output_router_logits: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs
+        **kwargs,
     ) -> MoeCausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
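
With the annotation dropped, extra keyword arguments still flow through forward, just without static checking. A hedged usage sketch; model and input_ids are assumed names, and the parameters come from the signature and docstring above:

    # Assumed usage after the change: extra kwargs are accepted untyped.
    outputs = model(
        input_ids=input_ids,
        labels=input_ids,            # shifted internally for the causal-LM loss
        output_router_logits=True,   # router logits also feed load_balancing_loss_func
        logits_to_keep=1,            # compute logits for the last position only
    )
    print(outputs.loss, outputs.aux_loss)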