Upload model

- config.json +6 -1
- model.safetensors +2 -2
- modeling_mamba.py +75 -18
config.json
CHANGED

@@ -1,6 +1,10 @@
 {
+  "architectures": [
+    "MambaModelForCausalLM"
+  ],
   "auto_map": {
-    "AutoConfig": "configuration_mamba.MambaConfig"
+    "AutoConfig": "configuration_mamba.MambaConfig",
+    "AutoModelForCausalLM": "modeling_mamba.MambaModelForCausalLM"
   },
   "bias": false,
   "conv_bias": true,
@@ -15,6 +19,7 @@
   "model_type": "mamba",
   "n_layer": 24,
   "pad_vocab_size_multiple": 8,
+  "torch_dtype": "float32",
   "transformers_version": "4.37.2",
   "vocab_size": 50280
 }
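The new `architectures` and `auto_map` entries are what let this checkpoint resolve through the Transformers Auto classes: `AutoConfig` maps to `configuration_mamba.MambaConfig` and `AutoModelForCausalLM` to `modeling_mamba.MambaModelForCausalLM`, both loaded from the repo itself. A minimal loading sketch, assuming a placeholder repo id (the actual repo name is not part of this commit):

```python
# Minimal sketch: load the checkpoint through the Auto classes.
# "user/mamba-370m" is a placeholder; substitute the actual repo id.
# trust_remote_code=True is required because auto_map points at custom
# code (configuration_mamba.py / modeling_mamba.py) stored in the repo.
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "user/mamba-370m"  # hypothetical

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

print(type(model).__name__)  # MambaModelForCausalLM, resolved via auto_map
print(model.dtype)           # torch.float32, matching the new torch_dtype field
```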
model.safetensors
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:699ed6f59fb948186f449c5031e0dc659d504c90d7e018302aa1e190cdb40220
+size 516567560
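Only the Git LFS pointer changes here; the weights themselves live in LFS storage, and the pointer records their sha256 digest and byte size. A sketch for checking a downloaded copy against the pointer (the local path is an assumption):

```python
# Verify a locally downloaded model.safetensors against the LFS pointer above.
import hashlib
from pathlib import Path

expected_oid = "699ed6f59fb948186f449c5031e0dc659d504c90d7e018302aa1e190cdb40220"
expected_size = 516567560

path = Path("model.safetensors")  # hypothetical local path
assert path.stat().st_size == expected_size, "size mismatch with LFS pointer"

digest = hashlib.sha256()
with path.open("rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        digest.update(chunk)
assert digest.hexdigest() == expected_oid, "sha256 mismatch with LFS pointer"
```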
modeling_mamba.py
CHANGED

@@ -8,8 +8,7 @@ from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
-
-    SequenceClassifierOutput,
+    SequenceClassifierOutputWithPast,
 )
 from transformers.modeling_utils import PreTrainedModel
 
@@ -320,9 +319,9 @@ class MambaModelForCausalLM(MambaPreTrainedModel):
         **kwargs,
     ) -> CausalLMOutputWithPast:
         batch_size = input_ids.shape[0]
+        output_hidden_states = output_hidden_states or self.config.output_hidden_states
         sequence_length = input_ids.shape[1]
         vocab_size = self.config.vocab_size
-        output_hidden_states = output_hidden_states or self.config.output_hidden_states
 
         outputs = self.backbone(
             input_ids=input_ids,
@@ -337,7 +336,7 @@ class MambaModelForCausalLM(MambaPreTrainedModel):
             )
         )
 
-        if labels:
+        if labels is not None:
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss()
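A note on the `if labels:` to `if labels is not None:` change in the hunk above: truth-testing a tensor with more than one element raises a `RuntimeError` in PyTorch, so the original guard would crash as soon as a real batch of labels was passed. A minimal illustration:

```python
import torch

labels = torch.tensor([[1, 2, 3]])

try:
    if labels:  # ambiguous: the tensor has more than one element
        pass
except RuntimeError as err:
    print(err)  # "Boolean value of Tensor with more than one element is ambiguous"

if labels is not None:  # the presence check the code actually wants
    print("labels provided, compute the loss")
```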
@@ -364,17 +363,75 @@ class MambaModelForCausalLM(MambaPreTrainedModel):
         }
 
 
-class MambaModelForSequenceClassification(
-
-
-
-
-
-
-
-
-
-
-
-
-
+# class MambaModelForSequenceClassification(MambaModelForCausalLM):
+#     def __init__(
+#         self,
+#         config,
+#         id2label={0: "NEGATIVE", 1: "POSITIVE"},
+#         label2id={"NEGATIVE": 0, "POSITIVE": 1},
+#         num_labels=2,
+#         **kwargs,
+#     ):
+#         super().__init__(
+#             config,
+#             **kwargs,
+#         )
+
+#         self.id2label = id2label
+#         self.label2id = label2id
+#         self.num_labels = num_labels  # TODO: config.num_labels
+
+#         self.score = nn.Linear(
+#             in_features=self.config.vocab_size,
+#             out_features=self.num_labels,
+#             bias=False,
+#         )
+
+#     def forward(
+#         self,
+#         input_ids: Optional[torch.Tensor] = None,
+#         labels: Optional[torch.Tensor] = None,
+#         output_hidden_states=False,
+#         **kwargs,
+#     ) -> SequenceClassifierOutputWithPast:
+#         batch_size = input_ids.shape[0]
+#         hidden_size = self.config.vocab_size
+#         hidden_states: Tuple[
+#             torch.Tensor[(batch_size, sequence_length, hidden_size)]
+#         ] = ()
+#         num_labels = self.num_labels  # TODO: config.num_labels
+#         sequence_length = input_ids.shape[1]
+#         vocab_size = self.config.vocab_size
+#         output_hidden_states = output_hidden_states or self.config.output_hidden_states
+
+#         outputs = super().forward(
+#             input_ids=input_ids,
+#             labels=None,
+#             output_hidden_states=output_hidden_states,
+#             **kwargs,
+#         )
+
+#         last_hidden_state = outputs.logits
+#         assert last_hidden_state.shape == (
+#             batch_size,
+#             sequence_length,
+#             hidden_size,
+#         ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
+#         hidden_states += (last_hidden_state,)
+
+#         logits: torch.FloatTensor[batch_size, num_labels] = self.score(
+#             last_hidden_state[:, -1, :]  # TODO: Check if this makes sense
+#         )
+
+#         if labels is not None:
+#             loss_fct = CrossEntropyLoss()
+#             loss = loss_fct(logits, labels)
+
+#         else:
+#             loss = None
+
+#         return SequenceClassifierOutputWithPast(
+#             loss=loss,
+#             logits=logits,
+#             hidden_states=hidden_states,
+#         )