xueyunlong committed on
Commit 0bfee9f · verified · 1 Parent(s): 6ab9edc

Upload 11 files

LMConfig.py ADDED
@@ -0,0 +1,83 @@
+ from transformers import PretrainedConfig
+ from typing import List
+
+
+ class LMConfig(PretrainedConfig):
+     model_type = "nanochat"
+
+     def __init__(
+             self,
+             dim: int = 512,
+             n_layers: int = 8,
+             tie_word_embeddings: bool = True,
+             ###########################################
+             attention: str = 'GQA',
+             # GQA
+             n_heads: int = 14,
+             n_kv_heads: int = 2,
+             # MLA
+             q_lora_rank: int = 0,
+             kv_lora_rank: int = 512,
+             qk_nope_head_dim: int = 64,
+             qk_rope_head_dim: int = 64,
+             v_head_dim: int = 64,
+             #############################################
+             vocab_size: int = 151650,
+             # vocab_size: int = 6400,
+             hidden_dim: int = None,
+             multiple_of: int = 64,
+             norm_eps: float = 1e-5,
+             max_seq_len: int = 8192,
+             rope_theta: float = 1e6,
+             dropout: float = 0.0,
+             flash_attn: bool = True,
+             ####################################################
+             # MoE-specific configuration.
+             # When use_moe is False, the settings below are ignored.
+             ####################################################
+             use_moe: bool = False,
+             ####################################################
+             num_experts_per_tok: int = 2,
+             n_routed_experts: int = 4,
+             n_shared_experts: bool = True,
+             scoring_func: str = 'softmax',
+             aux_loss_alpha: float = 0.1,
+             seq_aux: bool = True,
+             norm_topk_prob: bool = True,
+             **kwargs,
+     ):
+         self.dim = dim
+         self.n_layers = n_layers
+         self.tie_word_embeddings = tie_word_embeddings
+         self.vocab_size = vocab_size
+         self.hidden_dim = hidden_dim
+         self.multiple_of = multiple_of
+         self.norm_eps = norm_eps
+         self.max_seq_len = max_seq_len
+         self.rope_theta = rope_theta
+         self.dropout = dropout
+         self.flash_attn = flash_attn
+         #####################################################
+         self.attention = attention
+         # GQA
+         self.n_heads = n_heads
+         self.n_kv_heads = n_kv_heads
+         # MLA
+         self.q_lora_rank = q_lora_rank
+         self.kv_lora_rank = kv_lora_rank
+         self.qk_nope_head_dim = qk_nope_head_dim
+         self.qk_rope_head_dim = qk_rope_head_dim
+         self.v_head_dim = v_head_dim
+         ####################################################
+         # MoE-specific configuration.
+         # When use_moe is False, the settings below are ignored.
+         ####################################################
+         self.use_moe = use_moe
+         self.num_experts_per_tok = num_experts_per_tok  # number of experts selected per token
+         self.n_routed_experts = n_routed_experts  # total number of routed experts
+         self.n_shared_experts = n_shared_experts  # whether to add a shared expert
+         self.scoring_func = scoring_func  # gating score function, defaults to 'softmax'
+         self.aux_loss_alpha = aux_loss_alpha  # weight of the auxiliary load-balancing loss
+         self.seq_aux = seq_aux  # whether to compute the auxiliary loss at the sequence level
+         self.norm_topk_prob = norm_topk_prob  # whether to normalize the top-k probabilities
+         super().__init__(**kwargs)
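
LMConfig only records hyperparameters on top of transformers' PretrainedConfig, so it can be instantiated directly when experimenting with the architecture. A minimal sketch (not part of the upload; the overridden sizes are illustrative and it assumes LMConfig.py is importable from the working directory):

    from LMConfig import LMConfig

    # Defaults match the file above; any keyword can be overridden, and unknown
    # kwargs fall through to PretrainedConfig via super().__init__(**kwargs).
    cfg = LMConfig(dim=256, n_layers=4, n_heads=8, n_kv_heads=2, attention='GQA')
    print(cfg.model_type)                        # nanochat
    print(cfg.dim, cfg.n_kv_heads, cfg.use_moe)
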
added_tokens.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "</tool_call>": 151658,
+   "<tool_call>": 151657,
+   "<|box_end|>": 151649,
+   "<|box_start|>": 151648,
+   "<|endoftext|>": 151643,
+   "<|file_sep|>": 151664,
+   "<|fim_middle|>": 151660,
+   "<|fim_pad|>": 151662,
+   "<|fim_prefix|>": 151659,
+   "<|fim_suffix|>": 151661,
+   "<|im_end|>": 151645,
+   "<|im_start|>": 151644,
+   "<|image_pad|>": 151655,
+   "<|object_ref_end|>": 151647,
+   "<|object_ref_start|>": 151646,
+   "<|quad_end|>": 151651,
+   "<|quad_start|>": 151650,
+   "<|repo_name|>": 151663,
+   "<|video_pad|>": 151656,
+   "<|vision_end|>": 151653,
+   "<|vision_pad|>": 151654,
+   "<|vision_start|>": 151652
+ }
config.json ADDED
@@ -0,0 +1,40 @@
+ {
+   "_name_or_path": "out_sft",
+   "architectures": [
+     "NanoChatLM"
+   ],
+   "attention": "GQA",
+   "auto_map": {
+     "AutoConfig": "LMConfig.LMConfig",
+     "AutoModel": "model.NanoChatLM",
+     "AutoModelForCausalLM": "model.NanoChatLM"
+   },
+   "aux_loss_alpha": 0.1,
+   "dim": 896,
+   "dropout": 0.0,
+   "flash_attn": true,
+   "hidden_dim": 2432,
+   "kv_lora_rank": 512,
+   "max_seq_len": 512,
+   "model_type": "nanochat",
+   "multiple_of": 64,
+   "n_heads": 14,
+   "n_kv_heads": 2,
+   "n_layers": 24,
+   "n_routed_experts": 4,
+   "n_shared_experts": true,
+   "norm_eps": 1e-05,
+   "norm_topk_prob": true,
+   "num_experts_per_tok": 2,
+   "q_lora_rank": 0,
+   "qk_nope_head_dim": 64,
+   "qk_rope_head_dim": 64,
+   "rope_theta": 1000000.0,
+   "scoring_func": "softmax",
+   "seq_aux": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.42.4",
+   "use_moe": false,
+   "v_head_dim": 64,
+   "vocab_size": 151650
+ }
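
config.json wires the custom classes into the transformers Auto* machinery through auto_map, so the checkpoint is meant to be loaded with trust_remote_code=True. A minimal loading sketch (the path below is a placeholder for a local directory or hub repo containing these files):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo = "./out_sft"  # placeholder path
    tokenizer = AutoTokenizer.from_pretrained(repo)
    # auto_map resolves AutoModelForCausalLM to model.NanoChatLM
    model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
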
generation_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "_from_model_config": true,
+   "transformers_version": "4.42.4"
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.py ADDED
@@ -0,0 +1,480 @@
+ import math
+ import struct
+ import inspect
+ import time
+
+ from .LMConfig import LMConfig
+ from typing import Any, Optional, Tuple, List
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func  # only needed for the commented-out flash_attn_func path below
+
+
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, dim: int, eps: float):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x):
+         return self.weight * (x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x)
+
+
+ def precompute_pos_cis(dim: int, end: int, theta: float = 1e4):
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     t = torch.arange(end, device=freqs.device)  # type: ignore
+     freqs = torch.outer(t, freqs).float()  # type: ignore
+     pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+     return pos_cis
+
+
+ def apply_rotary_emb(xq, xk, pos_cis):
+     def unite_shape(pos_cis, x):
+         ndim = x.ndim
+         assert 0 <= 1 < ndim
+         assert pos_cis.shape == (x.shape[1], x.shape[-1])
+         shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+         return pos_cis.view(*shape)
+
+     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+     pos_cis = unite_shape(pos_cis, xq_)
+     xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
+     xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
+     return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
+     bs, slen, n_kv_heads, head_dim = x.shape
+     if n_rep == 1:
+         return x
+     return (
+         x[:, :, :, None, :]
+         .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+         .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+     )
+
+
+ class GQA(nn.Module):
+     def __init__(self, args: LMConfig):
+         super().__init__()
+         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
+         assert args.n_heads % self.n_kv_heads == 0
+         self.n_local_heads = args.n_heads
+         self.n_local_kv_heads = self.n_kv_heads
+         self.n_rep = self.n_local_heads // self.n_local_kv_heads
+         self.head_dim = args.dim // args.n_heads
+         self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=True)
+         self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=True)
+         self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=True)
+         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+         self.attn_dropout = nn.Dropout(args.dropout)
+         self.resid_dropout = nn.Dropout(args.dropout)
+         self.dropout = args.dropout
+         self.softmax_scale = self.head_dim ** -0.5
+         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
+         # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
+         mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
+         mask = torch.triu(mask, diagonal=1)
+         self.register_buffer("mask", mask, persistent=False)
+
+     def forward(self,
+                 x: torch.Tensor,
+                 pos_cis: torch.Tensor,
+                 past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+                 use_cache=False):
+         bsz, seq_len, _ = x.shape
+         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+         xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
+         xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
+         xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
+
+         xq, xk = apply_rotary_emb(xq, xk, pos_cis)
+         # KV cache: append the new keys/values to the cached ones
+         if past_key_value is not None:
+             xk = torch.cat([past_key_value[0], xk], dim=1)
+             xv = torch.cat([past_key_value[1], xv], dim=1)
+         past_kv = (xk, xv) if use_cache else None
+
+         xq, xk, xv = (
+             xq.transpose(1, 2),
+             repeat_kv(xk, self.n_rep).transpose(1, 2),
+             repeat_kv(xv, self.n_rep).transpose(1, 2)
+         )
+         if self.flash and seq_len != 1:
+             dropout_p = self.dropout if self.training else 0.0
+             output = F.scaled_dot_product_attention(
+                 xq, xk, xv,
+                 attn_mask=None,
+                 dropout_p=dropout_p,
+                 is_causal=True
+             )
+             # output = flash_attn_func(
+             #     xq, xk, xv,
+             #     dropout_p=dropout_p,
+             #     softmax_scale=self.softmax_scale,
+             #     causal=True
+             # )
+         else:
+             scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
+             scores += self.mask[:, :, :seq_len, :seq_len]
+             scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+             scores = self.attn_dropout(scores)
+             output = scores @ xv
+
+         output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
+         output = self.resid_dropout(self.wo(output))
+         return output, past_kv
+
+
+ class MLA(nn.Module):
+     """
+     Multi-head Latent Attention (MLA).
+
+     Attributes:
+         dim (int): Dimensionality of the input features.
+         n_heads (int): Number of attention heads.
+         n_local_heads (int): Number of local attention heads for distributed systems.
+         q_lora_rank (int): Rank for low-rank query projection.
+         kv_lora_rank (int): Rank for low-rank key/value projection.
+         qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
+         qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
+         qk_head_dim (int): Total dimensionality of query/key projections.
+         v_head_dim (int): Dimensionality of value projections.
+         softmax_scale (float): Scaling factor for softmax in attention computation.
+     """
+     def __init__(self, args: LMConfig):
+         super().__init__()
+         self.dim = args.dim
+         self.n_heads = args.n_heads
+         self.n_local_heads = args.n_heads
+         self.q_lora_rank = args.q_lora_rank
+         self.kv_lora_rank = args.kv_lora_rank
+         self.qk_nope_head_dim = args.qk_nope_head_dim
+         self.qk_rope_head_dim = args.qk_rope_head_dim
+         self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
+         self.v_head_dim = args.v_head_dim
+
+         if self.q_lora_rank == 0:
+             self.wq = nn.Linear(self.dim, self.n_heads * self.qk_head_dim, bias=False)
+         else:
+             self.wq_a = nn.Linear(self.dim, self.q_lora_rank)
+             self.q_norm = RMSNorm(self.q_lora_rank, eps=args.norm_eps)
+             self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
+         self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
+         self.kv_norm = RMSNorm(self.kv_lora_rank, eps=args.norm_eps)
+         self.wkv_b = nn.Linear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
+         self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)
+         self.softmax_scale = self.qk_head_dim ** -0.5
+         mask = torch.full((args.max_seq_len, args.max_seq_len), float("-inf")).triu_(1)
+         self.register_buffer("mask", mask, persistent=False)
+
+     def forward(self,
+                 x: torch.Tensor,
+                 freqs_cis: torch.Tensor,
+                 past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+                 use_cache=False):
+         """
+         Forward pass for Multi-head Latent Attention (MLA).
+
+         Args:
+             x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
+             freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
+             past_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): Cached (kv, k_pe) from previous steps.
+             use_cache (bool): Whether to return the updated cache.
+
+         Returns:
+             Output tensor with the same shape as the input, plus the updated cache (or None).
+         """
+         bsz, seqlen, _ = x.size()
+
+         if self.q_lora_rank == 0:
+             q = self.wq(x)  # x * Wqc
+         else:
+             q = self.wq_b(self.q_norm(self.wq_a(x)))  # x * Wc' * Wqc (low-rank query)
+         q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
+         q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+         kv = self.wkv_a(x)  # x * Wc (compressed latent plus rotary part)
+         kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+         q_pe, k_pe = apply_rotary_emb(q_pe, k_pe.unsqueeze(2), freqs_cis)
+
+         wkv_b = self.wkv_b.weight
+         wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
+         q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
+         kv = self.kv_norm(kv)
+         k_pe = k_pe.squeeze(2)
+         if past_key_value is not None:
+             kv = torch.cat([past_key_value[0], kv], dim=1)
+             k_pe = torch.cat([past_key_value[1], k_pe], dim=1)
+         past_kv = (kv, k_pe) if use_cache else None
+         scores = (torch.einsum("bshc,btc->bsht", q_nope, kv) +
+                   torch.einsum("bshr,btr->bsht", q_pe, k_pe)) * self.softmax_scale
+         mask = self.mask.unsqueeze(1)
+         scores += mask[:seqlen, :, :seqlen]
+         scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
+
+         x = torch.einsum("bsht,btc->bshc", scores, kv)
+         x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
+         x = self.wo(x.flatten(2))
+
+         return x, past_kv
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, config: LMConfig):
+         super().__init__()
+         if config.hidden_dim is None:
+             hidden_dim = 4 * config.dim
+             hidden_dim = int(2 * hidden_dim / 3)
+             config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
+         self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
+         self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
+         self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
+         self.dropout = nn.Dropout(config.dropout)
+
+     def forward(self, x):
+         return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
+
+
+ class MoEGate(nn.Module):
+     def __init__(self, config: LMConfig):
+         super().__init__()
+         self.config = config
+         self.top_k = config.num_experts_per_tok
+         self.n_routed_experts = config.n_routed_experts
+
+         self.scoring_func = config.scoring_func
+         self.alpha = config.aux_loss_alpha
+         self.seq_aux = config.seq_aux
+
+         self.norm_topk_prob = config.norm_topk_prob
+         self.gating_dim = config.dim
+         self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         import torch.nn.init as init
+         init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+
+     def forward(self, hidden_states):
+         bsz, seq_len, h = hidden_states.shape
+         hidden_states = hidden_states.view(-1, h)
+         logits = F.linear(hidden_states, self.weight, None)
+         if self.scoring_func == 'softmax':
+             scores = logits.softmax(dim=-1)
+         else:
+             raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')
+
+         topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
+
+         if self.top_k > 1 and self.norm_topk_prob:
+             denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
+             topk_weight = topk_weight / denominator
+
+         if self.training and self.alpha > 0.0:
+             scores_for_aux = scores
+             aux_topk = self.top_k
+             topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
+             if self.seq_aux:
+                 scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
+                 ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
+                 ce.scatter_add_(1, topk_idx_for_aux_loss,
+                                 torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
+                     seq_len * aux_topk / self.n_routed_experts)
+                 aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
+             else:
+                 mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
+                 ce = mask_ce.float().mean(0)
+                 Pi = scores_for_aux.mean(0)
+                 fi = ce * self.n_routed_experts
+                 aux_loss = (Pi * fi).sum() * self.alpha
+         else:
+             aux_loss = 0
+         return topk_idx, topk_weight, aux_loss
+
+
+ class MOEFeedForward(nn.Module):
+     def __init__(self, config: LMConfig):
+         super().__init__()
+         self.config = config
+         self.experts = nn.ModuleList([
+             FeedForward(config)
+             for _ in range(config.n_routed_experts)
+         ])
+         self.gate = MoEGate(config)
+         if config.n_shared_experts is not None:
+             self.shared_experts = FeedForward(config)
+
+     def forward(self, x):
+         identity = x
+         orig_shape = x.shape
+         bsz, seq_len, _ = x.shape
+         # Select experts with the gating network
+         topk_idx, topk_weight, aux_loss = self.gate(x)
+         x = x.view(-1, x.shape[-1])
+         flat_topk_idx = topk_idx.view(-1)
+         if self.training:
+             # Training: duplicate each token once per selected expert
+             x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
+             y = torch.empty_like(x, dtype=torch.float16)
+             for i, expert in enumerate(self.experts):
+                 y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # keep dtypes consistent
+             y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+             y = y.view(*orig_shape)
+         else:
+             # Inference: route each token only through its selected experts
+             y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
+         if self.config.n_shared_experts is not None:
+             y = y + self.shared_experts(identity)
+         self.aux_loss = aux_loss
+         return y
+
+     @torch.no_grad()
+     def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
+         expert_cache = torch.zeros_like(x)
+         idxs = flat_expert_indices.argsort()
+         tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
+         token_idxs = idxs // self.config.num_experts_per_tok
+         # Example: if tokens_per_expert = [6, 15, 20, 26, 33, 38, 46, 52]
+         # and token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12, ...],
+         # the tokens at token_idxs[:6] are processed by expert 0, those at token_idxs[6:15] by expert 1, and so on.
+         for i, end_idx in enumerate(tokens_per_expert):
+             start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
+             if start_idx == end_idx:
+                 continue
+             expert = self.experts[i]
+             exp_token_idx = token_idxs[start_idx:end_idx]
+             expert_tokens = x[exp_token_idx]
+             expert_out = expert(expert_tokens).to(expert_cache.dtype)
+             expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
+             # Accumulate the weighted expert outputs back into their token positions with scatter_add_
+             expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
+
+         return expert_cache
+
+
+ class NanoChatBlock(nn.Module):
+     def __init__(self, layer_id: int, config: LMConfig):
+         super().__init__()
+         self.n_heads = config.n_heads
+         self.dim = config.dim
+         self.head_dim = config.dim // config.n_heads
+         if config.attention == 'MLA':
+             self.attention = MLA(config)
+         elif config.attention == 'GQA':
+             self.attention = GQA(config)
+
+         self.layer_id = layer_id
+         self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
+         self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
+         self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
+
+     def forward(self, x, pos_cis, past_key_value=None, use_cache=False):
+         h_attn, past_kv = self.attention(
+             self.attention_norm(x),
+             pos_cis,
+             past_key_value=past_key_value,
+             use_cache=use_cache
+         )
+         h = x + h_attn
+         out = h + self.feed_forward(self.ffn_norm(h))
+         return out, past_kv
+
+
+ class NanoChatLM(PreTrainedModel):
+     config_class = LMConfig
+
+     def __init__(self, params: LMConfig = None):
+         self.params = params or LMConfig()
+         params = self.params  # fall back to the default config when none is given
+         super().__init__(self.params)
+         self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
+         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+         self.dropout = nn.Dropout(params.dropout)
+         self.layers = nn.ModuleList([NanoChatBlock(l, params) for l in range(self.n_layers)])
+         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
+         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+         if params.tie_word_embeddings:
+             self.output.weight = self.tok_embeddings.weight
+         self.register_buffer("pos_cis", precompute_pos_cis(params.dim // params.n_heads, params.max_seq_len,
+                                                            theta=params.rope_theta), persistent=False)
+         self.OUT = CausalLMOutputWithPast()
+
+     def forward(self,
+                 input_ids: Optional[torch.Tensor] = None,
+                 past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
+                 use_cache: bool = False,
+                 **args):
+         past_key_values = past_key_values or [None] * len(self.layers)
+         start_pos = args.get('start_pos', 0)
+         h = self.dropout(self.tok_embeddings(input_ids))
+         pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
+         past_kvs = []
+         for l, layer in enumerate(self.layers):
+             h, past_kv = layer(
+                 h, pos_cis,
+                 past_key_value=past_key_values[l],
+                 use_cache=use_cache
+             )
+             past_kvs.append(past_kv)
+         logits = self.output(self.norm(h))
+         aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
+         self.OUT.__setitem__('logits', logits)
+         self.OUT.__setitem__('aux_loss', aux_loss)
+         self.OUT.__setitem__('past_key_values', past_kvs)
+         return self.OUT
+
+     @torch.inference_mode()
+     def generate(self, input_ids, eos_token_id=151643, max_new_tokens=1024, temperature=0.75, top_p=0.90,
+                  stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
+         # Streaming generation: return the generator directly
+         if stream:
+             return self._generate_stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache)
+
+         # Non-streaming generation: run each sequence in the batch to completion
+         generated = []
+         for i in range(input_ids.size(0)):
+             non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
+             out = self._generate_stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache)
+             tokens_list = [tokens[:, -1:] for tokens in out]
+             gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
+             full_sequence = torch.cat([non_pad, gen], dim=-1)
+             generated.append(full_sequence)
+         max_length = max(seq.size(1) for seq in generated)
+         generated = [
+             torch.cat(
+                 [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
+                 dim=-1)
+             for seq in generated
+         ]
+         return torch.cat(generated, dim=0)
+
+     def _generate_stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
+         start, first_seq, past_kvs = input_ids.shape[1], True, None
+         while input_ids.shape[1] < max_new_tokens - 1:
+             if first_seq or not use_cache:
+                 out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache), False
+             else:
+                 out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
+                            start_pos=input_ids.shape[1] - 1)
+             logits, past_kvs = out.logits[:, -1, :], out.past_key_values
+             logits[:, list(set(input_ids.tolist()[0]))] /= rp
+             logits /= (temperature + 1e-9)
+             if top_p is not None and top_p < 1.0:
+                 sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+                 sorted_probs = F.softmax(sorted_logits, dim=-1)
+                 cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
+                 sorted_indices_to_remove[:, 0] = False
+                 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                 logits[indices_to_remove] = -float('Inf')
+             input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
+             input_ids = torch.cat((input_ids, input_ids_next), dim=1)
+             yield input_ids[:, start:]
+             if input_ids_next.item() == eos_token_id:
+                 break
+
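
NanoChatLM.generate above implements its own sampling loop (temperature, top-p, a simple repetition penalty rp, optional streaming) rather than the standard GenerationMixin path. A minimal end-to-end sketch, assuming model and tokenizer were loaded as in the earlier config.json example (the prompt is a placeholder; the chat template closes turns with <|im_end|> = 151645, so that id is passed instead of the 151643 default):

    messages = [{"role": "user", "content": "Hello!"}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    out = model.generate(input_ids,
                         eos_token_id=tokenizer.eos_token_id,  # <|im_end|> per tokenizer_config.json
                         max_new_tokens=512, temperature=0.75, top_p=0.9)
    print(tokenizer.decode(out[0, input_ids.shape[1]:], skip_special_tokens=True))
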
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dee4b365e9d90b07cc82f4ebfd15f7b2a3bc01a3614780b0421d521827ace4a4
+ size 1347624082
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "eos_token": {
+     "content": "<|im_end|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,207 @@
+ {
+   "add_bos_token": false,
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "151643": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151644": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151645": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151646": {
+       "content": "<|object_ref_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151647": {
+       "content": "<|object_ref_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151648": {
+       "content": "<|box_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151649": {
+       "content": "<|box_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151650": {
+       "content": "<|quad_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151651": {
+       "content": "<|quad_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151652": {
+       "content": "<|vision_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151653": {
+       "content": "<|vision_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151654": {
+       "content": "<|vision_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151655": {
+       "content": "<|image_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151656": {
+       "content": "<|video_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151657": {
+       "content": "<tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151658": {
+       "content": "</tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151659": {
+       "content": "<|fim_prefix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151660": {
+       "content": "<|fim_middle|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151661": {
+       "content": "<|fim_suffix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151662": {
+       "content": "<|fim_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151663": {
+       "content": "<|repo_name|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151664": {
+       "content": "<|file_sep|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     }
+   },
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "bos_token": null,
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "errors": "replace",
+   "model_max_length": 131072,
+   "pad_token": "<|endoftext|>",
+   "split_special_tokens": false,
+   "tokenizer_class": "Qwen2Tokenizer",
+   "unk_token": null
+ }
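
The chat_template above is the Qwen2-style ChatML format: each turn is wrapped in <|im_start|>role ... <|im_end|>, and a default system prompt is injected when none is supplied. A quick way to inspect the exact prompt the model sees (tokenizer loaded as in the earlier sketch; output shown approximately):

    text = tokenizer.apply_chat_template(
        [{"role": "user", "content": "Hi"}],
        tokenize=False, add_generation_prompt=True)
    print(text)
    # <|im_start|>system
    # You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
    # <|im_start|>user
    # Hi<|im_end|>
    # <|im_start|>assistant
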
vocab.json ADDED
The diff for this file is too large to render. See raw diff