x54-729 committed · commit 38f564d · 1 parent: eb85986

support flash attn 2

Files changed:
- configuration_internlm.py: +4 -7
- modeling_internlm.py: +186 -18
configuration_internlm.py
CHANGED
@@ -32,11 +32,8 @@ class InternLMConfig(PretrainedConfig):
     This is the configuration class to store the configuration of a [`InternLMModel`]. It is used to instantiate
     an InternLM model according to the specified arguments, defining the model architecture. Instantiating a
     configuration with the defaults will yield a similar configuration to that of the InternLM-7B.
-
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
-
-
     Args:
         vocab_size (`int`, *optional*, defaults to 32000):
             Vocabulary size of the InternLM model. Defines the number of different tokens that can be represented by the
@@ -64,16 +61,12 @@ class InternLMConfig(PretrainedConfig):
         tie_word_embeddings(`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
     Example:
-
     ```python
     >>> from transformers import InternLMModel, InternLMConfig
-
     >>> # Initializing a InternLM internlm-7b style configuration
     >>> configuration = InternLMConfig()
-
     >>> # Initializing a model from the internlm-7b style configuration
     >>> model = InternLMModel(configuration)
-
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""
@@ -98,6 +91,7 @@ class InternLMConfig(PretrainedConfig):
         tie_word_embeddings=False,
         bias=True,
         rotary={"base": 10000, "type": "dynamic"},  # pylint: disable=W0102
+        attn_implementation="eager",
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -112,6 +106,9 @@ class InternLMConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.bias = bias
         self.rotary = rotary
+        self.attn_implementation = attn_implementation
+        if self.attn_implementation is None:
+            self.attn_implementation = "eager"
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
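Taken together, the configuration change adds one constructor argument, `attn_implementation`, which defaults to `"eager"` and falls back to `"eager"` when `None` is passed. A minimal sketch of the resulting behavior, assuming the updated configuration_internlm.py above is importable from the working directory (the import path is an assumption, not part of the commit):

```python
# Sketch only: assumes configuration_internlm.py from this commit is on the import path.
from configuration_internlm import InternLMConfig

cfg = InternLMConfig()
print(cfg.attn_implementation)      # "eager" (default)

cfg_fa2 = InternLMConfig(attn_implementation="flash_attention_2")
print(cfg_fa2.attn_implementation)  # "flash_attention_2"

cfg_none = InternLMConfig(attn_implementation=None)
print(cfg_none.attn_implementation) # "eager" (None falls back, per the guard in __init__)
```

The same string is later used by modeling_internlm.py (below) to select between the eager and FlashAttention-2 attention classes.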
modeling_internlm.py
CHANGED
@@ -1,10 +1,6 @@
-#
-# Copyright (c) InternLM. All rights reserved.
+# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
 #
-# This code is based on
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+# This code is based on transformers/src/transformers/models/llama/modeling_llama.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -52,6 +48,17 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "InternLMConfig"
 
+def _get_unpad_data(attention_mask):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
@@ -85,7 +92,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
 
     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
 
-
 class InternLMRMSNorm(nn.Module):
     """RMSNorm implemention."""
 
@@ -228,8 +234,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
         k_sin = sin[position_ids].unsqueeze(1).expand(k.shape)
         k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
     else:
-        cos = cos[position_ids].unsqueeze(1)
-        sin = sin[position_ids].unsqueeze(1)
+        cos = cos[position_ids].unsqueeze(1)
+        sin = sin[position_ids].unsqueeze(1)
         q_embed = (q * cos) + (rotate_half(q) * sin)
         k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
@@ -273,6 +279,7 @@ class InternLMAttention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
         self.rotary_emb = self._init_rope()
+        self.is_causal = True
 
     def _init_rope(self):
         if self.config.rotary["type"] == "origin":
@@ -356,13 +363,167 @@ class InternLMAttention(nn.Module):
             attn_weights = None
 
         return attn_output, attn_weights, past_key_value
+
+
+class InternLMFlashAttention2(InternLMAttention):
+    """
+    InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # InternLM2FlashAttention2 attention does not support output_attentions
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        kv_seq_len = key_states.shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+        attn_output = self._flash_attention_forward(
+            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+        )
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    def _flash_attention_forward(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+        first unpad the input, then computes the attention scores and pad the final attention scores.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`int`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+        from flash_attn import flash_attn_func, flash_attn_varlen_func
+        from flash_attn.bert_padding import pad_input
+
+        # Contains at least one padding token in the sequence
+        causal = self.is_causal and query_length != 1
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            attn_output = flash_attn_func(
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+            )
+
+        return attn_output
+
+    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+        from flash_attn.bert_padding import index_first_axis, unpad_input
+
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
+        )
+
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q.to(torch.int64),
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+INTERNLM_ATTENTION_CLASSES = {
+    "eager": InternLMAttention,
+    "flash_attention_2": InternLMFlashAttention2,
+}
 
-
 class InternLMDecoderLayer(nn.Module):
     def __init__(self, config: InternLMConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.self_attn = InternLMAttention(config=config)
+
+        self.self_attn = INTERNLM_ATTENTION_CLASSES[config.attn_implementation](config=config)
+
         self.mlp = InternLMMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
@@ -539,8 +700,10 @@ class InternLMModel(InternLMPreTrainedModel):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.config = config
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+
         self.layers = nn.ModuleList([InternLMDecoderLayer(config) for _ in range(config.num_hidden_layers)])
         self.norm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -627,14 +790,16 @@ class InternLMModel(InternLMPreTrainedModel):
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
-        # embed positions
-        if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
-            )
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-        )
+        if self.config.attn_implementation == "flash_attention_2":
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        else:
+            if attention_mask is None:
+                attention_mask = torch.ones(
+                    (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+                )
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+            )
 
         hidden_states = inputs_embeds
 
@@ -759,6 +924,7 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
                 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                 (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
         Returns:
+
         Example:
         ```python
         >>> from transformers import AutoTokenizer, InternLMForCausalLM
@@ -770,7 +936,9 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
-        ```"""
+        ```
+
+        """
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
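The new `_get_unpad_data` helper is what feeds the varlen FlashAttention path: it turns a `(batch, seq_len)` padding mask into flat token indices, cumulative sequence lengths, and the longest real sequence length in the batch. A small self-contained illustration of its behavior (the helper body mirrors the hunk above; the toy mask is made up for demonstration):

```python
# Toy demonstration of the unpadding bookkeeping used by the flash-attn varlen path.
import torch
import torch.nn as nn


def _get_unpad_data(attention_mask):
    # Same logic as the helper added in this commit.
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch


# Two left-padded sequences: row 0 has 2 real tokens, row 1 has 4.
mask = torch.tensor([[0, 0, 1, 1],
                     [1, 1, 1, 1]])
indices, cu_seqlens, max_seqlen = _get_unpad_data(mask)
print(indices)     # tensor([2, 3, 4, 5, 6, 7]) -- flat positions of the real tokens
print(cu_seqlens)  # tensor([0, 2, 6], dtype=torch.int32) -- per-sequence boundaries
print(max_seqlen)  # 4
```

`_upad_input` gathers the query/key/value rows at those flat indices before calling `flash_attn_varlen_func`, and `pad_input` scatters the result back into the padded `(batch, seq_len, ...)` layout afterwards.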
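End to end, enabling the new code path amounts to setting `attn_implementation` on the config before building the model. A hypothetical loading sketch from the Hub side (the repo id, dtype, and device handling are placeholders, not part of this commit; it assumes flash-attn 2 is installed and a CUDA device with fp16/bf16 support):

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

repo = "internlm/internlm-chat-7b"  # placeholder repo id

config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
config.attn_implementation = "flash_attention_2"  # field added by this commit; "eager" is the default

tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo,
    config=config,
    torch_dtype=torch.float16,  # flash-attn kernels require fp16/bf16
    trust_remote_code=True,
).cuda()

inputs = tokenizer("Hey, are you conscious?", return_tensors="pt").to(model.device)
generate_ids = model.generate(inputs.input_ids, max_length=30)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0])
```

With the default `attn_implementation="eager"` the model behaves as before this commit, since `INTERNLM_ATTENTION_CLASSES["eager"]` maps back to the original `InternLMAttention`.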
|