cpatonn committed
Commit ca9e76b · verified · 1 Parent(s): 6b8275f

Upload folder using huggingface_hub

Files changed (3)
  1. _config.json +410 -0
  2. config.json +27 -0
  3. modeling_kimi.py +1028 -0
_config.json ADDED
@@ -0,0 +1,410 @@
1
+ {
2
+ "architectures": [
3
+ "KimiLinearForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_kimi.KimiLinearConfig",
7
+ "AutoModel": "modeling_kimi.KimiLinearModel",
8
+ "AutoModelForCausalLM": "modeling_kimi.KimiLinearForCausalLM"
9
+ },
10
+ "bos_token_id": 163584,
11
+ "dtype": "bfloat16",
12
+ "eos_token_id": 163586,
13
+ "first_k_dense_replace": 1,
14
+ "head_dim": 72,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 2304,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 9216,
19
+ "kv_lora_rank": 512,
20
+ "linear_attn_config": {
21
+ "full_attn_layers": [
22
+ 4,
23
+ 8,
24
+ 12,
25
+ 16,
26
+ 20,
27
+ 24,
28
+ 27
29
+ ],
30
+ "head_dim": 128,
31
+ "kda_layers": [
32
+ 1,
33
+ 2,
34
+ 3,
35
+ 5,
36
+ 6,
37
+ 7,
38
+ 9,
39
+ 10,
40
+ 11,
41
+ 13,
42
+ 14,
43
+ 15,
44
+ 17,
45
+ 18,
46
+ 19,
47
+ 21,
48
+ 22,
49
+ 23,
50
+ 25,
51
+ 26
52
+ ],
53
+ "num_heads": 32,
54
+ "short_conv_kernel_size": 4
55
+ },
56
+ "mla_use_nope": true,
57
+ "model_max_length": 1048576,
58
+ "model_type": "kimi_linear",
59
+ "moe_intermediate_size": 1024,
60
+ "moe_layer_freq": 1,
61
+ "moe_renormalize": true,
62
+ "moe_router_activation_func": "sigmoid",
63
+ "num_attention_heads": 32,
64
+ "num_expert_group": 1,
65
+ "num_experts": 256,
66
+ "num_experts_per_token": 8,
67
+ "num_hidden_layers": 27,
68
+ "num_key_value_heads": 32,
69
+ "num_nextn_predict_layers": 0,
70
+ "num_shared_experts": 1,
71
+ "pad_token_id": 163839,
72
+ "q_lora_rank": null,
73
+ "qk_nope_head_dim": 128,
74
+ "qk_rope_head_dim": 64,
75
+ "quantization_config": {
76
+ "config_groups": {
77
+ "group_0": {
78
+ "format": "pack-quantized",
79
+ "input_activations": null,
80
+ "output_activations": null,
81
+ "targets": [
82
+ "Linear"
83
+ ],
84
+ "weights": {
85
+ "actorder": null,
86
+ "block_structure": null,
87
+ "dynamic": false,
88
+ "group_size": 32,
89
+ "num_bits": 4,
90
+ "observer": "mse",
91
+ "observer_kwargs": {},
92
+ "strategy": "group",
93
+ "symmetric": true,
94
+ "type": "int"
95
+ }
96
+ }
97
+ },
98
+ "format": "pack-quantized",
99
+ "global_compression_ratio": null,
100
+ "ignore": [
101
+ "model.layers.0.self_attn.q_proj",
102
+ "model.layers.0.self_attn.k_proj",
103
+ "model.layers.0.self_attn.v_proj",
104
+ "model.layers.0.self_attn.f_a_proj",
105
+ "model.layers.0.self_attn.f_b_proj",
106
+ "model.layers.0.self_attn.b_proj",
107
+ "model.layers.0.self_attn.g_a_proj",
108
+ "model.layers.0.self_attn.g_b_proj",
109
+ "model.layers.0.self_attn.o_proj",
110
+ "model.layers.0.mlp.gate_proj",
111
+ "model.layers.0.mlp.up_proj",
112
+ "model.layers.0.mlp.down_proj",
113
+ "model.layers.1.self_attn.q_proj",
114
+ "model.layers.1.self_attn.k_proj",
115
+ "model.layers.1.self_attn.v_proj",
116
+ "model.layers.1.self_attn.f_a_proj",
117
+ "model.layers.1.self_attn.f_b_proj",
118
+ "model.layers.1.self_attn.b_proj",
119
+ "model.layers.1.self_attn.g_a_proj",
120
+ "model.layers.1.self_attn.g_b_proj",
121
+ "model.layers.1.self_attn.o_proj",
122
+ "model.layers.1.block_sparse_moe.shared_experts.gate_proj",
123
+ "model.layers.1.block_sparse_moe.shared_experts.up_proj",
124
+ "model.layers.1.block_sparse_moe.shared_experts.down_proj",
125
+ "model.layers.2.self_attn.q_proj",
126
+ "model.layers.2.self_attn.k_proj",
127
+ "model.layers.2.self_attn.v_proj",
128
+ "model.layers.2.self_attn.f_a_proj",
129
+ "model.layers.2.self_attn.f_b_proj",
130
+ "model.layers.2.self_attn.b_proj",
131
+ "model.layers.2.self_attn.g_a_proj",
132
+ "model.layers.2.self_attn.g_b_proj",
133
+ "model.layers.2.self_attn.o_proj",
134
+ "model.layers.2.block_sparse_moe.shared_experts.gate_proj",
135
+ "model.layers.2.block_sparse_moe.shared_experts.up_proj",
136
+ "model.layers.2.block_sparse_moe.shared_experts.down_proj",
137
+ "model.layers.3.self_attn.q_proj",
138
+ "model.layers.3.self_attn.kv_a_proj_with_mqa",
139
+ "model.layers.3.self_attn.kv_b_proj",
140
+ "model.layers.3.self_attn.o_proj",
141
+ "model.layers.3.block_sparse_moe.shared_experts.gate_proj",
142
+ "model.layers.3.block_sparse_moe.shared_experts.up_proj",
143
+ "model.layers.3.block_sparse_moe.shared_experts.down_proj",
144
+ "model.layers.4.self_attn.q_proj",
145
+ "model.layers.4.self_attn.k_proj",
146
+ "model.layers.4.self_attn.v_proj",
147
+ "model.layers.4.self_attn.f_a_proj",
148
+ "model.layers.4.self_attn.f_b_proj",
149
+ "model.layers.4.self_attn.b_proj",
150
+ "model.layers.4.self_attn.g_a_proj",
151
+ "model.layers.4.self_attn.g_b_proj",
152
+ "model.layers.4.self_attn.o_proj",
153
+ "model.layers.4.block_sparse_moe.shared_experts.gate_proj",
154
+ "model.layers.4.block_sparse_moe.shared_experts.up_proj",
155
+ "model.layers.4.block_sparse_moe.shared_experts.down_proj",
156
+ "model.layers.5.self_attn.q_proj",
157
+ "model.layers.5.self_attn.k_proj",
158
+ "model.layers.5.self_attn.v_proj",
159
+ "model.layers.5.self_attn.f_a_proj",
160
+ "model.layers.5.self_attn.f_b_proj",
161
+ "model.layers.5.self_attn.b_proj",
162
+ "model.layers.5.self_attn.g_a_proj",
163
+ "model.layers.5.self_attn.g_b_proj",
164
+ "model.layers.5.self_attn.o_proj",
165
+ "model.layers.5.block_sparse_moe.shared_experts.gate_proj",
166
+ "model.layers.5.block_sparse_moe.shared_experts.up_proj",
167
+ "model.layers.5.block_sparse_moe.shared_experts.down_proj",
168
+ "model.layers.6.self_attn.q_proj",
169
+ "model.layers.6.self_attn.k_proj",
170
+ "model.layers.6.self_attn.v_proj",
171
+ "model.layers.6.self_attn.f_a_proj",
172
+ "model.layers.6.self_attn.f_b_proj",
173
+ "model.layers.6.self_attn.b_proj",
174
+ "model.layers.6.self_attn.g_a_proj",
175
+ "model.layers.6.self_attn.g_b_proj",
176
+ "model.layers.6.self_attn.o_proj",
177
+ "model.layers.6.block_sparse_moe.shared_experts.gate_proj",
178
+ "model.layers.6.block_sparse_moe.shared_experts.up_proj",
179
+ "model.layers.6.block_sparse_moe.shared_experts.down_proj",
180
+ "model.layers.7.self_attn.q_proj",
181
+ "model.layers.7.self_attn.kv_a_proj_with_mqa",
182
+ "model.layers.7.self_attn.kv_b_proj",
183
+ "model.layers.7.self_attn.o_proj",
184
+ "model.layers.7.block_sparse_moe.shared_experts.gate_proj",
185
+ "model.layers.7.block_sparse_moe.shared_experts.up_proj",
186
+ "model.layers.7.block_sparse_moe.shared_experts.down_proj",
187
+ "model.layers.8.self_attn.q_proj",
188
+ "model.layers.8.self_attn.k_proj",
189
+ "model.layers.8.self_attn.v_proj",
190
+ "model.layers.8.self_attn.f_a_proj",
191
+ "model.layers.8.self_attn.f_b_proj",
192
+ "model.layers.8.self_attn.b_proj",
193
+ "model.layers.8.self_attn.g_a_proj",
194
+ "model.layers.8.self_attn.g_b_proj",
195
+ "model.layers.8.self_attn.o_proj",
196
+ "model.layers.8.block_sparse_moe.shared_experts.gate_proj",
197
+ "model.layers.8.block_sparse_moe.shared_experts.up_proj",
198
+ "model.layers.8.block_sparse_moe.shared_experts.down_proj",
199
+ "model.layers.9.self_attn.q_proj",
200
+ "model.layers.9.self_attn.k_proj",
201
+ "model.layers.9.self_attn.v_proj",
202
+ "model.layers.9.self_attn.f_a_proj",
203
+ "model.layers.9.self_attn.f_b_proj",
204
+ "model.layers.9.self_attn.b_proj",
205
+ "model.layers.9.self_attn.g_a_proj",
206
+ "model.layers.9.self_attn.g_b_proj",
207
+ "model.layers.9.self_attn.o_proj",
208
+ "model.layers.9.block_sparse_moe.shared_experts.gate_proj",
209
+ "model.layers.9.block_sparse_moe.shared_experts.up_proj",
210
+ "model.layers.9.block_sparse_moe.shared_experts.down_proj",
211
+ "model.layers.10.self_attn.q_proj",
212
+ "model.layers.10.self_attn.k_proj",
213
+ "model.layers.10.self_attn.v_proj",
214
+ "model.layers.10.self_attn.f_a_proj",
215
+ "model.layers.10.self_attn.f_b_proj",
216
+ "model.layers.10.self_attn.b_proj",
217
+ "model.layers.10.self_attn.g_a_proj",
218
+ "model.layers.10.self_attn.g_b_proj",
219
+ "model.layers.10.self_attn.o_proj",
220
+ "model.layers.10.block_sparse_moe.shared_experts.gate_proj",
221
+ "model.layers.10.block_sparse_moe.shared_experts.up_proj",
222
+ "model.layers.10.block_sparse_moe.shared_experts.down_proj",
223
+ "model.layers.11.self_attn.q_proj",
224
+ "model.layers.11.self_attn.kv_a_proj_with_mqa",
225
+ "model.layers.11.self_attn.kv_b_proj",
226
+ "model.layers.11.self_attn.o_proj",
227
+ "model.layers.11.block_sparse_moe.shared_experts.gate_proj",
228
+ "model.layers.11.block_sparse_moe.shared_experts.up_proj",
229
+ "model.layers.11.block_sparse_moe.shared_experts.down_proj",
230
+ "model.layers.12.self_attn.q_proj",
231
+ "model.layers.12.self_attn.k_proj",
232
+ "model.layers.12.self_attn.v_proj",
233
+ "model.layers.12.self_attn.f_a_proj",
234
+ "model.layers.12.self_attn.f_b_proj",
235
+ "model.layers.12.self_attn.b_proj",
236
+ "model.layers.12.self_attn.g_a_proj",
237
+ "model.layers.12.self_attn.g_b_proj",
238
+ "model.layers.12.self_attn.o_proj",
239
+ "model.layers.12.block_sparse_moe.shared_experts.gate_proj",
240
+ "model.layers.12.block_sparse_moe.shared_experts.up_proj",
241
+ "model.layers.12.block_sparse_moe.shared_experts.down_proj",
242
+ "model.layers.13.self_attn.q_proj",
243
+ "model.layers.13.self_attn.k_proj",
244
+ "model.layers.13.self_attn.v_proj",
245
+ "model.layers.13.self_attn.f_a_proj",
246
+ "model.layers.13.self_attn.f_b_proj",
247
+ "model.layers.13.self_attn.b_proj",
248
+ "model.layers.13.self_attn.g_a_proj",
249
+ "model.layers.13.self_attn.g_b_proj",
250
+ "model.layers.13.self_attn.o_proj",
251
+ "model.layers.13.block_sparse_moe.shared_experts.gate_proj",
252
+ "model.layers.13.block_sparse_moe.shared_experts.up_proj",
253
+ "model.layers.13.block_sparse_moe.shared_experts.down_proj",
254
+ "model.layers.14.self_attn.q_proj",
255
+ "model.layers.14.self_attn.k_proj",
256
+ "model.layers.14.self_attn.v_proj",
257
+ "model.layers.14.self_attn.f_a_proj",
258
+ "model.layers.14.self_attn.f_b_proj",
259
+ "model.layers.14.self_attn.b_proj",
260
+ "model.layers.14.self_attn.g_a_proj",
261
+ "model.layers.14.self_attn.g_b_proj",
262
+ "model.layers.14.self_attn.o_proj",
263
+ "model.layers.14.block_sparse_moe.shared_experts.gate_proj",
264
+ "model.layers.14.block_sparse_moe.shared_experts.up_proj",
265
+ "model.layers.14.block_sparse_moe.shared_experts.down_proj",
266
+ "model.layers.15.self_attn.q_proj",
267
+ "model.layers.15.self_attn.kv_a_proj_with_mqa",
268
+ "model.layers.15.self_attn.kv_b_proj",
269
+ "model.layers.15.self_attn.o_proj",
270
+ "model.layers.15.block_sparse_moe.shared_experts.gate_proj",
271
+ "model.layers.15.block_sparse_moe.shared_experts.up_proj",
272
+ "model.layers.15.block_sparse_moe.shared_experts.down_proj",
273
+ "model.layers.16.self_attn.q_proj",
274
+ "model.layers.16.self_attn.k_proj",
275
+ "model.layers.16.self_attn.v_proj",
276
+ "model.layers.16.self_attn.f_a_proj",
277
+ "model.layers.16.self_attn.f_b_proj",
278
+ "model.layers.16.self_attn.b_proj",
279
+ "model.layers.16.self_attn.g_a_proj",
280
+ "model.layers.16.self_attn.g_b_proj",
281
+ "model.layers.16.self_attn.o_proj",
282
+ "model.layers.16.block_sparse_moe.shared_experts.gate_proj",
283
+ "model.layers.16.block_sparse_moe.shared_experts.up_proj",
284
+ "model.layers.16.block_sparse_moe.shared_experts.down_proj",
285
+ "model.layers.17.self_attn.q_proj",
286
+ "model.layers.17.self_attn.k_proj",
287
+ "model.layers.17.self_attn.v_proj",
288
+ "model.layers.17.self_attn.f_a_proj",
289
+ "model.layers.17.self_attn.f_b_proj",
290
+ "model.layers.17.self_attn.b_proj",
291
+ "model.layers.17.self_attn.g_a_proj",
292
+ "model.layers.17.self_attn.g_b_proj",
293
+ "model.layers.17.self_attn.o_proj",
294
+ "model.layers.17.block_sparse_moe.shared_experts.gate_proj",
295
+ "model.layers.17.block_sparse_moe.shared_experts.up_proj",
296
+ "model.layers.17.block_sparse_moe.shared_experts.down_proj",
297
+ "model.layers.18.self_attn.q_proj",
298
+ "model.layers.18.self_attn.k_proj",
299
+ "model.layers.18.self_attn.v_proj",
300
+ "model.layers.18.self_attn.f_a_proj",
301
+ "model.layers.18.self_attn.f_b_proj",
302
+ "model.layers.18.self_attn.b_proj",
303
+ "model.layers.18.self_attn.g_a_proj",
304
+ "model.layers.18.self_attn.g_b_proj",
305
+ "model.layers.18.self_attn.o_proj",
306
+ "model.layers.18.block_sparse_moe.shared_experts.gate_proj",
307
+ "model.layers.18.block_sparse_moe.shared_experts.up_proj",
308
+ "model.layers.18.block_sparse_moe.shared_experts.down_proj",
309
+ "model.layers.19.self_attn.q_proj",
310
+ "model.layers.19.self_attn.kv_a_proj_with_mqa",
311
+ "model.layers.19.self_attn.kv_b_proj",
312
+ "model.layers.19.self_attn.o_proj",
313
+ "model.layers.19.block_sparse_moe.shared_experts.gate_proj",
314
+ "model.layers.19.block_sparse_moe.shared_experts.up_proj",
315
+ "model.layers.19.block_sparse_moe.shared_experts.down_proj",
316
+ "model.layers.20.self_attn.q_proj",
317
+ "model.layers.20.self_attn.k_proj",
318
+ "model.layers.20.self_attn.v_proj",
319
+ "model.layers.20.self_attn.f_a_proj",
320
+ "model.layers.20.self_attn.f_b_proj",
321
+ "model.layers.20.self_attn.b_proj",
322
+ "model.layers.20.self_attn.g_a_proj",
323
+ "model.layers.20.self_attn.g_b_proj",
324
+ "model.layers.20.self_attn.o_proj",
325
+ "model.layers.20.block_sparse_moe.shared_experts.gate_proj",
326
+ "model.layers.20.block_sparse_moe.shared_experts.up_proj",
327
+ "model.layers.20.block_sparse_moe.shared_experts.down_proj",
328
+ "model.layers.21.self_attn.q_proj",
329
+ "model.layers.21.self_attn.k_proj",
330
+ "model.layers.21.self_attn.v_proj",
331
+ "model.layers.21.self_attn.f_a_proj",
332
+ "model.layers.21.self_attn.f_b_proj",
333
+ "model.layers.21.self_attn.b_proj",
334
+ "model.layers.21.self_attn.g_a_proj",
335
+ "model.layers.21.self_attn.g_b_proj",
336
+ "model.layers.21.self_attn.o_proj",
337
+ "model.layers.21.block_sparse_moe.shared_experts.gate_proj",
338
+ "model.layers.21.block_sparse_moe.shared_experts.up_proj",
339
+ "model.layers.21.block_sparse_moe.shared_experts.down_proj",
340
+ "model.layers.22.self_attn.q_proj",
341
+ "model.layers.22.self_attn.k_proj",
342
+ "model.layers.22.self_attn.v_proj",
343
+ "model.layers.22.self_attn.f_a_proj",
344
+ "model.layers.22.self_attn.f_b_proj",
345
+ "model.layers.22.self_attn.b_proj",
346
+ "model.layers.22.self_attn.g_a_proj",
347
+ "model.layers.22.self_attn.g_b_proj",
348
+ "model.layers.22.self_attn.o_proj",
349
+ "model.layers.22.block_sparse_moe.shared_experts.gate_proj",
350
+ "model.layers.22.block_sparse_moe.shared_experts.up_proj",
351
+ "model.layers.22.block_sparse_moe.shared_experts.down_proj",
352
+ "model.layers.23.self_attn.q_proj",
353
+ "model.layers.23.self_attn.kv_a_proj_with_mqa",
354
+ "model.layers.23.self_attn.kv_b_proj",
355
+ "model.layers.23.self_attn.o_proj",
356
+ "model.layers.23.block_sparse_moe.shared_experts.gate_proj",
357
+ "model.layers.23.block_sparse_moe.shared_experts.up_proj",
358
+ "model.layers.23.block_sparse_moe.shared_experts.down_proj",
359
+ "model.layers.24.self_attn.q_proj",
360
+ "model.layers.24.self_attn.k_proj",
361
+ "model.layers.24.self_attn.v_proj",
362
+ "model.layers.24.self_attn.f_a_proj",
363
+ "model.layers.24.self_attn.f_b_proj",
364
+ "model.layers.24.self_attn.b_proj",
365
+ "model.layers.24.self_attn.g_a_proj",
366
+ "model.layers.24.self_attn.g_b_proj",
367
+ "model.layers.24.self_attn.o_proj",
368
+ "model.layers.24.block_sparse_moe.shared_experts.gate_proj",
369
+ "model.layers.24.block_sparse_moe.shared_experts.up_proj",
370
+ "model.layers.24.block_sparse_moe.shared_experts.down_proj",
371
+ "model.layers.25.self_attn.q_proj",
372
+ "model.layers.25.self_attn.k_proj",
373
+ "model.layers.25.self_attn.v_proj",
374
+ "model.layers.25.self_attn.f_a_proj",
375
+ "model.layers.25.self_attn.f_b_proj",
376
+ "model.layers.25.self_attn.b_proj",
377
+ "model.layers.25.self_attn.g_a_proj",
378
+ "model.layers.25.self_attn.g_b_proj",
379
+ "model.layers.25.self_attn.o_proj",
380
+ "model.layers.25.block_sparse_moe.shared_experts.gate_proj",
381
+ "model.layers.25.block_sparse_moe.shared_experts.up_proj",
382
+ "model.layers.25.block_sparse_moe.shared_experts.down_proj",
383
+ "model.layers.26.self_attn.q_proj",
384
+ "model.layers.26.self_attn.kv_a_proj_with_mqa",
385
+ "model.layers.26.self_attn.kv_b_proj",
386
+ "model.layers.26.self_attn.o_proj",
387
+ "model.layers.26.block_sparse_moe.shared_experts.gate_proj",
388
+ "model.layers.26.block_sparse_moe.shared_experts.up_proj",
389
+ "model.layers.26.block_sparse_moe.shared_experts.down_proj",
390
+ "lm_head"
391
+ ],
392
+ "kv_cache_scheme": null,
393
+ "quant_method": "compressed-tensors",
394
+ "quantization_status": "compressed",
395
+ "sparsity_config": {},
396
+ "transform_config": {},
397
+ "version": "0.12.3.dev20+gd429903"
398
+ },
399
+ "rms_norm_eps": 1e-05,
400
+ "rope_scaling": null,
401
+ "rope_theta": 10000.0,
402
+ "routed_scaling_factor": 2.446,
403
+ "tie_word_embeddings": false,
404
+ "topk_group": 1,
405
+ "transformers_version": "4.57.0",
406
+ "use_cache": true,
407
+ "use_grouped_topk": true,
408
+ "v_head_dim": 128,
409
+ "vocab_size": 163840
410
+ }
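
For orientation, this configuration is resolved through `auto_map`, so loading the checkpoint requires `trust_remote_code=True` and, per the import guard in `modeling_kimi.py` below, the `fla-core` package. A minimal loading sketch, with a placeholder repository path rather than a real repo id:

from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "path/to/this-checkpoint"  # placeholder; substitute the actual repo id or a local folder

tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo,
    trust_remote_code=True,   # auto_map points at configuration_kimi / modeling_kimi inside the repo
    torch_dtype="auto",       # unquantized tensors are declared as bfloat16 in the config
    device_map="auto",
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))
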
config.json CHANGED
@@ -109,6 +109,7 @@
  "model.layers.0.self_attn.o_proj",
  "model.layers.0.mlp.gate_proj",
  "model.layers.0.mlp.up_proj",
+ "model.layers.0.mlp.gate_up_proj",
  "model.layers.0.mlp.down_proj",
  "model.layers.1.self_attn.q_proj",
  "model.layers.1.self_attn.k_proj",
@@ -121,6 +122,7 @@
  "model.layers.1.self_attn.o_proj",
  "model.layers.1.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.1.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.1.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.1.block_sparse_moe.shared_experts.down_proj",
  "model.layers.2.self_attn.q_proj",
  "model.layers.2.self_attn.k_proj",
@@ -133,6 +135,7 @@
  "model.layers.2.self_attn.o_proj",
  "model.layers.2.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.2.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.2.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.2.block_sparse_moe.shared_experts.down_proj",
  "model.layers.3.self_attn.q_proj",
  "model.layers.3.self_attn.kv_a_proj_with_mqa",
@@ -140,6 +143,7 @@
  "model.layers.3.self_attn.o_proj",
  "model.layers.3.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.3.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.3.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.3.block_sparse_moe.shared_experts.down_proj",
  "model.layers.4.self_attn.q_proj",
  "model.layers.4.self_attn.k_proj",
@@ -152,6 +156,7 @@
  "model.layers.4.self_attn.o_proj",
  "model.layers.4.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.4.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.4.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.4.block_sparse_moe.shared_experts.down_proj",
  "model.layers.5.self_attn.q_proj",
  "model.layers.5.self_attn.k_proj",
@@ -164,6 +169,7 @@
  "model.layers.5.self_attn.o_proj",
  "model.layers.5.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.5.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.5.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.5.block_sparse_moe.shared_experts.down_proj",
  "model.layers.6.self_attn.q_proj",
  "model.layers.6.self_attn.k_proj",
@@ -176,6 +182,7 @@
  "model.layers.6.self_attn.o_proj",
  "model.layers.6.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.6.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.6.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.6.block_sparse_moe.shared_experts.down_proj",
  "model.layers.7.self_attn.q_proj",
  "model.layers.7.self_attn.kv_a_proj_with_mqa",
@@ -183,6 +190,7 @@
  "model.layers.7.self_attn.o_proj",
  "model.layers.7.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.7.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.7.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.7.block_sparse_moe.shared_experts.down_proj",
  "model.layers.8.self_attn.q_proj",
  "model.layers.8.self_attn.k_proj",
@@ -195,6 +203,7 @@
  "model.layers.8.self_attn.o_proj",
  "model.layers.8.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.8.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.8.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.8.block_sparse_moe.shared_experts.down_proj",
  "model.layers.9.self_attn.q_proj",
  "model.layers.9.self_attn.k_proj",
@@ -207,6 +216,7 @@
  "model.layers.9.self_attn.o_proj",
  "model.layers.9.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.9.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.9.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.9.block_sparse_moe.shared_experts.down_proj",
  "model.layers.10.self_attn.q_proj",
  "model.layers.10.self_attn.k_proj",
@@ -219,6 +229,7 @@
  "model.layers.10.self_attn.o_proj",
  "model.layers.10.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.10.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.10.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.10.block_sparse_moe.shared_experts.down_proj",
  "model.layers.11.self_attn.q_proj",
  "model.layers.11.self_attn.kv_a_proj_with_mqa",
@@ -226,6 +237,7 @@
  "model.layers.11.self_attn.o_proj",
  "model.layers.11.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.11.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.11.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.11.block_sparse_moe.shared_experts.down_proj",
  "model.layers.12.self_attn.q_proj",
  "model.layers.12.self_attn.k_proj",
@@ -238,6 +250,7 @@
  "model.layers.12.self_attn.o_proj",
  "model.layers.12.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.12.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.12.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.12.block_sparse_moe.shared_experts.down_proj",
  "model.layers.13.self_attn.q_proj",
  "model.layers.13.self_attn.k_proj",
@@ -250,6 +263,7 @@
  "model.layers.13.self_attn.o_proj",
  "model.layers.13.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.13.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.13.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.13.block_sparse_moe.shared_experts.down_proj",
  "model.layers.14.self_attn.q_proj",
  "model.layers.14.self_attn.k_proj",
@@ -262,6 +276,7 @@
  "model.layers.14.self_attn.o_proj",
  "model.layers.14.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.14.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.14.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.14.block_sparse_moe.shared_experts.down_proj",
  "model.layers.15.self_attn.q_proj",
  "model.layers.15.self_attn.kv_a_proj_with_mqa",
@@ -269,6 +284,7 @@
  "model.layers.15.self_attn.o_proj",
  "model.layers.15.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.15.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.15.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.15.block_sparse_moe.shared_experts.down_proj",
  "model.layers.16.self_attn.q_proj",
  "model.layers.16.self_attn.k_proj",
@@ -281,6 +297,7 @@
  "model.layers.16.self_attn.o_proj",
  "model.layers.16.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.16.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.16.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.16.block_sparse_moe.shared_experts.down_proj",
  "model.layers.17.self_attn.q_proj",
  "model.layers.17.self_attn.k_proj",
@@ -293,6 +310,7 @@
  "model.layers.17.self_attn.o_proj",
  "model.layers.17.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.17.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.17.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.17.block_sparse_moe.shared_experts.down_proj",
  "model.layers.18.self_attn.q_proj",
  "model.layers.18.self_attn.k_proj",
@@ -305,6 +323,7 @@
  "model.layers.18.self_attn.o_proj",
  "model.layers.18.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.18.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.18.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.18.block_sparse_moe.shared_experts.down_proj",
  "model.layers.19.self_attn.q_proj",
  "model.layers.19.self_attn.kv_a_proj_with_mqa",
@@ -312,6 +331,7 @@
  "model.layers.19.self_attn.o_proj",
  "model.layers.19.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.19.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.19.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.19.block_sparse_moe.shared_experts.down_proj",
  "model.layers.20.self_attn.q_proj",
  "model.layers.20.self_attn.k_proj",
@@ -324,6 +344,7 @@
  "model.layers.20.self_attn.o_proj",
  "model.layers.20.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.20.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.20.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.20.block_sparse_moe.shared_experts.down_proj",
  "model.layers.21.self_attn.q_proj",
  "model.layers.21.self_attn.k_proj",
@@ -336,6 +357,7 @@
  "model.layers.21.self_attn.o_proj",
  "model.layers.21.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.21.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.21.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.21.block_sparse_moe.shared_experts.down_proj",
  "model.layers.22.self_attn.q_proj",
  "model.layers.22.self_attn.k_proj",
@@ -348,6 +370,7 @@
  "model.layers.22.self_attn.o_proj",
  "model.layers.22.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.22.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.22.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.22.block_sparse_moe.shared_experts.down_proj",
  "model.layers.23.self_attn.q_proj",
  "model.layers.23.self_attn.kv_a_proj_with_mqa",
@@ -355,6 +378,7 @@
  "model.layers.23.self_attn.o_proj",
  "model.layers.23.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.23.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.23.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.23.block_sparse_moe.shared_experts.down_proj",
  "model.layers.24.self_attn.q_proj",
  "model.layers.24.self_attn.k_proj",
@@ -367,6 +391,7 @@
  "model.layers.24.self_attn.o_proj",
  "model.layers.24.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.24.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.24.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.24.block_sparse_moe.shared_experts.down_proj",
  "model.layers.25.self_attn.q_proj",
  "model.layers.25.self_attn.k_proj",
@@ -379,6 +404,7 @@
  "model.layers.25.self_attn.o_proj",
  "model.layers.25.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.25.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.25.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.25.block_sparse_moe.shared_experts.down_proj",
  "model.layers.26.self_attn.q_proj",
  "model.layers.26.self_attn.kv_a_proj_with_mqa",
@@ -386,6 +412,7 @@
  "model.layers.26.self_attn.o_proj",
  "model.layers.26.block_sparse_moe.shared_experts.gate_proj",
  "model.layers.26.block_sparse_moe.shared_experts.up_proj",
+ "model.layers.26.block_sparse_moe.shared_experts.gate_up_proj",
  "model.layers.26.block_sparse_moe.shared_experts.down_proj",
  "lm_head"
  ],
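
The added `gate_up_proj` entries extend the quantization `ignore` list, apparently so that runtimes which fuse `gate_proj`/`up_proj` into a single `gate_up_proj` module also skip int4 packing for those weights and keep them in bfloat16, alongside `lm_head` and the attention projections already listed. A generic sketch of how such an ignore list is consumed when deciding which `Linear` modules to quantize (an illustration of the idea only, not the compressed-tensors implementation; the module names are examples):

def should_quantize(name: str, targets: tuple[str, ...], ignore: set[str],
                    module_type: str = "Linear") -> bool:
    # A module is quantized only if its type is targeted and its name is not ignored.
    return module_type in targets and name not in ignore

ignore = {
    "lm_head",
    "model.layers.0.mlp.gate_up_proj",
    "model.layers.1.block_sparse_moe.shared_experts.gate_up_proj",
}
print(should_quantize("model.layers.0.mlp.gate_up_proj", ("Linear",), ignore))          # False: kept in bf16
print(should_quantize("model.layers.1.block_sparse_moe.experts.7.w1", ("Linear",), ignore))  # True: packed to int4
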
modeling_kimi.py ADDED
@@ -0,0 +1,1028 @@
1
+ import math
2
+ from collections.abc import Callable
3
+ from typing import Any, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import transformers
8
+ from einops import rearrange
9
+ from packaging import version
10
+ from torch import nn
11
+ from transformers.activations import ACT2FN
12
+ from transformers.cache_utils import Cache
13
+ from transformers.generation import GenerationMixin
14
+ from transformers.masking_utils import create_causal_mask
15
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
17
+ CausalLMOutputWithPast)
18
+ from transformers.modeling_utils import (ALL_ATTENTION_FUNCTIONS,
19
+ PreTrainedModel)
20
+ from transformers.processing_utils import Unpack
21
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
22
+ from transformers.utils import (TransformersKwargs, auto_docstring,
23
+ can_return_tuple, logging)
24
+ from transformers.utils.generic import OutputRecorder, check_model_inputs
25
+
26
+ try:
27
+ from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
28
+ from fla.modules import FusedRMSNormGated, ShortConvolution
29
+ from fla.ops.kda import chunk_kda, fused_recurrent_kda
30
+ from fla.ops.kda.gate import fused_kda_gate
31
+ except ImportError:
32
+ raise ImportError("Plese run `pip install -U fla-core`")
33
+
34
+ from .configuration_kimi import KimiLinearConfig
35
+
36
+ assert version.parse(transformers.__version__) >= version.parse("4.56.0"), \
37
+ "Please upgrade transformers to >= 4.56.0"
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ class KimiDynamicCache:
43
+ """
44
+ Dynamic cache for Kimi model.
45
+ Inspired by Qwen3-Next
46
+ """
47
+ is_compileable = False
48
+
49
+ def __init__(self, config: KimiLinearConfig):
50
+ super().__init__()
51
+ self.config = config
52
+
53
+ if config.linear_attn_config is not None:
54
+ self.layer_types = []
55
+ for i in range(config.num_hidden_layers):
56
+ if config.is_kda_layer(i):
57
+ self.layer_types.append("linear_attention")
58
+ else:
59
+ self.layer_types.append("full_attention")
60
+ else:
61
+ self.layer_types = ["full_attention"] * config.num_hidden_layers
62
+
63
+ self.transformer_layers = [
64
+ i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention"
65
+ ]
66
+
67
+ linear_layers = [i for i in range(
68
+ config.num_hidden_layers) if self.layer_types[i] == "linear_attention"]
69
+ self.last_linear_layer = linear_layers[-1] if linear_layers else -1
70
+
71
+ self.conv_states = [None for _ in range(config.num_hidden_layers)]
72
+ self.recurrent_states = [None for _ in range(config.num_hidden_layers)]
73
+ self.key_cache = [None for _ in range(config.num_hidden_layers)]
74
+ self.value_cache = [None for _ in range(config.num_hidden_layers)]
75
+
76
+ def __len__(self):
77
+ return len(self.layer_types)
78
+
79
+ def update(
80
+ self,
81
+ key_states: torch.Tensor,
82
+ value_states: torch.Tensor,
83
+ layer_idx: int,
84
+ cache_kwargs: Optional[dict[str, Any]] = None,
85
+ ) -> tuple[torch.Tensor, torch.Tensor]:
86
+ if self.key_cache[layer_idx] is None:
87
+ self.key_cache[layer_idx] = key_states
88
+ self.value_cache[layer_idx] = value_states
89
+ else:
90
+ self.key_cache[layer_idx] = torch.cat(
91
+ [self.key_cache[layer_idx], key_states], dim=2)
92
+ self.value_cache[layer_idx] = torch.cat(
93
+ [self.value_cache[layer_idx], value_states], dim=2)
94
+
95
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
96
+
97
+ def reorder_cache(self, beam_idx: torch.LongTensor):
98
+ """Reorders the cache for beam search, given the selected beam indices."""
99
+ for layer_idx in range(len(self.key_cache)):
100
+ if self.key_cache[layer_idx] is not None:
101
+ device = self.key_cache[layer_idx].device
102
+ beam_idx = beam_idx.to(device)
103
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
104
+ 0, beam_idx)
105
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
106
+ 0, beam_idx)
107
+
108
+ if self.conv_states[layer_idx] is not None:
109
+ device = self.conv_states[layer_idx][0].device
110
+ beam_idx = beam_idx.to(device)
111
+ q_conv, k_conv, v_conv = self.conv_states[layer_idx]
112
+ self.conv_states[layer_idx] = (
113
+ q_conv.index_select(0, beam_idx),
114
+ k_conv.index_select(0, beam_idx),
115
+ v_conv.index_select(0, beam_idx)
116
+ )
117
+ self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(
118
+ 0, beam_idx)
119
+
120
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
121
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
122
+ # take any layer that contains cache and not empty tensor
123
+ layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
124
+ if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None:
125
+ return 0
126
+ return self.key_cache[layer_idx].shape[-2]
127
+
128
+ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
129
+ """
130
+ Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
131
+ the given layer at `layer_idx`.
132
+ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.
133
+ """
134
+ kv_offset = 0
135
+ query_length = cache_position.shape[0]
136
+ past_seen_tokens = self.get_seq_length(layer_idx)
137
+ kv_length = query_length + past_seen_tokens
138
+ return kv_length, kv_offset
139
+
140
+ @property
141
+ def has_previous_state(self):
142
+ """We have a previous state if the last linear (conv) layer was already updated."""
143
+ if self.last_linear_layer == -1:
144
+ return False
145
+ return self.conv_states[self.last_linear_layer] is not None
146
+
147
+
148
+ class KimiRMSNorm(nn.Module):
149
+ def __init__(self, hidden_size, eps=1e-6):
150
+ """
151
+ KimiRMSNorm is equivalent to T5LayerNorm
152
+ """
153
+ super().__init__()
154
+ self.weight = nn.Parameter(torch.ones(hidden_size))
155
+ self.variance_epsilon = eps
156
+
157
+ def forward(self, hidden_states):
158
+ input_dtype = hidden_states.dtype
159
+ hidden_states = hidden_states.to(torch.float32)
160
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
161
+ hidden_states = hidden_states * \
162
+ torch.rsqrt(variance + self.variance_epsilon)
163
+ return self.weight * hidden_states.to(input_dtype)
164
+
165
+
166
+ ALL_LAYERNORM_LAYERS.append(KimiRMSNorm)
167
+
168
+
169
+ class KimiBlockSparseMLP(nn.Module):
170
+ def __init__(self, config: KimiLinearConfig, hidden_size=None, intermediate_size=None):
171
+ super().__init__()
172
+ self.config = config
173
+ self.ffn_dim = config.intermediate_size if intermediate_size is None else intermediate_size
174
+ self.hidden_dim = config.hidden_size if hidden_size is None else hidden_size
175
+
176
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # gate
177
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) # down
178
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # up
179
+
180
+ self.act_fn = ACT2FN[config.hidden_act]
181
+
182
+ def forward(self, hidden_states):
183
+ current_hidden_states = self.act_fn(
184
+ self.w1(hidden_states)) * self.w3(hidden_states)
185
+ current_hidden_states = self.w2(current_hidden_states)
186
+ return current_hidden_states
187
+
188
+
189
+ class KimiMLP(nn.Module):
190
+ def __init__(self, config: KimiLinearConfig, hidden_size=None, intermediate_size=None):
191
+ super().__init__()
192
+ self.config = config
193
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
194
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
195
+ self.gate_proj = nn.Linear(
196
+ self.hidden_size, self.intermediate_size, bias=False)
197
+ self.up_proj = nn.Linear(
198
+ self.hidden_size, self.intermediate_size, bias=False)
199
+ self.down_proj = nn.Linear(
200
+ self.intermediate_size, self.hidden_size, bias=False)
201
+ self.act_fn = ACT2FN[config.hidden_act]
202
+
203
+ def forward(self, x):
204
+ down_proj = self.down_proj(self.act_fn(
205
+ self.gate_proj(x)) * self.up_proj(x))
206
+ return down_proj
207
+
208
+
209
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
210
+ """
211
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
212
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
213
+ """
214
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
215
+ if n_rep == 1:
216
+ return hidden_states
217
+ hidden_states = hidden_states[:, :, None, :, :].expand(
218
+ batch, num_key_value_heads, n_rep, slen, head_dim)
219
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
220
+
221
+
222
+ def eager_attention_forward(
223
+ module: nn.Module,
224
+ query: torch.Tensor,
225
+ key: torch.Tensor,
226
+ value: torch.Tensor,
227
+ attention_mask: Optional[torch.Tensor],
228
+ scaling: float,
229
+ dropout: float = 0.0,
230
+ **kwargs: Unpack[TransformersKwargs],
231
+ ):
232
+ key_states = repeat_kv(key, module.num_key_value_groups)
233
+ value_states = repeat_kv(value, module.num_key_value_groups)
234
+
235
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
236
+ if attention_mask is not None:
237
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
238
+ attn_weights = attn_weights + causal_mask
239
+
240
+ attn_weights = nn.functional.softmax(
241
+ attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
242
+ attn_weights = nn.functional.dropout(
243
+ attn_weights, p=dropout, training=module.training)
244
+ attn_output = torch.matmul(attn_weights, value_states)
245
+ attn_output = attn_output.transpose(1, 2).contiguous()
246
+
247
+ return attn_output, attn_weights
248
+
249
+
250
+ class KimiMLAAttention(nn.Module):
251
+ """
252
+ Multi-Latent Attention adapted from deepseek-v3
253
+ """
254
+
255
+ def __init__(self, config: KimiLinearConfig, layer_idx: int):
256
+ nn.Module.__init__(self)
257
+ self.config = config
258
+ self.layer_idx = layer_idx
259
+ self.hidden_size = config.hidden_size
260
+ self.num_heads = config.num_attention_heads
261
+ self.num_key_value_heads = config.num_key_value_heads
262
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
263
+
264
+ self.rope_theta = config.rope_theta
265
+ self.attention_dropout = getattr(config, "attention_dropout", 0.0)
266
+
267
+ try:
268
+ self.q_lora_rank = config.q_lora_rank
269
+ self.qk_rope_head_dim = config.qk_rope_head_dim
270
+ self.kv_lora_rank = config.kv_lora_rank
271
+ self.v_head_dim = config.v_head_dim
272
+ self.qk_nope_head_dim = config.qk_nope_head_dim
273
+ self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
274
+ self.use_nope = config.mla_use_nope
275
+ self.scaling = self.q_head_dim ** (-0.5)
276
+ except Exception as e:
277
+ raise ValueError(
278
+ f"Kimi MLA config is not found or not properly formatted: {e}")
279
+
280
+ assert self.q_lora_rank is None
281
+ self.q_proj = nn.Linear(
282
+ self.hidden_size, self.num_heads * self.q_head_dim, bias=False,
283
+ )
284
+ self.kv_a_proj_with_mqa = nn.Linear(
285
+ self.hidden_size,
286
+ self.kv_lora_rank + self.qk_rope_head_dim,
287
+ bias=False,
288
+ )
289
+ self.kv_a_layernorm = KimiRMSNorm(self.kv_lora_rank)
290
+ self.kv_b_proj = nn.Linear(
291
+ self.kv_lora_rank,
292
+ self.num_heads
293
+ * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
294
+ bias=False,
295
+ )
296
+ self.o_proj = nn.Linear(
297
+ self.num_heads * self.v_head_dim,
298
+ self.hidden_size,
299
+ bias=False,
300
+ )
301
+ self.is_causal = True
302
+ assert self.use_nope
303
+
304
+ def forward(
305
+ self,
306
+ hidden_states: torch.Tensor,
307
+ attention_mask: Optional[torch.Tensor] = None,
308
+ past_key_values: Optional[Cache] = None,
309
+ **kwargs,
310
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
311
+ batch_size, seq_length = hidden_states.shape[:-1]
312
+ query_shape = (batch_size, seq_length, -1, self.q_head_dim)
313
+ key_shape = (batch_size, seq_length, -1,
314
+ self.qk_nope_head_dim + self.v_head_dim)
315
+
316
+ q_states = self.q_proj(hidden_states)
317
+ q_states = q_states.view(query_shape).transpose(1, 2)
318
+ q_pass, q_rot = torch.split(
319
+ q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
320
+
321
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
322
+ k_pass, k_rot = torch.split(
323
+ compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
324
+
325
+ k_pass = self.kv_b_proj(self.kv_a_layernorm(
326
+ k_pass)).view(key_shape).transpose(1, 2)
327
+ k_pass, value_states = torch.split(
328
+ k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
329
+
330
+ k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
331
+ k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
332
+
333
+ query_states = torch.cat((q_pass, q_rot), dim=-1)
334
+ key_states = torch.cat((k_pass, k_rot), dim=-1)
335
+
336
+ if past_key_values is not None:
337
+ key_states, value_states = past_key_values.update(
338
+ key_states, value_states, self.layer_idx)
339
+
340
+ if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
341
+ value_states = F.pad(
342
+ value_states, [0, self.q_head_dim - self.v_head_dim])
343
+
344
+ attention_interface: Callable = eager_attention_forward
345
+ if self.config._attn_implementation != "eager":
346
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
347
+
348
+ attn_output, _ = attention_interface(
349
+ self,
350
+ query_states,
351
+ key_states,
352
+ value_states,
353
+ attention_mask,
354
+ dropout=0.0 if not self.training else self.attention_dropout,
355
+ scaling=self.scaling,
356
+ **kwargs,
357
+ )
358
+
359
+ if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
360
+ attn_output = attn_output[:, :, :, : self.v_head_dim]
361
+
362
+ attn_output = attn_output.reshape(
363
+ batch_size, seq_length, -1).contiguous()
364
+ attn_output = self.o_proj(attn_output)
365
+ return attn_output
366
+
367
+
368
+ class KimiDeltaAttention(nn.Module):
369
+ def __init__(self, config: KimiLinearConfig, layer_idx: int):
370
+ super().__init__()
371
+ self.config = config
372
+ self.mode = "chunk"
373
+
374
+ self.hidden_size = config.hidden_size
375
+ self.conv_size = config.linear_attn_config["short_conv_kernel_size"]
376
+ self.head_dim = config.linear_attn_config["head_dim"]
377
+ self.num_heads = config.linear_attn_config["num_heads"]
378
+ self.head_k_dim = self.head_dim
379
+ self.num_k_heads = self.num_heads
380
+
381
+ self.layer_idx = layer_idx
382
+
383
+ assert self.mode in [
384
+ 'chunk', 'fused_recurrent'], f"Not supported mode `{self.mode}`."
385
+
386
+ projection_k_size = self.head_k_dim * self.num_k_heads
387
+ projection_size = self.head_dim * self.num_heads
388
+
389
+ self.q_proj = nn.Linear(
390
+ self.hidden_size, projection_k_size, bias=False)
391
+ self.k_proj = nn.Linear(
392
+ self.hidden_size, projection_k_size, bias=False)
393
+ self.v_proj = nn.Linear(self.hidden_size, projection_size, bias=False)
394
+
395
+ self.q_conv1d = ShortConvolution(
396
+ hidden_size=projection_k_size,
397
+ kernel_size=self.conv_size,
398
+ activation='silu',
399
+ )
400
+ self.k_conv1d = ShortConvolution(
401
+ hidden_size=projection_k_size,
402
+ kernel_size=self.conv_size,
403
+ activation='silu'
404
+ )
405
+ self.v_conv1d = ShortConvolution(
406
+ hidden_size=projection_size,
407
+ kernel_size=self.conv_size,
408
+ activation='silu'
409
+ )
410
+
411
+ self.A_log = torch.nn.Parameter(torch.log(torch.empty(
412
+ self.num_heads, dtype=torch.float32).uniform_(1, 16)).view(1, 1, -1, 1))
413
+
414
+ self.f_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)
415
+ self.f_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)
416
+
417
+ self.dt_bias = nn.Parameter(
418
+ torch.empty(projection_size, dtype=torch.float32))
419
+
420
+ self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)
421
+
422
+ self.g_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)
423
+ self.g_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)
424
+
425
+ self.o_norm = FusedRMSNormGated(
426
+ self.head_dim, eps=config.rms_norm_eps, activation='sigmoid')
427
+ self.o_proj = nn.Linear(projection_size, self.hidden_size, bias=False)
428
+
429
+ def forward(
430
+ self,
431
+ hidden_states: torch.Tensor,
432
+ attention_mask: Optional[torch.Tensor] = None,
433
+ cache_params: Optional[KimiDynamicCache] = None,
434
+ **kwargs: Unpack[dict]
435
+ ) -> torch.Tensor:
436
+ if attention_mask is not None:
437
+ if attention_mask.dim() != 2:
438
+ attention_mask = kwargs.get("padding_mask", None)
439
+
440
+ if attention_mask is not None and attention_mask.dim() != 2:
441
+ raise ValueError(
442
+ "attention_mask must be a 0-1 matrix of shape [batch_size, seq_len] "
443
+ "(0 = padding). 3D masks are not supported here."
444
+ )
445
+ use_cache = cache_params is not None
446
+ batch_size, q_len, _ = hidden_states.shape
447
+ mode = 'fused_recurrent' if q_len <= 64 else self.mode
448
+ if self.training:
449
+ assert mode == 'chunk', "Only chunk mode is supported in training."
450
+
451
+ cu_seqlens = kwargs.get('cu_seqlens', None)
452
+ indices = None
453
+ if attention_mask is not None:
454
+ indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
455
+ hidden_states = index_first_axis(
456
+ rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)
457
+
458
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
459
+ recurrent_state = None
460
+ if cache_params is not None:
461
+ if cache_params.conv_states[self.layer_idx] is not None:
462
+ conv_state_q, conv_state_k, conv_state_v = cache_params.conv_states[
463
+ self.layer_idx]
464
+ recurrent_state = cache_params.recurrent_states[self.layer_idx]
465
+ q, conv_state_q = self.q_conv1d(
466
+ x=self.q_proj(hidden_states),
467
+ cache=conv_state_q,
468
+ output_final_state=use_cache,
469
+ cu_seqlens=cu_seqlens
470
+ )
471
+ k, conv_state_k = self.k_conv1d(
472
+ x=self.k_proj(hidden_states),
473
+ cache=conv_state_k,
474
+ output_final_state=use_cache,
475
+ cu_seqlens=cu_seqlens
476
+ )
477
+ v, conv_state_v = self.v_conv1d(
478
+ x=self.v_proj(hidden_states),
479
+ cache=conv_state_v,
480
+ output_final_state=use_cache,
481
+ cu_seqlens=cu_seqlens
482
+ )
483
+ g = self.f_b_proj(self.f_a_proj(hidden_states))
484
+ g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
485
+ beta = self.b_proj(hidden_states).float().sigmoid()
486
+
487
+ q, k = map(lambda x: rearrange(
488
+ x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
489
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
490
+
491
+ if mode == 'chunk':
492
+ o, recurrent_state = chunk_kda(
493
+ q=q,
494
+ k=k,
495
+ v=v,
496
+ g=g,
497
+ beta=beta,
498
+ initial_state=recurrent_state,
499
+ output_final_state=True,
500
+ use_qk_l2norm_in_kernel=True,
501
+ cu_seqlens=cu_seqlens,
502
+ )
503
+ else:
504
+ o, recurrent_state = fused_recurrent_kda(
505
+ q=q,
506
+ k=k,
507
+ v=v,
508
+ g=g,
509
+ beta=beta,
510
+ initial_state=recurrent_state,
511
+ output_final_state=True,
512
+ use_qk_l2norm_in_kernel=True,
513
+ cu_seqlens=cu_seqlens,
514
+ )
515
+ if cache_params is not None:
516
+ cache_params.recurrent_states[self.layer_idx] = recurrent_state
517
+ cache_params.conv_states[self.layer_idx] = (
518
+ conv_state_q, conv_state_k, conv_state_v)
519
+
520
+ g = self.g_b_proj(self.g_a_proj(hidden_states))
521
+ g = rearrange(g, '... (h d) -> ... h d', d=self.head_dim)
522
+ o = self.o_norm(o, g)
523
+
524
+ o = rearrange(o, 'b t h d -> b t (h d)')
525
+ o = self.o_proj(o)
526
+ if attention_mask is not None:
527
+ o = pad_input(o.squeeze(0), indices, batch_size, q_len)
528
+
529
+ return o
530
+
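+ # --- Editor's illustrative sketch (not part of the original modeling code) ---
+ # Rough intuition for what the chunk_kda / fused_recurrent_kda kernels compute: a
+ # naive, per-timestep gated delta-rule recurrence in the spirit of Gated DeltaNet.
+ # The real KDA kernels differ (per-channel gating, L2-normalized q/k, chunked
+ # parallel form), so treat this purely as a reference recurrence; it is never called.
+ def _sketch_gated_delta_rule(q, k, v, alpha, beta):
+     # q, k: (T, d_k); v: (T, d_v); alpha, beta: (T,) with values in (0, 1)
+     import torch
+ 
+     d_k, d_v = k.shape[-1], v.shape[-1]
+     S = torch.zeros(d_v, d_k, dtype=v.dtype)                  # fast-weight state
+     outs = []
+     for t in range(len(k)):
+         S = alpha[t] * S                                      # gated decay of the state
+         S = S - beta[t] * torch.outer(S @ k[t] - v[t], k[t])  # delta-rule correction
+         outs.append(S @ q[t])                                 # read out with the query
+     return torch.stack(outs)                                  # (T, d_v)
+ 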
531
+
532
+ class KimiMoEGate(nn.Module):
533
+ """
534
+ MoEGate adapted from Deepseek-V3.
535
+ Parameter correspondences:
536
+ num_experts -> n_routed_experts
537
+ num_experts_per_token -> num_experts_per_tok
538
+ num_expert_group -> n_group
539
+ moe_router_activation_func -> scoring_func
540
+ """
541
+
542
+ def __init__(self, config: KimiLinearConfig):
543
+ super().__init__()
544
+ self.config = config
545
+ self.top_k = config.num_experts_per_token
546
+ self.num_experts = config.num_experts
547
+ self.routed_scaling_factor = config.routed_scaling_factor
548
+ self.moe_router_activation_func = config.moe_router_activation_func
549
+ self.num_expert_group = getattr(config, "num_expert_group", 1)
550
+ self.topk_group = getattr(config, "topk_group", 1)
551
+
552
+ # topk selection algorithm
553
+ self.moe_renormalize = config.moe_renormalize
554
+ self.gating_dim = config.hidden_size
555
+ self.weight = nn.Parameter(
556
+ torch.empty((self.num_experts, self.gating_dim))
557
+ )
558
+
559
+ self.e_score_correction_bias = nn.Parameter(
560
+ torch.empty((self.num_experts))
561
+ )
562
+ self.reset_parameters()
563
+
564
+ def reset_parameters(self) -> None:
565
+ import torch.nn.init as init
566
+
567
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+ # also give the correction bias a defined starting value (it is otherwise left uninitialized)
+ init.zeros_(self.e_score_correction_bias)
568
+
569
+ def forward(self, hidden_states):
570
+ bsz, seq_len, h = hidden_states.shape
571
+ # compute gating score
572
+ hidden_states = hidden_states.view(-1, h)
573
+ logits = F.linear(
574
+ hidden_states.type(torch.float32), self.weight.type(
575
+ torch.float32), None
576
+ )
577
+ if self.moe_router_activation_func == "sigmoid":
578
+ scores = logits.sigmoid()
579
+ elif self.moe_router_activation_func == "softmax":
580
+ scores = logits.softmax(dim=1)
581
+ else:
582
+ raise NotImplementedError(
583
+ f"insupportable scoring function for MoE gating: {self.moe_router_activation_func}"
584
+ )
585
+
586
+ # select top-k experts
587
+ assert not self.training
588
+ scores_for_choice = scores.view(bsz * seq_len, -1)
589
+ # add the bias out-of-place: `scores` (gathered below for the routing weights) must stay bias-free
+ scores_for_choice = scores_for_choice + self.e_score_correction_bias.unsqueeze(0)
590
+ group_scores = (
591
+ scores_for_choice.view(
592
+ bsz * seq_len, self.num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
593
+ ) # [n, num_expert_group]
594
+ group_idx = torch.topk(
595
+ group_scores, k=self.topk_group, dim=-1, sorted=False
596
+ )[
597
+ 1
598
+ ] # [n, top_k_group]
599
+ group_mask = torch.zeros_like(group_scores) # [n, num_expert_group]
600
+ group_mask.scatter_(1, group_idx, 1) # [n, num_expert_group]
601
+ score_mask = (
602
+ group_mask.unsqueeze(-1)
603
+ .expand(
604
+ bsz * seq_len, self.num_expert_group, self.num_experts // self.num_expert_group
605
+ )
606
+ .reshape(bsz * seq_len, -1)
607
+ ) # [n, e]
608
+ tmp_scores = scores_for_choice.masked_fill(
609
+ ~score_mask.bool(), 0.0) # [n, e]
610
+ _, topk_idx = torch.topk(
611
+ tmp_scores, k=self.top_k, dim=-1, sorted=False
612
+ )
613
+ topk_weight = scores.gather(1, topk_idx)
614
+
615
+ # norm gate to sum 1
616
+ if self.top_k > 1 and self.moe_renormalize:
617
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
618
+ topk_weight = topk_weight / denominator
619
+ # must multiply the scaling factor
620
+ topk_weight = topk_weight * self.routed_scaling_factor
621
+
622
+ return topk_idx, topk_weight
623
+
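+ # --- Editor's illustrative sketch (not part of the original modeling code) ---
+ # A toy, self-contained version of the group-limited top-k routing performed in
+ # KimiMoEGate.forward: score groups by their two best experts, keep the best
+ # group(s), then pick top-k experts inside the surviving groups. All sizes below
+ # are made up for illustration; the function is never called by the model.
+ def _sketch_group_limited_topk():
+     import torch
+ 
+     tokens, experts, groups, topk_group, top_k = 3, 8, 2, 1, 2
+     scores = torch.rand(tokens, experts)            # e.g. sigmoid router outputs
+     bias = torch.zeros(experts)                     # stand-in for e_score_correction_bias
+     scores_for_choice = scores + bias
+     group_scores = scores_for_choice.view(tokens, groups, -1).topk(2, dim=-1)[0].sum(-1)
+     group_idx = group_scores.topk(topk_group, dim=-1, sorted=False)[1]
+     group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
+     score_mask = group_mask.unsqueeze(-1).expand(tokens, groups, experts // groups).reshape(tokens, -1)
+     masked = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
+     topk_idx = masked.topk(top_k, dim=-1, sorted=False)[1]
+     topk_weight = scores.gather(1, topk_idx)        # weights come from the unbiased scores
+     topk_weight = topk_weight / (topk_weight.sum(-1, keepdim=True) + 1e-20)
+     return topk_idx, topk_weight
+ 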
624
+
625
+ class KimiSparseMoeBlock(nn.Module):
626
+ """
627
+ Adapted from Deepseek-V3's MoE implementation.
628
+ The naming is consistent with Kimi's version.
629
+ """
630
+
631
+ def __init__(self, config: KimiLinearConfig):
632
+ super().__init__()
633
+ self.config = config
634
+ self.hidden_dim = config.hidden_size
635
+ self.num_experts = config.num_experts
636
+ self.top_k = config.num_experts_per_token
637
+ self.moe_renormalize = config.moe_renormalize
638
+
639
+ self.ep_size = 1
640
+ self.experts_per_rank = config.num_experts
641
+ self.ep_rank = 0
642
+ self.experts = nn.ModuleList(
643
+ [
644
+ KimiBlockSparseMLP(
645
+ config, intermediate_size=config.moe_intermediate_size
646
+ )
647
+ for _ in range(config.num_experts)
648
+ ]
649
+ )
650
+ self.gate = KimiMoEGate(config)
651
+ if config.num_shared_experts is not None:
652
+ intermediate_size = config.moe_intermediate_size * config.num_shared_experts
653
+ self.shared_experts = KimiMLP(
654
+ config=config, intermediate_size=intermediate_size
655
+ )
656
+
657
+ def forward(self, hidden_states):
658
+ identity = hidden_states
659
+ orig_shape = hidden_states.shape
660
+ topk_idx, topk_weight = self.gate(hidden_states)
661
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
662
+ flat_topk_idx = topk_idx.view(-1)
663
+ if not self.training:
664
+ y = self.moe_infer(hidden_states, topk_idx,
665
+ topk_weight).view(*orig_shape)
666
+ else:
667
+ raise NotImplementedError(
668
+ "Training mode is not supported in KimiSparseMoeBlock")
669
+ if self.config.num_shared_experts is not None:
670
+ y = y + self.shared_experts(identity)
671
+ return y
672
+
673
+ @torch.no_grad()
674
+ def moe_infer(self, x, topk_ids, topk_weight):
675
+ cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
676
+ cnts.scatter_(1, topk_ids, 1)
677
+ tokens_per_expert = cnts.sum(dim=0)
678
+ idxs = topk_ids.view(-1).argsort()
679
+ sorted_tokens = x[idxs // topk_ids.shape[1]]
680
+
681
+ tokens_per_expert = tokens_per_expert.cpu().numpy()
682
+
683
+ outputs = []
684
+ start_idx = 0
685
+ for i, num_tokens in enumerate(tokens_per_expert):
686
+ end_idx = start_idx + num_tokens
687
+ if num_tokens == 0:
688
+ continue
689
+ expert = self.experts[i + self.ep_rank * self.experts_per_rank]
690
+ tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
691
+ expert_out = expert(tokens_for_this_expert)
692
+ outputs.append(expert_out)
693
+ start_idx = end_idx
694
+
695
+ outs = torch.cat(outputs, dim=0) if len(
696
+ outputs) else sorted_tokens.new_empty(0)
697
+
698
+ new_x = torch.empty_like(outs)
699
+ new_x[idxs] = outs
700
+ final_out = (
701
+ new_x.view(*topk_ids.shape, -1)
702
+ .type(topk_weight.dtype)
703
+ .mul_(topk_weight.unsqueeze(dim=-1))
704
+ .sum(dim=1)
705
+ .type(new_x.dtype)
706
+ )
707
+ return final_out
708
+
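+ # --- Editor's illustrative sketch (not part of the original modeling code) ---
+ # The sort-run-scatter pattern of moe_infer on toy data: flatten the (token, slot)
+ # assignments, sort the token copies by expert id, run each expert once on its
+ # contiguous slice, then undo the sort and take the weighted sum per token.
+ # Identity "experts" and made-up sizes are used; the model never calls this.
+ def _sketch_expert_dispatch():
+     import torch
+ 
+     x = torch.randn(4, 6)                                      # 4 tokens, hidden size 6
+     topk_ids = torch.tensor([[0, 2], [1, 2], [0, 1], [2, 3]])  # 2 experts per token
+     topk_weight = torch.full((4, 2), 0.5)
+     experts = [lambda t: t for _ in range(4)]                  # identity experts
+ 
+     idxs = topk_ids.view(-1).argsort()                         # group copies by expert id
+     sorted_tokens = x[idxs // topk_ids.shape[1]]
+     counts = torch.bincount(topk_ids.view(-1), minlength=len(experts)).tolist()
+     outs, start = [], 0
+     for i, n in enumerate(counts):
+         if n:
+             outs.append(experts[i](sorted_tokens[start:start + n]))
+         start += n
+     outs = torch.cat(outs)
+     unsorted = torch.empty_like(outs)
+     unsorted[idxs] = outs                                      # undo the sort
+     return (unsorted.view(4, 2, -1) * topk_weight.unsqueeze(-1)).sum(1)
+ 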
709
+
710
+ class KimiDecoderLayer(nn.Module):
711
+ def __init__(self, config: KimiLinearConfig, layer_idx: int):
712
+ super().__init__()
713
+ self.hidden_size = config.hidden_size
714
+ self.config = config
715
+ if config.is_kda_layer(layer_idx):
716
+ self.is_linear_attn = True
717
+ self.self_attn = KimiDeltaAttention(
718
+ config=config, layer_idx=layer_idx)
719
+ elif config.is_mla:
720
+ self.is_linear_attn = False
721
+ self.self_attn = KimiMLAAttention(
722
+ config=config, layer_idx=layer_idx)
723
+ else:
724
+ raise NotImplementedError
725
+ if (
726
+ config.num_experts is not None
727
+ and layer_idx >= config.first_k_dense_replace
728
+ and layer_idx % getattr(config, "moe_layer_freq", 1) == 0
729
+ ):
730
+ self.block_sparse_moe = KimiSparseMoeBlock(config)
731
+ else:
732
+ self.mlp = KimiMLP(config)
733
+ self.input_layernorm = KimiRMSNorm(
734
+ config.hidden_size, eps=config.rms_norm_eps)
735
+ self.post_attention_layernorm = KimiRMSNorm(
736
+ config.hidden_size, eps=config.rms_norm_eps)
737
+
738
+ def forward(
739
+ self,
740
+ hidden_states: torch.Tensor,
741
+ attention_mask: Optional[torch.Tensor] = None,
742
+ position_ids: Optional[torch.LongTensor] = None,
743
+ past_key_values: Optional[Tuple[torch.Tensor]] = None,
744
+ output_attentions: Optional[bool] = False,
745
+ use_cache: Optional[bool] = False,
746
+ **kwargs: Unpack[FlashAttentionKwargs],
747
+ ) -> torch.FloatTensor:
748
+ """
749
+ Args:
750
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
751
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
752
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
753
+ output_attentions (`bool`, *optional*):
754
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
755
+ returned tensors for more detail.
756
+ use_cache (`bool`, *optional*):
757
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
758
+ (see `past_key_values`).
759
+ past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
760
+ """
761
+
762
+ residual = hidden_states
763
+
764
+ hidden_states = self.input_layernorm(hidden_states)
765
+
766
+ # Self Attention
767
+ if self.is_linear_attn is False:
768
+ hidden_states = self.self_attn(
769
+ hidden_states=hidden_states,
770
+ attention_mask=attention_mask,
771
+ position_ids=position_ids,
772
+ past_key_values=past_key_values,
773
+ output_attentions=output_attentions,
774
+ use_cache=use_cache,
775
+ **kwargs,
776
+ )
777
+ else:
778
+ hidden_states = self.self_attn(
779
+ hidden_states=hidden_states,
780
+ attention_mask=attention_mask,
781
+ cache_params=past_key_values,
782
+ output_attentions=output_attentions,
783
+ use_cache=use_cache,
784
+ **kwargs,
785
+ )
786
+ hidden_states = residual + hidden_states
787
+
788
+ # Fully Connected
789
+ residual = hidden_states
790
+ hidden_states = self.post_attention_layernorm(hidden_states)
791
+ if hasattr(self, "block_sparse_moe"):
792
+ hidden_states = self.block_sparse_moe(hidden_states)
793
+ else:
794
+ hidden_states = self.mlp(hidden_states)
795
+ hidden_states = residual + hidden_states
796
+
797
+ return hidden_states
798
+
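+ # --- Editor's illustrative sketch (not part of the original modeling code) ---
+ # KimiDecoderLayer picks KimiDeltaAttention for layers flagged by config.is_kda_layer
+ # and KimiMLAAttention otherwise. The snippet below shows how such a hybrid layout
+ # could be read off a linear_attn_config-style dict; treating the layer lists as
+ # 1-indexed is an assumption made here purely for illustration.
+ def _sketch_hybrid_layer_layout():
+     linear_attn_config = {"kda_layers": [1, 2, 3, 5, 6, 7], "full_attn_layers": [4, 8]}
+     num_hidden_layers = 8
+     return [
+         "KDA" if (i + 1) in linear_attn_config["kda_layers"] else "MLA"
+         for i in range(num_hidden_layers)
+     ]  # -> ['KDA', 'KDA', 'KDA', 'MLA', 'KDA', 'KDA', 'KDA', 'MLA']
+ 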
799
+
800
+ class KimiPreTrainedModel(PreTrainedModel):
801
+ config_class = KimiLinearConfig
802
+ base_model_prefix = "model"
803
+ supports_gradient_checkpointing = True
804
+ _no_split_modules = ["KimiDecoderLayer"]
805
+ _skip_keys_device_placement = "past_key_values"
806
+ _supports_flash_attn_2 = True
807
+ _can_record_outputs = {
808
+ "router_logits": OutputRecorder(KimiBlockSparseMLP, index=1),
809
+ "hidden_states": KimiDecoderLayer,
810
+ "attentions": KimiMLAAttention,
811
+ }
812
+ _is_stateful = True
813
+
814
+ def _init_weights(self, module):
815
+ std = self.config.initializer_range
816
+ if isinstance(module, nn.Linear):
817
+ module.weight.data.normal_(mean=0.0, std=std)
818
+ if module.bias is not None:
819
+ module.bias.data.zero_()
820
+ elif isinstance(module, nn.Embedding):
821
+ module.weight.data.normal_(mean=0.0, std=std)
822
+ if module.padding_idx is not None:
823
+ module.weight.data[module.padding_idx].zero_()
824
+
825
+
826
+ class KimiLinearModel(KimiPreTrainedModel):
827
+ def __init__(self, config: KimiLinearConfig):
828
+ super().__init__(config)
829
+ self.padding_idx = config.pad_token_id
830
+ self.vocab_size = config.vocab_size
831
+
832
+ self.embed_tokens = nn.Embedding(
833
+ config.vocab_size, config.hidden_size, self.padding_idx)
834
+ self.layers = nn.ModuleList([KimiDecoderLayer(
835
+ config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
836
+ self.norm = KimiRMSNorm(
837
+ config.hidden_size, eps=config.rms_norm_eps)
838
+
839
+ if getattr(config, "_attn_implementation", None) is not None:
840
+ if config._attn_implementation != "flash_attention_2":
841
+ logger.warning_once(
842
+ f"Ignoring the provided attention implementation {config._attn_implementation}")
843
+ logger.warning_once("Using flash_attention_2 backend instead.")
844
+ config._attn_implementation = "flash_attention_2"
845
+ else:
846
+ config._attn_implementation = "flash_attention_2"
847
+
848
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
849
+ self.gradient_checkpointing = False
850
+ # Initialize weights and apply final processing
851
+ self.post_init()
852
+
853
+ def _update_linear_attn_mask(self, attention_mask, cache_position):
854
+ """
855
+ NOTE: Left-padding is used for linear attention mask.
856
+ No need for zeroing states when
857
+ 1. Cached forward
858
+ 2. Attending to all inputs
859
+ """
860
+ linear_attn_mask = attention_mask
861
+ if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
862
+ linear_attn_mask = None
863
+ return linear_attn_mask
864
+
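+     # Editor's illustrative note (added comment, not original code): for a fresh
+     # prompt with no padding, e.g. attention_mask = [[1, 1, 1, 1]] and
+     # cache_position starting at 0, the mask above is dropped (None is returned);
+     # it is likewise dropped for cached decoding steps where cache_position[0] > 0.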
865
+ @check_model_inputs
866
+ @auto_docstring
867
+ def forward(
868
+ self,
869
+ input_ids: torch.LongTensor = None,
870
+ attention_mask: Optional[torch.Tensor] = None,
871
+ position_ids: Optional[torch.LongTensor] = None,
872
+ past_key_values: Optional[Cache] = None,
873
+ inputs_embeds: Optional[torch.FloatTensor] = None,
874
+ cache_position: Optional[torch.LongTensor] = None,
875
+ use_cache: Optional[bool] = None,
876
+ **kwargs: Unpack[TransformersKwargs],
877
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
878
+
879
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
880
+
881
+ if (input_ids is None) and (inputs_embeds is None):
882
+ raise ValueError(
883
+ "You must specify exactly one of input_ids or inputs_embeds")
884
+
885
+ # Get inputs_embeds
886
+ if inputs_embeds is None:
887
+ inputs_embeds = self.embed_tokens(input_ids)
888
+
889
+ if use_cache and past_key_values is None:
890
+ past_key_values = KimiDynamicCache(config=self.config)
891
+
892
+ if cache_position is None:
893
+ past_seen_tokens = past_key_values.get_seq_length(
894
+ ) if past_key_values is not None else 0
895
+ cache_position: torch.Tensor = torch.arange(
896
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
897
+ )
898
+
899
+ if position_ids is None:
900
+ position_ids = cache_position.unsqueeze(0)
901
+
902
+ causal_mask = create_causal_mask(
903
+ config=self.config,
904
+ input_embeds=inputs_embeds,
905
+ attention_mask=attention_mask,
906
+ cache_position=cache_position,
907
+ past_key_values=past_key_values,
908
+ position_ids=position_ids,
909
+ )
910
+ linear_attn_mask = self._update_linear_attn_mask(
911
+ attention_mask, cache_position)
912
+
913
+ hidden_states = inputs_embeds
914
+ if past_key_values is not None:
915
+ assert isinstance(past_key_values, KimiDynamicCache)
916
+
917
+ for decoder_layer in self.layers:
918
+ layer_mask = linear_attn_mask if decoder_layer.is_linear_attn else causal_mask
919
+
920
+ hidden_states = decoder_layer(
921
+ hidden_states,
922
+ attention_mask=layer_mask,
923
+ past_key_values=past_key_values,
924
+ cache_position=cache_position,
925
+ **kwargs,
926
+ )
927
+
928
+ hidden_states = self.norm(hidden_states)
929
+
930
+ return BaseModelOutputWithPast(
931
+ last_hidden_state=hidden_states,
932
+ past_key_values=past_key_values,
933
+ )
934
+
935
+
936
+ class KimiLinearForCausalLM(KimiPreTrainedModel, GenerationMixin):
937
+ _tied_weights_keys = ["lm_head.weight"]
938
+
939
+ def __init__(self, config):
940
+ super().__init__(config)
941
+ self.model = KimiLinearModel(config)
942
+ self.vocab_size = config.vocab_size
943
+ self.lm_head = nn.Linear(
944
+ config.hidden_size, config.vocab_size, bias=False)
945
+
946
+ # Initialize weights and apply final processing
947
+ self.post_init()
948
+
949
+ @can_return_tuple
950
+ @auto_docstring
951
+ def forward(
952
+ self,
953
+ input_ids: torch.LongTensor = None,
954
+ attention_mask: Optional[torch.Tensor] = None,
955
+ position_ids: Optional[torch.LongTensor] = None,
956
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
957
+ inputs_embeds: Optional[torch.FloatTensor] = None,
958
+ labels: Optional[torch.LongTensor] = None,
959
+ use_cache: Optional[bool] = None,
960
+ output_attentions: Optional[bool] = None,
961
+ output_hidden_states: Optional[bool] = None,
962
+ generation_mode: Optional[bool] = None,
963
+ return_dict: Optional[bool] = None,
964
+ cache_position: Optional[torch.LongTensor] = None,
965
+ **kwargs: Unpack[TransformersKwargs],
966
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
967
+ r"""
968
+ Args:
969
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
970
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
971
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
972
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
973
+
974
+ Returns:
975
+
976
+ Example:
977
+
978
+ ```python
979
+ >>> from transformers import AutoTokenizer, KimiLinearForCausalLM
980
+
981
+ >>> model = KimiLinearForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
982
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
983
+
984
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
985
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
986
+
987
+ >>> # Generate
988
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
989
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
990
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
991
+ ```"""
992
+
993
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
994
+ output_hidden_states = (
995
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
996
+ )
997
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
998
+
999
+ outputs = self.model(
1000
+ input_ids=input_ids,
1001
+ attention_mask=attention_mask,
1002
+ position_ids=position_ids,
1003
+ past_key_values=past_key_values,
1004
+ inputs_embeds=inputs_embeds,
1005
+ use_cache=use_cache,
1006
+ output_attentions=output_attentions,
1007
+ output_hidden_states=output_hidden_states,
1008
+ return_dict=return_dict,
1009
+ cache_position=cache_position,
1010
+ )
1011
+
1012
+ hidden_states = outputs[0]
1013
+ if generation_mode:
1014
+ hidden_states = hidden_states[:, -1:]
1015
+ logits = self.lm_head(hidden_states)
1016
+
1017
+ loss = None
1018
+ if labels is not None:
1019
+ loss = self.loss_function(
1020
+ logits, labels, self.vocab_size, **kwargs)
1021
+
1022
+ return CausalLMOutputWithPast(
1023
+ loss=loss,
1024
+ logits=logits,
1025
+ past_key_values=outputs.past_key_values,
1026
+ hidden_states=outputs.hidden_states,
1027
+ attentions=outputs.attentions,
1028
+ )
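+ 
+ 
+ # --- Editor's illustrative sketch (not part of the original modeling code) ---
+ # One way to load a checkpoint that ships this module as remote code, assuming the
+ # hosting repo wires it up through an auto_map (an assumption here) and exposes a
+ # tokenizer. The repo id below is a placeholder, not taken from this file.
+ def _sketch_remote_code_usage():
+     from transformers import AutoModelForCausalLM, AutoTokenizer
+ 
+     repo = "path/to/kimi-linear-checkpoint"   # placeholder repo id
+     tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
+     model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
+     inputs = tokenizer("Hello, world!", return_tensors="pt")
+     return model.generate(**inputs, max_new_tokens=16)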