chore: adding lolcats configs, csrc and src
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- configs/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml +52 -0
- configs/experiment/distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml +52 -0
- configs/experiment/eval_alpaca_clean.yaml +56 -0
- configs/experiment/finetune_lora_fqkvo_alpaca_clean.yaml +58 -0
- configs/experiment/finetune_lora_qkvo_alpaca_clean.yaml +56 -0
- configs/experiment/no_distill_alpaca_clean.yaml +29 -0
- configs/model/base_llama3_1_8b.yaml +15 -0
- configs/model/base_llama3_8b.yaml +15 -0
- configs/model/base_mistral_7b.yaml +15 -0
- configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml +40 -0
- configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml +40 -0
- configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wsw64_fd64_w01.yaml +34 -0
- configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01.yaml +34 -0
- configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wsw64_fd64_w01.yaml +36 -0
- configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wtk64_fd64_w01.yaml +35 -0
- configs/model/distill_llama3_1_8b_lk_smd_fd64.yaml +35 -0
- configs/model/distill_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml +39 -0
- configs/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml +39 -0
- configs/model/distill_llama3_1_8b_lk_t2r.yaml +35 -0
- configs/model/distill_llama3_8b_lk_smd_fd64.yaml +29 -0
- configs/model/distill_llama3_8b_lk_smd_wsw64_fd64_w01.yaml +33 -0
- configs/model/distill_llama3_8b_lk_smd_wtk64_fd64_w01.yaml +33 -0
- configs/model/distill_llama3_8b_lk_t2r.yaml +29 -0
- configs/model/distill_mistral_7b_lk_smd_fd64.yaml +29 -0
- configs/model/distill_mistral_7b_lk_smd_wsw64_fd64_w01.yaml +35 -0
- configs/model/distill_mistral_7b_lk_smd_wtk64_fd64_w01.yaml +35 -0
- configs/model/distill_mistral_7b_lk_t2r.yaml +29 -0
- csrc/__init__.py +6 -0
- csrc/causal_attention.cpp +225 -0
- csrc/causal_attention.py +77 -0
- csrc/causal_attention_cuda.cu +1483 -0
- csrc/causal_attention_kv_cuda.cu +1483 -0
- csrc/setup.py +53 -0
- src/__init__.py +0 -0
- src/dataloaders/__init__.py +22 -0
- src/dataloaders/alpaca_clean.py +149 -0
- src/dataloaders/alpaca_clean_instruct.py +148 -0
- src/dataloaders/utils/__init__.py +4 -0
- src/dataloaders/utils/llama3.py +62 -0
- src/dataloaders/utils/packing.py +80 -0
- src/dataloaders/utils/setup.py +123 -0
- src/finetune.py +68 -0
- src/model/__init__.py +0 -0
- src/model/convert_model.py +173 -0
- src/model/feature_map.py +306 -0
- src/model/linear_attention/__init__.py +23 -0
- src/model/linear_attention/linear_attention.py +459 -0
- src/model/linear_attention/linear_window_attention_sw.py +339 -0
- src/model/linear_attention/linear_window_attention_sw_linear.py +522 -0
- src/model/linear_attention/linear_window_attention_sw_long.py +23 -0
configs/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml
ADDED
@@ -0,0 +1,52 @@
+dataset:
+  name: alpaca_clean
+  dataset_config:
+    name: default
+    path: yahma/alpaca-cleaned
+    chunk_size: 1024 # sequence length for distilling
+    concat_data: true
+    cache_dir: 'data/alpaca' # Change this to where you want to save
+  pretrained_model_config: # will be updated based on model_config
+    pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3-8B'
+    cache_dir: '/scratch/'
+  preprocess_config: null
+
+dataloader:
+  batch_size: 1
+  num_workers: 2
+  drop_last: false
+  pin_memory: true
+
+optimizer:
+  optim: adamw_torch_fused
+  lr: 0.01
+  weight_decay: 0.0
+
+lr_scheduler:
+  lr_scheduler_type: reduce_lr_on_plateau
+  mode: min
+  factor: 0.1
+  patience: 10
+  min_lr: 0.00001
+
+trainer: # HuggingFace Trainer-like arguments
+  name: distill_attention_xent_mse
+  reverse_kl: false
+  mse_factor: 1000
+  xent_factor: 0
+
+  bf16: true
+  train_split: train
+  val_split: validation
+  num_train_epochs: 2
+  gradient_accumulation_steps: 8
+  seed: 42
+  batch_size: 1
+  load_best_model_at_end: true
+  greater_is_better: false
+  metric_for_best_model: distill/eval/loss
+  logging_steps: 100
+  evaluation_strategy: steps
+  max_steps: -1
+  eval_steps: 100
+  max_eval_batches: null
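Note: at train time an experiment config like the one above is paired with a model config from configs/model/. A minimal sketch of how the two YAML files could be loaded and merged with OmegaConf follows; the file paths and the merging step are illustrative assumptions, since the training entry point is not part of this diff.

from omegaconf import OmegaConf

# Hypothetical illustration: load one experiment config and one model config,
# then keep the dataset's pretrained_model_config in sync with the model config.
distill_config = OmegaConf.load(
    'configs/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml')
model_config = OmegaConf.load('configs/model/distill_llama3_8b_lk_smd_fd64.yaml')

distill_config.dataset.pretrained_model_config.pretrained_model_name_or_path = \
    model_config.model.pretrained_model_name_or_path

print(OmegaConf.to_yaml(distill_config.trainer))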
configs/experiment/distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml
ADDED
@@ -0,0 +1,52 @@
+dataset:
+  name: alpaca_clean
+  dataset_config:
+    name: default
+    path: yahma/alpaca-cleaned
+    chunk_size: 1024 # sequence length for distilling
+    concat_data: true
+    cache_dir: 'data/alpaca' # Change this to where you want to save
+  pretrained_model_config: # will be updated based on model_config
+    pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3.1-8B'
+    cache_dir: '/data_persistent2/sim_data/llama-3_1-8b/'
+  preprocess_config: null
+
+dataloader:
+  batch_size: 1
+  num_workers: 2
+  drop_last: false
+  pin_memory: true
+
+optimizer:
+  optim: adamw_torch_fused
+  lr: 0.01
+  weight_decay: 0.0
+
+lr_scheduler:
+  lr_scheduler_type: reduce_lr_on_plateau
+  mode: min
+  factor: 0.1
+  patience: 10
+  min_lr: 0.00001
+
+trainer: # HuggingFace Trainer-like arguments
+  name: distill_attention_xent_mse
+  reverse_kl: false
+  mse_factor: 1000
+  xent_factor: 1
+
+  bf16: true
+  train_split: train
+  val_split: validation
+  num_train_epochs: 2
+  gradient_accumulation_steps: 8
+  seed: 42
+  batch_size: 1
+  load_best_model_at_end: true
+  greater_is_better: false
+  metric_for_best_model: distill/eval/loss
+  logging_steps: 100
+  evaluation_strategy: steps
+  max_steps: -1
+  eval_steps: 100
+  max_eval_batches: null
configs/experiment/eval_alpaca_clean.yaml
ADDED
@@ -0,0 +1,56 @@
+dataset:
+  name: alpaca_clean
+  dataset_config:
+    name: alpaca
+    path: yahma/alpaca-cleaned
+    chunk_size: 1024 # sequence length for distilling
+    concat_data: true
+    cache_dir: 'data/alpaca' # Change this to where you want to save
+  pretrained_model_config:
+    pretrained_model_name_or_path: 'mistralai/Mistral-7B-v0.1' # will be updated based on model_config
+    cache_dir: '/scratch/'
+  preprocess_config: null
+
+dataloader:
+  batch_size: 1
+  num_workers: 2
+  drop_last: false
+  pin_memory: true
+
+optimizer:
+  optim: adamw_torch_fused
+  lr: 1e-4
+  weight_decay: 0.0
+
+lr_scheduler:
+  lr_scheduler_type: reduce_lr_on_plateau
+  mode: min
+  factor: 0.1
+  patience: 10
+  min_lr: 0.00001
+
+trainer: # HuggingFace Trainer-like arguments
+  name: finetune_seq2seq
+  bf16: true
+  train_split: train
+  val_split: test
+  num_train_epochs: 2
+  gradient_accumulation_steps: 8
+  seed: 42
+  batch_size: 1
+  load_best_model_at_end: true
+  greater_is_better: true
+  metric_for_best_model: eval/rouge/geometric_mean
+  logging_steps: 100
+  evaluation_strategy: steps
+  max_steps: -1
+  eval_steps: 100
+  max_eval_batches: null
+
+finetune:
+  method: lora
+  kwargs:
+    r: 8
+    lora_alpha: 16
+    lora_dropout: 0 # 0.05
+    target_modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj']
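The finetune block above corresponds closely to Hugging Face PEFT's LoraConfig arguments. A minimal sketch under that assumption (the repository's own adapter wiring is not shown in this diff):

from peft import LoraConfig, get_peft_model

# Values copied from finetune.kwargs above; task_type is an assumption for a causal LM.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.0,
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
    task_type='CAUSAL_LM',
)
# model = get_peft_model(model, lora_config)  # `model` would be the converted base model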
configs/experiment/finetune_lora_fqkvo_alpaca_clean.yaml
ADDED
@@ -0,0 +1,58 @@
+dataset:
+  name: alpaca_clean
+  dataset_config:
+    name: default
+    path: yahma/alpaca-cleaned
+    chunk_size: 1024
+    concat_data: true
+    cache_dir: "data/alpaca"
+  pretrained_model_config:
+    pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" # will be updated based on model_config
+    cache_dir: "/data_persistent2/sim_data/"
+  preprocess_config: null
+
+dataloader:
+  batch_size: 1
+  num_workers: 2
+  drop_last: false
+  pin_memory: true
+
+optimizer:
+  optim: adamw_torch_fused
+  lr: 1e-4
+  weight_decay: 0.0
+
+lr_scheduler:
+  lr_scheduler_type: reduce_lr_on_plateau
+  mode: min
+  factor: 0.1
+  patience: 10
+  min_lr: 0.00001
+
+trainer: # HuggingFace Trainer-like arguments
+  name: default_lm
+  bf16: true
+  train_split: train
+  val_split: validation
+  num_train_epochs: 2
+  gradient_accumulation_steps: 8
+  seed: 42
+  batch_size: 1
+  load_best_model_at_end: true
+  greater_is_better: false
+  metric_for_best_model: eval/loss # eval/rouge/geometric_mean
+  logging_steps: 100
+  evaluation_strategy: steps
+  max_steps: -1
+  eval_steps: 100
+  max_eval_batches: null
+  num_save_ckpt_steps: 200
+
+finetune:
+  method: lora
+  kwargs:
+    r: 8
+    lora_alpha: 16
+    lora_dropout: 0 # 0.05
+    target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
+  trainable_weights: ['feature_map_q.mlp.layer', 'feature_map_k.mlp.layer', 'window_factors']
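Relative to the plain q/k/v/o LoRA config below, this file also lists trainable_weights: parameter-name substrings (the learned feature maps and the window factors) that should stay trainable alongside the LoRA adapters. A hedged sketch of how such a list could be applied, not the repository's actual code:

# Illustrative only: unfreeze any parameter whose name contains one of the substrings.
def mark_trainable(model, substrings=('feature_map_q.mlp.layer',
                                      'feature_map_k.mlp.layer',
                                      'window_factors')):
    for name, param in model.named_parameters():
        if any(s in name for s in substrings):
            param.requires_grad = True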
configs/experiment/finetune_lora_qkvo_alpaca_clean.yaml
ADDED
@@ -0,0 +1,56 @@
+dataset:
+  name: alpaca_clean
+  dataset_config:
+    name: default
+    path: yahma/alpaca-cleaned
+    chunk_size: 1024
+    concat_data: true
+    cache_dir: "data/alpaca"
+  pretrained_model_config:
+    pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" # will be updated based on model_config
+    cache_dir: "/scratch/"
+  preprocess_config: null
+
+dataloader:
+  batch_size: 1
+  num_workers: 2
+  drop_last: false
+  pin_memory: true
+
+optimizer:
+  optim: adamw_torch_fused
+  lr: 1e-4
+  weight_decay: 0.0
+
+lr_scheduler:
+  lr_scheduler_type: reduce_lr_on_plateau
+  mode: min
+  factor: 0.1
+  patience: 10
+  min_lr: 0.00001
+
+trainer: # HuggingFace Trainer-like arguments
+  name: default_lm
+  bf16: true
+  train_split: train
+  val_split: validation
+  num_train_epochs: 2
+  gradient_accumulation_steps: 8
+  seed: 42
+  batch_size: 1
+  load_best_model_at_end: true
+  greater_is_better: false
+  metric_for_best_model: eval/loss # eval/rouge/geometric_mean
+  logging_steps: 100
+  evaluation_strategy: steps
+  max_steps: -1
+  eval_steps: 100
+  max_eval_batches: null
+
+finetune:
+  method: lora
+  kwargs:
+    r: 8
+    lora_alpha: 16
+    lora_dropout: 0 # 0.05
+    target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
configs/experiment/no_distill_alpaca_clean.yaml
ADDED
@@ -0,0 +1,29 @@
+dataset:
+  name: alpaca_clean
+  dataset_config:
+    name: alpaca
+    path: yahma/alpaca-cleaned
+    chunk_size: 1024 # sequence length for distilling
+    concat_data: true
+    cache_dir: 'data/alpaca' # Change this to where you want to save
+  pretrained_model_config:
+    pretrained_model_name_or_path: 'mistralai/Mistral-7B-v0.1' # will be updated based on model_config
+    cache_dir: '/scr-ssd/mzhang/models/mistral-v0.1'
+  preprocess_config: null
+
+dataloader:
+  batch_size: 1
+  num_workers: 2
+  drop_last: false
+  pin_memory: true
+
+optimizer:
+  optim: adamw_torch_fused
+  lr: 0.01
+  weight_decay: 0.0
+
+lr_scheduler:
+  lr_scheduler_type: none
+
+trainer: # HuggingFace Trainer-like arguments
+  name: null
configs/model/base_llama3_1_8b.yaml
ADDED
@@ -0,0 +1,15 @@
+name: llama
+model:
+  pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3.1-8B'
+  cache_dir: '/scr-ssd/mzhang/models/llama3' # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: flash_attention_2
+  rope_theta: 500000.0
+
+attention:
+  attention_type: softmax
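The keys under model: mirror keyword arguments of Hugging Face's AutoModelForCausalLM.from_pretrained. A minimal sketch under that assumption (the repository's own loader is not part of this diff; rope_theta is left to the model's config here):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Meta-Llama-3.1-8B',
    cache_dir='/scr-ssd/mzhang/models/llama3',
    return_dict=True,
    load_in_8bit=False,
    load_in_4bit=False,
    device_map='auto',
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,              # config string 'bfloat16' mapped to a torch dtype
    attn_implementation='flash_attention_2',
)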
configs/model/base_llama3_8b.yaml
ADDED
@@ -0,0 +1,15 @@
+name: llama
+model:
+  pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3-8B'
+  cache_dir: '/scr-ssd/mzhang/models/llama3' # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: flash_attention_2
+  rope_theta: 500000.0
+
+attention:
+  attention_type: softmax
configs/model/base_mistral_7b.yaml
ADDED
@@ -0,0 +1,15 @@
+name: llama
+model:
+  pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
+  cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: flash_attention_2
+  rope_theta: 10000.0
+
+attention:
+  attention_type: softmax
configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml
ADDED
@@ -0,0 +1,40 @@
+# Experimental config for chunked linear attention
+name: llama
+model:
+  pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
+  cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: flash_attention_2
+  rope_theta: 500000.0
+  rope_scaling:
+    factor: 8.0
+    low_freq_factor: 1.0
+    high_freq_factor: 4.0
+    original_max_position_embeddings: 8192
+    rope_type: llama3
+
+attention:
+  attention_type: lolcats_long_llama_window_sw
+  state_chunk_len: 1024
+  window_size: 64
+  affine_attention_factors: false
+  init_window_factor: -2.1972245773362196
+  feature_map: softmax_dim
+  feature_map_kwargs:
+    eps: 1e-12
+    # mlp: null # to set
+    fullspace: true
+  layer_idx: null # to set
+  learned_kernel: untied_head_einsum
+  learned_kernel_kwargs:
+    feature_dim: 64
+    skip_connection: false
+    bias: false
+    zero_init: false
+  tie_qk_kernels: false
+  train_qk: false
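For reference, init_window_factor = -2.1972245773362196 equals log(0.1/0.9), the logit of 0.1, which matches the w01 suffix on these file names; passed through a sigmoid it would initialize the window/linear-attention mixing factor near 0.1. This reading is inferred from the value itself rather than stated anywhere in the diff:

import math

init_window_factor = math.log(0.1 / 0.9)             # -2.1972245773362196
print(1.0 / (1.0 + math.exp(-init_window_factor)))   # sigmoid(...) -> 0.1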
configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml
ADDED
@@ -0,0 +1,40 @@
+# Experimental config for chunked linear attention
+name: llama
+model:
+  pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
+  cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: flash_attention_2
+  rope_theta: 500000.0
+  rope_scaling:
+    factor: 8.0
+    low_freq_factor: 1.0
+    high_freq_factor: 4.0
+    original_max_position_embeddings: 8192
+    rope_type: llama3
+
+attention:
+  attention_type: lolcats_long_llama_window_tk
+  state_chunk_len: 1024
+  window_size: 64
+  affine_attention_factors: false
+  init_window_factor: -2.1972245773362196
+  feature_map: softmax_dim
+  feature_map_kwargs:
+    eps: 1e-12
+    # mlp: null # to set
+    fullspace: true
+  layer_idx: null # to set
+  learned_kernel: untied_head_einsum
+  learned_kernel_kwargs:
+    feature_dim: 64
+    skip_connection: false
+    bias: false
+    zero_init: false
+  tie_qk_kernels: false
+  train_qk: false
configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wsw64_fd64_w01.yaml
ADDED
@@ -0,0 +1,34 @@
+# Experimental config for chunked linear attention
+name: llama
+model:
+  pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
+  cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: flash_attention_2
+  rope_theta: 500000.0
+
+attention:
+  attention_type: lolcats_long_llama_window_sw
+  state_chunk_len: 1024
+  window_size: 64
+  affine_attention_factors: false
+  init_window_factor: -2.1972245773362196
+  feature_map: softmax_dim
+  feature_map_kwargs:
+    eps: 1e-12
+    # mlp: null # to set
+    fullspace: true
+  layer_idx: null # to set
+  learned_kernel: untied_head_einsum
+  learned_kernel_kwargs:
+    feature_dim: 64
+    skip_connection: false
+    bias: false
+    zero_init: false
+  tie_qk_kernels: false
+  train_qk: false
configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01.yaml
ADDED
@@ -0,0 +1,34 @@
+# Experimental config for chunked linear attention
+name: llama
+model:
+  pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
+  cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: flash_attention_2
+  rope_theta: 500000.0
+
+attention:
+  attention_type: lolcats_long_llama_window_tk
+  state_chunk_len: 1024
+  window_size: 64
+  affine_attention_factors: false
+  init_window_factor: -2.1972245773362196
+  feature_map: softmax_dim
+  feature_map_kwargs:
+    eps: 1e-12
+    # mlp: null # to set
+    fullspace: true
+  layer_idx: null # to set
+  learned_kernel: untied_head_einsum
+  learned_kernel_kwargs:
+    feature_dim: 64
+    skip_connection: false
+    bias: false
+    zero_init: false
+  tie_qk_kernels: false
+  train_qk: false
configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wsw64_fd64_w01.yaml
ADDED
@@ -0,0 +1,36 @@
+# Experimental config for chunked linear attention
+name: llama
+model:
+  pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
+  cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: flash_attention_2 # eager # so we can load attention weights
+  rope_theta: 10000.0
+
+attention:
+  attention_type: lolcats_long_llama_window_sw
+  state_chunk_len: 512 # 1024
+  window_size: 64
+  affine_attention_factors: false
+  init_window_factor: -2.1972245773362196
+  train_window_factor: true
+  train_attention_weights: false
+  feature_map: softmax_dim
+  feature_map_kwargs:
+    eps: 1e-12
+    # mlp: null # to set
+    fullspace: true
+  layer_idx: null # to set
+  learned_kernel: untied_head_einsum
+  learned_kernel_kwargs:
+    feature_dim: 64
+    skip_connection: false
+    bias: false
+    zero_init: false
+  tie_qk_kernels: false
+  train_qk: false
configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wtk64_fd64_w01.yaml
ADDED
@@ -0,0 +1,35 @@
+name: llama
+model:
+  pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
+  cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: flash_attention_2 # eager # so we can load attention weights
+  rope_theta: 10000.0
+
+attention:
+  attention_type: lolcats_long_llama_window_tk
+  state_chunk_len: 512 # 1024
+  window_size: 64
+  affine_attention_factors: false
+  init_window_factor: -2.1972245773362196
+  train_window_factor: true
+  train_attention_weights: false
+  feature_map: softmax_dim
+  feature_map_kwargs:
+    eps: 1e-12
+    # mlp: null # to set
+    fullspace: true
+  layer_idx: null # to set
+  learned_kernel: untied_head_einsum
+  learned_kernel_kwargs:
+    feature_dim: 64
+    skip_connection: false
+    bias: false
+    zero_init: false
+  tie_qk_kernels: false
+  train_qk: false
configs/model/distill_llama3_1_8b_lk_smd_fd64.yaml
ADDED
@@ -0,0 +1,35 @@
+name: llama
+model:
+  pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
+  cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: eager
+  rope_theta: 500000.0
+  rope_scaling:
+    factor: 8.0
+    low_freq_factor: 1.0
+    high_freq_factor: 4.0
+    original_max_position_embeddings: 8192
+    rope_type: llama3
+
+attention:
+  attention_type: lolcats_llama
+  feature_map: softmax_dim
+  feature_map_kwargs:
+    eps: 1e-12
+    # mlp: null # to set
+    fullspace: true
+  layer_idx: null # to set
+  learned_kernel: untied_head_einsum
+  learned_kernel_kwargs:
+    feature_dim: 64
+    skip_connection: false
+    bias: false
+    zero_init: false
+  tie_qk_kernels: false
+  train_qk: false
configs/model/distill_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml
ADDED
@@ -0,0 +1,39 @@
+name: llama
+model:
+  pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
+  cache_dir: "/scratch/" # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: eager
+  rope_theta: 500000.0
+  rope_scaling:
+    factor: 8.0
+    low_freq_factor: 1.0
+    high_freq_factor: 4.0
+    original_max_position_embeddings: 8192
+    rope_type: llama3
+
+attention:
+  attention_type: lolcats_llama_window_sw
+  state_chunk_len: 1024
+  window_size: 64
+  affine_attention_factors: false
+  init_window_factor: -2.1972245773362196
+  feature_map: softmax_dim
+  feature_map_kwargs:
+    eps: 1e-12
+    # mlp: null # to set
+    fullspace: true
+  layer_idx: null # to set
+  learned_kernel: untied_head_einsum
+  learned_kernel_kwargs:
+    feature_dim: 64
+    skip_connection: false
+    bias: false
+    zero_init: false
+  tie_qk_kernels: false
+  train_qk: false
configs/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml
ADDED
@@ -0,0 +1,39 @@
+name: llama
+model:
+  pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
+  cache_dir: "/scratch/" # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: eager
+  rope_theta: 500000.0
+  rope_scaling:
+    factor: 8.0
+    low_freq_factor: 1.0
+    high_freq_factor: 4.0
+    original_max_position_embeddings: 8192
+    rope_type: llama3
+
+attention:
+  attention_type: lolcats_llama_window_tk
+  state_chunk_len: 1024
+  window_size: 64
+  affine_attention_factors: false
+  init_window_factor: -2.1972245773362196
+  feature_map: softmax_dim
+  feature_map_kwargs:
+    eps: 1e-12
+    # mlp: null # to set
+    fullspace: true
+  layer_idx: null # to set
+  learned_kernel: untied_head_einsum
+  learned_kernel_kwargs:
+    feature_dim: 64
+    skip_connection: false
+    bias: false
+    zero_init: false
+  tie_qk_kernels: false
+  train_qk: false
configs/model/distill_llama3_1_8b_lk_t2r.yaml
ADDED
@@ -0,0 +1,35 @@
+name: llama
+model:
+  pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
+  cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: eager
+  rope_theta: 500000.0
+  rope_scaling:
+    factor: 8.0
+    low_freq_factor: 1.0
+    high_freq_factor: 4.0
+    original_max_position_embeddings: 8192
+    rope_type: llama3
+
+attention:
+  attention_type: lolcats_llama
+  feature_map: relu
+  feature_map_kwargs:
+    eps: 1e-12
+    # mlp: null # to set
+    fullspace: true
+  layer_idx: null # to set
+  learned_kernel: untied_head_einsum
+  learned_kernel_kwargs:
+    feature_dim: 128
+    skip_connection: false
+    bias: true
+    zero_init: false
+  tie_qk_kernels: false
+  train_qk: false
configs/model/distill_llama3_8b_lk_smd_fd64.yaml
ADDED
@@ -0,0 +1,29 @@
+name: llama
+model:
+  pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
+  cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: flash_attention_2
+  rope_theta: 500000.0
+
+attention:
+  attention_type: lolcats_llama
+  feature_map: softmax_dim
+  feature_map_kwargs:
+    eps: 1e-12
+    # mlp: null # to set
+    fullspace: true
+  layer_idx: null # to set
+  learned_kernel: untied_head_einsum
+  learned_kernel_kwargs:
+    feature_dim: 64
+    skip_connection: false
+    bias: false
+    zero_init: false
+  tie_qk_kernels: false
+  train_qk: false
configs/model/distill_llama3_8b_lk_smd_wsw64_fd64_w01.yaml
ADDED
@@ -0,0 +1,33 @@
+name: llama
+model:
+  pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
+  cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: flash_attention_2
+  rope_theta: 500000.0
+
+attention:
+  attention_type: lolcats_llama_window_sw
+  state_chunk_len: 1024
+  window_size: 64
+  affine_attention_factors: false
+  init_window_factor: -2.1972245773362196
+  feature_map: softmax_dim
+  feature_map_kwargs:
+    eps: 1e-12
+    # mlp: null # to set
+    fullspace: true
+  layer_idx: null # to set
+  learned_kernel: untied_head_einsum
+  learned_kernel_kwargs:
+    feature_dim: 64
+    skip_connection: false
+    bias: false
+    zero_init: false
+  tie_qk_kernels: false
+  train_qk: false
configs/model/distill_llama3_8b_lk_smd_wtk64_fd64_w01.yaml
ADDED
@@ -0,0 +1,33 @@
+name: llama
+model:
+  pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
+  cache_dir: '/scr-ssd/mzhang/models/llama3' # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: flash_attention_2
+  rope_theta: 500000.0
+
+attention:
+  attention_type: lolcats_llama_window_tk
+  state_chunk_len: 1024
+  window_size: 64
+  affine_attention_factors: false
+  init_window_factor: -2.1972245773362196
+  feature_map: softmax_dim
+  feature_map_kwargs:
+    eps: 1e-12
+    # mlp: null # to set
+    fullspace: true
+  layer_idx: null # to set
+  learned_kernel: untied_head_einsum
+  learned_kernel_kwargs:
+    feature_dim: 64
+    skip_connection: false
+    bias: false
+    zero_init: false
+  tie_qk_kernels: false
+  train_qk: false
configs/model/distill_llama3_8b_lk_t2r.yaml
ADDED
@@ -0,0 +1,29 @@
+name: llama
+model:
+  pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
+  cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: flash_attention_2
+  rope_theta: 500000.0
+
+attention:
+  attention_type: lolcats_llama
+  feature_map: relu
+  feature_map_kwargs:
+    eps: 1e-12
+    # mlp: null # to set
+    fullspace: true
+  layer_idx: null # to set
+  learned_kernel: untied_head_einsum
+  learned_kernel_kwargs:
+    feature_dim: 128
+    skip_connection: false
+    bias: true
+    zero_init: false
+  tie_qk_kernels: false
+  train_qk: false
configs/model/distill_mistral_7b_lk_smd_fd64.yaml
ADDED
@@ -0,0 +1,29 @@
+name: llama
+model:
+  pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
+  cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: flash_attention_2 # eager # so we can load attention weights
+  rope_theta: 10000.0
+
+attention:
+  attention_type: lolcats_llama
+  feature_map: softmax_dim
+  feature_map_kwargs:
+    eps: 1e-12
+    # mlp: null # to set
+    fullspace: true
+  layer_idx: null # to set
+  learned_kernel: untied_head_einsum
+  learned_kernel_kwargs:
+    feature_dim: 64
+    skip_connection: false
+    bias: false
+    zero_init: false
+  tie_qk_kernels: false
+  train_qk: false
configs/model/distill_mistral_7b_lk_smd_wsw64_fd64_w01.yaml
ADDED
@@ -0,0 +1,35 @@
+name: llama
+model:
+  pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
+  cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: flash_attention_2 # eager # so we can load attention weights
+  rope_theta: 10000.0
+
+attention:
+  attention_type: lolcats_llama_window_sw
+  state_chunk_len: 512 # 1024
+  window_size: 64
+  affine_attention_factors: false
+  init_window_factor: -2.1972245773362196
+  train_window_factor: true
+  train_attention_weights: false
+  feature_map: softmax_dim
+  feature_map_kwargs:
+    eps: 1e-12
+    # mlp: null # to set
+    fullspace: true
+  layer_idx: null # to set
+  learned_kernel: untied_head_einsum
+  learned_kernel_kwargs:
+    feature_dim: 64
+    skip_connection: false
+    bias: false
+    zero_init: false
+  tie_qk_kernels: false
+  train_qk: false
configs/model/distill_mistral_7b_lk_smd_wtk64_fd64_w01.yaml
ADDED
@@ -0,0 +1,35 @@
+name: llama
+model:
+  pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
+  cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: flash_attention_2 # eager # so we can load attention weights
+  rope_theta: 10000.0
+
+attention:
+  attention_type: lolcats_llama_window_tk
+  state_chunk_len: 512 # 1024
+  window_size: 64
+  affine_attention_factors: false
+  init_window_factor: -2.1972245773362196
+  train_window_factor: true
+  train_attention_weights: false
+  feature_map: softmax_dim
+  feature_map_kwargs:
+    eps: 1e-12
+    # mlp: null # to set
+    fullspace: true
+  layer_idx: null # to set
+  learned_kernel: untied_head_einsum
+  learned_kernel_kwargs:
+    feature_dim: 64
+    skip_connection: false
+    bias: false
+    zero_init: false
+  tie_qk_kernels: false
+  train_qk: false
configs/model/distill_mistral_7b_lk_t2r.yaml
ADDED
@@ -0,0 +1,29 @@
+name: llama
+model:
+  pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
+  cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
+  return_dict: true
+  load_in_8bit: false
+  load_in_4bit: false
+  device_map: auto
+  low_cpu_mem_usage: true
+  torch_dtype: bfloat16
+  attn_implementation: flash_attention_2 # eager # so we can load attention weights
+  rope_theta: 10000.0
+
+attention:
+  attention_type: lolcats_llama
+  feature_map: relu
+  feature_map_kwargs:
+    eps: 1e-12
+    # mlp: null # to set
+    fullspace: true
+  layer_idx: null # to set
+  learned_kernel: untied_head_einsum
+  learned_kernel_kwargs:
+    feature_dim: 128
+    skip_connection: false
+    bias: true
+    zero_init: false
+  tie_qk_kernels: false
+  train_qk: false
csrc/__init__.py
ADDED
@@ -0,0 +1,6 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <[email protected]>,
+# Apoorv Vyas <[email protected]>
+#
+from .causal_attention import causal_dot_product
csrc/causal_attention.cpp
ADDED
@@ -0,0 +1,225 @@
+//
+// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+// Written by Angelos Katharopoulos <[email protected]>,
+// Apoorv Vyas <[email protected]>
+//
+
+#include <torch/extension.h>
+
+
+/**
+ * Compute a*b^T and save it into out.
+ *
+ * a \in R^A
+ * b \in R^B
+ */
+inline void vvt_dot(float *a, float *b, float *out, int A, int B) {
+    for (int i=0; i<A; i++) {
+        float * bi = b;
+        for (int j=0; j<B; j++) {
+            *out += (*a) * (*bi);
+            out++;
+            bi++;
+        }
+        a++;
+    }
+}
+
+
+/**
+ * Implement a vector matrix product v*m and save it into out.
+ *
+ * v \in R^A
+ * m \in R^{AxB}
+ */
+inline void vm_dot(float *v, float *m, float *out, int A, int B) {
+    // TODO: Consider removing the zeroing part and assuming out already
+    // contains 0s
+    for (int i=0; i<B; i++) {
+        out[i] = 0;
+    }
+
+    for (int i=0; i<A; i++) {
+        float *oi = out;
+        for (int j=0; j<B; j++) {
+            *oi += (*v) * (*m);
+            oi++;
+            m++;
+        }
+        v++;
+    }
+}
+
+
+/**
+ * Implement a vector transposed-matrix product and save it into out.
+ *
+ * v \in R^B
+ * m \in R^{AxB}
+ */
+inline void vmt_dot(float *v, float *m, float *out, int A, int B) {
+    for (int i=0; i<A; i++) {
+        float *vi = v;
+        float s = 0;
+        for (int j=0; j<B; j++) {
+            s += (*vi) * (*m);
+            vi++;
+            m++;
+        }
+        // TODO: Should we be aggregating? See the comment on vm_dot.
+        *out = s;
+        out++;
+    }
+}
+
+
+/**
+ * Compute the causally masked dot products of queries, keys and values.
+ *
+ * Basically compute V_j' = (Q_{0:j} * K_{0:j}^T) * V_{0:j} for all j. The
+ * computation is done efficiently by changing the order of the dot products.
+ */
+void causal_dot_product(
+    const torch::Tensor queries,
+    const torch::Tensor keys,
+    const torch::Tensor values,
+    torch::Tensor product
+) {
+    // Extract some shapes
+    int N = queries.size(0);
+    int H = queries.size(1);
+    int L = queries.size(2);
+    int E = queries.size(3);
+    int M = values.size(3);
+
+    // Create accessors for all the arguments
+    auto qa = queries.accessor<float, 4>();
+    auto ka = keys.accessor<float, 4>();
+    auto va = values.accessor<float, 4>();
+    auto pa = product.accessor<float, 4>();
+
+    #pragma omp parallel for collapse(2)
+    for (int n=0; n<N; n++) {
+        for (int h=0; h<H; h++) {
+            auto kv = torch::zeros({E, M}, queries.options());
+            float *kvp = kv.data_ptr<float>();
+            for (int l=0; l<L; l++) {
+                vvt_dot(
+                    &ka[n][h][l][0],
+                    &va[n][h][l][0],
+                    kvp,
+                    E,
+                    M
+                );
+                vm_dot(
+                    &qa[n][h][l][0],
+                    kvp,
+                    &pa[n][h][l][0],
+                    E,
+                    M
+                );
+            }
+        }
+    }
+}
+
+
+/**
+ * Compute the gradients of queries, keys and values given the gradient of the
+ * causal_dot_product output.
+ *
+ * Make sure that everything is computed in O(N D^2) complexity.
+ */
+void causal_dot_backward(
+    const torch::Tensor queries,
+    const torch::Tensor keys,
+    const torch::Tensor values,
+    const torch::Tensor grad_out,
+    torch::Tensor grad_queries,
+    torch::Tensor grad_keys,
+    torch::Tensor grad_values
+) {
+    // Extract some shapes
+    int N = queries.size(0);
+    int H = queries.size(1);
+    int L = queries.size(2);
+    int E = queries.size(3);
+    int M = values.size(3);
+
+    // Create accessors for all the arguments
+    auto qa = queries.accessor<float, 4>();
+    auto ka = keys.accessor<float, 4>();
+    auto va = values.accessor<float, 4>();
+    auto ga = grad_out.accessor<float, 4>();
+    auto gqa = grad_queries.accessor<float, 4>();
+    auto gka = grad_keys.accessor<float, 4>();
+    auto gva = grad_values.accessor<float, 4>();
+
+    #pragma omp parallel for collapse(2)
+    for (int n=0; n<N; n++) {
+        for (int h=0; h<H; h++) {
+            auto kv = torch::zeros({E, M}, queries.options());
+            float *kvp = kv.data_ptr<float>();
+
+            // Compute the gradient wrt the queries
+            for (int l=0; l<L; l++) {
+                vvt_dot(
+                    &ka[n][h][l][0],
+                    &va[n][h][l][0],
+                    kvp,
+                    E,
+                    M
+                );
+                vmt_dot(
+                    &ga[n][h][l][0],
+                    kvp,
+                    &gqa[n][h][l][0],
+                    E,
+                    M
+                );
+            }
+
+            // Compute the gradient wrt the keys and values
+            kv.zero_();
+            for (int l=L-1; l>=0; l--) {
+                vvt_dot(
+                    &qa[n][h][l][0],
+                    &ga[n][h][l][0],
+                    kvp,
+                    E,
+                    M
+                );
+                vmt_dot(
+                    &va[n][h][l][0],
+                    kvp,
+                    &gka[n][h][l][0],
+                    E,
+                    M
+                );
+                vm_dot(
+                    &ka[n][h][l][0],
+                    kvp,
+                    &gva[n][h][l][0],
+                    E,
+                    M
+                );
+            }
+        }
+    }
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def(
+        "causal_dot_product",
+        &causal_dot_product,
+        "Compute the weighted sum of values but attending only to previous "
+        "values."
+    );
+    m.def(
+        "causal_dot_backward",
+        &causal_dot_backward,
+        "Compute the gradient of queries, keys and values given the gradient "
+        "of causal_dot_product."
+    );
+}
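For orientation, causal_dot_product computes, for every position l, the query applied to the running sum of key-value outer products over positions i <= l (unnormalized linear attention). A quadratic-time PyTorch reference of the same quantity, useful for checking the extension; this sketch is not part of the diff:

import torch

def causal_dot_product_reference(q, k, v):
    # q, k: (N, H, L, E); v: (N, H, L, M)
    # out[n, h, l] = sum over i <= l of (q[n, h, l] . k[n, h, i]) * v[n, h, i]
    scores = torch.einsum('nhle,nhse->nhls', q, k)                      # (N, H, L, L)
    causal = torch.tril(torch.ones(q.shape[2], q.shape[2],
                                   dtype=torch.bool, device=q.device))
    scores = scores.masked_fill(~causal, 0.0)                           # keep i <= l only
    return torch.einsum('nhls,nhsm->nhlm', scores, v)                   # (N, H, L, M)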
csrc/causal_attention.py
ADDED
@@ -0,0 +1,77 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <[email protected]>,
+# Apoorv Vyas <[email protected]>
+#
+
+import torch
+
+try:
+    from causal_attention_cuda import causal_dot_product as causal_dot_product_cuda
+    from causal_attention_cuda import causal_dot_backward as causal_dot_backward_cuda
+except ImportError as e:
+    print(e)
+    causal_dot_product_cuda = causal_dot_backward_cuda = None
+
+
+class CausalDotProduct(torch.autograd.Function):
+    """Compute the weighted sum of values but attending only to previous
+    values."""
+    dot = {
+        # "cpu": causal_dot_product_cpu,
+        "cuda": causal_dot_product_cuda
+    }
+    dot_backward = {
+        # "cpu": causal_dot_backward_cpu,
+        "cuda": causal_dot_backward_cuda
+    }
+
+    @staticmethod
+    def forward(ctx, Q, K, V):
+        # Save the inputs for the gradient computation
+        ctx.save_for_backward(Q, K, V)
+
+        # Create the output tensor
+        device = Q.device
+        N, H, L, _ = Q.shape
+        _, _, _, M = V.shape
+        product = torch.zeros((N, H, L, M), dtype=Q.dtype, device=device)
+
+        # Actually perform the dot product
+        CausalDotProduct.dot[device.type](
+            Q.data,
+            K.data,
+            V.data,
+            product
+        )
+        # breakpoint()
+        # CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product)
+
+        return product
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        # Extract the saved tensors
+        Q, K, V = ctx.saved_tensors
+
+        # Allocate memory for the gradients
+        grad_Q = torch.zeros_like(Q)
+        grad_K = torch.zeros_like(K)
+        grad_V = torch.zeros_like(V)
+
+        # Actually compute the gradients
+        CausalDotProduct.dot_backward[Q.device.type](
+            Q.data,
+            K.data,
+            V.data,
+            grad_out,
+            grad_Q,
+            grad_K,
+            grad_V
+        )
+
+        return grad_Q, grad_K, grad_V
+
+
+# Alias the autograd functions to python style snake case naming
+causal_dot_product = CausalDotProduct.apply
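Once the extension builds, the wrapper is used as a plain function on (N, H, L, D)-shaped float32 CUDA tensors. A small smoke test, assuming the csrc package imports cleanly on a CUDA machine:

import torch
from csrc import causal_dot_product

N, H, L, E, M = 2, 4, 128, 64, 64
q = torch.randn(N, H, L, E, device='cuda', requires_grad=True)
k = torch.randn(N, H, L, E, device='cuda', requires_grad=True)
v = torch.randn(N, H, L, M, device='cuda', requires_grad=True)

out = causal_dot_product(q, k, v)   # (N, H, L, M)
out.sum().backward()                # exercises causal_dot_backward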
csrc/causal_attention_cuda.cu
ADDED
@@ -0,0 +1,1483 @@
(1,483 added lines are not rendered in this limited view; see the raw diff for the CUDA kernel source)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//
|
| 2 |
+
// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
| 3 |
+
// Written by Angelos Katharopoulos <[email protected]>,
|
| 4 |
+
// Apoorv Vyas <[email protected]>
|
| 5 |
+
//
|
| 6 |
+
|
| 7 |
+
//
|
| 8 |
+
// For modifications made inside namespace nvidia (authored by jdemouth):
|
| 9 |
+
//
|
| 10 |
+
// Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved.
|
| 11 |
+
//
|
| 12 |
+
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
| 13 |
+
// this software and associated documentation files (the "Software"), to deal in
|
| 14 |
+
// the Software without restriction, including without limitation the rights to
|
| 15 |
+
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
| 16 |
+
// the Software, and to permit persons to whom the Software is furnished to do so,
|
| 17 |
+
// subject to the following conditions:
|
| 18 |
+
//
|
| 19 |
+
// The above copyright notice and this permission notice shall be included in all
|
| 20 |
+
// copies or substantial portions of the Software.
|
| 21 |
+
//
|
| 22 |
+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 23 |
+
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
| 24 |
+
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
| 25 |
+
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
| 26 |
+
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
| 27 |
+
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
| 28 |
+
//
|
| 29 |
+
|
| 30 |
+
#include <torch/extension.h>
|
| 31 |
+
#include <assert.h>
|
| 32 |
+
#include <stdio.h>
|
| 33 |
+
|
| 34 |
+
#define ENABLE_NVIDIA_OPTIMIZATIONS
|
| 35 |
+
|
| 36 |
+
#ifdef ENABLE_NVIDIA_OPTIMIZATIONS
|
| 37 |
+
namespace nvidia {
|
| 38 |
+
|
| 39 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 40 |
+
|
| 41 |
+
constexpr int THREADS_PER_WARP = 32;
|
| 42 |
+
|
| 43 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
constexpr int LOW_OCCUPANCY_THRESHOLD = 40; // TODO: Make it HW specific (like 1/2 SMs).
|
| 46 |
+
|
| 47 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
|
| 49 |
+
static inline __device__ __host__ int div_up(int m, int n) {
|
| 50 |
+
return (m + n-1) / n;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 54 |
+
|
| 55 |
+
static inline __device__ __host__ int round_up(int m, int n) {
|
| 56 |
+
return div_up(m, n) * n;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 60 |
+
|
| 61 |
+
template< typename T >
|
| 62 |
+
struct Lmha_params {
|
| 63 |
+
|
| 64 |
+
// The output buffer. Dimensions [B, H, L, M].
|
| 65 |
+
T *out;
|
| 66 |
+
|
| 67 |
+
// The input Qs. Dimensions [B, H, L, E].
|
| 68 |
+
const T *q;
|
| 69 |
+
// The input Ks. Dimensions [B, H, L, E].
|
| 70 |
+
const T *k;
|
| 71 |
+
// The input Vs. Dimensions [B, H, L, M].
|
| 72 |
+
const T *v;
|
| 73 |
+
|
| 74 |
+
// The different dimensions.
|
| 75 |
+
int B, L, H, E, M;
|
| 76 |
+
|
| 77 |
+
// The strides for the different tensors.
|
| 78 |
+
int q_stride_B, q_stride_H, q_stride_L;
|
| 79 |
+
int k_stride_B, k_stride_H, k_stride_L;
|
| 80 |
+
int v_stride_B, v_stride_H, v_stride_L;
|
| 81 |
+
int o_stride_B, o_stride_H, o_stride_L;
|
| 82 |
+
};
|
| 83 |
+
|
| 84 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 85 |
+
|
| 86 |
+
template< int E, bool GO_BACKWARD, int WARPS, int COLS_PER_THREAD = 4 >
|
| 87 |
+
__global__ __launch_bounds__(WARPS * THREADS_PER_WARP)
|
| 88 |
+
void lmha_low_occupancy_kernel(Lmha_params<float> params) {
|
| 89 |
+
|
| 90 |
+
// The number of threads per block.
|
| 91 |
+
constexpr int THREADS_PER_BLOCK = WARPS * THREADS_PER_WARP;
|
| 92 |
+
// The number of rows per thread.
|
| 93 |
+
constexpr int ROWS_PER_THREAD = E / THREADS_PER_WARP;
|
| 94 |
+
// The number of steps per iteration.
|
| 95 |
+
constexpr int COLS_PER_ITER = WARPS * COLS_PER_THREAD;
|
| 96 |
+
|
| 97 |
+
// Make sure E is a multiple of the warp size.
|
| 98 |
+
static_assert(E % THREADS_PER_WARP == 0, "");
|
| 99 |
+
|
| 100 |
+
// Shared memory to store V/O.
|
| 101 |
+
__shared__ float smem_v[COLS_PER_ITER], smem_o[COLS_PER_ITER];
|
| 102 |
+
// Shared memory buffer to performance the reductions.
|
| 103 |
+
__shared__ float smem_reds[E * WARPS];
|
| 104 |
+
|
| 105 |
+
// The sequence processed by that block.
|
| 106 |
+
const int bi = blockIdx.z;
|
| 107 |
+
// The head processed by that block.
|
| 108 |
+
const int hi = blockIdx.y;
|
| 109 |
+
// The hidden cell in the V/output buffers.
|
| 110 |
+
const int vi = blockIdx.x;
|
| 111 |
+
|
| 112 |
+
// The linear index of the thread.
|
| 113 |
+
const int tidx = threadIdx.x;
|
| 114 |
+
|
| 115 |
+
// Decompose the block in warp/lane.
|
| 116 |
+
const int warp = tidx / THREADS_PER_WARP;
|
| 117 |
+
const int lane = tidx % THREADS_PER_WARP;
|
| 118 |
+
|
| 119 |
+
// The base offset loaded by the thread in Q and K.
|
| 120 |
+
int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + lane;
|
| 121 |
+
int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + lane;
|
| 122 |
+
|
| 123 |
+
// If we walk backward, account for the extra offset.
|
| 124 |
+
if( GO_BACKWARD ) {
|
| 125 |
+
offset_q += (params.L-1)*params.q_stride_L;
|
| 126 |
+
offset_k += (params.L-1)*params.k_stride_L;
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
// Position the warp at the beginning of the proper timestep.
|
| 130 |
+
if( GO_BACKWARD ) {
|
| 131 |
+
offset_q -= warp*COLS_PER_THREAD*params.q_stride_L;
|
| 132 |
+
offset_k -= warp*COLS_PER_THREAD*params.k_stride_L;
|
| 133 |
+
} else {
|
| 134 |
+
offset_q += warp*COLS_PER_THREAD*params.q_stride_L;
|
| 135 |
+
offset_k += warp*COLS_PER_THREAD*params.k_stride_L;
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
// Determine the base pointers for Q and K.
|
| 139 |
+
const float *ptr_q = ¶ms.q[offset_q];
|
| 140 |
+
const float *ptr_k = ¶ms.k[offset_k];
|
| 141 |
+
|
| 142 |
+
// Is a given row valid?
|
| 143 |
+
int valid_qk[ROWS_PER_THREAD];
|
| 144 |
+
#pragma unroll
|
| 145 |
+
for( int ii = 0; ii < ROWS_PER_THREAD; ++ii ) {
|
| 146 |
+
valid_qk[ii] = lane + ii*THREADS_PER_WARP < params.E;
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
// The offset to the position loaded by the thread in V.
|
| 150 |
+
int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + vi;
|
| 151 |
+
int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + vi;
|
| 152 |
+
|
| 153 |
+
// If we walk backward, account for the extra offset.
|
| 154 |
+
if( GO_BACKWARD ) {
|
| 155 |
+
offset_v += (params.L-1)*params.v_stride_L;
|
| 156 |
+
offset_o += (params.L-1)*params.o_stride_L;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
// We load/store a strided matrix of COLS_PER_ITER x OUTPUTS_PER_BLOCK.
|
| 160 |
+
if( GO_BACKWARD ) {
|
| 161 |
+
offset_v -= tidx*params.v_stride_L;
|
| 162 |
+
offset_o -= tidx*params.o_stride_L;
|
| 163 |
+
} else {
|
| 164 |
+
offset_v += tidx*params.v_stride_L;
|
| 165 |
+
offset_o += tidx*params.o_stride_L;
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
// Determine the base pointer for V.
|
| 169 |
+
const float *ptr_v = ¶ms.v[offset_v];
|
| 170 |
+
// The output pointer.
|
| 171 |
+
float *ptr_o = ¶ms.out[offset_o];
|
| 172 |
+
|
| 173 |
+
// The running KVs.
|
| 174 |
+
float running_kv[ROWS_PER_THREAD];
|
| 175 |
+
#pragma unroll
|
| 176 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
| 177 |
+
running_kv[ri] = 0.f;
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
// Iterate over the timesteps. TODO: Use params.loop_count!!!
|
| 181 |
+
for( int iter = 0; iter < params.L; iter += COLS_PER_ITER ) {
|
| 182 |
+
|
| 183 |
+
// Each thread loads a matrix of elements.
|
| 184 |
+
float q[ROWS_PER_THREAD][COLS_PER_THREAD], k[ROWS_PER_THREAD][COLS_PER_THREAD];
|
| 185 |
+
|
| 186 |
+
// Trigger the memory loads for Q and K.
|
| 187 |
+
#pragma unroll
|
| 188 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
| 189 |
+
#pragma unroll
|
| 190 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
| 191 |
+
|
| 192 |
+
// For Q/K, each warp loads from various timesteps.
|
| 193 |
+
int ti = iter + warp*COLS_PER_THREAD;
|
| 194 |
+
if( GO_BACKWARD ) {
|
| 195 |
+
ti = params.L - 1 - ti;
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
// Is it a valid access?
|
| 199 |
+
int valid;
|
| 200 |
+
if( GO_BACKWARD ) {
|
| 201 |
+
valid = valid_qk[ri] && ti - ci >= 0;
|
| 202 |
+
} else {
|
| 203 |
+
valid = valid_qk[ri] && ti + ci < params.L;
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
// The extra offset to add.
|
| 207 |
+
if( GO_BACKWARD ) {
|
| 208 |
+
offset_q = ri*THREADS_PER_WARP - ci*params.q_stride_L;
|
| 209 |
+
offset_k = ri*THREADS_PER_WARP - ci*params.k_stride_L;
|
| 210 |
+
} else {
|
| 211 |
+
offset_q = ri*THREADS_PER_WARP + ci*params.q_stride_L;
|
| 212 |
+
offset_k = ri*THREADS_PER_WARP + ci*params.k_stride_L;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
// Load Q/K if they are valid.
|
| 216 |
+
q[ri][ci] = valid ? ptr_q[offset_q] : 0.f;
|
| 217 |
+
k[ri][ci] = valid ? ptr_k[offset_k] : 0.f;
|
| 218 |
+
}
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
// For the V tensor, we assign contiguous thread to different loads. So, ti is different.
|
| 222 |
+
int ti = iter + tidx;
|
| 223 |
+
if( GO_BACKWARD ) {
|
| 224 |
+
ti = params.L - 1 - ti;
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
// Is it a valid access?
|
| 228 |
+
int valid_vo = tidx < COLS_PER_ITER;
|
| 229 |
+
if( GO_BACKWARD ) {
|
| 230 |
+
valid_vo &= ti >= 0;
|
| 231 |
+
} else {
|
| 232 |
+
valid_vo &= ti < params.L;
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
// Trigger the loads for V.
|
| 236 |
+
float ldg_v = valid_vo ? *ptr_v : 0.f;
|
| 237 |
+
|
| 238 |
+
// Move the load pointers.
|
| 239 |
+
if( GO_BACKWARD ) {
|
| 240 |
+
ptr_q -= COLS_PER_ITER*params.q_stride_L;
|
| 241 |
+
ptr_k -= COLS_PER_ITER*params.k_stride_L;
|
| 242 |
+
ptr_v -= COLS_PER_ITER*params.v_stride_L;
|
| 243 |
+
} else {
|
| 244 |
+
ptr_q += COLS_PER_ITER*params.q_stride_L;
|
| 245 |
+
ptr_k += COLS_PER_ITER*params.k_stride_L;
|
| 246 |
+
ptr_v += COLS_PER_ITER*params.v_stride_L;
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
// Store to shared memory.
|
| 250 |
+
if( tidx < COLS_PER_ITER ) {
|
| 251 |
+
smem_v[tidx] = ldg_v;
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
// Make sure V is in shared memory.
|
| 255 |
+
__syncthreads();
|
| 256 |
+
|
| 257 |
+
// Read V from shared memory.
|
| 258 |
+
float v[COLS_PER_THREAD];
|
| 259 |
+
#pragma unroll
|
| 260 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
| 261 |
+
v[ci] = smem_v[warp*COLS_PER_THREAD + ci];
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
// Each thread computes local K*V products.
|
| 265 |
+
float kv[ROWS_PER_THREAD][COLS_PER_THREAD];
|
| 266 |
+
#pragma unroll
|
| 267 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
| 268 |
+
#pragma unroll
|
| 269 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
| 270 |
+
kv[ri][ci] = 0.f;
|
| 271 |
+
}
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
// Update the K*V^T product.
|
| 275 |
+
#pragma unroll
|
| 276 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
| 277 |
+
#pragma unroll
|
| 278 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
| 279 |
+
kv[ri][ci] += k[ri][ci] * v[ci];
|
| 280 |
+
}
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
// We must perform the prefix sums within the thread-block. Start with the thread.
|
| 284 |
+
#pragma unroll
|
| 285 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
| 286 |
+
#pragma unroll
|
| 287 |
+
for( int ci = 1; ci < COLS_PER_THREAD; ++ci ) {
|
| 288 |
+
kv[ri][ci] += kv[ri][ci-1];
|
| 289 |
+
}
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
// Store the partial sums to shared memory. Unless we have no inter-warp reduction to perform.
|
| 293 |
+
#pragma unroll
|
| 294 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
| 295 |
+
smem_reds[warp*E + ri*THREADS_PER_WARP + lane] = kv[ri][COLS_PER_THREAD-1];
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
// Make sure the data is in shared memory.
|
| 299 |
+
__syncthreads();
|
| 300 |
+
|
| 301 |
+
// Each thread deals with one or more column(s) of the matrix.
|
| 302 |
+
constexpr int SUMS_PER_THREAD = (E + THREADS_PER_BLOCK-1) / THREADS_PER_BLOCK;
|
| 303 |
+
#pragma unroll
|
| 304 |
+
for( int ii = 0, idx = tidx; ii < SUMS_PER_THREAD; ++ii, idx += THREADS_PER_BLOCK ) {
|
| 305 |
+
if( idx < E ) {
|
| 306 |
+
float sum = smem_reds[idx];
|
| 307 |
+
#pragma unroll
|
| 308 |
+
for( int jj = 1; jj < WARPS; ++jj ) {
|
| 309 |
+
smem_reds[idx + jj*E] = sum += smem_reds[idx + jj*E];
|
| 310 |
+
}
|
| 311 |
+
}
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
// Make sure the reductions are stored in shared memory.
|
| 315 |
+
__syncthreads();
|
| 316 |
+
|
| 317 |
+
// Each thread updates his partial products.
|
| 318 |
+
#pragma unroll
|
| 319 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
| 320 |
+
float sum = running_kv[ri];
|
| 321 |
+
if( warp > 0 ) {
|
| 322 |
+
sum += smem_reds[(warp-1)*E + lane + ri*THREADS_PER_WARP];
|
| 323 |
+
}
|
| 324 |
+
#pragma unroll
|
| 325 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
| 326 |
+
kv[ri][ci] += sum;
|
| 327 |
+
}
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
// Compute the partial output values for that thread.
|
| 331 |
+
float sum[COLS_PER_THREAD];
|
| 332 |
+
#pragma unroll
|
| 333 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
| 334 |
+
sum[ci] = q[0][ci] * kv[0][ci];
|
| 335 |
+
#pragma unroll
|
| 336 |
+
for( int ri = 1; ri < ROWS_PER_THREAD; ++ri ) {
|
| 337 |
+
sum[ci] += q[ri][ci] * kv[ri][ci];
|
| 338 |
+
}
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
// Run the parallel reductions inside the warp.
|
| 342 |
+
#pragma unroll
|
| 343 |
+
for( int mask = THREADS_PER_WARP / 2; mask >= 1; mask /= 2 ) {
|
| 344 |
+
#pragma unroll
|
| 345 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
| 346 |
+
sum[ci] += __shfl_xor_sync(uint32_t(-1), sum[ci], mask);
|
| 347 |
+
}
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
// Store the final output to shared memory.
|
| 351 |
+
if( lane == 0 ) {
|
| 352 |
+
#pragma unroll
|
| 353 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
| 354 |
+
smem_o[warp*COLS_PER_THREAD + ci] = sum[ci];
|
| 355 |
+
}
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
// Make sure the data is in shared memory.
|
| 359 |
+
__syncthreads();
|
| 360 |
+
|
| 361 |
+
// Store the output.
|
| 362 |
+
if( valid_vo ) {
|
| 363 |
+
*ptr_o = smem_o[tidx];
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
// Each thread updates his running kv.
|
| 367 |
+
#pragma unroll
|
| 368 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
| 369 |
+
running_kv[ri] += smem_reds[(WARPS-1)*E + lane + ri*THREADS_PER_WARP];
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
// Move to next location.
|
| 373 |
+
if( GO_BACKWARD ) {
|
| 374 |
+
ptr_o -= COLS_PER_ITER*params.o_stride_L;
|
| 375 |
+
} else {
|
| 376 |
+
ptr_o += COLS_PER_ITER*params.o_stride_L;
|
| 377 |
+
}
|
| 378 |
+
}
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 382 |
+
|
| 383 |
+
template< int E, bool GO_BACKWARD, int WARPS >
|
| 384 |
+
int lmha_low_occupancy_(const Lmha_params<float> ¶ms) {
|
| 385 |
+
|
| 386 |
+
// Make sure we are not going to launch an invalid grid.
|
| 387 |
+
if( params.H > 65535 || params.B > 65535 ) {
|
| 388 |
+
return 1;
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
// Prepare the grid and trigger the CUDA kernel.
|
| 392 |
+
dim3 grid;
|
| 393 |
+
grid.x = params.M;
|
| 394 |
+
grid.y = params.H;
|
| 395 |
+
grid.z = params.B;
|
| 396 |
+
lmha_low_occupancy_kernel<E, GO_BACKWARD, WARPS><<<grid, WARPS*THREADS_PER_WARP>>>(params);
|
| 397 |
+
return 0;
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 401 |
+
|
| 402 |
+
template< int E, bool GO_BACKWARD >
|
| 403 |
+
int lmha_low_occupancy_(const Lmha_params<float> ¶ms, int blocks) {
|
| 404 |
+
if( params.M * blocks >= 8*LOW_OCCUPANCY_THRESHOLD ) {
|
| 405 |
+
return lmha_low_occupancy_<E, GO_BACKWARD, 4>(params);
|
| 406 |
+
} else if( params.M * blocks >= 4*LOW_OCCUPANCY_THRESHOLD ) {
|
| 407 |
+
return lmha_low_occupancy_<E, GO_BACKWARD, 8>(params);
|
| 408 |
+
} else {
|
| 409 |
+
return lmha_low_occupancy_<E, GO_BACKWARD, 16>(params);
|
| 410 |
+
}
|
| 411 |
+
return 1;
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 415 |
+
|
| 416 |
+
template< int E, typename Params >
|
| 417 |
+
static inline __device__ __host__ int smem_buffer_elts_(const Params ¶ms) {
|
| 418 |
+
int M = round_up(params.M, 4);
|
| 419 |
+
return 2*E + 2*M;
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 423 |
+
|
| 424 |
+
template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD >
|
| 425 |
+
__global__
|
| 426 |
+
void lmha_kernel(Lmha_params<float> params) {
|
| 427 |
+
|
| 428 |
+
// Make sure E is a multiple of 4.
|
| 429 |
+
static_assert(E % 4 == 0, "");
|
| 430 |
+
|
| 431 |
+
// The amount of shared memory per buffer (2 buffers for double-buffering).
|
| 432 |
+
const int smem_buffer_elts = smem_buffer_elts_<E>(params);
|
| 433 |
+
// The M dimension for shared memory.
|
| 434 |
+
const int M = round_up(params.M, 4);
|
| 435 |
+
|
| 436 |
+
// Shared memory to store Q, K and V. Size is 2*smem_buffer_elts.
|
| 437 |
+
extern __shared__ float smem_[];
|
| 438 |
+
|
| 439 |
+
// The various shared memory buffers.
|
| 440 |
+
float *smem_q = &smem_[0*E];
|
| 441 |
+
float *smem_k = &smem_[1*E];
|
| 442 |
+
float *smem_v = &smem_[2*E];
|
| 443 |
+
float *smem_o = &smem_[2*E + M];
|
| 444 |
+
|
| 445 |
+
// The index of the shared memory buffer (for double-buffering).
|
| 446 |
+
int smem_curr = 0;
|
| 447 |
+
|
| 448 |
+
// The sequence processed by that block.
|
| 449 |
+
const int bi = blockIdx.y;
|
| 450 |
+
// The head processed by that block.
|
| 451 |
+
const int hi = blockIdx.x;
|
| 452 |
+
|
| 453 |
+
// The linear index of the thread.
|
| 454 |
+
const int tidx = threadIdx.x;
|
| 455 |
+
|
| 456 |
+
// The offset to the position loaded by the thread in Q.
|
| 457 |
+
int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + tidx;
|
| 458 |
+
// The offset to the position loaded by the thread in K.
|
| 459 |
+
int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + tidx;
|
| 460 |
+
|
| 461 |
+
// If we walk backward, account for the extra offset.
|
| 462 |
+
if( GO_BACKWARD ) {
|
| 463 |
+
offset_q += (params.L-1)*params.q_stride_L;
|
| 464 |
+
offset_k += (params.L-1)*params.k_stride_L;
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
// Determine the base pointers for Q and K.
|
| 468 |
+
const float *ptr_q = ¶ms.q[offset_q];
|
| 469 |
+
const float *ptr_k = ¶ms.k[offset_k];
|
| 470 |
+
|
| 471 |
+
// The offset to the position loaded by the thread in V and O.
|
| 472 |
+
int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + tidx;
|
| 473 |
+
int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + tidx;
|
| 474 |
+
|
| 475 |
+
// If we walk backward, account for the extra offset.
|
| 476 |
+
if( GO_BACKWARD ) {
|
| 477 |
+
offset_v += (params.L-1)*params.v_stride_L;
|
| 478 |
+
offset_o += (params.L-1)*params.o_stride_L;
|
| 479 |
+
}
|
| 480 |
+
|
| 481 |
+
// Determine the base pointers for V.
|
| 482 |
+
const float *ptr_v = ¶ms.v[offset_v];
|
| 483 |
+
|
| 484 |
+
// Is it an active Q/K thread?
|
| 485 |
+
const int active_qk = tidx < params.E;
|
| 486 |
+
|
| 487 |
+
// Trigger the memory loads for Q and K.
|
| 488 |
+
float ldg_q = 0.f, ldg_k = 0.f;
|
| 489 |
+
if( active_qk ) {
|
| 490 |
+
ldg_q = *ptr_q;
|
| 491 |
+
ldg_k = *ptr_k;
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
+
// Is it an active V thread?
|
| 495 |
+
const int active_v = tidx < params.M;
|
| 496 |
+
|
| 497 |
+
// Trigger the memory loads for V.
|
| 498 |
+
float ldg_v = 0.f;
|
| 499 |
+
if( active_v ) {
|
| 500 |
+
ldg_v = *ptr_v;
|
| 501 |
+
}
|
| 502 |
+
|
| 503 |
+
// Move the load pointers.
|
| 504 |
+
if( GO_BACKWARD ) {
|
| 505 |
+
ptr_q -= params.q_stride_L;
|
| 506 |
+
ptr_k -= params.k_stride_L;
|
| 507 |
+
ptr_v -= params.v_stride_L;
|
| 508 |
+
} else {
|
| 509 |
+
ptr_q += params.q_stride_L;
|
| 510 |
+
ptr_k += params.k_stride_L;
|
| 511 |
+
ptr_v += params.v_stride_L;
|
| 512 |
+
}
|
| 513 |
+
|
| 514 |
+
// The number of FLOAT4s per head.
|
| 515 |
+
constexpr int FLOAT4s_PER_HEAD = E / 4;
|
| 516 |
+
// The number of FLOAT4s per thread.
|
| 517 |
+
constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD;
|
| 518 |
+
|
| 519 |
+
// The storage for the K*V^T values.
|
| 520 |
+
float4 kv[FLOAT4s_PER_THREAD];
|
| 521 |
+
#pragma unroll
|
| 522 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 523 |
+
kv[ii] = make_float4(0.f, 0.f, 0.f, 0.f);
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
// The output pointer.
|
| 527 |
+
float *out_ptr = ¶ms.out[offset_o];
|
| 528 |
+
|
| 529 |
+
// Store to shared memory Q and K.
|
| 530 |
+
if( tidx < E ) {
|
| 531 |
+
smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q;
|
| 532 |
+
smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k;
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
// Store to shared memory V. All threads store valid values.
|
| 536 |
+
if( tidx < M ) {
|
| 537 |
+
smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v;
|
| 538 |
+
}
|
| 539 |
+
|
| 540 |
+
// The position of the thread in the V dimension.
|
| 541 |
+
int vo = tidx / THREADS_PER_HEAD;
|
| 542 |
+
int vi = tidx % THREADS_PER_HEAD;
|
| 543 |
+
|
| 544 |
+
// Iterate over the timesteps.
|
| 545 |
+
for( int ti = 0; ti < params.L; ++ti ) {
|
| 546 |
+
|
| 547 |
+
// Is it the last iteration?
|
| 548 |
+
int is_last = ti == params.L - 1;
|
| 549 |
+
|
| 550 |
+
// Trigger the next loads for Q and K.
|
| 551 |
+
if( !is_last && active_qk ) {
|
| 552 |
+
ldg_q = *ptr_q;
|
| 553 |
+
ldg_k = *ptr_k;
|
| 554 |
+
}
|
| 555 |
+
|
| 556 |
+
// Trigger the next loads for V.
|
| 557 |
+
if( !is_last && active_v ) {
|
| 558 |
+
ldg_v = *ptr_v;
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
// Move the load pointers.
|
| 562 |
+
if( GO_BACKWARD ) {
|
| 563 |
+
ptr_q -= params.q_stride_L;
|
| 564 |
+
ptr_k -= params.k_stride_L;
|
| 565 |
+
ptr_v -= params.v_stride_L;
|
| 566 |
+
} else {
|
| 567 |
+
ptr_q += params.q_stride_L;
|
| 568 |
+
ptr_k += params.k_stride_L;
|
| 569 |
+
ptr_v += params.v_stride_L;
|
| 570 |
+
}
|
| 571 |
+
|
| 572 |
+
// Make sure the data is in shared memory.
|
| 573 |
+
__syncthreads();
|
| 574 |
+
|
| 575 |
+
// Each thread loads 4 values from K.
|
| 576 |
+
float4 k[FLOAT4s_PER_THREAD];
|
| 577 |
+
#pragma unroll
|
| 578 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 579 |
+
int ki = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4;
|
| 580 |
+
k[ii] = *reinterpret_cast<const float4*>(&smem_k[smem_curr*smem_buffer_elts + ki]);
|
| 581 |
+
}
|
| 582 |
+
|
| 583 |
+
// Each thread loads a single V value.
|
| 584 |
+
float v = 0.f;
|
| 585 |
+
if( vo < params.M ) {
|
| 586 |
+
v = *reinterpret_cast<const float *>(&smem_v[smem_curr*smem_buffer_elts + vo]);
|
| 587 |
+
}
|
| 588 |
+
|
| 589 |
+
// Update the K*V^T product.
|
| 590 |
+
#pragma unroll
|
| 591 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 592 |
+
kv[ii].x += k[ii].x * v;
|
| 593 |
+
kv[ii].y += k[ii].y * v;
|
| 594 |
+
kv[ii].z += k[ii].z * v;
|
| 595 |
+
kv[ii].w += k[ii].w * v;
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
// Load the Q values from shared memory.
|
| 599 |
+
float4 q[FLOAT4s_PER_THREAD];
|
| 600 |
+
#pragma unroll
|
| 601 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 602 |
+
int qi = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4;
|
| 603 |
+
q[ii] = *reinterpret_cast<const float4*>(&smem_q[smem_curr*smem_buffer_elts + qi]);
|
| 604 |
+
}
|
| 605 |
+
|
| 606 |
+
// Compute the partial output value for that thread.
|
| 607 |
+
float sum = 0.f;
|
| 608 |
+
#pragma unroll
|
| 609 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 610 |
+
sum += q[ii].x * kv[ii].x;
|
| 611 |
+
sum += q[ii].y * kv[ii].y;
|
| 612 |
+
sum += q[ii].z * kv[ii].z;
|
| 613 |
+
sum += q[ii].w * kv[ii].w;
|
| 614 |
+
}
|
| 615 |
+
|
| 616 |
+
// Finalize the computation of the sum (if we have more than 1 thread per head).
|
| 617 |
+
if( THREADS_PER_HEAD > 1 ) {
|
| 618 |
+
|
| 619 |
+
// Finalize the sum for each head.
|
| 620 |
+
#pragma unroll
|
| 621 |
+
for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) {
|
| 622 |
+
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
| 623 |
+
}
|
| 624 |
+
|
| 625 |
+
// Store to shared memory.
|
| 626 |
+
if( vo < M && vi == 0 ) {
|
| 627 |
+
smem_o[smem_curr*smem_buffer_elts + vo] = sum;
|
| 628 |
+
}
|
| 629 |
+
|
| 630 |
+
// Make sure the data is in shared memory.
|
| 631 |
+
__syncthreads();
|
| 632 |
+
|
| 633 |
+
// Active threads read the data to store.
|
| 634 |
+
if( active_v ) {
|
| 635 |
+
sum = smem_o[smem_curr*smem_buffer_elts + tidx];
|
| 636 |
+
}
|
| 637 |
+
|
| 638 |
+
} // THREADS_PER_HEAD > 1.
|
| 639 |
+
|
| 640 |
+
// Store the output. All the threads are active.
|
| 641 |
+
if( active_v ) {
|
| 642 |
+
*out_ptr = sum;
|
| 643 |
+
}
|
| 644 |
+
|
| 645 |
+
// Move to next location.
|
| 646 |
+
if( GO_BACKWARD ) {
|
| 647 |
+
out_ptr -= params.o_stride_L;
|
| 648 |
+
} else {
|
| 649 |
+
out_ptr += params.o_stride_L;
|
| 650 |
+
}
|
| 651 |
+
|
| 652 |
+
// Move the shared memory buffer.
|
| 653 |
+
smem_curr = (smem_curr + 1) % 2;
|
| 654 |
+
|
| 655 |
+
// Store to shared memory for Q and K.
|
| 656 |
+
if( !is_last && tidx < E ) {
|
| 657 |
+
smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q;
|
| 658 |
+
smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k;
|
| 659 |
+
}
|
| 660 |
+
|
| 661 |
+
// Store to shared memory for V.
|
| 662 |
+
if( !is_last && tidx < M ) {
|
| 663 |
+
smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v;
|
| 664 |
+
}
|
| 665 |
+
}
|
| 666 |
+
}
|
| 667 |
+
|
| 668 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 669 |
+
|
| 670 |
+
template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD >
|
| 671 |
+
int lmha_(const Lmha_params<float> ¶ms) {
|
| 672 |
+
// The M dimension rounded up to 4.
|
| 673 |
+
int M = round_up(params.M, 4);
|
| 674 |
+
|
| 675 |
+
// The number of threads in the block.
|
| 676 |
+
int block = round_up(max(E, M*THREADS_PER_HEAD), 32);
|
| 677 |
+
if( block > 512 || params.B > 65535 ) {
|
| 678 |
+
return 1;
|
| 679 |
+
}
|
| 680 |
+
|
| 681 |
+
// Prepare the kernel.
|
| 682 |
+
dim3 grid(params.H, params.B);
|
| 683 |
+
size_t smem = smem_buffer_elts_<E>(params)*2*sizeof(float);
|
| 684 |
+
lmha_kernel<E, THREADS_PER_HEAD, GO_BACKWARD><<<grid, block, smem>>>(params);
|
| 685 |
+
return 0;
|
| 686 |
+
}
|
| 687 |
+
|
| 688 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 689 |
+
|
| 690 |
+
template< bool GO_BACKWARD >
|
| 691 |
+
int lmha(const Lmha_params<float> ¶ms) {
|
| 692 |
+
int blocks = params.B * params.H;
|
| 693 |
+
int res = 1;
|
| 694 |
+
if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
|
| 695 |
+
if( params.E <= 32 ) {
|
| 696 |
+
res = lmha_low_occupancy_< 32, GO_BACKWARD>(params, blocks);
|
| 697 |
+
} else if( params.E <= 64 ) {
|
| 698 |
+
res = lmha_low_occupancy_< 64, GO_BACKWARD>(params, blocks);
|
| 699 |
+
} else if( params.E <= 128 ) {
|
| 700 |
+
res = lmha_low_occupancy_<128, GO_BACKWARD>(params, blocks);
|
| 701 |
+
} else if( params.E <= 256 ) {
|
| 702 |
+
res = lmha_low_occupancy_<256, GO_BACKWARD>(params, blocks);
|
| 703 |
+
}
|
| 704 |
+
} else {
|
| 705 |
+
if( params.E <= 32 ) {
|
| 706 |
+
res = lmha_< 32, 1, GO_BACKWARD>(params);
|
| 707 |
+
} else if( params.E <= 48 ) {
|
| 708 |
+
res = lmha_< 48, 1, GO_BACKWARD>(params);
|
| 709 |
+
} else if( params.E <= 64 ) {
|
| 710 |
+
res = lmha_< 64, 1, GO_BACKWARD>(params);
|
| 711 |
+
} else if( params.E <= 128 ) {
|
| 712 |
+
res = lmha_<128, 2, GO_BACKWARD>(params);
|
| 713 |
+
} else if( params.E <= 256 ) {
|
| 714 |
+
res = lmha_<256, 4, GO_BACKWARD>(params);
|
| 715 |
+
}
|
| 716 |
+
}
|
| 717 |
+
return res;
|
| 718 |
+
}
|
| 719 |
+
|
| 720 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 721 |
+
|
| 722 |
+
template< typename T >
|
| 723 |
+
inline void set_params(Lmha_params<T> ¶ms,
|
| 724 |
+
const torch::Tensor q,
|
| 725 |
+
const torch::Tensor k,
|
| 726 |
+
const torch::Tensor v,
|
| 727 |
+
torch::Tensor o) {
|
| 728 |
+
|
| 729 |
+
// Define the pointers.
|
| 730 |
+
params.out = o.data_ptr<T>();
|
| 731 |
+
params.q = q.data_ptr<T>();
|
| 732 |
+
params.k = k.data_ptr<T>();
|
| 733 |
+
params.v = v.data_ptr<T>();
|
| 734 |
+
|
| 735 |
+
// Define the strides.
|
| 736 |
+
params.q_stride_B = (int) q.stride(0);
|
| 737 |
+
params.q_stride_H = (int) q.stride(1);
|
| 738 |
+
params.q_stride_L = (int) q.stride(2);
|
| 739 |
+
params.k_stride_B = (int) k.stride(0);
|
| 740 |
+
params.k_stride_H = (int) k.stride(1);
|
| 741 |
+
params.k_stride_L = (int) k.stride(2);
|
| 742 |
+
params.v_stride_B = (int) v.stride(0);
|
| 743 |
+
params.v_stride_H = (int) v.stride(1);
|
| 744 |
+
params.v_stride_L = (int) v.stride(2);
|
| 745 |
+
params.o_stride_B = (int) o.stride(0);
|
| 746 |
+
params.o_stride_H = (int) o.stride(1);
|
| 747 |
+
params.o_stride_L = (int) o.stride(2);
|
| 748 |
+
|
| 749 |
+
// Extract the dimensions.
|
| 750 |
+
int N = q.size(0);
|
| 751 |
+
int H = q.size(1);
|
| 752 |
+
int L = q.size(2);
|
| 753 |
+
int E = q.size(3);
|
| 754 |
+
int M = v.size(3);
|
| 755 |
+
|
| 756 |
+
params.B = N;
|
| 757 |
+
params.L = L;
|
| 758 |
+
params.H = H;
|
| 759 |
+
params.E = E;
|
| 760 |
+
params.M = M;
|
| 761 |
+
}
|
| 762 |
+
|
| 763 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 764 |
+
|
| 765 |
+
int lmha_fwd(const torch::Tensor queries,
|
| 766 |
+
const torch::Tensor keys,
|
| 767 |
+
const torch::Tensor values,
|
| 768 |
+
torch::Tensor product) {
|
| 769 |
+
|
| 770 |
+
// Make sure that we are using the correct GPU device
|
| 771 |
+
torch::DeviceGuard _guard(queries.device());
|
| 772 |
+
|
| 773 |
+
// Make sure the inner-most dimension of the tensors is packed.
|
| 774 |
+
assert(queries.stride(3) == 1);
|
| 775 |
+
assert(keys .stride(3) == 1);
|
| 776 |
+
assert(values .stride(3) == 1);
|
| 777 |
+
assert(product.stride(3) == 1);
|
| 778 |
+
|
| 779 |
+
// Extract the dimensions.
|
| 780 |
+
int N = queries.size(0);
|
| 781 |
+
int H = queries.size(1);
|
| 782 |
+
int L = queries.size(2);
|
| 783 |
+
int E = queries.size(3);
|
| 784 |
+
int M = values.size (3);
|
| 785 |
+
|
| 786 |
+
// The structure of params.
|
| 787 |
+
Lmha_params<float> params;
|
| 788 |
+
set_params(params, queries, keys, values, product);
|
| 789 |
+
|
| 790 |
+
// Launch the kernel.
|
| 791 |
+
return lmha<false>(params);
|
| 792 |
+
}
|
| 793 |
+
|
| 794 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 795 |
+
|
| 796 |
+
template< typename T >
|
| 797 |
+
struct Lmha_bwd_params {
|
| 798 |
+
|
| 799 |
+
// The output buffer for K. Dimensions [B, H, L, D].
|
| 800 |
+
T *out_k;
|
| 801 |
+
// The output buffer for V. Dimensions [B, H, L, D].
|
| 802 |
+
T *out_v;
|
| 803 |
+
|
| 804 |
+
// The input Qs. Dimensions [B, H, L, D].
|
| 805 |
+
const T *q;
|
| 806 |
+
// The input Ks. Dimensions [B, H, L, D].
|
| 807 |
+
const T *k;
|
| 808 |
+
// The input Vs. Dimensions [B, H, L, D].
|
| 809 |
+
const T *v;
|
| 810 |
+
// The input Gs. Dimensions [B, H, L, D].
|
| 811 |
+
const T *g;
|
| 812 |
+
|
| 813 |
+
// The dimensions.
|
| 814 |
+
int B, L, H, M, E;
|
| 815 |
+
|
| 816 |
+
// The strides for the input tensors.
|
| 817 |
+
int q_stride_B, q_stride_L, q_stride_H;
|
| 818 |
+
int k_stride_B, k_stride_L, k_stride_H;
|
| 819 |
+
int v_stride_B, v_stride_L, v_stride_H;
|
| 820 |
+
int g_stride_B, g_stride_L, g_stride_H;
|
| 821 |
+
|
| 822 |
+
// The strides for the outputs.
|
| 823 |
+
int out_k_stride_B, out_k_stride_L, out_k_stride_H;
|
| 824 |
+
int out_v_stride_B, out_v_stride_L, out_v_stride_H;
|
| 825 |
+
};
|
| 826 |
+
|
| 827 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 828 |
+
|
| 829 |
+
template< int D, int THREADS_PER_HEAD >
|
| 830 |
+
__global__ __launch_bounds__(D*THREADS_PER_HEAD*2)
|
| 831 |
+
void lmha_bwd_kernel(Lmha_bwd_params<float> params) {
|
| 832 |
+
|
| 833 |
+
// Make sure D is a multiple of 4.
|
| 834 |
+
static_assert(D % 4 == 0, "");
|
| 835 |
+
|
| 836 |
+
// The shared memory buffers.
|
| 837 |
+
__shared__ struct Smem { float qg[2*D], kv[2*D], out_kv[2*D]; } smem_[2];
|
| 838 |
+
|
| 839 |
+
// The index of the shared memory buffer (for double-buffering).
|
| 840 |
+
int smem_curr = 0;
|
| 841 |
+
|
| 842 |
+
// The sequence processed by that block.
|
| 843 |
+
const int bi = blockIdx.y;
|
| 844 |
+
// The head processed by that block.
|
| 845 |
+
const int hi = blockIdx.x;
|
| 846 |
+
|
| 847 |
+
// The linear index of the thread.
|
| 848 |
+
const int tidx = threadIdx.x;
|
| 849 |
+
|
| 850 |
+
// Split the threads into two slices.
|
| 851 |
+
int so = tidx / (D*THREADS_PER_HEAD);
|
| 852 |
+
int si = tidx % (D*THREADS_PER_HEAD);
|
| 853 |
+
|
| 854 |
+
// The strides for B/L/H for the Q/G tensors.
|
| 855 |
+
int qg_stride_B, qg_stride_L, qg_stride_H;
|
| 856 |
+
if( so == 0 ) {
|
| 857 |
+
qg_stride_B = params.q_stride_B;
|
| 858 |
+
qg_stride_L = params.q_stride_L;
|
| 859 |
+
qg_stride_H = params.q_stride_H;
|
| 860 |
+
} else {
|
| 861 |
+
qg_stride_B = params.g_stride_B;
|
| 862 |
+
qg_stride_L = params.g_stride_L;
|
| 863 |
+
qg_stride_H = params.g_stride_H;
|
| 864 |
+
}
|
| 865 |
+
|
| 866 |
+
// The strides for B/L/H for the K/V tensors.
|
| 867 |
+
int kv_stride_B, kv_stride_L, kv_stride_H;
|
| 868 |
+
if( so == 0 ) {
|
| 869 |
+
kv_stride_B = params.k_stride_B;
|
| 870 |
+
kv_stride_L = params.k_stride_L;
|
| 871 |
+
kv_stride_H = params.k_stride_H;
|
| 872 |
+
} else {
|
| 873 |
+
kv_stride_B = params.v_stride_B;
|
| 874 |
+
kv_stride_L = params.v_stride_L;
|
| 875 |
+
kv_stride_H = params.v_stride_H;
|
| 876 |
+
}
|
| 877 |
+
|
| 878 |
+
// The hidden size.
|
| 879 |
+
int hidden_size_per_head = 0;
|
| 880 |
+
if( so == 0 ) {
|
| 881 |
+
hidden_size_per_head = params.E;
|
| 882 |
+
} else {
|
| 883 |
+
hidden_size_per_head = params.M;
|
| 884 |
+
}
|
| 885 |
+
|
| 886 |
+
// Where to start reading from.
|
| 887 |
+
int offset_qg = bi*qg_stride_B + hi*qg_stride_H + si;
|
| 888 |
+
int offset_kv = bi*kv_stride_B + hi*kv_stride_H + si;
|
| 889 |
+
|
| 890 |
+
// We walk backward, account for the extra offset.
|
| 891 |
+
offset_qg += (params.L-1)*qg_stride_L;
|
| 892 |
+
offset_kv += (params.L-1)*kv_stride_L;
|
| 893 |
+
|
| 894 |
+
// Determine the base pointers for Q, K, V and G.
|
| 895 |
+
const float *ptr_qg = &(so == 0 ? params.q : params.g)[offset_qg];
|
| 896 |
+
const float *ptr_kv = &(so == 0 ? params.k : params.v)[offset_kv];
|
| 897 |
+
|
| 898 |
+
// Is it an active thread?
|
| 899 |
+
const int active = si < hidden_size_per_head;
|
| 900 |
+
|
| 901 |
+
// Trigger the memory loads for Q, K, V and G.
|
| 902 |
+
float ldg_qg = 0.f, ldg_kv = 0.f;
|
| 903 |
+
if( active ) {
|
| 904 |
+
ldg_qg = *ptr_qg;
|
| 905 |
+
ldg_kv = *ptr_kv;
|
| 906 |
+
}
|
| 907 |
+
|
| 908 |
+
// Move the load pointers (backward).
|
| 909 |
+
ptr_qg -= qg_stride_L;
|
| 910 |
+
ptr_kv -= kv_stride_L;
|
| 911 |
+
|
| 912 |
+
// The number of FLOAT4s per head.
|
| 913 |
+
constexpr int FLOAT4s_PER_HEAD = D / 4;
|
| 914 |
+
// The number of FLOAT4s per thread.
|
| 915 |
+
constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD;
|
| 916 |
+
|
| 917 |
+
// The storage for the G*Q^T or Q^T*G values.
|
| 918 |
+
float4 gq[FLOAT4s_PER_THREAD];
|
| 919 |
+
#pragma unroll
|
| 920 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 921 |
+
gq[ii] = make_float4(0.f, 0.f, 0.f, 0.f);
|
| 922 |
+
}
|
| 923 |
+
|
| 924 |
+
// The strides for B/L/H for the K/V tensors.
|
| 925 |
+
int out_kv_stride_B, out_kv_stride_L, out_kv_stride_H;
|
| 926 |
+
if( so == 0 ) {
|
| 927 |
+
out_kv_stride_B = params.out_k_stride_B;
|
| 928 |
+
out_kv_stride_L = params.out_k_stride_L;
|
| 929 |
+
out_kv_stride_H = params.out_k_stride_H;
|
| 930 |
+
} else {
|
| 931 |
+
out_kv_stride_B = params.out_v_stride_B;
|
| 932 |
+
out_kv_stride_L = params.out_v_stride_L;
|
| 933 |
+
out_kv_stride_H = params.out_v_stride_H;
|
| 934 |
+
}
|
| 935 |
+
|
| 936 |
+
// Where to start reading from.
|
| 937 |
+
int offset_out_kv = bi*out_kv_stride_B + hi*out_kv_stride_H + si;
|
| 938 |
+
|
| 939 |
+
// We walk backward, account for the extra offset.
|
| 940 |
+
offset_out_kv += (params.L-1)*out_kv_stride_L;
|
| 941 |
+
|
| 942 |
+
// The output pointer.
|
| 943 |
+
float *ptr_out_kv = &(so == 0 ? params.out_k : params.out_v)[offset_out_kv];
|
| 944 |
+
|
| 945 |
+
// Store to shared memory.
|
| 946 |
+
if( si < D ) {
|
| 947 |
+
smem_[smem_curr].qg[so*D + si] = ldg_qg;
|
| 948 |
+
smem_[smem_curr].kv[so*D + si] = ldg_kv;
|
| 949 |
+
}
|
| 950 |
+
|
| 951 |
+
// The position of the thread in the output dimension.
|
| 952 |
+
int oo = si / THREADS_PER_HEAD % D;
|
| 953 |
+
int oi = si % THREADS_PER_HEAD * 4;
|
| 954 |
+
|
| 955 |
+
// Iterate over the timesteps.
|
| 956 |
+
for( int ti = 0; ti < params.L; ++ti ) {
|
| 957 |
+
|
| 958 |
+
// Is it the last iteration?
|
| 959 |
+
int is_last = ti == params.L - 1;
|
| 960 |
+
|
| 961 |
+
// Trigger the next loads.
|
| 962 |
+
if( !is_last && active ) {
|
| 963 |
+
ldg_qg = *ptr_qg;
|
| 964 |
+
ldg_kv = *ptr_kv;
|
| 965 |
+
}
|
| 966 |
+
|
| 967 |
+
// Move the load pointers.
|
| 968 |
+
ptr_qg -= qg_stride_L;
|
| 969 |
+
ptr_kv -= kv_stride_L;
|
| 970 |
+
|
| 971 |
+
// Make sure the data is in shared memory.
|
| 972 |
+
__syncthreads();
|
| 973 |
+
|
| 974 |
+
// Each thread loads 4 values from G or Q.
|
| 975 |
+
float4 g[FLOAT4s_PER_THREAD];
|
| 976 |
+
#pragma unroll
|
| 977 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 978 |
+
float *smem_ptr = &smem_[smem_curr].qg[(so^1)*D + oi];
|
| 979 |
+
g[ii] = *reinterpret_cast<const float4*>(&smem_ptr[ii*THREADS_PER_HEAD*4]);
|
| 980 |
+
}
|
| 981 |
+
|
| 982 |
+
// Each thread loads a single from Q or G value.
|
| 983 |
+
float q = smem_[smem_curr].qg[so*D + oo];
|
| 984 |
+
|
| 985 |
+
// Update the G*Q^T or Q*G^T product.
|
| 986 |
+
#pragma unroll
|
| 987 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 988 |
+
gq[ii].x += g[ii].x * q;
|
| 989 |
+
gq[ii].y += g[ii].y * q;
|
| 990 |
+
gq[ii].z += g[ii].z * q;
|
| 991 |
+
gq[ii].w += g[ii].w * q;
|
| 992 |
+
}
|
| 993 |
+
|
| 994 |
+
// Load the V or K values from shared memory.
|
| 995 |
+
float4 v[FLOAT4s_PER_THREAD];
|
| 996 |
+
#pragma unroll
|
| 997 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 998 |
+
float *smem_ptr = &smem_[smem_curr].kv[(so^1)*D + oi];
|
| 999 |
+
v[ii] = *reinterpret_cast<const float4*>(&smem_ptr[ii*THREADS_PER_HEAD*4]);
|
| 1000 |
+
}
|
| 1001 |
+
|
| 1002 |
+
// Compute the partial output value for that thread.
|
| 1003 |
+
float sum = 0.f;
|
| 1004 |
+
#pragma unroll
|
| 1005 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 1006 |
+
sum += v[ii].x * gq[ii].x;
|
| 1007 |
+
sum += v[ii].y * gq[ii].y;
|
| 1008 |
+
sum += v[ii].z * gq[ii].z;
|
| 1009 |
+
sum += v[ii].w * gq[ii].w;
|
| 1010 |
+
}
|
| 1011 |
+
|
| 1012 |
+
// Finalize the computation of the sum (if we have more than 1 thread per head).
|
| 1013 |
+
if( THREADS_PER_HEAD > 1 ) {
|
| 1014 |
+
|
| 1015 |
+
// Finalize the sum for each head.
|
| 1016 |
+
#pragma unroll
|
| 1017 |
+
for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) {
|
| 1018 |
+
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
| 1019 |
+
}
|
| 1020 |
+
|
| 1021 |
+
// Store to shared memory.
|
| 1022 |
+
if( oi == 0 ) {
|
| 1023 |
+
smem_[smem_curr].out_kv[so*D + oo] = sum;
|
| 1024 |
+
}
|
| 1025 |
+
|
| 1026 |
+
// Make sure the data is in shared memory.
|
| 1027 |
+
__syncthreads();
|
| 1028 |
+
|
| 1029 |
+
// Active threads read the data to store.
|
| 1030 |
+
if( si < hidden_size_per_head ) {
|
| 1031 |
+
sum = smem_[smem_curr].out_kv[so*D + si];
|
| 1032 |
+
}
|
| 1033 |
+
|
| 1034 |
+
} // THREADS_PER_HEAD > 1.
|
| 1035 |
+
|
| 1036 |
+
// Store the output. All the threads are active.
|
| 1037 |
+
if( si < hidden_size_per_head ) {
|
| 1038 |
+
*ptr_out_kv = sum;
|
| 1039 |
+
}
|
| 1040 |
+
|
| 1041 |
+
// Move to next location.
|
| 1042 |
+
ptr_out_kv -= out_kv_stride_L;
|
| 1043 |
+
|
| 1044 |
+
// Move the shared memory buffer.
|
| 1045 |
+
smem_curr = (smem_curr + 1) % 2;
|
| 1046 |
+
|
| 1047 |
+
// Store to shared memory for Q and K.
|
| 1048 |
+
if( !is_last && si < D ) {
|
| 1049 |
+
smem_[smem_curr].qg[so*D + si] = ldg_qg;
|
| 1050 |
+
smem_[smem_curr].kv[so*D + si] = ldg_kv;
|
| 1051 |
+
}
|
| 1052 |
+
}
|
| 1053 |
+
}
|
| 1054 |
+
|
| 1055 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1056 |
+
|
| 1057 |
+
template< int D, int THREADS_PER_HEAD >
|
| 1058 |
+
int lmha_bwd_(const Lmha_bwd_params<float> ¶ms) {
|
| 1059 |
+
int block = D*THREADS_PER_HEAD*2;
|
| 1060 |
+
if( block >= 1024 || params.B > 65535 ) {
|
| 1061 |
+
return 1;
|
| 1062 |
+
}
|
| 1063 |
+
dim3 grid(params.H, params.B);
|
| 1064 |
+
lmha_bwd_kernel<D, THREADS_PER_HEAD><<<grid, block>>>(params);
|
| 1065 |
+
return 0;
|
| 1066 |
+
}
|
| 1067 |
+
|
| 1068 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1069 |
+
|
| 1070 |
+
int lmha_bwd(const Lmha_bwd_params<float> ¶ms) {
|
| 1071 |
+
int blocks = params.B * params.H;
|
| 1072 |
+
if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
|
| 1073 |
+
return 1;
|
| 1074 |
+
}
|
| 1075 |
+
|
| 1076 |
+
int hidden_size_per_head = max(params.E, params.M);
|
| 1077 |
+
int res = 1;
|
| 1078 |
+
if( hidden_size_per_head <= 32 ) {
|
| 1079 |
+
res = lmha_bwd_< 32, 1>(params);
|
| 1080 |
+
} else if( hidden_size_per_head <= 64 ) {
|
| 1081 |
+
res = lmha_bwd_< 64, 1>(params);
|
| 1082 |
+
} else if( hidden_size_per_head <= 128 ) {
|
| 1083 |
+
res = lmha_bwd_<128, 2>(params);
|
| 1084 |
+
} else if( hidden_size_per_head <= 256 ) {
|
| 1085 |
+
res = lmha_bwd_<256, 4>(params);
|
| 1086 |
+
}
|
| 1087 |
+
return res;
|
| 1088 |
+
}
|
| 1089 |
+
|
| 1090 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1091 |
+
|
| 1092 |
+
int lmha_bwd(const torch::Tensor queries,
|
| 1093 |
+
const torch::Tensor keys,
|
| 1094 |
+
const torch::Tensor values,
|
| 1095 |
+
const torch::Tensor grad_out,
|
| 1096 |
+
torch::Tensor grad_queries,
|
| 1097 |
+
torch::Tensor grad_keys,
|
| 1098 |
+
torch::Tensor grad_values) {
|
| 1099 |
+
|
| 1100 |
+
// Make sure that we are using the correct GPU device
|
| 1101 |
+
torch::DeviceGuard _guard(queries.device());
|
| 1102 |
+
|
| 1103 |
+
// Make sure the inner-most dimension of the tensors is packed.
|
| 1104 |
+
assert(queries .stride(3) == 1);
|
| 1105 |
+
assert(keys .stride(3) == 1);
|
| 1106 |
+
assert(values .stride(3) == 1);
|
| 1107 |
+
assert(grad_out .stride(3) == 1);
|
| 1108 |
+
assert(grad_queries.stride(3) == 1);
|
| 1109 |
+
assert(grad_keys .stride(3) == 1);
|
| 1110 |
+
assert(grad_values .stride(3) == 1);
|
| 1111 |
+
|
| 1112 |
+
// Extract the dimensions.
|
| 1113 |
+
int N = queries.size(0);
|
| 1114 |
+
int H = queries.size(1);
|
| 1115 |
+
int L = queries.size(2);
|
| 1116 |
+
int E = queries.size(3);
|
| 1117 |
+
int M = values.size (3);
|
| 1118 |
+
|
| 1119 |
+
// Gradient on Q.
|
| 1120 |
+
|
| 1121 |
+
// The structure of params.
|
| 1122 |
+
Lmha_params<float> params;
|
| 1123 |
+
set_params(params, grad_out, values, keys, grad_queries);
|
| 1124 |
+
|
| 1125 |
+
// Launch the kernel.
|
| 1126 |
+
int res = lmha<false>(params);
|
| 1127 |
+
if( res ) {
|
| 1128 |
+
return res;
|
| 1129 |
+
}
|
| 1130 |
+
|
| 1131 |
+
// Gradient on K and V together.
|
| 1132 |
+
|
| 1133 |
+
Lmha_bwd_params<float> bwd_params;
|
| 1134 |
+
bwd_params.out_k = grad_keys.data_ptr<float>();
|
| 1135 |
+
bwd_params.out_v = grad_values.data_ptr<float>();
|
| 1136 |
+
bwd_params.q = queries.data_ptr<float>();
|
| 1137 |
+
bwd_params.k = keys.data_ptr<float>();
|
| 1138 |
+
bwd_params.v = values.data_ptr<float>();
|
| 1139 |
+
bwd_params.g = grad_out.data_ptr<float>();
|
| 1140 |
+
|
| 1141 |
+
bwd_params.B = N;
|
| 1142 |
+
bwd_params.L = L;
|
| 1143 |
+
bwd_params.H = H;
|
| 1144 |
+
bwd_params.E = E;
|
| 1145 |
+
bwd_params.M = M;
|
| 1146 |
+
|
| 1147 |
+
bwd_params.q_stride_B = queries.stride(0);
|
| 1148 |
+
bwd_params.q_stride_H = queries.stride(1);
|
| 1149 |
+
bwd_params.q_stride_L = queries.stride(2);
|
| 1150 |
+
bwd_params.k_stride_B = keys.stride(0);
|
| 1151 |
+
bwd_params.k_stride_H = keys.stride(1);
|
| 1152 |
+
bwd_params.k_stride_L = keys.stride(2);
|
| 1153 |
+
bwd_params.v_stride_B = values.stride(0);
|
| 1154 |
+
bwd_params.v_stride_H = values.stride(1);
|
| 1155 |
+
bwd_params.v_stride_L = values.stride(2);
|
| 1156 |
+
bwd_params.g_stride_B = grad_out.stride(0);
|
| 1157 |
+
bwd_params.g_stride_H = grad_out.stride(1);
|
| 1158 |
+
bwd_params.g_stride_L = grad_out.stride(2);
|
| 1159 |
+
|
| 1160 |
+
bwd_params.out_k_stride_B = grad_keys.stride(0);
|
| 1161 |
+
bwd_params.out_k_stride_H = grad_keys.stride(1);
|
| 1162 |
+
bwd_params.out_k_stride_L = grad_keys.stride(2);
|
| 1163 |
+
bwd_params.out_v_stride_B = grad_values.stride(0);
|
| 1164 |
+
bwd_params.out_v_stride_H = grad_values.stride(1);
|
| 1165 |
+
bwd_params.out_v_stride_L = grad_values.stride(2);
|
| 1166 |
+
|
| 1167 |
+
// Try to run the fused kernel.
|
| 1168 |
+
int fallback = lmha_bwd(bwd_params);
|
| 1169 |
+
|
| 1170 |
+
// If it failed, fallback on separate kernels for K and V.
|
| 1171 |
+
if( fallback ) {
|
| 1172 |
+
|
| 1173 |
+
// Gradient on K.
|
| 1174 |
+
|
| 1175 |
+
// Launch the kernel.
|
| 1176 |
+
set_params(params, values, grad_out, queries, grad_keys);
|
| 1177 |
+
res = lmha<true>(params);
|
| 1178 |
+
if( res ) {
|
| 1179 |
+
return res;
|
| 1180 |
+
}
|
| 1181 |
+
|
| 1182 |
+
// Gradient on V.
|
| 1183 |
+
|
| 1184 |
+
// Launch the kernel.
|
| 1185 |
+
set_params(params, keys, queries, grad_out, grad_values);
|
| 1186 |
+
return lmha<true>(params);
|
| 1187 |
+
}
|
| 1188 |
+
|
| 1189 |
+
// It worked...
|
| 1190 |
+
return 0;
|
| 1191 |
+
}
|
| 1192 |
+
|
| 1193 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1194 |
+
|
| 1195 |
+
} // namespace nvidia
|
| 1196 |
+
#endif // #ifdef ENABLE_NVIDIA_OPTIMIZATIONS
|
| 1197 |
+
|
| 1198 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1199 |
+
|
| 1200 |
+
typedef torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> float_accessor;
|
| 1201 |
+
|
| 1202 |
+
#define E_BLOCK_SIZE 8
|
| 1203 |
+
|
| 1204 |
+
__global__ void causal_dot_product_kernel(
|
| 1205 |
+
const float_accessor queries,
|
| 1206 |
+
const float_accessor keys,
|
| 1207 |
+
const float_accessor values,
|
| 1208 |
+
float_accessor result,
|
| 1209 |
+
const int N,
|
| 1210 |
+
const int H,
|
| 1211 |
+
const int L,
|
| 1212 |
+
const int E,
|
| 1213 |
+
const int M
|
| 1214 |
+
) {
|
| 1215 |
+
int n = blockIdx.y;
|
| 1216 |
+
int h = blockIdx.z;
|
| 1217 |
+
|
| 1218 |
+
int e_start = blockIdx.x * E_BLOCK_SIZE;
|
| 1219 |
+
int m = threadIdx.x % M;
|
| 1220 |
+
|
| 1221 |
+
extern __shared__ float shared_mem[];
|
| 1222 |
+
float* shared_kv = shared_mem;
|
| 1223 |
+
|
| 1224 |
+
for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
|
| 1225 |
+
shared_kv[m + e_local * M] = 0;
|
| 1226 |
+
}
|
| 1227 |
+
|
| 1228 |
+
for (int t=0; t<L; t++) {
|
| 1229 |
+
float res = 0;
|
| 1230 |
+
for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
|
| 1231 |
+
shared_kv[e_local*M + m] += keys[n][h][t][e_local + e_start] * values[n][h][t][m];
|
| 1232 |
+
res += queries[n][h][t][e_local + e_start] * shared_kv[e_local*M + m];
|
| 1233 |
+
}
|
| 1234 |
+
atomicAdd(
|
| 1235 |
+
&result[n][h][t][m],
|
| 1236 |
+
res
|
| 1237 |
+
);
|
| 1238 |
+
}
|
| 1239 |
+
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void causal_dot_product_(const torch::Tensor queries,
                         const torch::Tensor keys,
                         const torch::Tensor values,
                         torch::Tensor product) {
    // Make sure that we are using the correct GPU device
    torch::DeviceGuard _guard(queries.device());

    int N = queries.size(0);
    int H = queries.size(1);
    int L = queries.size(2);
    int E = queries.size(3);
    int M = values.size(3);

    const int blocks_per_sequence = (E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE;

    dim3 blockDim(M, 1, 1);
    dim3 gridDim(blocks_per_sequence, N, H);
    const int shared_mem_forward = E_BLOCK_SIZE * M * sizeof(float);

    causal_dot_product_kernel<<<gridDim, blockDim, shared_mem_forward>>>(
        queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        product.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        N, H, L, E, M
    );
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void causal_dot_product(const torch::Tensor queries,
                        const torch::Tensor keys,
                        const torch::Tensor values,
                        torch::Tensor product) {
#ifdef ENABLE_NVIDIA_OPTIMIZATIONS
    int fallback = nvidia::lmha_fwd(queries, keys, values, product);
#else
    int fallback = 1;
#endif
    if( fallback ) {
        causal_dot_product_(queries, keys, values, product);
    }
}
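The fallback launch above covers the feature dimension E in chunks of E_BLOCK_SIZE = 8, runs one thread per output feature m, and accumulates partial results with atomicAdd, so the caller is expected to pass a zero-initialized product tensor. A small sketch of the launch arithmetic, with assumed example sizes rather than values taken from this commit:

def fallback_launch_config(N, H, L, E, M, E_BLOCK_SIZE=8, float_bytes=4):
    # Mirrors causal_dot_product_: grid over (E-chunks, sequences, heads), one thread per m.
    blocks_per_sequence = (E + E_BLOCK_SIZE - 1) // E_BLOCK_SIZE
    grid = (blocks_per_sequence, N, H)
    block = (M, 1, 1)
    shared_mem_bytes = E_BLOCK_SIZE * M * float_bytes
    return grid, block, shared_mem_bytes

# Assumed example: N=1, H=32, L=2048, E=64, M=64
print(fallback_launch_config(1, 32, 2048, 64, 64))   # ((8, 1, 32), (64, 1, 1), 2048)

Because the block is (M, 1, 1), this path implicitly assumes M fits within CUDA's 1024-thread block limit.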

////////////////////////////////////////////////////////////////////////////////////////////////////

#define M_BLOCK_SIZE 4

// we need shared memory to store
// kv
// Backward direction
// kv_backwards
// Shared memory usage
__global__ void causal_dot_backward_query_key_kernel(
    const float_accessor queries,
    const float_accessor keys,
    const float_accessor values,
    const float_accessor grad_out,
    float_accessor grad_queries,
    float_accessor grad_keys,
    int N,
    int H,
    int L,
    int E,
    int M
) {
    int n = blockIdx.y;
    int h = blockIdx.z;

    int m_start = blockIdx.x * M_BLOCK_SIZE;
    int e = threadIdx.x % E;

    extern __shared__ float shared_mem[];
    const int shared_kv_size = M_BLOCK_SIZE * E;
    float* shared_kv = shared_mem;
    float* shared_kv_bw = shared_mem + shared_kv_size;

    for (int m_local = 0; m_local < M_BLOCK_SIZE && m_local + m_start < M; m_local++) {
        shared_kv[m_local * E + e] = 0;
        shared_kv_bw[m_local * E + e] = 0;
    }

    for (int l=0; l<L; l++) {
        float res = 0, res_bw = 0;
        int l_b = L - l - 1;
        for (int m_local = 0; m_local < M_BLOCK_SIZE && m_local + m_start < M; m_local++) {
            shared_kv[m_local*E + e] += keys[n][h][l][e] * values[n][h][l][m_start + m_local];
            shared_kv_bw[m_local*E + e] += queries[n][h][l_b][e] * grad_out[n][h][l_b][m_start + m_local];
            res += grad_out[n][h][l][m_start + m_local] * shared_kv[m_local*E + e];
            res_bw += values[n][h][l_b][m_start + m_local] * shared_kv_bw[m_local*E + e];
        }
        atomicAdd(
            &grad_queries[n][h][l][e],
            res
        );
        atomicAdd(
            &grad_keys[n][h][l_b][e],
            res_bw
        );
    }
}


__global__ void causal_dot_backward_value_kernel(
    const float_accessor queries,
    const float_accessor keys,
    const float_accessor values,
    const float_accessor grad_out,
    float_accessor grad_keys,
    float_accessor grad_values,
    int N,
    int H,
    int L,
    int E,
    int M
) {
    int n = blockIdx.y;
    int h = blockIdx.z;

    int e_start = blockIdx.x * E_BLOCK_SIZE;
    int m = threadIdx.x % M;

    extern __shared__ float shared_mem[];
    float* shared_kv = shared_mem;
    for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
        shared_kv[m + e_local * M] = 0;
    }

    for (int l = 0; l < L; l++) {
        int l_b = L - l - 1;
        float res = 0;
        for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
            shared_kv[e_local*M + m] += queries[n][h][l_b][e_start + e_local] * grad_out[n][h][l_b][m];
            res += keys[n][h][l_b][e_start + e_local] * shared_kv[e_local*M + m];
        }
        atomicAdd(
            &grad_values[n][h][l_b][m],
            res
        );
    }
}
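Together, the two kernels above implement the gradients of the causal dot product: grad_queries contracts with the forward prefix state sum_{s<=l} k_s v_s^T, while grad_keys and grad_values contract with a suffix state sum_{s>=l} q_s (dL/dout_s)^T accumulated by walking the sequence in reverse (l_b = L - l - 1). A naive PyTorch reference of those formulas, again only an illustrative helper for testing small shapes and not part of this commit:

import torch

def causal_dot_backward_reference(q, k, v, grad_out):
    # q, k: [N, H, L, E]; v, grad_out: [N, H, L, M].
    # grad_q[t] contracts grad_out[t] with the causal prefix state sum_{s<=t} k_s v_s^T.
    kv = torch.einsum("nhle,nhlm->nhlem", k, v).cumsum(dim=2)
    grad_q = torch.einsum("nhlm,nhlem->nhle", grad_out, kv)
    # grad_k[t] and grad_v[t] contract with the suffix state sum_{s>=t} q_s grad_out_s^T,
    # the quantity the kernels accumulate while walking the sequence in reverse.
    qg = torch.einsum("nhle,nhlm->nhlem", q, grad_out)
    qg = qg.flip(2).cumsum(dim=2).flip(2)
    grad_k = torch.einsum("nhlm,nhlem->nhle", v, qg)
    grad_v = torch.einsum("nhle,nhlem->nhlm", k, qg)
    return grad_q, grad_k, grad_v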

////////////////////////////////////////////////////////////////////////////////////////////////////

void causal_dot_backward_(const torch::Tensor queries,
                          const torch::Tensor keys,
                          const torch::Tensor values,
                          const torch::Tensor grad_out,
                          torch::Tensor grad_queries,
                          torch::Tensor grad_keys,
                          torch::Tensor grad_values) {

    // Make sure that we are using the correct GPU device
    torch::DeviceGuard _guard(queries.device());

    int N = queries.size(0);
    int H = queries.size(1);
    int L = queries.size(2);
    int E = queries.size(3);
    int M = values.size(3);

    const int blocks_per_sequence = (M + M_BLOCK_SIZE - 1) / M_BLOCK_SIZE;

    dim3 blockDim(E, 1, 1);
    dim3 gridDim(blocks_per_sequence, N, H);
    const int shared_mem_qk_backward = 2 * M_BLOCK_SIZE * E * sizeof(float);

    causal_dot_backward_query_key_kernel<<<gridDim, blockDim, shared_mem_qk_backward>>>(
        queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        grad_out.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        grad_queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        grad_keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        N, H, L, E, M
    );

    const int blocks_per_sequence_value = (E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE;

    dim3 blockDimv(M, 1, 1);
    dim3 gridDimv(blocks_per_sequence_value, N, H);
    const int shared_mem_v_backward = E_BLOCK_SIZE * M * sizeof(float);
    causal_dot_backward_value_kernel<<<gridDimv, blockDimv, shared_mem_v_backward>>>(
        queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        grad_out.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        grad_keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        grad_values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        N, H, L, E, M
    );
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void causal_dot_backward(const torch::Tensor queries,
                         const torch::Tensor keys,
                         const torch::Tensor values,
                         const torch::Tensor grad_out,
                         torch::Tensor grad_queries,
                         torch::Tensor grad_keys,
                         torch::Tensor grad_values) {
#ifdef ENABLE_NVIDIA_OPTIMIZATIONS
    int fallback = nvidia::lmha_bwd(queries,
                                    keys,
                                    values,
                                    grad_out,
                                    grad_queries,
                                    grad_keys,
                                    grad_values);
#else
    int fallback = 1;
#endif
    if( fallback ) {
        // Make sure that the gradient tensors are 0. This is needed because the
        // bwd pass might have partially executed and filled in some values in
        // grad_queries or grad_keys.
        //
        // This adds a small overhead every time we have to fall back to the old
        // kernel for the backward pass.
        grad_queries.zero_();
        grad_keys.zero_();
        causal_dot_backward_(queries, keys, values, grad_out, grad_queries, grad_keys, grad_values);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "causal_dot_product",
        &causal_dot_product,
        "Compute the weighted sum of values but attending only to previous "
        "values."
    );
    m.def(
        "causal_dot_backward",
        &causal_dot_backward,
        "Compute the gradients for the causal dot product."
    );
}
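The extension is built by csrc/setup.py elsewhere in this commit; the sketch below only illustrates how the two exported functions could be exercised once compiled. The JIT-load call, the module name causal_attention_cuda, and the tensor sizes are assumptions for the example, and the outputs/gradients are created as zero tensors because the fallback kernels accumulate into them with atomicAdd.

import torch
from torch.utils.cpp_extension import load

# Assumed module name and JIT build for illustration; the commit's real build goes through csrc/setup.py.
mod = load(name="causal_attention_cuda", sources=["csrc/causal_attention_cuda.cu"], verbose=True)

N, H, L, E, M = 1, 8, 512, 64, 64                      # assumed example sizes
q = torch.rand(N, H, L, E, device="cuda")
k = torch.rand(N, H, L, E, device="cuda")
v = torch.rand(N, H, L, M, device="cuda")
out = torch.zeros(N, H, L, M, device="cuda")           # accumulated into, so start from zero
mod.causal_dot_product(q, k, v, out)

grad_out = torch.rand_like(out)
grad_q = torch.zeros_like(q)
grad_k = torch.zeros_like(k)
grad_v = torch.zeros_like(v)
mod.causal_dot_backward(q, k, v, grad_out, grad_q, grad_k, grad_v)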
csrc/causal_attention_kv_cuda.cu
ADDED
@@ -0,0 +1,1483 @@
//
// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
// Written by Angelos Katharopoulos <[email protected]>,
// Apoorv Vyas <[email protected]>
//

//
// For modifications made inside namespace nvidia (authored by jdemouth):
//
// Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
//

#include <torch/extension.h>
#include <assert.h>
#include <stdio.h>

#define ENABLE_NVIDIA_OPTIMIZATIONS

#ifdef ENABLE_NVIDIA_OPTIMIZATIONS
namespace nvidia {
|
| 38 |
+
|
| 39 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 40 |
+
|
| 41 |
+
constexpr int THREADS_PER_WARP = 32;
|
| 42 |
+
|
| 43 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
constexpr int LOW_OCCUPANCY_THRESHOLD = 40; // TODO: Make it HW specific (like 1/2 SMs).
|
| 46 |
+
|
| 47 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
|
| 49 |
+
static inline __device__ __host__ int div_up(int m, int n) {
|
| 50 |
+
return (m + n-1) / n;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 54 |
+
|
| 55 |
+
static inline __device__ __host__ int round_up(int m, int n) {
|
| 56 |
+
return div_up(m, n) * n;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 60 |
+
|
| 61 |
+
template< typename T >
|
| 62 |
+
struct Lmha_params {
|
| 63 |
+
|
| 64 |
+
// The output buffer. Dimensions [B, H, L, M].
|
| 65 |
+
T *out;
|
| 66 |
+
|
| 67 |
+
// The input Qs. Dimensions [B, H, L, E].
|
| 68 |
+
const T *q;
|
| 69 |
+
// The input Ks. Dimensions [B, H, L, E].
|
| 70 |
+
const T *k;
|
| 71 |
+
// The input Vs. Dimensions [B, H, L, M].
|
| 72 |
+
const T *v;
|
| 73 |
+
|
| 74 |
+
// The different dimensions.
|
| 75 |
+
int B, L, H, E, M;
|
| 76 |
+
|
| 77 |
+
// The strides for the different tensors.
|
| 78 |
+
int q_stride_B, q_stride_H, q_stride_L;
|
| 79 |
+
int k_stride_B, k_stride_H, k_stride_L;
|
| 80 |
+
int v_stride_B, v_stride_H, v_stride_L;
|
| 81 |
+
int o_stride_B, o_stride_H, o_stride_L;
|
| 82 |
+
};
|
| 83 |
+
|
| 84 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 85 |
+
|
| 86 |
+
template< int E, bool GO_BACKWARD, int WARPS, int COLS_PER_THREAD = 4 >
|
| 87 |
+
__global__ __launch_bounds__(WARPS * THREADS_PER_WARP)
|
| 88 |
+
void lmha_low_occupancy_kernel(Lmha_params<float> params) {
|
| 89 |
+
|
| 90 |
+
// The number of threads per block.
|
| 91 |
+
constexpr int THREADS_PER_BLOCK = WARPS * THREADS_PER_WARP;
|
| 92 |
+
// The number of rows per thread.
|
| 93 |
+
constexpr int ROWS_PER_THREAD = E / THREADS_PER_WARP;
|
| 94 |
+
// The number of steps per iteration.
|
| 95 |
+
constexpr int COLS_PER_ITER = WARPS * COLS_PER_THREAD;
|
| 96 |
+
|
| 97 |
+
// Make sure E is a multiple of the warp size.
|
| 98 |
+
static_assert(E % THREADS_PER_WARP == 0, "");
|
| 99 |
+
|
| 100 |
+
// Shared memory to store V/O.
|
| 101 |
+
__shared__ float smem_v[COLS_PER_ITER], smem_o[COLS_PER_ITER];
|
| 102 |
+
// Shared memory buffer to perform the reductions.
|
| 103 |
+
__shared__ float smem_reds[E * WARPS];
|
| 104 |
+
|
| 105 |
+
// The sequence processed by that block.
|
| 106 |
+
const int bi = blockIdx.z;
|
| 107 |
+
// The head processed by that block.
|
| 108 |
+
const int hi = blockIdx.y;
|
| 109 |
+
// The hidden cell in the V/output buffers.
|
| 110 |
+
const int vi = blockIdx.x;
|
| 111 |
+
|
| 112 |
+
// The linear index of the thread.
|
| 113 |
+
const int tidx = threadIdx.x;
|
| 114 |
+
|
| 115 |
+
// Decompose the block in warp/lane.
|
| 116 |
+
const int warp = tidx / THREADS_PER_WARP;
|
| 117 |
+
const int lane = tidx % THREADS_PER_WARP;
|
| 118 |
+
|
| 119 |
+
// The base offset loaded by the thread in Q and K.
|
| 120 |
+
int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + lane;
|
| 121 |
+
int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + lane;
|
| 122 |
+
|
| 123 |
+
// If we walk backward, account for the extra offset.
|
| 124 |
+
if( GO_BACKWARD ) {
|
| 125 |
+
offset_q += (params.L-1)*params.q_stride_L;
|
| 126 |
+
offset_k += (params.L-1)*params.k_stride_L;
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
// Position the warp at the beginning of the proper timestep.
|
| 130 |
+
if( GO_BACKWARD ) {
|
| 131 |
+
offset_q -= warp*COLS_PER_THREAD*params.q_stride_L;
|
| 132 |
+
offset_k -= warp*COLS_PER_THREAD*params.k_stride_L;
|
| 133 |
+
} else {
|
| 134 |
+
offset_q += warp*COLS_PER_THREAD*params.q_stride_L;
|
| 135 |
+
offset_k += warp*COLS_PER_THREAD*params.k_stride_L;
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
// Determine the base pointers for Q and K.
|
| 139 |
+
const float *ptr_q = ¶ms.q[offset_q];
|
| 140 |
+
const float *ptr_k = ¶ms.k[offset_k];
|
| 141 |
+
|
| 142 |
+
// Is a given row valid?
|
| 143 |
+
int valid_qk[ROWS_PER_THREAD];
|
| 144 |
+
#pragma unroll
|
| 145 |
+
for( int ii = 0; ii < ROWS_PER_THREAD; ++ii ) {
|
| 146 |
+
valid_qk[ii] = lane + ii*THREADS_PER_WARP < params.E;
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
// The offset to the position loaded by the thread in V.
|
| 150 |
+
int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + vi;
|
| 151 |
+
int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + vi;
|
| 152 |
+
|
| 153 |
+
// If we walk backward, account for the extra offset.
|
| 154 |
+
if( GO_BACKWARD ) {
|
| 155 |
+
offset_v += (params.L-1)*params.v_stride_L;
|
| 156 |
+
offset_o += (params.L-1)*params.o_stride_L;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
// We load/store a strided matrix of COLS_PER_ITER x OUTPUTS_PER_BLOCK.
|
| 160 |
+
if( GO_BACKWARD ) {
|
| 161 |
+
offset_v -= tidx*params.v_stride_L;
|
| 162 |
+
offset_o -= tidx*params.o_stride_L;
|
| 163 |
+
} else {
|
| 164 |
+
offset_v += tidx*params.v_stride_L;
|
| 165 |
+
offset_o += tidx*params.o_stride_L;
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
// Determine the base pointer for V.
|
| 169 |
+
const float *ptr_v = ¶ms.v[offset_v];
|
| 170 |
+
// The output pointer.
|
| 171 |
+
float *ptr_o = ¶ms.out[offset_o];
|
| 172 |
+
|
| 173 |
+
// The running KVs.
|
| 174 |
+
float running_kv[ROWS_PER_THREAD];
|
| 175 |
+
#pragma unroll
|
| 176 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
| 177 |
+
running_kv[ri] = 0.f;
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
// Iterate over the timesteps. TODO: Use params.loop_count!!!
|
| 181 |
+
for( int iter = 0; iter < params.L; iter += COLS_PER_ITER ) {
|
| 182 |
+
|
| 183 |
+
// Each thread loads a matrix of elements.
|
| 184 |
+
float q[ROWS_PER_THREAD][COLS_PER_THREAD], k[ROWS_PER_THREAD][COLS_PER_THREAD];
|
| 185 |
+
|
| 186 |
+
// Trigger the memory loads for Q and K.
|
| 187 |
+
#pragma unroll
|
| 188 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
| 189 |
+
#pragma unroll
|
| 190 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
| 191 |
+
|
| 192 |
+
// For Q/K, each warp loads from various timesteps.
|
| 193 |
+
int ti = iter + warp*COLS_PER_THREAD;
|
| 194 |
+
if( GO_BACKWARD ) {
|
| 195 |
+
ti = params.L - 1 - ti;
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
// Is it a valid access?
|
| 199 |
+
int valid;
|
| 200 |
+
if( GO_BACKWARD ) {
|
| 201 |
+
valid = valid_qk[ri] && ti - ci >= 0;
|
| 202 |
+
} else {
|
| 203 |
+
valid = valid_qk[ri] && ti + ci < params.L;
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
// The extra offset to add.
|
| 207 |
+
if( GO_BACKWARD ) {
|
| 208 |
+
offset_q = ri*THREADS_PER_WARP - ci*params.q_stride_L;
|
| 209 |
+
offset_k = ri*THREADS_PER_WARP - ci*params.k_stride_L;
|
| 210 |
+
} else {
|
| 211 |
+
offset_q = ri*THREADS_PER_WARP + ci*params.q_stride_L;
|
| 212 |
+
offset_k = ri*THREADS_PER_WARP + ci*params.k_stride_L;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
// Load Q/K if they are valid.
|
| 216 |
+
q[ri][ci] = valid ? ptr_q[offset_q] : 0.f;
|
| 217 |
+
k[ri][ci] = valid ? ptr_k[offset_k] : 0.f;
|
| 218 |
+
}
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
// For the V tensor, we assign contiguous threads to different loads. So, ti is different.
|
| 222 |
+
int ti = iter + tidx;
|
| 223 |
+
if( GO_BACKWARD ) {
|
| 224 |
+
ti = params.L - 1 - ti;
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
// Is it a valid access?
|
| 228 |
+
int valid_vo = tidx < COLS_PER_ITER;
|
| 229 |
+
if( GO_BACKWARD ) {
|
| 230 |
+
valid_vo &= ti >= 0;
|
| 231 |
+
} else {
|
| 232 |
+
valid_vo &= ti < params.L;
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
// Trigger the loads for V.
|
| 236 |
+
float ldg_v = valid_vo ? *ptr_v : 0.f;
|
| 237 |
+
|
| 238 |
+
// Move the load pointers.
|
| 239 |
+
if( GO_BACKWARD ) {
|
| 240 |
+
ptr_q -= COLS_PER_ITER*params.q_stride_L;
|
| 241 |
+
ptr_k -= COLS_PER_ITER*params.k_stride_L;
|
| 242 |
+
ptr_v -= COLS_PER_ITER*params.v_stride_L;
|
| 243 |
+
} else {
|
| 244 |
+
ptr_q += COLS_PER_ITER*params.q_stride_L;
|
| 245 |
+
ptr_k += COLS_PER_ITER*params.k_stride_L;
|
| 246 |
+
ptr_v += COLS_PER_ITER*params.v_stride_L;
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
// Store to shared memory.
|
| 250 |
+
if( tidx < COLS_PER_ITER ) {
|
| 251 |
+
smem_v[tidx] = ldg_v;
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
// Make sure V is in shared memory.
|
| 255 |
+
__syncthreads();
|
| 256 |
+
|
| 257 |
+
// Read V from shared memory.
|
| 258 |
+
float v[COLS_PER_THREAD];
|
| 259 |
+
#pragma unroll
|
| 260 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
| 261 |
+
v[ci] = smem_v[warp*COLS_PER_THREAD + ci];
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
// Each thread computes local K*V products.
|
| 265 |
+
float kv[ROWS_PER_THREAD][COLS_PER_THREAD];
|
| 266 |
+
#pragma unroll
|
| 267 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
| 268 |
+
#pragma unroll
|
| 269 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
| 270 |
+
kv[ri][ci] = 0.f;
|
| 271 |
+
}
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
// Update the K*V^T product.
|
| 275 |
+
#pragma unroll
|
| 276 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
| 277 |
+
#pragma unroll
|
| 278 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
| 279 |
+
kv[ri][ci] += k[ri][ci] * v[ci];
|
| 280 |
+
}
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
// We must perform the prefix sums within the thread-block. Start with the thread.
|
| 284 |
+
#pragma unroll
|
| 285 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
| 286 |
+
#pragma unroll
|
| 287 |
+
for( int ci = 1; ci < COLS_PER_THREAD; ++ci ) {
|
| 288 |
+
kv[ri][ci] += kv[ri][ci-1];
|
| 289 |
+
}
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
// Store the partial sums to shared memory. Unless we have no inter-warp reduction to perform.
|
| 293 |
+
#pragma unroll
|
| 294 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
| 295 |
+
smem_reds[warp*E + ri*THREADS_PER_WARP + lane] = kv[ri][COLS_PER_THREAD-1];
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
// Make sure the data is in shared memory.
|
| 299 |
+
__syncthreads();
|
| 300 |
+
|
| 301 |
+
// Each thread deals with one or more column(s) of the matrix.
|
| 302 |
+
constexpr int SUMS_PER_THREAD = (E + THREADS_PER_BLOCK-1) / THREADS_PER_BLOCK;
|
| 303 |
+
#pragma unroll
|
| 304 |
+
for( int ii = 0, idx = tidx; ii < SUMS_PER_THREAD; ++ii, idx += THREADS_PER_BLOCK ) {
|
| 305 |
+
if( idx < E ) {
|
| 306 |
+
float sum = smem_reds[idx];
|
| 307 |
+
#pragma unroll
|
| 308 |
+
for( int jj = 1; jj < WARPS; ++jj ) {
|
| 309 |
+
smem_reds[idx + jj*E] = sum += smem_reds[idx + jj*E];
|
| 310 |
+
}
|
| 311 |
+
}
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
// Make sure the reductions are stored in shared memory.
|
| 315 |
+
__syncthreads();
|
| 316 |
+
|
| 317 |
+
// Each thread updates its partial products.
|
| 318 |
+
#pragma unroll
|
| 319 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
| 320 |
+
float sum = running_kv[ri];
|
| 321 |
+
if( warp > 0 ) {
|
| 322 |
+
sum += smem_reds[(warp-1)*E + lane + ri*THREADS_PER_WARP];
|
| 323 |
+
}
|
| 324 |
+
#pragma unroll
|
| 325 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
| 326 |
+
kv[ri][ci] += sum;
|
| 327 |
+
}
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
// Compute the partial output values for that thread.
|
| 331 |
+
float sum[COLS_PER_THREAD];
|
| 332 |
+
#pragma unroll
|
| 333 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
| 334 |
+
sum[ci] = q[0][ci] * kv[0][ci];
|
| 335 |
+
#pragma unroll
|
| 336 |
+
for( int ri = 1; ri < ROWS_PER_THREAD; ++ri ) {
|
| 337 |
+
sum[ci] += q[ri][ci] * kv[ri][ci];
|
| 338 |
+
}
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
// Run the parallel reductions inside the warp.
|
| 342 |
+
#pragma unroll
|
| 343 |
+
for( int mask = THREADS_PER_WARP / 2; mask >= 1; mask /= 2 ) {
|
| 344 |
+
#pragma unroll
|
| 345 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
| 346 |
+
sum[ci] += __shfl_xor_sync(uint32_t(-1), sum[ci], mask);
|
| 347 |
+
}
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
// Store the final output to shared memory.
|
| 351 |
+
if( lane == 0 ) {
|
| 352 |
+
#pragma unroll
|
| 353 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
| 354 |
+
smem_o[warp*COLS_PER_THREAD + ci] = sum[ci];
|
| 355 |
+
}
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
// Make sure the data is in shared memory.
|
| 359 |
+
__syncthreads();
|
| 360 |
+
|
| 361 |
+
// Store the output.
|
| 362 |
+
if( valid_vo ) {
|
| 363 |
+
*ptr_o = smem_o[tidx];
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
// Each thread updates its running kv.
|
| 367 |
+
#pragma unroll
|
| 368 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
| 369 |
+
running_kv[ri] += smem_reds[(WARPS-1)*E + lane + ri*THREADS_PER_WARP];
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
// Move to next location.
|
| 373 |
+
if( GO_BACKWARD ) {
|
| 374 |
+
ptr_o -= COLS_PER_ITER*params.o_stride_L;
|
| 375 |
+
} else {
|
| 376 |
+
ptr_o += COLS_PER_ITER*params.o_stride_L;
|
| 377 |
+
}
|
| 378 |
+
}
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 382 |
+
|
| 383 |
+
template< int E, bool GO_BACKWARD, int WARPS >
|
| 384 |
+
int lmha_low_occupancy_(const Lmha_params<float> ¶ms) {
|
| 385 |
+
|
| 386 |
+
// Make sure we are not going to launch an invalid grid.
|
| 387 |
+
if( params.H > 65535 || params.B > 65535 ) {
|
| 388 |
+
return 1;
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
// Prepare the grid and trigger the CUDA kernel.
|
| 392 |
+
dim3 grid;
|
| 393 |
+
grid.x = params.M;
|
| 394 |
+
grid.y = params.H;
|
| 395 |
+
grid.z = params.B;
|
| 396 |
+
lmha_low_occupancy_kernel<E, GO_BACKWARD, WARPS><<<grid, WARPS*THREADS_PER_WARP>>>(params);
|
| 397 |
+
return 0;
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 401 |
+
|
| 402 |
+
template< int E, bool GO_BACKWARD >
|
| 403 |
+
int lmha_low_occupancy_(const Lmha_params<float> ¶ms, int blocks) {
|
| 404 |
+
if( params.M * blocks >= 8*LOW_OCCUPANCY_THRESHOLD ) {
|
| 405 |
+
return lmha_low_occupancy_<E, GO_BACKWARD, 4>(params);
|
| 406 |
+
} else if( params.M * blocks >= 4*LOW_OCCUPANCY_THRESHOLD ) {
|
| 407 |
+
return lmha_low_occupancy_<E, GO_BACKWARD, 8>(params);
|
| 408 |
+
} else {
|
| 409 |
+
return lmha_low_occupancy_<E, GO_BACKWARD, 16>(params);
|
| 410 |
+
}
|
| 411 |
+
return 1;
|
| 412 |
+
}
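The selector above widens the thread block when there is not enough grid-level parallelism: with few B*H blocks available, more warps cooperate on each output column. A small Python model of that heuristic, using the LOW_OCCUPANCY_THRESHOLD = 40 constant defined earlier in this file (illustrative only, with assumed example sizes):

def pick_warps(M, blocks, low_occupancy_threshold=40):
    # Mirrors the branch order above: fewer warps when M * blocks already gives
    # enough parallel work, wider blocks otherwise.
    if M * blocks >= 8 * low_occupancy_threshold:
        return 4
    elif M * blocks >= 4 * low_occupancy_threshold:
        return 8
    return 16

print(pick_warps(M=64, blocks=8))    # 4, since 64 * 8 = 512 >= 320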
|
| 413 |
+
|
| 414 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 415 |
+
|
| 416 |
+
template< int E, typename Params >
|
| 417 |
+
static inline __device__ __host__ int smem_buffer_elts_(const Params ¶ms) {
|
| 418 |
+
int M = round_up(params.M, 4);
|
| 419 |
+
return 2*E + 2*M;
|
| 420 |
+
}
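smem_buffer_elts_ sizes one shared-memory slice: E floats each for Q and K plus round_up(M, 4) floats each for V and the output; the launcher later doubles it for double-buffering and multiplies by sizeof(float). A worked example of that arithmetic with assumed head sizes:

def lmha_smem_bytes(E, M, float_bytes=4):
    M_padded = -(-M // 4) * 4                    # round_up(M, 4)
    elts_per_buffer = 2 * E + 2 * M_padded       # Q, K slices of size E; V, O slices of size M
    return elts_per_buffer * 2 * float_bytes     # two buffers for double-buffering

print(lmha_smem_bytes(64, 64))                   # 2048 bytes for E = M = 64, well under the 48 KB default limit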
|
| 421 |
+
|
| 422 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 423 |
+
|
| 424 |
+
template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD >
|
| 425 |
+
__global__
|
| 426 |
+
void lmha_kernel(Lmha_params<float> params) {
|
| 427 |
+
|
| 428 |
+
// Make sure E is a multiple of 4.
|
| 429 |
+
static_assert(E % 4 == 0, "");
|
| 430 |
+
|
| 431 |
+
// The amount of shared memory per buffer (2 buffers for double-buffering).
|
| 432 |
+
const int smem_buffer_elts = smem_buffer_elts_<E>(params);
|
| 433 |
+
// The M dimension for shared memory.
|
| 434 |
+
const int M = round_up(params.M, 4);
|
| 435 |
+
|
| 436 |
+
// Shared memory to store Q, K and V. Size is 2*smem_buffer_elts.
|
| 437 |
+
extern __shared__ float smem_[];
|
| 438 |
+
|
| 439 |
+
// The various shared memory buffers.
|
| 440 |
+
float *smem_q = &smem_[0*E];
|
| 441 |
+
float *smem_k = &smem_[1*E];
|
| 442 |
+
float *smem_v = &smem_[2*E];
|
| 443 |
+
float *smem_o = &smem_[2*E + M];
|
| 444 |
+
|
| 445 |
+
// The index of the shared memory buffer (for double-buffering).
|
| 446 |
+
int smem_curr = 0;
|
| 447 |
+
|
| 448 |
+
// The sequence processed by that block.
|
| 449 |
+
const int bi = blockIdx.y;
|
| 450 |
+
// The head processed by that block.
|
| 451 |
+
const int hi = blockIdx.x;
|
| 452 |
+
|
| 453 |
+
// The linear index of the thread.
|
| 454 |
+
const int tidx = threadIdx.x;
|
| 455 |
+
|
| 456 |
+
// The offset to the position loaded by the thread in Q.
|
| 457 |
+
int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + tidx;
|
| 458 |
+
// The offset to the position loaded by the thread in K.
|
| 459 |
+
int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + tidx;
|
| 460 |
+
|
| 461 |
+
// If we walk backward, account for the extra offset.
|
| 462 |
+
if( GO_BACKWARD ) {
|
| 463 |
+
offset_q += (params.L-1)*params.q_stride_L;
|
| 464 |
+
offset_k += (params.L-1)*params.k_stride_L;
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
// Determine the base pointers for Q and K.
|
| 468 |
+
const float *ptr_q = ¶ms.q[offset_q];
|
| 469 |
+
const float *ptr_k = ¶ms.k[offset_k];
|
| 470 |
+
|
| 471 |
+
// The offset to the position loaded by the thread in V and O.
|
| 472 |
+
int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + tidx;
|
| 473 |
+
int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + tidx;
|
| 474 |
+
|
| 475 |
+
// If we walk backward, account for the extra offset.
|
| 476 |
+
if( GO_BACKWARD ) {
|
| 477 |
+
offset_v += (params.L-1)*params.v_stride_L;
|
| 478 |
+
offset_o += (params.L-1)*params.o_stride_L;
|
| 479 |
+
}
|
| 480 |
+
|
| 481 |
+
// Determine the base pointers for V.
|
| 482 |
+
const float *ptr_v = ¶ms.v[offset_v];
|
| 483 |
+
|
| 484 |
+
// Is it an active Q/K thread?
|
| 485 |
+
const int active_qk = tidx < params.E;
|
| 486 |
+
|
| 487 |
+
// Trigger the memory loads for Q and K.
|
| 488 |
+
float ldg_q = 0.f, ldg_k = 0.f;
|
| 489 |
+
if( active_qk ) {
|
| 490 |
+
ldg_q = *ptr_q;
|
| 491 |
+
ldg_k = *ptr_k;
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
+
// Is it an active V thread?
|
| 495 |
+
const int active_v = tidx < params.M;
|
| 496 |
+
|
| 497 |
+
// Trigger the memory loads for V.
|
| 498 |
+
float ldg_v = 0.f;
|
| 499 |
+
if( active_v ) {
|
| 500 |
+
ldg_v = *ptr_v;
|
| 501 |
+
}
|
| 502 |
+
|
| 503 |
+
// Move the load pointers.
|
| 504 |
+
if( GO_BACKWARD ) {
|
| 505 |
+
ptr_q -= params.q_stride_L;
|
| 506 |
+
ptr_k -= params.k_stride_L;
|
| 507 |
+
ptr_v -= params.v_stride_L;
|
| 508 |
+
} else {
|
| 509 |
+
ptr_q += params.q_stride_L;
|
| 510 |
+
ptr_k += params.k_stride_L;
|
| 511 |
+
ptr_v += params.v_stride_L;
|
| 512 |
+
}
|
| 513 |
+
|
| 514 |
+
// The number of FLOAT4s per head.
|
| 515 |
+
constexpr int FLOAT4s_PER_HEAD = E / 4;
|
| 516 |
+
// The number of FLOAT4s per thread.
|
| 517 |
+
constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD;
|
| 518 |
+
|
| 519 |
+
// The storage for the K*V^T values.
|
| 520 |
+
float4 kv[FLOAT4s_PER_THREAD];
|
| 521 |
+
#pragma unroll
|
| 522 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 523 |
+
kv[ii] = make_float4(0.f, 0.f, 0.f, 0.f);
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
// The output pointer.
|
| 527 |
+
float *out_ptr = ¶ms.out[offset_o];
|
| 528 |
+
|
| 529 |
+
// Store to shared memory Q and K.
|
| 530 |
+
if( tidx < E ) {
|
| 531 |
+
smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q;
|
| 532 |
+
smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k;
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
// Store to shared memory V. All threads store valid values.
|
| 536 |
+
if( tidx < M ) {
|
| 537 |
+
smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v;
|
| 538 |
+
}
|
| 539 |
+
|
| 540 |
+
// The position of the thread in the V dimension.
|
| 541 |
+
int vo = tidx / THREADS_PER_HEAD;
|
| 542 |
+
int vi = tidx % THREADS_PER_HEAD;
|
| 543 |
+
|
| 544 |
+
// Iterate over the timesteps.
|
| 545 |
+
for( int ti = 0; ti < params.L; ++ti ) {
|
| 546 |
+
|
| 547 |
+
// Is it the last iteration?
|
| 548 |
+
int is_last = ti == params.L - 1;
|
| 549 |
+
|
| 550 |
+
// Trigger the next loads for Q and K.
|
| 551 |
+
if( !is_last && active_qk ) {
|
| 552 |
+
ldg_q = *ptr_q;
|
| 553 |
+
ldg_k = *ptr_k;
|
| 554 |
+
}
|
| 555 |
+
|
| 556 |
+
// Trigger the next loads for V.
|
| 557 |
+
if( !is_last && active_v ) {
|
| 558 |
+
ldg_v = *ptr_v;
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
// Move the load pointers.
|
| 562 |
+
if( GO_BACKWARD ) {
|
| 563 |
+
ptr_q -= params.q_stride_L;
|
| 564 |
+
ptr_k -= params.k_stride_L;
|
| 565 |
+
ptr_v -= params.v_stride_L;
|
| 566 |
+
} else {
|
| 567 |
+
ptr_q += params.q_stride_L;
|
| 568 |
+
ptr_k += params.k_stride_L;
|
| 569 |
+
ptr_v += params.v_stride_L;
|
| 570 |
+
}
|
| 571 |
+
|
| 572 |
+
// Make sure the data is in shared memory.
|
| 573 |
+
__syncthreads();
|
| 574 |
+
|
| 575 |
+
// Each thread loads 4 values from K.
|
| 576 |
+
float4 k[FLOAT4s_PER_THREAD];
|
| 577 |
+
#pragma unroll
|
| 578 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 579 |
+
int ki = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4;
|
| 580 |
+
k[ii] = *reinterpret_cast<const float4*>(&smem_k[smem_curr*smem_buffer_elts + ki]);
|
| 581 |
+
}
|
| 582 |
+
|
| 583 |
+
// Each thread loads a single V value.
|
| 584 |
+
float v = 0.f;
|
| 585 |
+
if( vo < params.M ) {
|
| 586 |
+
v = *reinterpret_cast<const float *>(&smem_v[smem_curr*smem_buffer_elts + vo]);
|
| 587 |
+
}
|
| 588 |
+
|
| 589 |
+
// Update the K*V^T product.
|
| 590 |
+
#pragma unroll
|
| 591 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 592 |
+
kv[ii].x += k[ii].x * v;
|
| 593 |
+
kv[ii].y += k[ii].y * v;
|
| 594 |
+
kv[ii].z += k[ii].z * v;
|
| 595 |
+
kv[ii].w += k[ii].w * v;
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
// Load the Q values from shared memory.
|
| 599 |
+
float4 q[FLOAT4s_PER_THREAD];
|
| 600 |
+
#pragma unroll
|
| 601 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 602 |
+
int qi = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4;
|
| 603 |
+
q[ii] = *reinterpret_cast<const float4*>(&smem_q[smem_curr*smem_buffer_elts + qi]);
|
| 604 |
+
}
|
| 605 |
+
|
| 606 |
+
// Compute the partial output value for that thread.
|
| 607 |
+
float sum = 0.f;
|
| 608 |
+
#pragma unroll
|
| 609 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 610 |
+
sum += q[ii].x * kv[ii].x;
|
| 611 |
+
sum += q[ii].y * kv[ii].y;
|
| 612 |
+
sum += q[ii].z * kv[ii].z;
|
| 613 |
+
sum += q[ii].w * kv[ii].w;
|
| 614 |
+
}
|
| 615 |
+
|
| 616 |
+
// Finalize the computation of the sum (if we have more than 1 thread per head).
|
| 617 |
+
if( THREADS_PER_HEAD > 1 ) {
|
| 618 |
+
|
| 619 |
+
// Finalize the sum for each head.
|
| 620 |
+
#pragma unroll
|
| 621 |
+
for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) {
|
| 622 |
+
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
| 623 |
+
}
|
| 624 |
+
|
| 625 |
+
// Store to shared memory.
|
| 626 |
+
if( vo < M && vi == 0 ) {
|
| 627 |
+
smem_o[smem_curr*smem_buffer_elts + vo] = sum;
|
| 628 |
+
}
|
| 629 |
+
|
| 630 |
+
// Make sure the data is in shared memory.
|
| 631 |
+
__syncthreads();
|
| 632 |
+
|
| 633 |
+
// Active threads read the data to store.
|
| 634 |
+
if( active_v ) {
|
| 635 |
+
sum = smem_o[smem_curr*smem_buffer_elts + tidx];
|
| 636 |
+
}
|
| 637 |
+
|
| 638 |
+
} // THREADS_PER_HEAD > 1.
|
| 639 |
+
|
| 640 |
+
// Store the output. All the threads are active.
|
| 641 |
+
if( active_v ) {
|
| 642 |
+
*out_ptr = sum;
|
| 643 |
+
}
|
| 644 |
+
|
| 645 |
+
// Move to next location.
|
| 646 |
+
if( GO_BACKWARD ) {
|
| 647 |
+
out_ptr -= params.o_stride_L;
|
| 648 |
+
} else {
|
| 649 |
+
out_ptr += params.o_stride_L;
|
| 650 |
+
}
|
| 651 |
+
|
| 652 |
+
// Move the shared memory buffer.
|
| 653 |
+
smem_curr = (smem_curr + 1) % 2;
|
| 654 |
+
|
| 655 |
+
// Store to shared memory for Q and K.
|
| 656 |
+
if( !is_last && tidx < E ) {
|
| 657 |
+
smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q;
|
| 658 |
+
smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k;
|
| 659 |
+
}
|
| 660 |
+
|
| 661 |
+
// Store to shared memory for V.
|
| 662 |
+
if( !is_last && tidx < M ) {
|
| 663 |
+
smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v;
|
| 664 |
+
}
|
| 665 |
+
}
|
| 666 |
+
}
|
| 667 |
+
|
| 668 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 669 |
+
|
| 670 |
+
template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD >
|
| 671 |
+
int lmha_(const Lmha_params<float> ¶ms) {
|
| 672 |
+
// The M dimension rounded up to 4.
|
| 673 |
+
int M = round_up(params.M, 4);
|
| 674 |
+
|
| 675 |
+
// The number of threads in the block.
|
| 676 |
+
int block = round_up(max(E, M*THREADS_PER_HEAD), 32);
|
| 677 |
+
if( block > 512 || params.B > 65535 ) {
|
| 678 |
+
return 1;
|
| 679 |
+
}
|
| 680 |
+
|
| 681 |
+
// Prepare the kernel.
|
| 682 |
+
dim3 grid(params.H, params.B);
|
| 683 |
+
size_t smem = smem_buffer_elts_<E>(params)*2*sizeof(float);
|
| 684 |
+
lmha_kernel<E, THREADS_PER_HEAD, GO_BACKWARD><<<grid, block, smem>>>(params);
|
| 685 |
+
return 0;
|
| 686 |
+
}
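lmha_ sizes the thread block as round_up(max(E, M * THREADS_PER_HEAD), 32) and refuses to launch (returns 1, triggering the caller's fallback) when that exceeds 512 threads or when B exceeds 65535. A sketch of the same computation with assumed example sizes:

def lmha_block_size(E, M, threads_per_head):
    block = max(E, M * threads_per_head)
    block = -(-block // 32) * 32                 # round up to a whole warp
    return block if block <= 512 else None       # None models "return 1" (use the fallback path)

# Assumed examples: E = M = 128 with 2 threads/head fits; E = M = 256 with 4 threads/head does not.
print(lmha_block_size(128, 128, 2), lmha_block_size(256, 256, 4))   # 256 None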
|
| 687 |
+
|
| 688 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 689 |
+
|
| 690 |
+
template< bool GO_BACKWARD >
|
| 691 |
+
int lmha(const Lmha_params<float> ¶ms) {
|
| 692 |
+
int blocks = params.B * params.H;
|
| 693 |
+
int res = 1;
|
| 694 |
+
if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
|
| 695 |
+
if( params.E <= 32 ) {
|
| 696 |
+
res = lmha_low_occupancy_< 32, GO_BACKWARD>(params, blocks);
|
| 697 |
+
} else if( params.E <= 64 ) {
|
| 698 |
+
res = lmha_low_occupancy_< 64, GO_BACKWARD>(params, blocks);
|
| 699 |
+
} else if( params.E <= 128 ) {
|
| 700 |
+
res = lmha_low_occupancy_<128, GO_BACKWARD>(params, blocks);
|
| 701 |
+
} else if( params.E <= 256 ) {
|
| 702 |
+
res = lmha_low_occupancy_<256, GO_BACKWARD>(params, blocks);
|
| 703 |
+
}
|
| 704 |
+
} else {
|
| 705 |
+
if( params.E <= 32 ) {
|
| 706 |
+
res = lmha_< 32, 1, GO_BACKWARD>(params);
|
| 707 |
+
} else if( params.E <= 48 ) {
|
| 708 |
+
res = lmha_< 48, 1, GO_BACKWARD>(params);
|
| 709 |
+
} else if( params.E <= 64 ) {
|
| 710 |
+
res = lmha_< 64, 1, GO_BACKWARD>(params);
|
| 711 |
+
} else if( params.E <= 128 ) {
|
| 712 |
+
res = lmha_<128, 2, GO_BACKWARD>(params);
|
| 713 |
+
} else if( params.E <= 256 ) {
|
| 714 |
+
res = lmha_<256, 4, GO_BACKWARD>(params);
|
| 715 |
+
}
|
| 716 |
+
}
|
| 717 |
+
return res;
|
| 718 |
+
}
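The dispatcher above chooses the low-occupancy kernels when B*H is below LOW_OCCUPANCY_THRESHOLD and otherwise buckets E into the compile-time sizes 32/48/64/128/256, leaving anything larger to the generic fallback (res stays 1). A Python model of that selection, for illustration only:

def dispatch_lmha(B, H, E, low_occupancy_threshold=40):
    blocks = B * H
    if blocks < low_occupancy_threshold:
        path, buckets = "low_occupancy", (32, 64, 128, 256)
    else:
        path, buckets = "standard", (32, 48, 64, 128, 256)
    for e_bucket in buckets:
        if E <= e_bucket:
            return path, e_bucket
    return "fallback", None                      # res == 1: the caller runs the generic kernel instead

# Assumed examples: a single 8-head sequence vs. a batch of 16 sequences with 32 heads, E = 64.
print(dispatch_lmha(1, 8, 64), dispatch_lmha(16, 32, 64))   # ('low_occupancy', 64) ('standard', 64)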
|
| 719 |
+
|
| 720 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 721 |
+
|
| 722 |
+
template< typename T >
|
| 723 |
+
inline void set_params(Lmha_params<T> ¶ms,
|
| 724 |
+
const torch::Tensor q,
|
| 725 |
+
const torch::Tensor k,
|
| 726 |
+
const torch::Tensor v,
|
| 727 |
+
torch::Tensor o) {
|
| 728 |
+
|
| 729 |
+
// Define the pointers.
|
| 730 |
+
params.out = o.data_ptr<T>();
|
| 731 |
+
params.q = q.data_ptr<T>();
|
| 732 |
+
params.k = k.data_ptr<T>();
|
| 733 |
+
params.v = v.data_ptr<T>();
|
| 734 |
+
|
| 735 |
+
// Define the strides.
|
| 736 |
+
params.q_stride_B = (int) q.stride(0);
|
| 737 |
+
params.q_stride_H = (int) q.stride(1);
|
| 738 |
+
params.q_stride_L = (int) q.stride(2);
|
| 739 |
+
params.k_stride_B = (int) k.stride(0);
|
| 740 |
+
params.k_stride_H = (int) k.stride(1);
|
| 741 |
+
params.k_stride_L = (int) k.stride(2);
|
| 742 |
+
params.v_stride_B = (int) v.stride(0);
|
| 743 |
+
params.v_stride_H = (int) v.stride(1);
|
| 744 |
+
params.v_stride_L = (int) v.stride(2);
|
| 745 |
+
params.o_stride_B = (int) o.stride(0);
|
| 746 |
+
params.o_stride_H = (int) o.stride(1);
|
| 747 |
+
params.o_stride_L = (int) o.stride(2);
|
| 748 |
+
|
| 749 |
+
// Extract the dimensions.
|
| 750 |
+
int N = q.size(0);
|
| 751 |
+
int H = q.size(1);
|
| 752 |
+
int L = q.size(2);
|
| 753 |
+
int E = q.size(3);
|
| 754 |
+
int M = v.size(3);
|
| 755 |
+
|
| 756 |
+
params.B = N;
|
| 757 |
+
params.L = L;
|
| 758 |
+
params.H = H;
|
| 759 |
+
params.E = E;
|
| 760 |
+
params.M = M;
|
| 761 |
+
}
|
| 762 |
+
|
| 763 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 764 |
+
|
| 765 |
+
int lmha_fwd(const torch::Tensor queries,
|
| 766 |
+
const torch::Tensor keys,
|
| 767 |
+
const torch::Tensor values,
|
| 768 |
+
torch::Tensor product) {
|
| 769 |
+
|
| 770 |
+
// Make sure that we are using the correct GPU device
|
| 771 |
+
torch::DeviceGuard _guard(queries.device());
|
| 772 |
+
|
| 773 |
+
// Make sure the inner-most dimension of the tensors is packed.
|
| 774 |
+
assert(queries.stride(3) == 1);
|
| 775 |
+
assert(keys .stride(3) == 1);
|
| 776 |
+
assert(values .stride(3) == 1);
|
| 777 |
+
assert(product.stride(3) == 1);
|
| 778 |
+
|
| 779 |
+
// Extract the dimensions.
|
| 780 |
+
int N = queries.size(0);
|
| 781 |
+
int H = queries.size(1);
|
| 782 |
+
int L = queries.size(2);
|
| 783 |
+
int E = queries.size(3);
|
| 784 |
+
int M = values.size (3);
|
| 785 |
+
|
| 786 |
+
// The structure of params.
|
| 787 |
+
Lmha_params<float> params;
|
| 788 |
+
set_params(params, queries, keys, values, product);
|
| 789 |
+
|
| 790 |
+
// Launch the kernel.
|
| 791 |
+
return lmha<false>(params);
|
| 792 |
+
}
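lmha_fwd only asserts that the innermost dimension of each tensor is packed; the rest of the expected layout is implicit. A hedged sketch of the checks a Python caller could run before handing tensors to this entry point (check_lmha_inputs is an illustrative helper, not part of this commit):

import torch

def check_lmha_inputs(q, k, v, out):
    # [B, H, L, *] float32 CUDA tensors with a packed last dimension, matching the
    # assert(...stride(3) == 1) checks above; E comes from q/k, M from v/out.
    for t in (q, k, v, out):
        assert t.is_cuda and t.dtype == torch.float32 and t.dim() == 4
        assert t.stride(3) == 1
    assert q.shape[:3] == k.shape[:3] == v.shape[:3] == out.shape[:3]
    assert q.shape[3] == k.shape[3] and v.shape[3] == out.shape[3]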
|
| 793 |
+
|
| 794 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 795 |
+
|
| 796 |
+
template< typename T >
|
| 797 |
+
struct Lmha_bwd_params {
|
| 798 |
+
|
| 799 |
+
// The output buffer for K. Dimensions [B, H, L, D].
|
| 800 |
+
T *out_k;
|
| 801 |
+
// The output buffer for V. Dimensions [B, H, L, D].
|
| 802 |
+
T *out_v;
|
| 803 |
+
|
| 804 |
+
// The input Qs. Dimensions [B, H, L, D].
|
| 805 |
+
const T *q;
|
| 806 |
+
// The input Ks. Dimensions [B, H, L, D].
|
| 807 |
+
const T *k;
|
| 808 |
+
// The input Vs. Dimensions [B, H, L, D].
|
| 809 |
+
const T *v;
|
| 810 |
+
// The input Gs. Dimensions [B, H, L, D].
|
| 811 |
+
const T *g;
|
| 812 |
+
|
| 813 |
+
// The dimensions.
|
| 814 |
+
int B, L, H, M, E;
|
| 815 |
+
|
| 816 |
+
// The strides for the input tensors.
|
| 817 |
+
int q_stride_B, q_stride_L, q_stride_H;
|
| 818 |
+
int k_stride_B, k_stride_L, k_stride_H;
|
| 819 |
+
int v_stride_B, v_stride_L, v_stride_H;
|
| 820 |
+
int g_stride_B, g_stride_L, g_stride_H;
|
| 821 |
+
|
| 822 |
+
// The strides for the outputs.
|
| 823 |
+
int out_k_stride_B, out_k_stride_L, out_k_stride_H;
|
| 824 |
+
int out_v_stride_B, out_v_stride_L, out_v_stride_H;
|
| 825 |
+
};
|
| 826 |
+
|
| 827 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 828 |
+
|
| 829 |
+
template< int D, int THREADS_PER_HEAD >
|
| 830 |
+
__global__ __launch_bounds__(D*THREADS_PER_HEAD*2)
|
| 831 |
+
void lmha_bwd_kernel(Lmha_bwd_params<float> params) {
|
| 832 |
+
|
| 833 |
+
// Make sure D is a multiple of 4.
|
| 834 |
+
static_assert(D % 4 == 0, "");
|
| 835 |
+
|
| 836 |
+
// The shared memory buffers.
|
| 837 |
+
__shared__ struct Smem { float qg[2*D], kv[2*D], out_kv[2*D]; } smem_[2];
|
| 838 |
+
|
| 839 |
+
// The index of the shared memory buffer (for double-buffering).
|
| 840 |
+
int smem_curr = 0;
|
| 841 |
+
|
| 842 |
+
// The sequence processed by that block.
|
| 843 |
+
const int bi = blockIdx.y;
|
| 844 |
+
// The head processed by that block.
|
| 845 |
+
const int hi = blockIdx.x;
|
| 846 |
+
|
| 847 |
+
// The linear index of the thread.
|
| 848 |
+
const int tidx = threadIdx.x;
|
| 849 |
+
|
| 850 |
+
// Split the threads into two slices.
|
| 851 |
+
int so = tidx / (D*THREADS_PER_HEAD);
|
| 852 |
+
int si = tidx % (D*THREADS_PER_HEAD);
|
| 853 |
+
|
| 854 |
+
// The strides for B/L/H for the Q/G tensors.
|
| 855 |
+
int qg_stride_B, qg_stride_L, qg_stride_H;
|
| 856 |
+
if( so == 0 ) {
|
| 857 |
+
qg_stride_B = params.q_stride_B;
|
| 858 |
+
qg_stride_L = params.q_stride_L;
|
| 859 |
+
qg_stride_H = params.q_stride_H;
|
| 860 |
+
} else {
|
| 861 |
+
qg_stride_B = params.g_stride_B;
|
| 862 |
+
qg_stride_L = params.g_stride_L;
|
| 863 |
+
qg_stride_H = params.g_stride_H;
|
| 864 |
+
}
|
| 865 |
+
|
| 866 |
+
// The strides for B/L/H for the K/V tensors.
|
| 867 |
+
int kv_stride_B, kv_stride_L, kv_stride_H;
|
| 868 |
+
if( so == 0 ) {
|
| 869 |
+
kv_stride_B = params.k_stride_B;
|
| 870 |
+
kv_stride_L = params.k_stride_L;
|
| 871 |
+
kv_stride_H = params.k_stride_H;
|
| 872 |
+
} else {
|
| 873 |
+
kv_stride_B = params.v_stride_B;
|
| 874 |
+
kv_stride_L = params.v_stride_L;
|
| 875 |
+
kv_stride_H = params.v_stride_H;
|
| 876 |
+
}
|
| 877 |
+
|
| 878 |
+
// The hidden size.
|
| 879 |
+
int hidden_size_per_head = 0;
|
| 880 |
+
if( so == 0 ) {
|
| 881 |
+
hidden_size_per_head = params.E;
|
| 882 |
+
} else {
|
| 883 |
+
hidden_size_per_head = params.M;
|
| 884 |
+
}
|
| 885 |
+
|
| 886 |
+
// Where to start reading from.
|
| 887 |
+
int offset_qg = bi*qg_stride_B + hi*qg_stride_H + si;
|
| 888 |
+
int offset_kv = bi*kv_stride_B + hi*kv_stride_H + si;
|
| 889 |
+
|
| 890 |
+
// We walk backward, account for the extra offset.
|
| 891 |
+
offset_qg += (params.L-1)*qg_stride_L;
|
| 892 |
+
offset_kv += (params.L-1)*kv_stride_L;
|
| 893 |
+
|
| 894 |
+
// Determine the base pointers for Q, K, V and G.
|
| 895 |
+
const float *ptr_qg = &(so == 0 ? params.q : params.g)[offset_qg];
|
| 896 |
+
const float *ptr_kv = &(so == 0 ? params.k : params.v)[offset_kv];
|
| 897 |
+
|
| 898 |
+
// Is it an active thread?
|
| 899 |
+
const int active = si < hidden_size_per_head;
|
| 900 |
+
|
| 901 |
+
// Trigger the memory loads for Q, K, V and G.
|
| 902 |
+
float ldg_qg = 0.f, ldg_kv = 0.f;
|
| 903 |
+
if( active ) {
|
| 904 |
+
ldg_qg = *ptr_qg;
|
| 905 |
+
ldg_kv = *ptr_kv;
|
| 906 |
+
}
|
| 907 |
+
|
| 908 |
+
// Move the load pointers (backward).
|
| 909 |
+
ptr_qg -= qg_stride_L;
|
| 910 |
+
ptr_kv -= kv_stride_L;
|
| 911 |
+
|
| 912 |
+
// The number of FLOAT4s per head.
|
| 913 |
+
constexpr int FLOAT4s_PER_HEAD = D / 4;
|
| 914 |
+
// The number of FLOAT4s per thread.
|
| 915 |
+
constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD;
|
| 916 |
+
|
| 917 |
+
// The storage for the G*Q^T or Q^T*G values.
|
| 918 |
+
float4 gq[FLOAT4s_PER_THREAD];
|
| 919 |
+
#pragma unroll
|
| 920 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 921 |
+
gq[ii] = make_float4(0.f, 0.f, 0.f, 0.f);
|
| 922 |
+
}
|
| 923 |
+
|
| 924 |
+
// The strides for B/L/H for the K/V tensors.
|
| 925 |
+
int out_kv_stride_B, out_kv_stride_L, out_kv_stride_H;
|
| 926 |
+
if( so == 0 ) {
|
| 927 |
+
out_kv_stride_B = params.out_k_stride_B;
|
| 928 |
+
out_kv_stride_L = params.out_k_stride_L;
|
| 929 |
+
out_kv_stride_H = params.out_k_stride_H;
|
| 930 |
+
} else {
|
| 931 |
+
out_kv_stride_B = params.out_v_stride_B;
|
| 932 |
+
out_kv_stride_L = params.out_v_stride_L;
|
| 933 |
+
out_kv_stride_H = params.out_v_stride_H;
|
| 934 |
+
}
|
| 935 |
+
|
| 936 |
+
// Where to start reading from.
|
| 937 |
+
int offset_out_kv = bi*out_kv_stride_B + hi*out_kv_stride_H + si;
|
| 938 |
+
|
| 939 |
+
// We walk backward, account for the extra offset.
|
| 940 |
+
offset_out_kv += (params.L-1)*out_kv_stride_L;
|
| 941 |
+
|
| 942 |
+
// The output pointer.
|
| 943 |
+
float *ptr_out_kv = &(so == 0 ? params.out_k : params.out_v)[offset_out_kv];
|
| 944 |
+
|
| 945 |
+
// Store to shared memory.
|
| 946 |
+
if( si < D ) {
|
| 947 |
+
smem_[smem_curr].qg[so*D + si] = ldg_qg;
|
| 948 |
+
smem_[smem_curr].kv[so*D + si] = ldg_kv;
|
| 949 |
+
}
|
| 950 |
+
|
| 951 |
+
// The position of the thread in the output dimension.
|
| 952 |
+
int oo = si / THREADS_PER_HEAD % D;
|
| 953 |
+
int oi = si % THREADS_PER_HEAD * 4;
|
| 954 |
+
|
| 955 |
+
// Iterate over the timesteps.
|
| 956 |
+
for( int ti = 0; ti < params.L; ++ti ) {
|
| 957 |
+
|
| 958 |
+
// Is it the last iteration?
|
| 959 |
+
int is_last = ti == params.L - 1;
|
| 960 |
+
|
| 961 |
+
// Trigger the next loads.
|
| 962 |
+
if( !is_last && active ) {
|
| 963 |
+
ldg_qg = *ptr_qg;
|
| 964 |
+
ldg_kv = *ptr_kv;
|
| 965 |
+
}
|
| 966 |
+
|
| 967 |
+
// Move the load pointers.
|
| 968 |
+
ptr_qg -= qg_stride_L;
|
| 969 |
+
ptr_kv -= kv_stride_L;
|
| 970 |
+
|
| 971 |
+
// Make sure the data is in shared memory.
|
| 972 |
+
__syncthreads();
|
| 973 |
+
|
| 974 |
+
// Each thread loads 4 values from G or Q.
|
| 975 |
+
float4 g[FLOAT4s_PER_THREAD];
|
| 976 |
+
#pragma unroll
|
| 977 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 978 |
+
float *smem_ptr = &smem_[smem_curr].qg[(so^1)*D + oi];
|
| 979 |
+
g[ii] = *reinterpret_cast<const float4*>(&smem_ptr[ii*THREADS_PER_HEAD*4]);
|
| 980 |
+
}
|
| 981 |
+
|
| 982 |
+
// Each thread loads a single value from Q or G.
|
| 983 |
+
float q = smem_[smem_curr].qg[so*D + oo];
|
| 984 |
+
|
| 985 |
+
// Update the G*Q^T or Q*G^T product.
|
| 986 |
+
#pragma unroll
|
| 987 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 988 |
+
gq[ii].x += g[ii].x * q;
|
| 989 |
+
gq[ii].y += g[ii].y * q;
|
| 990 |
+
gq[ii].z += g[ii].z * q;
|
| 991 |
+
gq[ii].w += g[ii].w * q;
|
| 992 |
+
}
|
| 993 |
+
|
| 994 |
+
// Load the V or K values from shared memory.
|
| 995 |
+
float4 v[FLOAT4s_PER_THREAD];
|
| 996 |
+
#pragma unroll
|
| 997 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 998 |
+
float *smem_ptr = &smem_[smem_curr].kv[(so^1)*D + oi];
|
| 999 |
+
v[ii] = *reinterpret_cast<const float4*>(&smem_ptr[ii*THREADS_PER_HEAD*4]);
|
| 1000 |
+
}
|
| 1001 |
+
|
| 1002 |
+
// Compute the partial output value for that thread.
|
| 1003 |
+
float sum = 0.f;
|
| 1004 |
+
#pragma unroll
|
| 1005 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
| 1006 |
+
sum += v[ii].x * gq[ii].x;
|
| 1007 |
+
sum += v[ii].y * gq[ii].y;
|
| 1008 |
+
sum += v[ii].z * gq[ii].z;
|
| 1009 |
+
sum += v[ii].w * gq[ii].w;
|
| 1010 |
+
}
|
| 1011 |
+
|
| 1012 |
+
// Finalize the computation of the sum (if we have more than 1 thread per head).
|
| 1013 |
+
if( THREADS_PER_HEAD > 1 ) {
|
| 1014 |
+
|
| 1015 |
+
// Finalize the sum for each head.
|
| 1016 |
+
#pragma unroll
|
| 1017 |
+
for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) {
|
| 1018 |
+
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
| 1019 |
+
}
|
| 1020 |
+
|
| 1021 |
+
// Store to shared memory.
|
| 1022 |
+
if( oi == 0 ) {
|
| 1023 |
+
smem_[smem_curr].out_kv[so*D + oo] = sum;
|
| 1024 |
+
}
|
| 1025 |
+
|
| 1026 |
+
// Make sure the data is in shared memory.
|
| 1027 |
+
__syncthreads();
|
| 1028 |
+
|
| 1029 |
+
// Active threads read the data to store.
|
| 1030 |
+
if( si < hidden_size_per_head ) {
|
| 1031 |
+
sum = smem_[smem_curr].out_kv[so*D + si];
|
| 1032 |
+
}
|
| 1033 |
+
|
| 1034 |
+
} // THREADS_PER_HEAD > 1.
|
| 1035 |
+
|
| 1036 |
+
// Store the output for threads within the hidden size.
|
| 1037 |
+
if( si < hidden_size_per_head ) {
|
| 1038 |
+
*ptr_out_kv = sum;
|
| 1039 |
+
}
|
| 1040 |
+
|
| 1041 |
+
// Move to next location.
|
| 1042 |
+
ptr_out_kv -= out_kv_stride_L;
|
| 1043 |
+
|
| 1044 |
+
// Move the shared memory buffer.
|
| 1045 |
+
smem_curr = (smem_curr + 1) % 2;
|
| 1046 |
+
|
| 1047 |
+
// Store the next Q/G and K/V values to shared memory.
|
| 1048 |
+
if( !is_last && si < D ) {
|
| 1049 |
+
smem_[smem_curr].qg[so*D + si] = ldg_qg;
|
| 1050 |
+
smem_[smem_curr].kv[so*D + si] = ldg_kv;
|
| 1051 |
+
}
|
| 1052 |
+
}
|
| 1053 |
+
}
|
| 1054 |
+
|
| 1055 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1056 |
+
|
| 1057 |
+
template< int D, int THREADS_PER_HEAD >
|
| 1058 |
+
int lmha_bwd_(const Lmha_bwd_params<float> ¶ms) {
|
| 1059 |
+
int block = D*THREADS_PER_HEAD*2;
|
| 1060 |
+
if( block >= 1024 || params.B > 65535 ) {
|
| 1061 |
+
return 1;
|
| 1062 |
+
}
|
| 1063 |
+
dim3 grid(params.H, params.B);
|
| 1064 |
+
lmha_bwd_kernel<D, THREADS_PER_HEAD><<<grid, block>>>(params);
|
| 1065 |
+
return 0;
|
| 1066 |
+
}
|
| 1067 |
+
|
| 1068 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1069 |
+
|
| 1070 |
+
int lmha_bwd(const Lmha_bwd_params<float> ¶ms) {
|
| 1071 |
+
int blocks = params.B * params.H;
|
| 1072 |
+
if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
|
| 1073 |
+
return 1;
|
| 1074 |
+
}
|
| 1075 |
+
|
| 1076 |
+
int hidden_size_per_head = max(params.E, params.M);
|
| 1077 |
+
int res = 1;
|
| 1078 |
+
if( hidden_size_per_head <= 32 ) {
|
| 1079 |
+
res = lmha_bwd_< 32, 1>(params);
|
| 1080 |
+
} else if( hidden_size_per_head <= 64 ) {
|
| 1081 |
+
res = lmha_bwd_< 64, 1>(params);
|
| 1082 |
+
} else if( hidden_size_per_head <= 128 ) {
|
| 1083 |
+
res = lmha_bwd_<128, 2>(params);
|
| 1084 |
+
} else if( hidden_size_per_head <= 256 ) {
|
| 1085 |
+
res = lmha_bwd_<256, 4>(params);
|
| 1086 |
+
}
|
| 1087 |
+
return res;
|
| 1088 |
+
}
|
| 1089 |
+
|
| 1090 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1091 |
+
|
| 1092 |
+
int lmha_bwd(const torch::Tensor queries,
|
| 1093 |
+
const torch::Tensor keys,
|
| 1094 |
+
const torch::Tensor values,
|
| 1095 |
+
const torch::Tensor grad_out,
|
| 1096 |
+
torch::Tensor grad_queries,
|
| 1097 |
+
torch::Tensor grad_keys,
|
| 1098 |
+
torch::Tensor grad_values) {
|
| 1099 |
+
|
| 1100 |
+
// Make sure that we are using the correct GPU device
|
| 1101 |
+
torch::DeviceGuard _guard(queries.device());
|
| 1102 |
+
|
| 1103 |
+
// Make sure the inner-most dimension of the tensors is packed.
|
| 1104 |
+
assert(queries .stride(3) == 1);
|
| 1105 |
+
assert(keys .stride(3) == 1);
|
| 1106 |
+
assert(values .stride(3) == 1);
|
| 1107 |
+
assert(grad_out .stride(3) == 1);
|
| 1108 |
+
assert(grad_queries.stride(3) == 1);
|
| 1109 |
+
assert(grad_keys .stride(3) == 1);
|
| 1110 |
+
assert(grad_values .stride(3) == 1);
|
| 1111 |
+
|
| 1112 |
+
// Extract the dimensions.
|
| 1113 |
+
int N = queries.size(0);
|
| 1114 |
+
int H = queries.size(1);
|
| 1115 |
+
int L = queries.size(2);
|
| 1116 |
+
int E = queries.size(3);
|
| 1117 |
+
int M = values.size (3);
|
| 1118 |
+
|
| 1119 |
+
// Gradient on Q.
|
| 1120 |
+
|
| 1121 |
+
// The structure of params.
|
| 1122 |
+
Lmha_params<float> params;
|
| 1123 |
+
set_params(params, grad_out, values, keys, grad_queries);
|
| 1124 |
+
|
| 1125 |
+
// Launch the kernel.
|
| 1126 |
+
int res = lmha<false>(params);
|
| 1127 |
+
if( res ) {
|
| 1128 |
+
return res;
|
| 1129 |
+
}
|
| 1130 |
+
|
| 1131 |
+
// Gradient on K and V together.
|
| 1132 |
+
|
| 1133 |
+
Lmha_bwd_params<float> bwd_params;
|
| 1134 |
+
bwd_params.out_k = grad_keys.data_ptr<float>();
|
| 1135 |
+
bwd_params.out_v = grad_values.data_ptr<float>();
|
| 1136 |
+
bwd_params.q = queries.data_ptr<float>();
|
| 1137 |
+
bwd_params.k = keys.data_ptr<float>();
|
| 1138 |
+
bwd_params.v = values.data_ptr<float>();
|
| 1139 |
+
bwd_params.g = grad_out.data_ptr<float>();
|
| 1140 |
+
|
| 1141 |
+
bwd_params.B = N;
|
| 1142 |
+
bwd_params.L = L;
|
| 1143 |
+
bwd_params.H = H;
|
| 1144 |
+
bwd_params.E = E;
|
| 1145 |
+
bwd_params.M = M;
|
| 1146 |
+
|
| 1147 |
+
bwd_params.q_stride_B = queries.stride(0);
|
| 1148 |
+
bwd_params.q_stride_H = queries.stride(1);
|
| 1149 |
+
bwd_params.q_stride_L = queries.stride(2);
|
| 1150 |
+
bwd_params.k_stride_B = keys.stride(0);
|
| 1151 |
+
bwd_params.k_stride_H = keys.stride(1);
|
| 1152 |
+
bwd_params.k_stride_L = keys.stride(2);
|
| 1153 |
+
bwd_params.v_stride_B = values.stride(0);
|
| 1154 |
+
bwd_params.v_stride_H = values.stride(1);
|
| 1155 |
+
bwd_params.v_stride_L = values.stride(2);
|
| 1156 |
+
bwd_params.g_stride_B = grad_out.stride(0);
|
| 1157 |
+
bwd_params.g_stride_H = grad_out.stride(1);
|
| 1158 |
+
bwd_params.g_stride_L = grad_out.stride(2);
|
| 1159 |
+
|
| 1160 |
+
bwd_params.out_k_stride_B = grad_keys.stride(0);
|
| 1161 |
+
bwd_params.out_k_stride_H = grad_keys.stride(1);
|
| 1162 |
+
bwd_params.out_k_stride_L = grad_keys.stride(2);
|
| 1163 |
+
bwd_params.out_v_stride_B = grad_values.stride(0);
|
| 1164 |
+
bwd_params.out_v_stride_H = grad_values.stride(1);
|
| 1165 |
+
bwd_params.out_v_stride_L = grad_values.stride(2);
|
| 1166 |
+
|
| 1167 |
+
// Try to run the fused kernel.
|
| 1168 |
+
int fallback = lmha_bwd(bwd_params);
|
| 1169 |
+
|
| 1170 |
+
// If it failed, fallback on separate kernels for K and V.
|
| 1171 |
+
if( fallback ) {
|
| 1172 |
+
|
| 1173 |
+
// Gradient on K.
|
| 1174 |
+
|
| 1175 |
+
// Launch the kernel.
|
| 1176 |
+
set_params(params, values, grad_out, queries, grad_keys);
|
| 1177 |
+
res = lmha<true>(params);
|
| 1178 |
+
if( res ) {
|
| 1179 |
+
return res;
|
| 1180 |
+
}
|
| 1181 |
+
|
| 1182 |
+
// Gradient on V.
|
| 1183 |
+
|
| 1184 |
+
// Launch the kernel.
|
| 1185 |
+
set_params(params, keys, queries, grad_out, grad_values);
|
| 1186 |
+
return lmha<true>(params);
|
| 1187 |
+
}
|
| 1188 |
+
|
| 1189 |
+
// It worked...
|
| 1190 |
+
return 0;
|
| 1191 |
+
}
|
| 1192 |
+
|
| 1193 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1194 |
+
|
| 1195 |
+
} // namespace nvidia
|
| 1196 |
+
#endif // #ifdef ENABLE_NVIDIA_OPTIMIZATIONS
|
| 1197 |
+
|
| 1198 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1199 |
+
|
| 1200 |
+
typedef torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> float_accessor;
|
| 1201 |
+
|
| 1202 |
+
#define E_BLOCK_SIZE 8
|
| 1203 |
+
|
| 1204 |
+
__global__ void causal_dot_product_kernel(
|
| 1205 |
+
const float_accessor queries,
|
| 1206 |
+
const float_accessor keys,
|
| 1207 |
+
const float_accessor values,
|
| 1208 |
+
float_accessor result,
|
| 1209 |
+
const int N,
|
| 1210 |
+
const int H,
|
| 1211 |
+
const int L,
|
| 1212 |
+
const int E,
|
| 1213 |
+
const int M
|
| 1214 |
+
) {
|
| 1215 |
+
int n = blockIdx.y;
|
| 1216 |
+
int h = blockIdx.z;
|
| 1217 |
+
|
| 1218 |
+
int e_start = blockIdx.x * E_BLOCK_SIZE;
|
| 1219 |
+
int m = threadIdx.x % M;
|
| 1220 |
+
|
| 1221 |
+
extern __shared__ float shared_mem[];
|
| 1222 |
+
float* shared_kv = shared_mem;
|
| 1223 |
+
|
| 1224 |
+
for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
|
| 1225 |
+
shared_kv[m + e_local * M] = 0;
|
| 1226 |
+
}
|
| 1227 |
+
|
| 1228 |
+
for (int t=0; t<L; t++) {
|
| 1229 |
+
float res = 0;
|
| 1230 |
+
for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
|
| 1231 |
+
shared_kv[e_local*M + m] += keys[n][h][t][e_local + e_start] * values[n][h][t][m];
|
| 1232 |
+
res += queries[n][h][t][e_local + e_start] * shared_kv[e_local*M + m];
|
| 1233 |
+
}
|
| 1234 |
+
atomicAdd(
|
| 1235 |
+
&result[n][h][t][m],
|
| 1236 |
+
res
|
| 1237 |
+
);
|
| 1238 |
+
}
|
| 1239 |
+
}
|
| 1240 |
+
|
| 1241 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1242 |
+
|
| 1243 |
+
void causal_dot_product_(const torch::Tensor queries,
|
| 1244 |
+
const torch::Tensor keys,
|
| 1245 |
+
const torch::Tensor values,
|
| 1246 |
+
torch::Tensor product) {
|
| 1247 |
+
// Make sure that we are using the correct GPU device
|
| 1248 |
+
torch::DeviceGuard _guard(queries.device());
|
| 1249 |
+
|
| 1250 |
+
int N = queries.size(0);
|
| 1251 |
+
int H = queries.size(1);
|
| 1252 |
+
int L = queries.size(2);
|
| 1253 |
+
int E = queries.size(3);
|
| 1254 |
+
int M = values.size(3);
|
| 1255 |
+
|
| 1256 |
+
const int blocks_per_sequence = (E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE;
|
| 1257 |
+
|
| 1258 |
+
dim3 blockDim(M, 1, 1);
|
| 1259 |
+
dim3 gridDim(blocks_per_sequence, N, H);
|
| 1260 |
+
const int shared_mem_forward = E_BLOCK_SIZE * M * sizeof(float);
|
| 1261 |
+
|
| 1262 |
+
causal_dot_product_kernel<<<gridDim, blockDim, shared_mem_forward>>>(
|
| 1263 |
+
queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
| 1264 |
+
keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
| 1265 |
+
values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
| 1266 |
+
product.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
| 1267 |
+
N, H, L, E, M
|
| 1268 |
+
);
|
| 1269 |
+
}
|
| 1270 |
+
|
| 1271 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1272 |
+
|
| 1273 |
+
void causal_dot_product(const torch::Tensor queries,
|
| 1274 |
+
const torch::Tensor keys,
|
| 1275 |
+
const torch::Tensor values,
|
| 1276 |
+
torch::Tensor product) {
|
| 1277 |
+
#ifdef ENABLE_NVIDIA_OPTIMIZATIONS
|
| 1278 |
+
int fallback = nvidia::lmha_fwd(queries, keys, values, product);
|
| 1279 |
+
#else
|
| 1280 |
+
int fallback = 1;
|
| 1281 |
+
#endif
|
| 1282 |
+
if( fallback ) {
|
| 1283 |
+
causal_dot_product_(queries, keys, values, product);
|
| 1284 |
+
}
|
| 1285 |
+
}
|
| 1286 |
+
|
| 1287 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1288 |
+
|
| 1289 |
+
#define M_BLOCK_SIZE 4
|
| 1290 |
+
|
| 1291 |
+
// We need shared memory for two running states: the forward-direction KV
// accumulator (shared_kv) and the backward-direction accumulator
// (shared_kv_bw), i.e. 2 * M_BLOCK_SIZE * E floats in total.
|
| 1296 |
+
__global__ void causal_dot_backward_query_key_kernel(
|
| 1297 |
+
const float_accessor queries,
|
| 1298 |
+
const float_accessor keys,
|
| 1299 |
+
const float_accessor values,
|
| 1300 |
+
const float_accessor grad_out,
|
| 1301 |
+
float_accessor grad_queries,
|
| 1302 |
+
float_accessor grad_keys,
|
| 1303 |
+
int N,
|
| 1304 |
+
int H,
|
| 1305 |
+
int L,
|
| 1306 |
+
int E,
|
| 1307 |
+
int M
|
| 1308 |
+
) {
|
| 1309 |
+
int n = blockIdx.y;
|
| 1310 |
+
int h = blockIdx.z;
|
| 1311 |
+
|
| 1312 |
+
int m_start = blockIdx.x * M_BLOCK_SIZE;
|
| 1313 |
+
int e = threadIdx.x % E;
|
| 1314 |
+
|
| 1315 |
+
extern __shared__ float shared_mem[];
|
| 1316 |
+
const int shared_kv_size = M_BLOCK_SIZE * E;
|
| 1317 |
+
float* shared_kv = shared_mem;
|
| 1318 |
+
float* shared_kv_bw = shared_mem + shared_kv_size;
|
| 1319 |
+
|
| 1320 |
+
for (int m_local = 0; m_local < M_BLOCK_SIZE && m_local + m_start < M; m_local++) {
|
| 1321 |
+
shared_kv[m_local * E + e] = 0;
|
| 1322 |
+
shared_kv_bw[m_local * E + e] = 0;
|
| 1323 |
+
}
|
| 1324 |
+
|
| 1325 |
+
for (int l=0; l<L; l++) {
|
| 1326 |
+
float res = 0, res_bw = 0;
|
| 1327 |
+
int l_b = L - l - 1;
|
| 1328 |
+
for (int m_local = 0; m_local < M_BLOCK_SIZE && m_local + m_start < M; m_local++) {
|
| 1329 |
+
shared_kv[m_local*E + e] += keys[n][h][l][e] * values[n][h][l][m_start + m_local];
|
| 1330 |
+
shared_kv_bw[m_local*E + e] += queries[n][h][l_b][e] * grad_out[n][h][l_b][m_start + m_local];
|
| 1331 |
+
res += grad_out[n][h][l][m_start + m_local] * shared_kv[m_local*E + e];
|
| 1332 |
+
res_bw += values[n][h][l_b][m_start + m_local] * shared_kv_bw[m_local*E + e];
|
| 1333 |
+
}
|
| 1334 |
+
atomicAdd(
|
| 1335 |
+
&grad_queries[n][h][l][e],
|
| 1336 |
+
res
|
| 1337 |
+
);
|
| 1338 |
+
atomicAdd(
|
| 1339 |
+
&grad_keys[n][h][l_b][e],
|
| 1340 |
+
res_bw
|
| 1341 |
+
);
|
| 1342 |
+
}
|
| 1343 |
+
}
|
| 1344 |
+
|
| 1345 |
+
|
| 1346 |
+
__global__ void causal_dot_backward_value_kernel(
|
| 1347 |
+
const float_accessor queries,
|
| 1348 |
+
const float_accessor keys,
|
| 1349 |
+
const float_accessor values,
|
| 1350 |
+
const float_accessor grad_out,
|
| 1351 |
+
float_accessor grad_keys,
|
| 1352 |
+
float_accessor grad_values,
|
| 1353 |
+
int N,
|
| 1354 |
+
int H,
|
| 1355 |
+
int L,
|
| 1356 |
+
int E,
|
| 1357 |
+
int M
|
| 1358 |
+
) {
|
| 1359 |
+
int n = blockIdx.y;
|
| 1360 |
+
int h = blockIdx.z;
|
| 1361 |
+
|
| 1362 |
+
int e_start = blockIdx.x * E_BLOCK_SIZE;
|
| 1363 |
+
int m = threadIdx.x % M;
|
| 1364 |
+
|
| 1365 |
+
extern __shared__ float shared_mem[];
|
| 1366 |
+
float* shared_kv = shared_mem;
|
| 1367 |
+
for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
|
| 1368 |
+
shared_kv[m + e_local * M] = 0;
|
| 1369 |
+
}
|
| 1370 |
+
|
| 1371 |
+
for (int l = 0; l < L; l++) {
|
| 1372 |
+
int l_b = L - l - 1;
|
| 1373 |
+
float res = 0;
|
| 1374 |
+
for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
|
| 1375 |
+
shared_kv[e_local*M + m] += queries[n][h][l_b][e_start + e_local] * grad_out[n][h][l_b][m];
|
| 1376 |
+
res += keys[n][h][l_b][e_start + e_local] * shared_kv[e_local*M + m];
|
| 1377 |
+
}
|
| 1378 |
+
atomicAdd(
|
| 1379 |
+
&grad_values[n][h][l_b][m],
|
| 1380 |
+
res
|
| 1381 |
+
);
|
| 1382 |
+
}
|
| 1383 |
+
}
|
| 1384 |
+
|
| 1385 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1386 |
+
|
| 1387 |
+
void causal_dot_backward_(const torch::Tensor queries,
|
| 1388 |
+
const torch::Tensor keys,
|
| 1389 |
+
const torch::Tensor values,
|
| 1390 |
+
const torch::Tensor grad_out,
|
| 1391 |
+
torch::Tensor grad_queries,
|
| 1392 |
+
torch::Tensor grad_keys,
|
| 1393 |
+
torch::Tensor grad_values) {
|
| 1394 |
+
|
| 1395 |
+
// Make sure that we are using the correct GPU device
|
| 1396 |
+
torch::DeviceGuard _guard(queries.device());
|
| 1397 |
+
|
| 1398 |
+
int N = queries.size(0);
|
| 1399 |
+
int H = queries.size(1);
|
| 1400 |
+
int L = queries.size(2);
|
| 1401 |
+
int E = queries.size(3);
|
| 1402 |
+
int M = values.size(3);
|
| 1403 |
+
|
| 1404 |
+
const int blocks_per_sequence = (M + M_BLOCK_SIZE - 1) / M_BLOCK_SIZE;
|
| 1405 |
+
|
| 1406 |
+
dim3 blockDim(E, 1, 1);
|
| 1407 |
+
dim3 gridDim(blocks_per_sequence, N, H);
|
| 1408 |
+
const int shared_mem_qk_backward = 2 * M_BLOCK_SIZE * E * sizeof(float);
|
| 1409 |
+
|
| 1410 |
+
causal_dot_backward_query_key_kernel<<<gridDim, blockDim, shared_mem_qk_backward>>>(
|
| 1411 |
+
queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
| 1412 |
+
keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
| 1413 |
+
values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
| 1414 |
+
grad_out.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
| 1415 |
+
grad_queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
| 1416 |
+
grad_keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
| 1417 |
+
N, H, L, E, M
|
| 1418 |
+
);
|
| 1419 |
+
|
| 1420 |
+
const int blocks_per_sequence_value = (E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE;
|
| 1421 |
+
|
| 1422 |
+
dim3 blockDimv(M, 1, 1);
|
| 1423 |
+
dim3 gridDimv(blocks_per_sequence_value, N, H);
|
| 1424 |
+
const int shared_mem_v_backward = E_BLOCK_SIZE * M * sizeof(float);
|
| 1425 |
+
causal_dot_backward_value_kernel<<<gridDimv, blockDimv, shared_mem_v_backward>>>(
|
| 1426 |
+
queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
| 1427 |
+
keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
| 1428 |
+
values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
| 1429 |
+
grad_out.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
| 1430 |
+
grad_keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
| 1431 |
+
grad_values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
| 1432 |
+
N, H, L, E, M
|
| 1433 |
+
);
|
| 1434 |
+
}
|
| 1435 |
+
|
| 1436 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1437 |
+
|
| 1438 |
+
void causal_dot_backward(const torch::Tensor queries,
|
| 1439 |
+
const torch::Tensor keys,
|
| 1440 |
+
const torch::Tensor values,
|
| 1441 |
+
const torch::Tensor grad_out,
|
| 1442 |
+
torch::Tensor grad_queries,
|
| 1443 |
+
torch::Tensor grad_keys,
|
| 1444 |
+
torch::Tensor grad_values) {
|
| 1445 |
+
#ifdef ENABLE_NVIDIA_OPTIMIZATIONS
|
| 1446 |
+
int fallback = nvidia::lmha_bwd(queries,
|
| 1447 |
+
keys,
|
| 1448 |
+
values,
|
| 1449 |
+
grad_out,
|
| 1450 |
+
grad_queries,
|
| 1451 |
+
grad_keys,
|
| 1452 |
+
grad_values);
|
| 1453 |
+
#else
|
| 1454 |
+
int fallback = 1;
|
| 1455 |
+
#endif
|
| 1456 |
+
if( fallback ) {
|
| 1457 |
+
// Make sure that the gradient tensors are 0. This is needed because the
|
| 1458 |
+
// bwd pass might have partially executed and filled in some values in
|
| 1459 |
+
// grad_queries or grad_keys.
|
| 1460 |
+
//
|
| 1461 |
+
// This adds a small overhead every time we have to fall back to the old
|
| 1462 |
+
// kernel for the backward pass.
|
| 1463 |
+
grad_queries.zero_();
|
| 1464 |
+
grad_keys.zero_();
|
| 1465 |
+
causal_dot_backward_(queries, keys, values, grad_out, grad_queries, grad_keys, grad_values);
|
| 1466 |
+
}
|
| 1467 |
+
}
|
| 1468 |
+
|
| 1469 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1470 |
+
|
| 1471 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 1472 |
+
m.def(
|
| 1473 |
+
"causal_dot_product",
|
| 1474 |
+
&causal_dot_product,
|
| 1475 |
+
"Compute the weighted sum of values but attending only to previous "
|
| 1476 |
+
"values."
|
| 1477 |
+
);
|
| 1478 |
+
m.def(
|
| 1479 |
+
"causal_dot_backward",
|
| 1480 |
+
&causal_dot_backward,
|
| 1481 |
+
"Compute the gradients for the causal dot product."
|
| 1482 |
+
);
|
| 1483 |
+
}
|
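The module above exposes `causal_dot_product` and `causal_dot_backward` through pybind11. As a quick orientation, here is a minimal, hedged smoke test for the forward op; it assumes the extension has been built from `csrc/setup.py`, that a CUDA device is available, and it uses made-up shapes. Note the output buffer is zero-initialized by the caller, since the fallback kernel accumulates into it with `atomicAdd`.

```python
import torch
import causal_attention_cuda  # built from csrc/ via setup.py

# Made-up shapes: (batch N, heads H, length L, query/key dim E, value dim M).
N, H, L, E, M = 2, 4, 128, 64, 64
q = torch.randn(N, H, L, E, device='cuda', dtype=torch.float32)
k = torch.randn(N, H, L, E, device='cuda', dtype=torch.float32)
v = torch.randn(N, H, L, M, device='cuda', dtype=torch.float32)
out = torch.zeros(N, H, L, M, device='cuda', dtype=torch.float32)  # kernel accumulates into this

causal_attention_cuda.causal_dot_product(q, k, v, out)

# Naive reference: out[:, :, t] = sum_{s <= t} (q_t . k_s) * v_s
mask = torch.tril(torch.ones(L, L, device='cuda'))
ref = torch.einsum('nhle,nhse,nhsm,ls->nhlm', q, k, v, mask)
print((out - ref).abs().max())  # should be small (float32 accumulation error only)
```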
csrc/setup.py
ADDED
|
@@ -0,0 +1,53 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>,
# Apoorv Vyas <[email protected]>
#

import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
import subprocess

def get_last_arch_torch():
    arch = torch.cuda.get_arch_list()[-1]
    print(f"Found arch: {arch} from existing torch installation")
    return arch

def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]
    return raw_output, bare_metal_major, bare_metal_minor

def append_nvcc_threads(nvcc_extra_args):
    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
        return nvcc_extra_args + ["--threads", "4"]
    return nvcc_extra_args

arch = get_last_arch_torch()
sm_num = arch[-2:]
cc_flag = ['--generate-code=arch=compute_90,code=compute_90'] # for H100
# cc_flag = ['--generate-code=arch=compute_80,code=compute_80'] # for A100
# cc_flag = ['--generate-code=arch=compute_89,code=compute_89'] # for RTX 6000, 4090
# cc_flag = ['--generate-code=arch=compute_86,code=compute_86'] # for A6000, 3090
# cc_flag = ['--generate-code=arch=compute_75,code=compute_75']

setup(
    name='causal_attention_cuda_cpp',
    ext_modules=[
        CUDAExtension('causal_attention_cuda', [
            # 'causal_attention.cpp',
            'causal_attention_cuda.cu',
        ],
        extra_compile_args={'cxx': ['-O3'],
                            'nvcc': append_nvcc_threads(['-O3', '-lineinfo', '--use_fast_math', '-std=c++17'] + cc_flag)
        })
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
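One note on the architecture flags: `sm_num` is derived from the local torch install, but the `--generate-code` flag is then hard-coded for H100, with the alternatives left as comments. A hypothetical one-liner (not part of this commit) that derives the flag from the detected arch would look like:

```python
# arch is e.g. "sm_90", so sm_num is "90"; purely illustrative, not in the committed file.
cc_flag = [f'--generate-code=arch=compute_{sm_num},code=compute_{sm_num}']
```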
src/__init__.py
ADDED
|
File without changes
|
src/dataloaders/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
"""
Load dataloaders
"""
import importlib


def load_data(dataset_config: dict, dataloader_config: dict):
    """Return dataloaders from dataset_config"""
    try:
        dataset_module = importlib.import_module(f'dataloaders.{dataset_config["name"]}')
    except Exception:
        try:
            dataset_module = importlib.import_module(f'src.dataloaders.{dataset_config["name"]}')
        except Exception as e2:
            print(e2)
            try:  # e.g., tasks like GLUE where name is benchmark and path specifies the dataset / task
                dataset_module = importlib.import_module(f'dataloaders.{dataset_config["path"]}')
            except Exception as e3:
                print(f'Error from {dataset_config}')
                raise e3
    _load_data = getattr(dataset_module, 'load_data')
    return _load_data(**dataset_config, **dataloader_config)
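For orientation, a hedged usage sketch of this dispatcher. The outer config dict must carry the keys that the per-dataset `load_data` functions below expect (`name`, `dataset_config`, `pretrained_model_config`, `preprocess_config`); every concrete value here is a placeholder, not a setting taken from this commit.

```python
from src.dataloaders import load_data

dataset_config = {
    'name': 'alpaca_clean',  # resolves src/dataloaders/alpaca_clean.py
    'dataset_config': {'path': 'some/hf-dataset', 'cache_dir': './data',
                       'chunk_size': 1024, 'concat_data': True},
    'pretrained_model_config': {'pretrained_model_name_or_path': 'some/model',
                                'cache_dir': './checkpoints'},
    'preprocess_config': {},
}
dataloader_config = {'batch_size': 1, 'num_workers': 2}

dataloaders = load_data(dataset_config, dataloader_config)
train_loader = dataloaders['train']  # also 'validation' and 'test'
```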
src/dataloaders/alpaca_clean.py
ADDED
|
@@ -0,0 +1,149 @@
| 1 |
+
"""
|
| 2 |
+
Alpaca training dataloaders
|
| 3 |
+
|
| 4 |
+
We adopt the original Alpaca prompt template, which goes something like:
|
| 5 |
+
```
|
| 6 |
+
Below is an instruction that describes a task.
|
| 7 |
+
Write a response that appropriately completes the request.
|
| 8 |
+
### Instruction:
|
| 9 |
+
{instruction}
|
| 10 |
+
|
| 11 |
+
### Response:
|
| 12 |
+
{response}
|
| 13 |
+
```
|
| 14 |
+
See `PROMPT_DICT` for more.
|
| 15 |
+
"""
|
| 16 |
+
from functools import partial
|
| 17 |
+
from os.path import join
|
| 18 |
+
|
| 19 |
+
from datasets import load_metric, load_dataset
|
| 20 |
+
|
| 21 |
+
from .utils import (
|
| 22 |
+
get_lm_loader, get_seq2seq_loader,
|
| 23 |
+
convert_to_hf_dataset,
|
| 24 |
+
get_tokenizer_from_config,
|
| 25 |
+
download_scrolls_metric as download_metric
|
| 26 |
+
)
|
| 27 |
+
from .utils.packing import ConcatDataset
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
PROMPT_DICT = {
|
| 31 |
+
"prompt_input": (
|
| 32 |
+
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
| 33 |
+
"Write a response that appropriately completes the request.\n\n"
|
| 34 |
+
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
| 35 |
+
),
|
| 36 |
+
"prompt_no_input": (
|
| 37 |
+
"Below is an instruction that describes a task. "
|
| 38 |
+
"Write a response that appropriately completes the request.\n\n"
|
| 39 |
+
"### Instruction:\n{instruction}\n\n### Response:\n"
|
| 40 |
+
),
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def load_data(name: str, dataset_config: dict, pretrained_model_config: dict,
|
| 45 |
+
preprocess_config: dict, **loader_kwargs: any):
|
| 46 |
+
"""
|
| 47 |
+
Shared function to load dataset from experiment config
|
| 48 |
+
-> e.g., see configs/experiments/distill_alpaca_clean_lr1e-2.yaml
|
| 49 |
+
"""
|
| 50 |
+
# Misc. setup
|
| 51 |
+
cache_dir = dataset_config['cache_dir']
|
| 52 |
+
input_len = dataset_config['chunk_size']
|
| 53 |
+
concat_data = dataset_config['concat_data']
|
| 54 |
+
|
| 55 |
+
tokenizer_name = pretrained_model_config['pretrained_model_name_or_path']
|
| 56 |
+
tokenizer_name = tokenizer_name.split('/')[-1]
|
| 57 |
+
# save_path = join(cache_dir, f'{name}_{tokenizer_name}')
|
| 58 |
+
|
| 59 |
+
# Setup tokenizer
|
| 60 |
+
tokenizer = get_tokenizer_from_config(pretrained_model_config)
|
| 61 |
+
if tokenizer.pad_token is None:
|
| 62 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 63 |
+
print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}')
|
| 64 |
+
|
| 65 |
+
tokenizer.padding_side = 'left' # for decoder-only generation
|
| 66 |
+
# Get initial data
|
| 67 |
+
ignore_kwargs = ['concat_data', 'chunk_size', 'pose_kwargs']
|
| 68 |
+
dataset = load_dataset(
|
| 69 |
+
**{k: v for k, v in dataset_config.items() if k not in ignore_kwargs}
|
| 70 |
+
)
|
| 71 |
+
if dataset_config['name'] == 'samsum': # hack
|
| 72 |
+
dataset = dataset.rename_column('dialogue', 'input')
|
| 73 |
+
dataset = dataset.rename_column('summary', 'output')
|
| 74 |
+
_instruction = 'Summarize this dialogue.'
|
| 75 |
+
for split in dataset.keys():
|
| 76 |
+
dataset[split] = dataset[split].add_column(
|
| 77 |
+
'instruction', [_instruction] * len(dataset[split])
|
| 78 |
+
)
|
| 79 |
+
train_set, val_set, test_set = dataset['train'], dataset['validation'], dataset['test']
|
| 80 |
+
dataset = train_set # hack to work with below code
|
| 81 |
+
else:
|
| 82 |
+
dataset = dataset['train']
|
| 83 |
+
train_set = convert_to_hf_dataset([dataset[ix] for ix in range(200, len(dataset))], cache_dir)
|
| 84 |
+
val_set = convert_to_hf_dataset([dataset[ix] for ix in range(200)], cache_dir)
|
| 85 |
+
test_set = convert_to_hf_dataset([dataset[ix] for ix in range(200)], cache_dir)
|
| 86 |
+
|
| 87 |
+
# Convert to dicts of {input_ids, attention_mask, labels}
|
| 88 |
+
train_set = train_set.map(
|
| 89 |
+
partial(template_and_tokenize, tokenizer=tokenizer, include_label=True),
|
| 90 |
+
remove_columns=list(dataset.features),) # load_from_cache_file=False)
|
| 91 |
+
val_set = val_set.map(
|
| 92 |
+
partial(template_and_tokenize, tokenizer=tokenizer, include_label=True),
|
| 93 |
+
remove_columns=list(dataset.features),) # load_from_cache_file=False)
|
| 94 |
+
test_set = test_set.map(
|
| 95 |
+
partial(template_and_tokenize, tokenizer=tokenizer, include_label=False),
|
| 96 |
+
remove_columns=list(dataset.features),) # load_from_cache_file=False)
|
| 97 |
+
|
| 98 |
+
# Chunk together train and val sets
|
| 99 |
+
if concat_data:
|
| 100 |
+
train_set = ConcatDataset(train_set, chunk_size=input_len)
|
| 101 |
+
val_set = ConcatDataset(val_set, chunk_size=input_len)
|
| 102 |
+
|
| 103 |
+
# Get dataloaders
|
| 104 |
+
dataloaders = {
|
| 105 |
+
'train': get_lm_loader(train_set, tokenizer, 'train', input_len, **loader_kwargs),
|
| 106 |
+
'validation': get_lm_loader(val_set, tokenizer, 'validation', input_len, **loader_kwargs),
|
| 107 |
+
'test': get_seq2seq_loader(test_set, tokenizer, 'test', **loader_kwargs),
|
| 108 |
+
}
|
| 109 |
+
# Evaluation metric
|
| 110 |
+
try:
|
| 111 |
+
metric = load_metric(download_metric(), 'gov_report') # hack but we want rouge
|
| 112 |
+
except Exception as e:
|
| 113 |
+
print(f'Error loading metric: {e}')
|
| 114 |
+
metric = None
|
| 115 |
+
|
| 116 |
+
# Finishing touches
|
| 117 |
+
for k, v in dataloaders.items(): # Make tokenizer accessible
|
| 118 |
+
dataloaders[k].dataset.tokenizer = tokenizer
|
| 119 |
+
dataloaders[k].dataset.metric = metric
|
| 120 |
+
return dataloaders
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def template_and_tokenize(sample, tokenizer, include_label: bool = True):
|
| 124 |
+
"""
|
| 125 |
+
Format dataset context and answers into single-sequence prompts
|
| 126 |
+
"""
|
| 127 |
+
if sample.get('input', '') == '':
|
| 128 |
+
prompt = PROMPT_DICT["prompt_no_input"].format_map(sample)
|
| 129 |
+
else:
|
| 130 |
+
prompt = PROMPT_DICT["prompt_input"].format_map(sample)
|
| 131 |
+
|
| 132 |
+
prompt = tokenizer.encode(prompt, add_special_tokens=True)
|
| 133 |
+
if include_label:
|
| 134 |
+
answer = tokenizer.encode(f'{sample["output"]}{tokenizer.eos_token}',
|
| 135 |
+
add_special_tokens=False)
|
| 136 |
+
target = None
|
| 137 |
+
else:
|
| 138 |
+
answer = []
|
| 139 |
+
target = tokenizer.encode(f'{sample["output"]}{tokenizer.eos_token}',
|
| 140 |
+
add_special_tokens=False)
|
| 141 |
+
input_ids = prompt + answer
|
| 142 |
+
attn_mask = [1] * len(input_ids)
|
| 143 |
+
|
| 144 |
+
sample = {
|
| 145 |
+
"input_ids": input_ids,
|
| 146 |
+
"attention_mask" : attn_mask,
|
| 147 |
+
"labels": [-100] * len(prompt) + answer if include_label else target,
|
| 148 |
+
}
|
| 149 |
+
return sample
|
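To make the labeling convention in `template_and_tokenize` concrete: prompt tokens are masked out of the loss with -100 and only the response tokens are supervised. A tiny illustration with made-up token ids:

```python
prompt = [101, 102, 103]            # templated instruction tokens (made-up ids)
answer = [201, 202, 2]              # response tokens + eos (made-up ids)

input_ids = prompt + answer
labels = [-100] * len(prompt) + answer
assert len(labels) == len(input_ids)
# A cross-entropy loss with ignore_index=-100 then trains only on the response span.
```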
src/dataloaders/alpaca_clean_instruct.py
ADDED
|
@@ -0,0 +1,148 @@
| 1 |
+
"""
|
| 2 |
+
Alpaca Clean dataset with Llama3-Instruct prompt formatting
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from functools import partial
|
| 6 |
+
from os.path import join
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch.utils.data import Dataset, DataLoader
|
| 13 |
+
|
| 14 |
+
from datasets import load_metric, load_dataset
|
| 15 |
+
from transformers import AutoTokenizer
|
| 16 |
+
from transformers import DataCollatorForSeq2Seq, DefaultDataCollator, DataCollatorWithPadding
|
| 17 |
+
|
| 18 |
+
from .utils import (
|
| 19 |
+
get_lm_loader, get_seq2seq_loader,
|
| 20 |
+
convert_to_hf_dataset,
|
| 21 |
+
get_tokenizer_from_config,
|
| 22 |
+
download_scrolls_metric as download_metric
|
| 23 |
+
)
|
| 24 |
+
from .utils.packing import ConcatDataset
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
SYSTEM_PROMPT = "You are a helpful AI assistant who always responds to appropriately complete a user's request."
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def encode_response(response: str, tokenizer) -> list[int]:
|
| 31 |
+
tokens = tokenizer.encode(response.strip(), add_special_tokens=False)
|
| 32 |
+
# For Llama 3 Instruct: tokens.append(tokenizer.get_added_vocab()["<|eot_id|>"])
|
| 33 |
+
tokens.append(tokenizer.eos_token_id)
|
| 34 |
+
try: # Llama 3 Instruct
|
| 35 |
+
tokens.append(tokenizer.get_added_vocab()["<|end_of_text|>"])
|
| 36 |
+
except KeyError:
|
| 37 |
+
pass
|
| 38 |
+
return tokens
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def load_data(name: str, dataset_config: dict, pretrained_model_config: dict,
|
| 42 |
+
preprocess_config: dict, **loader_kwargs: any):
|
| 43 |
+
|
| 44 |
+
# Misc. setup
|
| 45 |
+
cache_dir = dataset_config['cache_dir']
|
| 46 |
+
input_len = dataset_config['chunk_size']
|
| 47 |
+
concat_data = dataset_config['concat_data']
|
| 48 |
+
load_from_cache_file = False  # set to False to re-tokenize the dataset
|
| 49 |
+
|
| 50 |
+
# Hard-code system prompt handling
|
| 51 |
+
if 'istral' in pretrained_model_config['pretrained_model_name_or_path']:
|
| 52 |
+
system_prompt = ''
|
| 53 |
+
else:
|
| 54 |
+
system_prompt = SYSTEM_PROMPT
|
| 55 |
+
|
| 56 |
+
tokenizer_name = pretrained_model_config['pretrained_model_name_or_path']
|
| 57 |
+
tokenizer_name = tokenizer_name.split('/')[-1]
|
| 58 |
+
save_path = join(cache_dir, f'{name}_{tokenizer_name}')
|
| 59 |
+
|
| 60 |
+
# Setup tokenizer
|
| 61 |
+
tokenizer = get_tokenizer_from_config(pretrained_model_config)
|
| 62 |
+
if tokenizer.pad_token is None:
|
| 63 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 64 |
+
print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}')
|
| 65 |
+
|
| 66 |
+
tokenizer.padding_side = 'left' # for decoder-only generation
|
| 67 |
+
|
| 68 |
+
# Get initial data
|
| 69 |
+
ignore_kwargs = ['concat_data', 'chunk_size', 'pose_kwargs', 'system_prompt', 'name']
|
| 70 |
+
train_set = load_dataset(
|
| 71 |
+
**{k: v for k, v in dataset_config.items() if k not in ignore_kwargs},
|
| 72 |
+
split='train[100:-100]',
|
| 73 |
+
)
|
| 74 |
+
val_set = load_dataset( # we just use this dataset as a validation set
|
| 75 |
+
**{k: v for k, v in dataset_config.items() if k not in ignore_kwargs},
|
| 76 |
+
split='train[:100]+train[-100:]',
|
| 77 |
+
)
|
| 78 |
+
test_set = load_dataset(
|
| 79 |
+
**{k: v for k, v in dataset_config.items() if k not in ignore_kwargs},
|
| 80 |
+
split='train[:100]+train[-100:]',
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# Convert to dicts of {input_ids, attention_mask, labels}
|
| 84 |
+
train_set = train_set.map(partial(template_and_tokenize, tokenizer=tokenizer,
|
| 85 |
+
include_label=True, system_prompt=system_prompt),
|
| 86 |
+
remove_columns=list(train_set.features),
|
| 87 |
+
load_from_cache_file=load_from_cache_file)
|
| 88 |
+
val_set = val_set.map(partial(template_and_tokenize, tokenizer=tokenizer,
|
| 89 |
+
include_label=True, system_prompt=system_prompt),
|
| 90 |
+
remove_columns=list(val_set.features),
|
| 91 |
+
load_from_cache_file=load_from_cache_file)
|
| 92 |
+
test_set = test_set.map(partial(template_and_tokenize, tokenizer=tokenizer,
|
| 93 |
+
include_label=False, system_prompt=system_prompt),
|
| 94 |
+
remove_columns=list(test_set.features),
|
| 95 |
+
load_from_cache_file=load_from_cache_file)
|
| 96 |
+
|
| 97 |
+
# Chunk together train and val sets
|
| 98 |
+
if concat_data:
|
| 99 |
+
train_set = ConcatDataset(train_set, chunk_size=input_len)
|
| 100 |
+
val_set = ConcatDataset(val_set, chunk_size=input_len)
|
| 101 |
+
|
| 102 |
+
# Get dataloaders
|
| 103 |
+
dataloaders = {
|
| 104 |
+
'train': get_lm_loader(train_set, tokenizer, 'train', input_len, **loader_kwargs),
|
| 105 |
+
'validation': get_lm_loader(val_set, tokenizer, 'validation', input_len, **loader_kwargs),
|
| 106 |
+
'test': get_seq2seq_loader(test_set, tokenizer, 'test', **loader_kwargs),
|
| 107 |
+
}
|
| 108 |
+
# Evaluation metric
|
| 109 |
+
metric = load_metric(download_metric(), 'gov_report') # hack but we want rouge
|
| 110 |
+
|
| 111 |
+
# Finishing touches
|
| 112 |
+
for k, v in dataloaders.items(): # Make tokenizer accessible
|
| 113 |
+
dataloaders[k].dataset.tokenizer = tokenizer
|
| 114 |
+
dataloaders[k].dataset.metric = metric
|
| 115 |
+
return dataloaders
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def template_and_tokenize(sample, tokenizer, include_label: bool = True,
|
| 119 |
+
system_prompt: str = None):
|
| 120 |
+
if system_prompt is None:
|
| 121 |
+
system_prompt = SYSTEM_PROMPT
|
| 122 |
+
|
| 123 |
+
prompt = sample['instruction']
|
| 124 |
+
if sample['input'] != '':
|
| 125 |
+
prompt += f"\n\n{sample['input']}"
|
| 126 |
+
|
| 127 |
+
messages = [
|
| 128 |
+
{"role": "system", "content": system_prompt},
|
| 129 |
+
] if system_prompt != '' else []
|
| 130 |
+
messages.append({"role": "user", "content": prompt})
|
| 131 |
+
prompt_ids = tokenizer.apply_chat_template(
|
| 132 |
+
messages, tokenize=True, add_generation_prompt=True,
|
| 133 |
+
)
|
| 134 |
+
if include_label:
|
| 135 |
+
answer = encode_response(sample['output'], tokenizer)
|
| 136 |
+
else:
|
| 137 |
+
answer = []
|
| 138 |
+
target = encode_response(sample['output'], tokenizer)
|
| 139 |
+
|
| 140 |
+
input_ids = prompt_ids + answer
|
| 141 |
+
attn_mask = [1] * len(input_ids)
|
| 142 |
+
sample = {
|
| 143 |
+
"input_ids": input_ids,
|
| 144 |
+
"attention_mask" : attn_mask,
|
| 145 |
+
"labels": [-100] * len(prompt_ids) + answer if include_label else target,
|
| 146 |
+
}
|
| 147 |
+
return sample
|
| 148 |
+
|
src/dataloaders/utils/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
"""
Helper functions for dataset setup and loading
"""
from .setup import *
src/dataloaders/utils/llama3.py
ADDED
|
@@ -0,0 +1,62 @@
| 1 |
+
"""
|
| 2 |
+
Data utils for Llama3
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
def encode_header(message: str, tokenizer) -> list[int]:
|
| 6 |
+
tokens = []
|
| 7 |
+
tokens.append(tokenizer.get_added_vocab()["<|start_header_id|>"])
|
| 8 |
+
tokens.extend(tokenizer.encode(message["role"], add_special_tokens=False))
|
| 9 |
+
tokens.append(tokenizer.get_added_vocab()["<|end_header_id|>"])
|
| 10 |
+
tokens.extend(tokenizer.encode("\n\n", add_special_tokens=False))
|
| 11 |
+
return tokens
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def encode_message(message: str, tokenizer, include_header: bool = True) -> list[int]:
|
| 15 |
+
tokens = encode_header(message, tokenizer) if include_header else []
|
| 16 |
+
tokens.extend(
|
| 17 |
+
tokenizer.encode(message["content"].strip(), add_special_tokens=False)
|
| 18 |
+
)
|
| 19 |
+
tokens.append(tokenizer.get_added_vocab()["<|eot_id|>"])
|
| 20 |
+
return tokens
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def template_and_tokenize(sample, tokenizer, include_label: bool = True,
|
| 24 |
+
system_prompt: str = None):
|
| 25 |
+
if system_prompt is not None:
|
| 26 |
+
dialog = [{'role': 'system', 'content': system_prompt}]
|
| 27 |
+
else:
|
| 28 |
+
dialog = []
|
| 29 |
+
|
| 30 |
+
chat = []
|
| 31 |
+
instruction = sample['instruction']
|
| 32 |
+
if sample['input'] != '':
|
| 33 |
+
instruction += f"\n\n{sample['input']}"
|
| 34 |
+
dialog.extend([
|
| 35 |
+
{'role': 'user', 'content': instruction},
|
| 36 |
+
{'role': 'assistant', 'content': sample['output']},
|
| 37 |
+
])
|
| 38 |
+
|
| 39 |
+
prompt = []
|
| 40 |
+
prompt.append(tokenizer.get_added_vocab()["<|begin_of_text|>"])
|
| 41 |
+
for message in dialog[:-1]:
|
| 42 |
+
prompt.extend(encode_message(message, tokenizer))
|
| 43 |
+
|
| 44 |
+
if include_label:
|
| 45 |
+
answer = encode_message(dialog[-1], tokenizer)
|
| 46 |
+
answer.append(tokenizer.get_added_vocab()["<|end_of_text|>"])
|
| 47 |
+
else:
|
| 48 |
+
answer = []
|
| 49 |
+
target = encode_message(dialog[-1], tokenizer, include_header=False)
|
| 50 |
+
target.append(tokenizer.get_added_vocab()["<|end_of_text|>"])
|
| 51 |
+
# Add the start of an assistant message for the model to complete.
|
| 52 |
+
prompt.extend(encode_header({"role": "assistant", "content": ""}, tokenizer))
|
| 53 |
+
|
| 54 |
+
input_ids = prompt + answer
|
| 55 |
+
attn_mask = [1] * len(input_ids)
|
| 56 |
+
|
| 57 |
+
sample = {
|
| 58 |
+
"input_ids": input_ids,
|
| 59 |
+
"attention_mask" : attn_mask,
|
| 60 |
+
"labels": [-100] * len(prompt) + answer if include_label else target,
|
| 61 |
+
}
|
| 62 |
+
return sample
|
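A hedged end-to-end check of the helpers above; the tokenizer id is a placeholder for any Llama-3-style tokenizer that defines the `<|start_header_id|>` / `<|eot_id|>` special tokens used by `encode_header` and `encode_message`.

```python
from transformers import AutoTokenizer
from src.dataloaders.utils.llama3 import template_and_tokenize

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")  # placeholder id
sample = {"instruction": "Add the numbers.", "input": "2 and 3", "output": "5"}
out = template_and_tokenize(sample, tokenizer, include_label=True, system_prompt="You are helpful.")

# Only the assistant answer is supervised; the prompt span is masked with -100.
print(sum(l != -100 for l in out["labels"]), "supervised tokens out of", len(out["input_ids"]))
```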
src/dataloaders/utils/packing.py
ADDED
|
@@ -0,0 +1,80 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
| 3 |
+
"""
|
| 4 |
+
Copied from https://github.com/meta-llama/llama-recipes/blob/9b3dabcaac78980eae40005bbc8b1a8276c82af3/src/llama_recipes/data/concatenator.py#L1
|
| 5 |
+
"""
|
| 6 |
+
import random
|
| 7 |
+
from itertools import chain
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from torch.utils.data import Dataset
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Concatenator(object):
|
| 15 |
+
def __init__(self, chunk_size=2048):
|
| 16 |
+
self.chunk_size=chunk_size
|
| 17 |
+
self.residual = {"input_ids": [], "attention_mask": []}
|
| 18 |
+
|
| 19 |
+
def __call__(self, batch):
|
| 20 |
+
concatenated_samples = {
|
| 21 |
+
k: v + list(chain(*batch[k])) for k, v in self.residual.items()
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
total_length = len(concatenated_samples[list(concatenated_samples.keys())[0]])
|
| 25 |
+
|
| 26 |
+
if total_length >= self.chunk_size:
|
| 27 |
+
chunk_num = total_length // self.chunk_size
|
| 28 |
+
result = {
|
| 29 |
+
k: [
|
| 30 |
+
v[i : i + self.chunk_size]
|
| 31 |
+
for i in range(0, chunk_num * self.chunk_size, self.chunk_size)
|
| 32 |
+
]
|
| 33 |
+
for k, v in concatenated_samples.items()
|
| 34 |
+
}
|
| 35 |
+
self.residual = {
|
| 36 |
+
k: v[(chunk_num * self.chunk_size) :]
|
| 37 |
+
for k, v in concatenated_samples.items()
|
| 38 |
+
}
|
| 39 |
+
else:
|
| 40 |
+
result = concatenated_samples
|
| 41 |
+
self.residual = {k: [] for k in concatenated_samples.keys()}
|
| 42 |
+
|
| 43 |
+
result["labels"] = result["input_ids"].copy()
|
| 44 |
+
|
| 45 |
+
return result
|
| 46 |
+
|
| 47 |
+
class ConcatDataset(Dataset):
|
| 48 |
+
"""
|
| 49 |
+
Concatenates or packs samples of a dataset into chunks of size `chunk_size`
|
| 50 |
+
"""
|
| 51 |
+
def __init__(self, dataset, chunk_size: int = 1024, seed: int = 42,) -> None:
|
| 52 |
+
self.dataset = dataset
|
| 53 |
+
self.chunk_size = chunk_size
|
| 54 |
+
self.samples = []
|
| 55 |
+
buffer = {
|
| 56 |
+
"input_ids": [],
|
| 57 |
+
"attention_mask": [],
|
| 58 |
+
"labels": [],
|
| 59 |
+
}
|
| 60 |
+
random.seed(seed)
|
| 61 |
+
for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
|
| 62 |
+
buffer = {k: v + sample[k] for k,v in buffer.items()}
|
| 63 |
+
|
| 64 |
+
while len(next(iter(buffer.values()))) > self.chunk_size:
|
| 65 |
+
self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
|
| 66 |
+
buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
|
| 67 |
+
# Slow hack, but filter out any samples without valid labels (all -100)
|
| 68 |
+
self.filtered_samples = []
|
| 69 |
+
for s in self.samples:
|
| 70 |
+
if sum(s['labels']) != chunk_size * -100:
|
| 71 |
+
self.filtered_samples.append(s)
|
| 72 |
+
if len(self.filtered_samples) < len(self.samples):
|
| 73 |
+
print(f'OG dataset: {len(self.samples)} samples -> Filtered dataset: {len(self.filtered_samples)}')
|
| 74 |
+
print(f'-> Filtered out {len(self.samples) - len(self.filtered_samples)} samples')
|
| 75 |
+
|
| 76 |
+
def __getitem__(self, idx):
|
| 77 |
+
return self.filtered_samples[idx]
|
| 78 |
+
|
| 79 |
+
def __len__(self):
|
| 80 |
+
return len(self.filtered_samples)
|
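For reference, a hedged usage sketch of `ConcatDataset` with toy pre-tokenized samples (in the dataloaders above these come from `template_and_tokenize`); ten 300-token samples pack into two full 1024-token chunks, and the 952-token remainder is dropped.

```python
from src.dataloaders.utils.packing import ConcatDataset

samples = [{'input_ids': [1] * 300, 'attention_mask': [1] * 300, 'labels': [5] * 300}
           for _ in range(10)]
packed = ConcatDataset(samples, chunk_size=1024)
print(len(packed), len(packed[0]['input_ids']))  # 2 1024
```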
src/dataloaders/utils/setup.py
ADDED
|
@@ -0,0 +1,123 @@
| 1 |
+
"""
|
| 2 |
+
Helper functions for dataset setup and loading
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
from os.path import join
|
| 6 |
+
import shutil
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from torch.utils.data import Dataset, DataLoader
|
| 10 |
+
|
| 11 |
+
from datasets import Dataset as HFDataset
|
| 12 |
+
from huggingface_hub import hf_hub_download
|
| 13 |
+
from transformers import AutoTokenizer, LlamaTokenizer
|
| 14 |
+
from transformers import DataCollatorForSeq2Seq
|
| 15 |
+
# from transformers import DefaultDataCollator, DataCollatorWithPadding
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_seq2seq_loader(dataset: Dataset, tokenizer: AutoTokenizer,
|
| 19 |
+
split: str, **loader_kwargs: any):
|
| 20 |
+
"""
|
| 21 |
+
Get dataloader for seq2seq tasks (evaluation)
|
| 22 |
+
"""
|
| 23 |
+
tokenizer.padding_side = 'right'
|
| 24 |
+
collate_fn = DataCollatorForSeq2Seq(
|
| 25 |
+
tokenizer, label_pad_token_id=-100, return_tensors='pt')
|
| 26 |
+
return DataLoader(
|
| 27 |
+
dataset, shuffle='train' in split, collate_fn=collate_fn, **loader_kwargs)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_lm_loader(dataset: Dataset, tokenizer: AutoTokenizer,
|
| 31 |
+
split: str, max_length: int = None, **loader_kwargs: any):
|
| 32 |
+
"""
|
| 33 |
+
Get dataloader for language modeling (training)
|
| 34 |
+
-> Currently this ends up being the same as get_seq2seq_loader
|
| 35 |
+
"""
|
| 36 |
+
# collate_fn = DefaultDataCollator(return_tensors='pt')
|
| 37 |
+
# collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, padding=True,
|
| 38 |
+
# max_length=max_length, return_tensors='pt')
|
| 39 |
+
collate_fn = DataCollatorForSeq2Seq(
|
| 40 |
+
tokenizer, label_pad_token_id=-100, return_tensors='pt')
|
| 41 |
+
return DataLoader(
|
| 42 |
+
dataset, shuffle='train' in split, collate_fn=collate_fn, **loader_kwargs)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def convert_to_hf_dataset(dataset, cache_dir: str):
|
| 46 |
+
"""
|
| 47 |
+
Convert iterable dataset to HuggingFace HFDataset object
|
| 48 |
+
"""
|
| 49 |
+
def gen():
|
| 50 |
+
for _, sample in enumerate(dataset):
|
| 51 |
+
yield sample # dataset[idx]
|
| 52 |
+
return HFDataset.from_generator(gen, cache_dir=cache_dir)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_tokenizer_from_config(model_config):
|
| 56 |
+
"""
|
| 57 |
+
Get pretrained tokenizer based on (pretrained) model config
|
| 58 |
+
"""
|
| 59 |
+
# Get tokenizer
|
| 60 |
+
if 'llama' in model_config['pretrained_model_name_or_path']:
|
| 61 |
+
try: # if we store locally
|
| 62 |
+
model_path = join(model_config['cache_dir'],
|
| 63 |
+
model_config['pretrained_model_name_or_path'])
|
| 64 |
+
tokenizer = LlamaTokenizer.from_pretrained(model_path)
|
| 65 |
+
except Exception as e:
|
| 66 |
+
try:
|
| 67 |
+
tokenizer = AutoTokenizer.from_pretrained(**model_config)
|
| 68 |
+
print("-> Bad LlamaTokenizer.from_pretrained(model_path)", e)
|
| 69 |
+
print("-> But resolved with: AutoTokenizer.from_pretrained(**model_config)")
|
| 70 |
+
except Exception as e2:
|
| 71 |
+
print("-> Error with AutoTokenizer.from_pretrained(**model_config)", e2)
|
| 72 |
+
# tokenizer = LlamaTokenizer.from_pretrained(**model_config) # v4.43 errors with `*** TypeError: not a string`
|
| 73 |
+
elif 'Mistral-7B-Instruct-v0.3' in model_config['pretrained_model_name_or_path']:
|
| 74 |
+
tokenizer = LlamaTokenizer.from_pretrained(**model_config) # hack where AutoTokenizer doesn't recognize
|
| 75 |
+
elif 'Mistral-7B' in model_config['pretrained_model_name_or_path']:
|
| 76 |
+
tokenizer = AutoTokenizer.from_pretrained(**model_config)
|
| 77 |
+
else:
|
| 78 |
+
tokenizer = AutoTokenizer.from_pretrained(**model_config)
|
| 79 |
+
return tokenizer
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def add_special_tokens_to_dataset(dataset, tokenizer):
|
| 83 |
+
"""
|
| 84 |
+
Add special tokens as attributes to a dataset object
|
| 85 |
+
"""
|
| 86 |
+
token_map = {k: v for k, v in tokenizer.special_tokens_map.items()}
|
| 87 |
+
special_ids = tokenizer.all_special_ids
|
| 88 |
+
for idx, k in enumerate(tokenizer.special_tokens_map.keys()):
|
| 89 |
+
token_map[f'{k}_id'] = special_ids[idx]
|
| 90 |
+
for k, v in token_map.items():
|
| 91 |
+
setattr(dataset, k, v)
|
| 92 |
+
return dataset
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def train_test_split(samples: any, train_size: int, test_size: int, seed: int):
|
| 96 |
+
"""
|
| 97 |
+
Split samples into train and test sets
|
| 98 |
+
"""
|
| 99 |
+
try:
|
| 100 |
+
assert len(samples) == train_size + test_size
|
| 101 |
+
except Exception as e:
|
| 102 |
+
print(len(samples), train_size + test_size)
|
| 103 |
+
raise e
|
| 104 |
+
arange = np.arange(len(samples))
|
| 105 |
+
np.random.seed(seed)
|
| 106 |
+
test_idx = np.random.choice(arange, size=test_size, replace=False)
|
| 107 |
+
train_idx = np.setdiff1d(arange, test_idx)
|
| 108 |
+
return samples[train_idx], samples[test_idx]
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def download_scrolls_metric():
|
| 112 |
+
"""
|
| 113 |
+
Download ROUGE, F1, and other accuracy metrics included in the SCROLLS dataset
|
| 114 |
+
"""
|
| 115 |
+
scrolls_metric_path = hf_hub_download(
|
| 116 |
+
repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset"
|
| 117 |
+
)
|
| 118 |
+
updated_scrolls_metric_path = (
|
| 119 |
+
os.path.dirname(scrolls_metric_path) +
|
| 120 |
+
os.path.basename(scrolls_metric_path).replace(".", "_") + ".py"
|
| 121 |
+
)
|
| 122 |
+
shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
|
| 123 |
+
return updated_scrolls_metric_path
|
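And a toy example of the `train_test_split` helper defined above, using a small numpy array as a stand-in for real samples:

```python
import numpy as np
from src.dataloaders.utils.setup import train_test_split

samples = np.arange(10)       # stand-in for an array of samples
train, test = train_test_split(samples, train_size=8, test_size=2, seed=0)
print(len(train), len(test))  # 8 2
```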
src/finetune.py
ADDED
|
@@ -0,0 +1,68 @@
| 1 |
+
"""
|
| 2 |
+
Finetuning functions for post-distillation training
|
| 3 |
+
"""
|
| 4 |
+
from os.path import join
|
| 5 |
+
from omegaconf import OmegaConf
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch.nn import Module
|
| 9 |
+
|
| 10 |
+
from src.utils.setup import update_config_from_args
|
| 11 |
+
from src.dataloaders import load_data
|
| 12 |
+
from src.trainer import get_trainer, get_optimizer, get_scheduler
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def prepare_finetune_configs(args, model_config: dict,
|
| 16 |
+
finetune_config_name: str = None,
|
| 17 |
+
finetune_checkpoint_name: str = None,
|
| 18 |
+
config_dir='./configs/experiment'):
|
| 19 |
+
"""
|
| 20 |
+
Prepare finetuning configs
|
| 21 |
+
"""
|
| 22 |
+
# Load finetuning config
|
| 23 |
+
finetune_config = (finetune_config_name if finetune_config_name is not None else
|
| 24 |
+
finetune_checkpoint_name.split('-f=')[-1].split('-')[0])
|
| 25 |
+
finetune_config_path = join(config_dir, f'{finetune_config}.yaml')
|
| 26 |
+
finetune_config = OmegaConf.load(finetune_config_path)
|
| 27 |
+
finetune_config = update_config_from_args(finetune_config, args,
|
| 28 |
+
ignore_args=['lr', 'weight_decay'])
|
| 29 |
+
# Update data tokenizer to match model
|
| 30 |
+
if getattr(finetune_config.dataset, 'pretrained_model_config', None) is not None:
|
| 31 |
+
for k in ['pretrained_model_name_or_path', 'cache_dir']:
|
| 32 |
+
finetune_config.dataset.pretrained_model_config[k] = model_config['model'][k]
|
| 33 |
+
# Set finetuning args
|
| 34 |
+
for arg, argv in finetune_config.trainer.items():
|
| 35 |
+
if arg != 'name':
|
| 36 |
+
setattr(args, arg, argv)
|
| 37 |
+
for _config in ['dataloader', 'optimizer', 'lr_scheduler']:
|
| 38 |
+
setattr(args, _config, OmegaConf.to_container(getattr(finetune_config, _config)))
|
| 39 |
+
return finetune_config, args
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_finetuner(model: Module, finetune_config: dict, device: torch.device,
|
| 43 |
+
args: any, wandb: any, initial_eval: bool = False):
|
| 44 |
+
"""
|
| 45 |
+
Initialize finetuning trainer
|
| 46 |
+
"""
|
| 47 |
+
model.to(device) # if using a fused optimizer
|
| 48 |
+
model.train()
|
| 49 |
+
|
| 50 |
+
# Initialize optimizer and scheduler
|
| 51 |
+
optimizer = get_optimizer(model=model, **finetune_config.optimizer)
|
| 52 |
+
scheduler = get_scheduler(optimizer=optimizer, **finetune_config.lr_scheduler)
|
| 53 |
+
|
| 54 |
+
dataloaders = load_data(finetune_config.dataset, finetune_config.dataloader)
|
| 55 |
+
train_loader = dataloaders[finetune_config.trainer.train_split]
|
| 56 |
+
eval_loader = dataloaders[finetune_config.trainer.val_split]
|
| 57 |
+
|
| 58 |
+
OurTrainer = get_trainer(finetune_config.trainer.name)
|
| 59 |
+
trainer = OurTrainer(model=model,
|
| 60 |
+
args=args,
|
| 61 |
+
train_loader=train_loader,
|
| 62 |
+
eval_loader=eval_loader,
|
| 63 |
+
optimizer_and_scheduler=(optimizer, scheduler),
|
| 64 |
+
device=device,
|
| 65 |
+
wandb=wandb,
|
| 66 |
+
checkpoint_suffix='_ft',
|
| 67 |
+
**finetune_config.trainer)
|
| 68 |
+
return trainer
|
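A usage sketch for the two helpers above (not part of the commit): `args` is assumed to be an argparse-style namespace and `model_config` a loaded model YAML; the config name references configs/experiment/finetune_lora_qkvo_alpaca_clean.yaml from this commit, and the exact args fields are assumptions.

# Hedged sketch: args fields and the commented-out trainer call are assumptions.
import argparse
import torch
from src.finetune import prepare_finetune_configs, get_finetuner

args = argparse.Namespace(lr=1e-4, weight_decay=0.0)  # plus whatever update_config_from_args reads
model_config = {'model': {'pretrained_model_name_or_path': 'meta-llama/Meta-Llama-3-8B',
                          'cache_dir': './checkpoints'}}
finetune_config, args = prepare_finetune_configs(
    args, model_config, finetune_config_name='finetune_lora_qkvo_alpaca_clean')

# With a converted model in hand (see src/model/convert_model.py below):
# trainer = get_finetuner(model, finetune_config, torch.device('cuda'), args, wandb=None)
# trainer.train()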
src/model/__init__.py
ADDED
File without changes
src/model/convert_model.py
ADDED
@@ -0,0 +1,173 @@
"""
Attention conversion helpers
"""
from functools import partial
from tqdm import tqdm
import torch.nn as nn


def convert_attention(model: nn.Module,
                      attention_config: dict,
                      train_attention: bool = False,
                      remove_base_attn: bool = True,):
    """
    Call to convert all attention layers
    """
    softmax_attns = []
    if 'softmax_attentions' in attention_config:
        softmax_attns = attention_config['softmax_attentions']
    if attention_config.attention_type != 'softmax':
        layers = traverse_layers(model)
        for layer_idx, layer in enumerate(tqdm(layers, desc='Converting attentions...')):
            if layer_idx not in softmax_attns:
                layer.self_attn = convert_llama_attention(
                    layer, attention_config, layers, train_attention, remove_base_attn,
                )
                layer.self_attn.converted = True
            else:  # Freeze any preserved softmax attention layers
                for p in layer.parameters():
                    p.requires_grad = False
    else:
        print(f'-> attention_config.attention_type is {attention_config.attention_type}; not converting attentions')
    return model


def toggle_attention(llama_model: nn.Module, train: bool = False):
    """
    Make attentions trainable if train is True
    -> Set train_attention = False when finetuning
    """
    for layer in traverse_layers(llama_model):
        layer.self_attn.train_attention = train
    return llama_model


def remove_base_attention(llama_model: nn.Module):
    """
    Remove teacher attention after distillation (if we keep it)
    """
    for layer in traverse_layers(llama_model):
        if getattr(layer.self_attn, 'base_attn', False):
            del layer.self_attn.base_attn
    return llama_model


def traverse_layers(model: nn.Module, verbose: bool = False):
    """
    Return list of model layers
    """
    try:
        layers = model.model.layers
        if verbose:
            print('-> Loading from model.model.layers')
    except AttributeError as e:  # if base model
        if verbose:
            print(e)
        try:
            layers = model.layers
            if verbose:
                print('-> Loading from model.layers')
        except AttributeError as e1:  # If we make a PEFT model
            if verbose:
                print(e1)
            layers = model.base_model.model.model.layers
            if verbose:
                print('-> Loading from model.base_model.model.model.layers')
    return layers


def convert_llama_attention(layer: nn.Module,
                            attention_config: dict,
                            layers: list[nn.Module],  # list of layers
                            train_attention: bool = False,
                            remove_base_attn: bool = True):
    """
    Converts a single layer's attention layer as specified by attention_config
    """
    return get_attention(**attention_config)(
        base_attn=layer.self_attn,
        layer_idx=layer.self_attn.layer_idx,  # Transformers v4.36
        max_layer_idx=len(layers) - 1,
        train_attention=train_attention,
        remove_base_attn=remove_base_attn,
    )


def get_attention(attention_type: str, **kwargs: any):
    """
    Get the linear attention class; either purely linear or linear with sliding window
    -> 'linear' == 'lolcats_llama'
    -> 'linear and sliding_window' == 'lolcats_llama_window_*'
    """
    kwargs['attention_type'] = attention_type

    if attention_type == 'lolcats_llama':
        from .linear_attention import LolcatsLinearAttention
        return partial(LolcatsLinearAttention, **kwargs)

    elif attention_type == 'lolcats_llama_window_tk':
        from .linear_attention import LolcatsTKWindowAttention
        return partial(LolcatsTKWindowAttention, **kwargs)

    elif attention_type == 'lolcats_llama_window_sw':
        from .linear_attention import LolcatsSlidingWindowAttention
        return partial(LolcatsSlidingWindowAttention, **kwargs)

    elif attention_type == 'lolcats_llama_window_sw_linear':
        from .linear_attention.linear_window_attention_sw_linear import LolcatsLinearSlidingWindowAttention
        return partial(LolcatsLinearSlidingWindowAttention, **kwargs)

    ## Experimental chunked linear attentions below
    elif attention_type == 'lolcats_long_llama_window_tk':
        from .linear_attention import LolcatsTKWindowLongAttention
        return partial(LolcatsTKWindowLongAttention, **kwargs)

    elif attention_type == 'lolcats_long_llama_window_sw':
        from .linear_attention import LolcatsSlidingWindowLongAttention
        return partial(LolcatsSlidingWindowLongAttention, **kwargs)

    ## TK generation build (requires Thunderkittens)
    elif attention_type == 'lolcats_llama_window_tk_gen':
        from .linear_attention import LolcatsWindowAttentionTKGen
        return partial(LolcatsWindowAttentionTKGen, **kwargs)

    else:
        print(f'-> attention_type {attention_type} not handled... returning None')
        return None


def get_attention_cache(attention_type: str, past_key_values: any = None):
    """
    Determine how we store past keys and values when generating
    """
    if attention_type is None:
        return past_key_values

    # print(f'Returning attention cache based on attention_type == {attention_type}')
    elif 'lolcats_llama_window_tk_gen' in attention_type:
        from .linear_attention import LinearAttentionTKWindowGenerationCache
        return LinearAttentionTKWindowGenerationCache()

    elif 'llama_window_tk' in attention_type:
        from .linear_attention import LinearAttentionTKWindowCache
        return LinearAttentionTKWindowCache()

    elif 'llama_window_sw' in attention_type:
        from .linear_attention import LinearAttentionSlidingWindowCache
        return LinearAttentionSlidingWindowCache()

    elif 'llama_window_sw_linear' in attention_type:
        from .linear_attention import LinearAttentionSlidingWindowCache
        return LinearAttentionSlidingWindowCache()

    ## TK generation build (requires Thunderkittens)
    elif attention_type == 'lolcats_llama_window_tk_gen':
        from .linear_attention.linear_window_attention_tk_gen import LinearAttentionTKWindowGenerationCache
        return LinearAttentionTKWindowGenerationCache()

    elif 'softmax' in attention_type:
        return past_key_values

    else:
        from .linear_attention import LinearAttentionState
        return LinearAttentionState()
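A sketch of the intended conversion flow (not part of the commit): the attention_config below is abbreviated from the distill_* model YAMLs added in this commit, so the exact keys and values are assumptions rather than a verbatim config.

# Hedged sketch: config keys mirror the model YAMLs but are abbreviated and may differ.
from omegaconf import OmegaConf
from transformers import AutoModelForCausalLM
from src.model.convert_model import convert_attention, toggle_attention

model = AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-8B')
attention_config = OmegaConf.create({
    'attention_type': 'lolcats_llama_window_sw',
    'feature_map': 'softmax_dim',
    'feature_map_kwargs': {'eps': 1e-12, 'fullspace': True},
    'learned_kernel': 'untied_head_einsum',
    'learned_kernel_kwargs': {'feature_dim': 64, 'skip_connection': False,
                              'bias': False, 'zero_init': False},
    'window_size': 64,
})
# Swap each self_attn for a LoLCATs attention and keep it trainable for distillation
model = convert_attention(model, attention_config, train_attention=True)
# After attention transfer, freeze the learned maps before LoRA finetuning
model = toggle_attention(model, train=False)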
src/model/feature_map.py
ADDED
@@ -0,0 +1,306 @@
"""
Learnable linear attention feature map classes and functions
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def init_feature_map(name: str, mlp: nn.Module, **kwargs: dict):
    """
    Initialize feature map final activation for linear attention
    """
    return FeatureMap(activation_name=name, mlp=mlp, **kwargs)


def init_feature_map_act(name: str, fullspace: bool = True, **kwargs):
    """
    Initialize feature map final activation for linear attention
    """
    if name == 'softmax_dim' and fullspace:
        return SoftmaxDim(**kwargs)
    elif name == 'softmax_dim' and not fullspace:
        return SoftmaxDimHalfspace(**kwargs)
    elif name == 'exp_dim' and fullspace:
        return Exp(**kwargs)
    elif name == 'exp_dim' and not fullspace:
        return ExpHalfspace(**kwargs)
    elif name == 'pos_elu':
        return PosELU(**kwargs)
    elif name == 'relu':
        return ReLU(**kwargs)

    else:
        raise NotImplementedError


def init_learned_kernel(name: str, **kwargs: any):
    """
    Initialize feature map MLP for linear attention
    """
    if name == 'untied_head_einsum':
        return FeatureMapMLP(**kwargs)
    elif name == 'untied_head_adapter':
        return FeatureMapAdapter(**kwargs)
    else:
        raise NotImplementedError


class FeatureMap(nn.Module):
    """
    Final 'activation' of feature map. Can probably be combined with
    `FeatureMapMLP` below

    Full feature map is like f(xW + b)
    -> This is the `f` part
    """
    def __init__(self,
                 activation_name: str,
                 head_dim_idx: int = -1,
                 eps: float = 1e-12,
                 mlp: nn.Module = None,
                 fullspace: bool = True,):
        super().__init__()
        self.head_dim_idx = head_dim_idx
        self.eps = eps
        self.mlp = mlp if mlp is not None else nn.Identity()
        self.activation = init_feature_map_act(activation_name, fullspace, eps=eps)

    def forward(self, x: torch.Tensor, *mlp_args: any, **mlp_kwargs: any):
        """
        Assume x.shape is (batch_size, n_heads, seq_len, head_dim)
        """
        return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x)

    def q_map(self, *args: any, **kwargs: any):
        """
        Use for inference in case q and k feature maps differ
        """
        return self.forward(*args, **kwargs)

    def k_map(self, *args: any, **kwargs: any):
        """
        Use for inference in case q and k feature maps differ
        """
        return self.forward(*args, **kwargs)


# -----------------------
# Feature map activations
# -----------------------
class FeatureMapAct(nn.Module):
    """
    Base class for feature map activations
    """
    def __init__(self, eps: float = 1e-12):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
        """
        x.shape is (batch_size, n_heads, seq_len, head_dim)
        """
        return x


class PosELU(FeatureMapAct):
    """
    1 + ELU activation as in https://arxiv.org/abs/2006.16236
    """
    def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
        return (1 + F.elu(x)).clamp(min=self.eps)


class ReLU(FeatureMapAct):
    """
    ReLU activation as in https://arxiv.org/abs/2103.13076
    """
    def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
        return F.relu(x).clamp(min=self.eps)


class SoftmaxDim(FeatureMapAct):
    """
    Softmax activation as in https://arxiv.org/abs/2402.04347
    """
    def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
        return torch.cat([
            torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1)
        ], dim=-1).clamp(min=self.eps)


class SoftmaxDimHalfspace(FeatureMapAct):
    """
    Softmax activation as in https://arxiv.org/abs/2402.04347
    """
    def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
        return torch.softmax(x, dim=-1).clamp(min=self.eps)


class Exp(FeatureMapAct):
    """
    Exp activation as in https://arxiv.org/abs/2402.04347
    """
    def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
        x_max = torch.amax(x, dim=-1, keepdim=True)
        x_min = torch.amin(x, dim=-1, keepdim=True)
        return torch.cat([
            torch.exp(x - x_max), torch.exp(-x + x_min)
        ], dim=-1).clamp(min=self.eps)


class ExpHalfspace(FeatureMapAct):
    """
    Exp activation as in https://arxiv.org/abs/2402.04347
    """
    def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
        x_max = torch.amax(x, dim=-1, keepdim=True)
        return torch.exp(x - x_max).clamp(min=self.eps)


# ----------------
# Feature map MLPs
# ----------------

class FeatureMapMLP(nn.Module):
    """
    Learnable MLP in feature map.

    Full feature map is like f(xW + b)
    -> This is the `W` and (optional) `b` part
    """
    def __init__(self,
                 num_heads: int,
                 head_dim: int,     # input dim
                 feature_dim: int,  # output dim
                 dtype: torch.dtype,
                 device: torch.device,
                 skip_connection: bool = False,
                 bias: bool = False,
                 zero_init: bool = False,
                 normal_init: bool = False,):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.feature_dim = feature_dim
        self.dtype = dtype
        self.device = device
        self.skip_connection = skip_connection
        self.bias = bias
        self.zero_init = zero_init
        self.normal_init = normal_init
        self.init_weights_()

        if self.zero_init:  # Zero-out weights or set as identity post-initialization
            self.zero_init_with_skip_() if self.skip_connection else self.zero_init_()

        if self.normal_init:
            with torch.no_grad():
                nn.init.normal_(self.layer)

        if self.skip_connection:
            assertion_fail = f'If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}'
            assert self.head_dim == self.feature_dim, assertion_fail

    def init_weights_(self):
        """
        Initialize (W)eights and (b)iases
        """
        self.layer = nn.Parameter(torch.zeros(
            (self.num_heads, self.head_dim, self.feature_dim),
            dtype=self.dtype, device=self.device,
        ))
        nn.init.kaiming_uniform_(self.layer)

        if self.bias:
            self.bias = nn.Parameter(torch.zeros(
                (1, self.num_heads, 1, 1),  # self.feature_dim),
                dtype=self.dtype, device=self.device,
            ))
            nn.init.kaiming_uniform_(self.bias)
        else:
            self.bias = 0.  # hack

    def zero_init_with_skip_(self):
        """
        Initialize weights to zero matrix if skip connection
        """
        with torch.no_grad():
            nn.init.zeros_(self.layer)

    def zero_init_(self):
        """
        Initialize weights to identity matrix if no skip connection
        """
        with torch.no_grad():
            for i in range(self.layer.shape[0]):
                try:
                    nn.init.eye_(self.layer[i])
                except RuntimeError:
                    with torch.no_grad():
                        dtype = self.layer[i].dtype
                        weight = torch.eye(*self.layer[i].shape,
                                           requires_grad=self.layer[i].requires_grad,
                                           device=self.layer[i].device)
                        self.layer[i] = weight.to(dtype=dtype)

    def forward(self, x: torch.Tensor):
        """
        Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
        """
        _x = torch.einsum('hdf,bhld->bhlf', self.layer, x) + self.bias
        return x + _x if self.skip_connection else _x


class FeatureMapAdapter(FeatureMapMLP):
    """
    Learnable Feature map with bottleneck adapter
    as in https://arxiv.org/abs/1902.00751

    We don't use but could be fun to try
    """
    def __init__(self, hidden_dim: int, *args, **kwargs):
        kwargs['skip_connection'] = True
        kwargs['bias'] = True
        kwargs['zero_init'] = True
        self.hidden_dim = hidden_dim
        super().__init__(*args, **kwargs)

    def init_weights_(self):
        """
        Initialize (W)eights and (b)iases
        """
        kwargs = {'dtype': self.dtype, 'device': self.device}
        self.layer0 = nn.Parameter(
            torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs)
        )
        self.layer1 = nn.Parameter(
            torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs)
        )
        nn.init.kaiming_uniform_(self.layer0)
        nn.init.kaiming_uniform_(self.layer1)

        self.bias0 = nn.Parameter(torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs))
        self.bias1 = nn.Parameter(torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs))
        nn.init.kaiming_uniform_(self.bias0)
        nn.init.kaiming_uniform_(self.bias1)

    def zero_init_with_skip_(self):
        with torch.no_grad():
            nn.init.zeros_(self.layer0)
            nn.init.zeros_(self.layer1)
            nn.init.zeros_(self.bias0)
            nn.init.zeros_(self.bias1)

    def zero_init_(self):
        assert NotImplementedError

    def forward(self, x: torch.Tensor):
        """
        Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
        -> Down-project, apply nonlinearity, up-project; add skip connection
        """
        _x = torch.einsum('hde,bhld->bhle', self.layer0, x) + self.bias0
        _x = F.relu(_x)
        _x = torch.einsum('hef,bhle->bhlf', self.layer1, _x) + self.bias1
        return x + _x if self.skip_connection else _x
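A self-contained shape check for the composed feature map f(xW) (not part of the commit): FeatureMapMLP maps head_dim to feature_dim per head, and the 'softmax_dim' activation concatenates softmax(x) and softmax(-x), so the output dimension doubles. The dimensions below are arbitrary and chosen only for illustration.

# Hedged sketch: standalone check of the classes above; all sizes are illustrative.
import torch
from src.model.feature_map import init_feature_map, init_learned_kernel

b, h, l, d, f = 2, 4, 16, 64, 64
mlp = init_learned_kernel('untied_head_einsum', num_heads=h, head_dim=d, feature_dim=f,
                          dtype=torch.float32, device='cpu')
feature_map = init_feature_map('softmax_dim', mlp, fullspace=True)

x = torch.randn(b, h, l, d)
phi_x = feature_map(x)
print(phi_x.shape)        # torch.Size([2, 4, 16, 128]) -- 2 * feature_dim
print(phi_x.min() >= 0)   # tensor(True): nonnegative features for linear attention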
src/model/linear_attention/__init__.py
ADDED
@@ -0,0 +1,23 @@
"""
Linear and linear attention + sliding window classes
"""
from .linear_attention import (
    LolcatsLinearAttention, LinearAttentionState
)
from .linear_window_attention_tk import (
    LolcatsTKWindowAttention, LinearAttentionTKWindowCache
)
from .linear_window_attention_sw import (
    LolcatsSlidingWindowAttention, LinearAttentionSlidingWindowCache
)
# Experimental chunk linear attentions
from .linear_window_attention_tk_long import (
    LolcatsTKWindowLongAttention,
)
from .linear_window_attention_sw_long import (
    LolcatsSlidingWindowLongAttention,
)
from .linear_window_attention_tk_gen import (
    LolcatsWindowAttentionTKGen,
    LinearAttentionTKWindowGenerationCache
)
src/model/linear_attention/linear_attention.py
ADDED
@@ -0,0 +1,459 @@
"""
Linear attention classes
"""
from typing import List, Tuple, Optional
import copy
import torch
import torch.nn as nn
from omegaconf import OmegaConf, DictConfig

from transformers.cache_utils import Cache  # starting at Transformers v4.36

# Causal linear attention dot product CUDA kernel from fast-transformers
try:
    from csrc import causal_dot_product as fast_causal_dot_product
except ImportError:
    fast_causal_dot_product = None

from src.model.feature_map import init_feature_map, init_learned_kernel
from src.model.rotary import get_rotary_embeddings, apply_rotary_pos_emb
from .utils import repeat_kv


# -------------------
# Attention functions
# -------------------

def causal_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """
    Causal linear attention dot product
    - If available, use CUDA kernel from fast-transformers
    """
    if fast_causal_dot_product is None:
        kv = torch.einsum('bhlf,bhld->bhlfd', k, v)
        return torch.einsum('bhlf,bhlfd->bhld', q, kv.cumsum(dim=2))
    return fast_causal_dot_product(q, k, v)

def linear_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     fp32_attention: bool = False, eps: float = 1e-12,
                     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Compute linear attention with CUDA kernel implementation from fast-transformers
    - https://github.com/idiap/fast-transformers
    - Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim);
      v is shape (b, h, l, head_dim)
    """
    dtype = q.dtype
    # Causal mask already applied
    y = causal_dot_product(q.contiguous().to(dtype=torch.float32),
                           k.contiguous().to(dtype=torch.float32),
                           v.contiguous().to(dtype=torch.float32))
    if fp32_attention:
        y = (y / (torch.einsum(
            "bhld,bhld->bhl", q.float(), k.float().cumsum(dim=2)
        ) + eps)[..., None]).to(dtype=dtype)
    else:
        y = y.to(dtype=dtype)
        k = k.float().cumsum(dim=2).to(dtype=dtype)
        y = y / (torch.einsum("bhld,bhld->bhl", q, k) + eps)[..., None]
    return y, None, None


def softmax_attention(q: torch.Tensor, k: torch.Tensor, v: Optional[torch.Tensor] = None,
                      causal: bool = True, fp32_attention: bool = True,
                      ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Standard softmax attention; only compute outputs if v is not None
    -> Assume q, k, v are shape (batch_size, num_heads, seq_len, head_dim)
    """
    y = None
    a = torch.einsum('bhmd,bhnd->bhmn', q, k) * (k.shape[-1] ** -0.5)
    if causal:  # Apply causal mask
        m, n = a.shape[-2:]
        causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(n - m + 1)
        a = a.masked_fill(causal_mask, -torch.finfo(a.dtype).max)
    if fp32_attention:
        a = torch.softmax(a, dim=-1, dtype=torch.float32).to(q.dtype)
    else:
        a = torch.softmax(a, dim=-1)
    if v is not None:
        y = torch.einsum('bhmn,bhnd->bhmd', a, v)
    return y, a, None


def quadratic_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor = None,
                        causal: bool = True, fp32_attention: bool = False, eps: float = 1e-12,
                        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Compute attention with feature maps by instantiating L x L matrix of attention weights
    -> Use for attention distillation
    -> Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim); v is shape (b, h, l, head_dim)
    """
    y = None
    dtype = q.dtype
    if fp32_attention:
        q, k = q.float(), k.float()
    a = torch.einsum('bhmd,bhnd->bhmn', q, k)  # note we don't scale, tho we could
    if causal:  # Apply causal mask
        m, n = a.shape[-2:]
        causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(n - m + 1)
        a = a.masked_fill(causal_mask, 0)
    # Normalize to compute attention
    a = a / (a.sum(dim=-1, keepdim=True) + eps)
    a = a.to(dtype=dtype) if fp32_attention else a
    if torch.isnan(a).sum() > 0:
        breakpoint()
    if v is not None:
        y = torch.einsum('bhmn,bhnd->bhmd', a, v)
    return y, a, None
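A small numerical check (not part of the commit) that the recurrent cumulative-sum form above matches the explicitly materialized quadratic form on the same nonnegative features; it assumes the csrc CUDA kernel is not installed, so the pure-PyTorch einsum fallback in causal_dot_product runs.

# Hedged sketch: compares the O(l) and O(l^2) forms; assumes the einsum fallback is used.
import torch
from src.model.linear_attention.linear_attention import linear_attention, quadratic_attention

torch.manual_seed(0)
b, h, l, f, d = 1, 2, 32, 16, 8
q = torch.rand(b, h, l, f) + 1e-3   # stand-ins for featurized (nonnegative) queries
k = torch.rand(b, h, l, f) + 1e-3   # and keys
v = torch.randn(b, h, l, d)

y_linear, _, _ = linear_attention(q, k, v)                 # cumulative-sum (recurrent) form
y_quad, _, _ = quadratic_attention(q, k, v, causal=True)   # materializes the l x l weights
print(torch.allclose(y_linear, y_quad, atol=1e-4))         # True: same attention, different compute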

# ---------------------
# Attention layer class
# ---------------------

class LolcatsLinearAttention(nn.Module):
    """
    LoLCATs attention implementation initialized from a
    `LlamaAttention` or `MistralAttention` object (base_attn)

    Most of the arguments are directly tied to argparse args
    - For now we don't support padding.
    """
    def __init__(self,
                 base_attn: nn.Module,  # like LlamaAttention
                 feature_map: str,
                 feature_map_kwargs: dict,
                 layer_idx: Optional[int] = None,
                 max_layer_idx: Optional[int] = None,
                 learned_kernel: Optional[str] = None,
                 learned_kernel_kwargs: Optional[dict] = None,
                 tie_qk_kernels: Optional[bool] = False,
                 rotary_config: Optional[dict] = None,
                 train_attention: Optional[bool] = False,
                 remove_base_attn: Optional[bool] = True,
                 attention_type: Optional[str] = 'lolcats_llama',
                 mask_value: int = 0,
                 eps: float = 1e-12,
                 fp32_attention: bool = False,
                 track_state_grads: bool = False,
                 rank: Optional[int] = 0,
                 **kwargs: any) -> None:
        super().__init__()
        self.base_config = getattr(base_attn, 'config', None)
        if self.base_config is not None:
            self.base_config = self.base_config.to_dict()
        self.attention_type = attention_type
        self.mask_value = mask_value
        self.eps = eps
        self.layer_idx = (layer_idx if layer_idx is not None else base_attn.layer_idx)
        self.max_layer_idx = max_layer_idx
        self.tie_qk_kernels = tie_qk_kernels
        self.train_attention = train_attention
        self.base_inference = False
        self.fp32_attention = fp32_attention
        self.track_state_grads = track_state_grads
        if rank == 0:  # multi-gpu
            if fp32_attention and layer_idx == 0:
                print(f'-> fp32_attention is {fp32_attention}')
            if layer_idx == 0 and feature_map_kwargs is not None:
                for k, v in feature_map_kwargs.items():
                    print(f'-> {k}: {v}')
            if layer_idx == 0 and learned_kernel_kwargs is not None:
                for k, v in learned_kernel_kwargs.items():
                    print(f'-> {k}: {v}')

        self.remove_base_attn = remove_base_attn

        # Rotary embeddings (patch for Llama 3.1, Transformer v4.43.0)
        self.rotary_config = rotary_config
        if isinstance(self.rotary_config, DictConfig):  # ensure dict
            self.rotary_config = OmegaConf.to_container(self.rotary_config)

        self.rotary_emb = None
        if self.base_config is not None and self.rotary_config is None:
            self.rotary_emb = base_attn.rotary_emb

        self.init_weights_(base_attn, remove_base_attn)
        self.init_feature_map_(feature_map, feature_map_kwargs,
                               learned_kernel, learned_kernel_kwargs)

    def init_feature_map_(self,
                          feature_map: str,
                          feature_map_kwargs: dict,
                          learned_kernel: str = None,
                          learned_kernel_kwargs: dict = None):
        """
        Initialize MLP-based feature map
        """
        self.fmap_gqa = False  # Turn True if specified below
        if learned_kernel is not None:
            # Ensure dict
            learned_kernel_kwargs = {k: v for k, v in learned_kernel_kwargs.items()}
            learned_kernel_kwargs['num_heads'] = self.num_heads
            learned_kernel_kwargs['head_dim'] = self.head_dim
            learned_kernel_kwargs['dtype'] = self.q_proj.weight.dtype
            learned_kernel_kwargs['device'] = self.q_proj.weight.device
            # Create MLP
            mlp_learned_kernel = init_learned_kernel(learned_kernel, **learned_kernel_kwargs)
        # Add "activation"; see src.models.feature_map.py
        self.feature_map_q = init_feature_map(name=feature_map,
                                              mlp=mlp_learned_kernel,
                                              **feature_map_kwargs)
        if self.tie_qk_kernels:  # tie mlp weights for query and key feature maps
            self.feature_map_k = self.feature_map_q
        else:
            self.feature_map_k = copy.deepcopy(self.feature_map_q)

    def init_weights_(self, base_attn: nn.Module, remove_base_attn: bool = True):
        """
        Initialize module layers, weights, positional dependencies, etc.
        from original softmax attention layer (base_attn)
        """
        # Make other attributes accessible
        self.attention_dropout = 0  # We don't use dropout
        self.hidden_size = base_attn.hidden_size
        self.num_heads = base_attn.num_heads
        self.head_dim = base_attn.head_dim
        self.num_key_value_heads = base_attn.num_key_value_heads
        self.num_key_value_groups = base_attn.num_key_value_groups

        self.q_shape = [self.num_heads, self.head_dim]
        self.k_shape = [self.num_key_value_heads, self.head_dim]
        self.v_shape = [self.num_key_value_heads, self.head_dim]
        device = base_attn.q_proj.weight.device
        # Rotary embeddings
        if self.rotary_emb is None:
            self.max_position_embeddings = base_attn.max_position_embeddings
            scaling_factor = getattr(base_attn.rotary_emb, 'scaling_factor', 1.)
            if self.rotary_config is None:
                self.rotary_emb = get_rotary_embeddings(
                    rope_scaling_type=None,
                    head_dim=self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,  # base_attn.rotary_emb.max_position_embeddings,
                    rope_theta=base_attn.rotary_emb.base,
                    rope_scaling_factor=scaling_factor,  # base_attn.rotary_emb.scaling_factor,
                    device=device,
                )
            else:
                if 'device' not in self.rotary_config:
                    self.rotary_config['device'] = device
                self.rotary_emb = get_rotary_embeddings(**self.rotary_config)

        # Copy original model projection layers
        self.q_proj = base_attn.q_proj
        self.k_proj = base_attn.k_proj
        self.v_proj = base_attn.v_proj
        self.o_proj = base_attn.o_proj
        try:  # If wanting to use FA2 for ground-truth inference
            self._flash_attn_uses_top_left_mask = base_attn._flash_attn_uses_top_left_mask
        except AttributeError:
            pass

        if self.remove_base_attn or remove_base_attn:
            del base_attn  # We don't need to keep these around
        else:
            self.base_attn = base_attn  # For some training runs helpful to just call

    def process_qkv(self,
                    hidden_states: torch.Tensor,
                    attention_mask: Optional[torch.Tensor] = None,
                    position_ids: Optional[torch.LongTensor] = None,
                    past_key_value: Optional[Tuple[int, torch.Tensor, torch.Tensor]] = None,):  # "legacy" cache approach
        """
        Compute queries, keys, and values
        """
        b, l, _ = hidden_states.size()
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        kv_seq_len = k.shape[-2]

        # Shape is (batch_size, seq_len, num_heads, head_dim)
        q = q.view(b, l, *self.q_shape).transpose(1, 2)
        k = k.view(b, l, *self.k_shape).transpose(1, 2)
        v = v.view(b, l, *self.v_shape).transpose(1, 2)

        if past_key_value is not None:  # and k.shape[2] > q.shape[2]:  # e.g., when generating
            past_key_value.window_size = getattr(self, 'decode_window_size', None)  # self.decode_window_size
            if isinstance(past_key_value, Cache):  # In Transformers v4.36+ this is a DynamicCache object
                kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
            else:
                kv_seq_len += past_key_value[0].shape[-2]

        # Apply rotary embeddings and repeat for GQA
        if position_ids is not None and kv_seq_len <= position_ids[0, -1]:
            kv_seq_len = position_ids[0, -1] + 1  # hack for adjusting position ids
        try:  # As in Transformers v4.36
            cos, sin = self.rotary_emb(k, seq_len=kv_seq_len)
            q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
        except TypeError:  # As in Transformers v4.39+
            cos, sin = self.rotary_emb(v, position_ids)
            q, k = apply_rotary_pos_emb(q, k, cos, sin)

        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)
        return q, k, v, kv_seq_len

    def forward(self,
                hidden_states: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_value: Optional[Tuple[int, torch.Tensor, torch.Tensor]] = None,  # "legacy" cache approach
                output_attentions: bool = False,
                use_cache: bool = False,
                **kwargs,
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Forward pass modified from transformers.models.mistral.modeling_mistral (v4.36)
        - Consistent with HuggingFace Transformers for easy use with their pretrained models
        """
        b, l, _ = hidden_states.size()
        q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask,
                                               position_ids, past_key_value)
        if self.base_inference:
            with torch.no_grad():
                # 1. Compute "ground-truth" attention output and weights
                y_true, _, _ = softmax_attention(q, k, v, causal=True)
                y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
                y_true = self.o_proj(y_true)
                attn_weights = (None, None)

        elif self.train_attention:  # Distilling / learning attentions
            # Note for now we assume no padding when distilling; attention masks only enforce causality
            assert output_attentions is True, f'When training feature maps, output_attentions should be True but is {output_attentions}'
            with torch.no_grad():
                # 1. Compute "ground-truth" attention output and weights
                _y_true, attn_true, _ = softmax_attention(q, k, v, causal=True)
                y_true = _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
                y_true = self.o_proj(y_true)

            # 2. Compute "predicted" attention (just weights)
            q, k = self.feature_map_q.q_map(q), self.feature_map_k.k_map(k)
            y_pred, attn_pred, _ = quadratic_attention(q, k, v, causal=True)
            attn_weights = ((attn_pred, attn_true), (y_pred, _y_true))  # Save both attention weights so we can supervise.

        else:  # Finetuning
            q, k = self.feature_map_q(q), self.feature_map_k(k)
            # Apply prefill mask
            if attention_mask is not None and q.shape[2] > 1:
                if len(attention_mask.shape) == 4:
                    lin_attn_mask = (attention_mask == 0)[:, :1, -1, :l][..., None]  # b, 1, k_len, 1
                else:
                    lin_attn_mask = attention_mask[:, None, :, None]  # b, 1, k_len, 1
                k = k.masked_fill(~lin_attn_mask, 0)

            if past_key_value is not None:  # Initialize states
                if len(past_key_value.kv_states) == self.layer_idx:
                    b, h, _, f = k.shape
                    past_key_value.kv_states.append(
                        torch.zeros(b, h, f, self.head_dim, dtype=q.dtype, device=q.device)
                    )
                    past_key_value.k_states.append(
                        torch.zeros(b, h, 1, f, dtype=q.dtype, device=q.device)
                    )
                # Generating
                if q.shape[2] == 1 and kv_seq_len > 1 and past_key_value is not None:
                    assert use_cache is True
                    kv_state, k_state = past_key_value.update(k, v, self.layer_idx,
                                                              accumulate_in_fp32=self.fp32_attention)
                    if self.fp32_attention:
                        q = q.float()
                        y_true = (torch.einsum('bhlf,bhfd->bhld', q, kv_state.float()) /
                                  torch.einsum('bhlf,bhlf->bhl', q, k_state.float())[..., None]).to(dtype=k.dtype)
                    else:
                        y_true = (torch.einsum('bhlf,bhfd->bhld', q, kv_state) /
                                  torch.einsum('bhlf,bhlf->bhl', q, k_state)[..., None])
                else:
                    kv_state = past_key_value.kv_states[self.layer_idx]
                    k_state = past_key_value.k_states[self.layer_idx]
                    y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps)  # Ordinarily the states are ignored
                    past_key_value.update(k.detach(), v.detach(), self.layer_idx,
                                          accumulate_in_fp32=self.fp32_attention)
                    # doing some unnecessary recomputation here
            else:
                y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps)

            # Concatenate heads and apply output projection
            y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
            y_true = self.o_proj(y_true)
            attn_weights = None

        return y_true, attn_weights, past_key_value


class LinearAttentionState(Cache):
    """
    Handle the KV and K states for linear attention
    - Adopts HF Transformers `past_key_values` convention
    - Inherits from `Cache` class
    - Modified from transformers.cache_utils.DynamicCache (v4.36)
    """
    def __init__(self) -> None:
        self._seen_tokens = 0  # should be `self.seen_tokens` in Transformers v4.36
        self._seen_tokens_by_layer: List[int] = []
        self.kv_states: List[torch.Tensor] = []
        self.k_states: List[torch.Tensor] = []

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Returns the sequence length of the cached states. A layer index can be optionally passed.
        """
        if len(self._seen_tokens_by_layer) <= layer_idx:  # Initializing kv and k states
            self._seen_tokens_by_layer.append(0)
        return self._seen_tokens_by_layer[layer_idx]

    def get_max_length(self) -> Optional[int]:
        """
        Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.
        """
        return None

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all cache is usable
        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
        # length, we will need to evict part of the cache (and thus not all cache is usable)
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
               layer_idx: Optional[int] = None, cache_kwargs: Optional[any] = None,
               accumulate_in_fp32: bool = True, **kwargs: any,
               ) -> Tuple[torch.Tensor, torch.Tensor]:

        with torch.no_grad():
            if layer_idx == 0:
                self._seen_tokens += key_states.shape[-2]
            dtype = key_states.dtype
            if accumulate_in_fp32:
                key_states, value_states = key_states.float(), value_states.float()

            kv_state = torch.einsum('bhlf,bhld->bhfd', key_states, value_states).detach()
            k_state = key_states.sum(dim=-2, keepdim=True).detach()  # b, h, 1, f; note the 1
            # Update the cache
            if len(self.k_states) <= layer_idx:  # Initializing kv and k states
                print('if len(self.k_states) <= layer_idx: # Initializing kv and k states')
                self.kv_states.append(kv_state.to(dtype))
                self.k_states.append(k_state.to(dtype))
            else:
                kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype)
                k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype)
                self.kv_states[layer_idx] = kv_state
                self.k_states[layer_idx] = k_state
            self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
        return self.kv_states[layer_idx], self.k_states[layer_idx]

    def to_legacy_cache(self):
        """Hack, but just return self"""
        return self

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """
        Reorders the cache for beam search, given the selected beam indices.
        -> Copied from transformers/src/transformers/cache_utils.py
        """
        raise NotImplementedError('Reordering cache not implemented for LinearAttentionState')
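A sketch of the recurrent state this cache maintains (not part of the commit): kv_state accumulates phi(k)^T v and k_state accumulates phi(k), so each decode step reduces to two small contractions. The tensors below are stand-ins for already-featurized keys and queries, and the explicit get_usable_length call stands in for the initialization that process_qkv normally performs.

# Hedged sketch: illustrative shapes only; features are random stand-ins for phi(k), phi(q).
import torch
from src.model.linear_attention.linear_attention import LinearAttentionState

b, h, f, d = 1, 2, 16, 8
cache = LinearAttentionState()

k_prefill = torch.rand(b, h, 10, f)    # featurized keys phi(k) for a 10-token prefill
v_prefill = torch.randn(b, h, 10, d)
cache.get_usable_length(10, layer_idx=0)               # initializes the per-layer token counter
kv_state, k_state = cache.update(k_prefill, v_prefill, layer_idx=0)
print(kv_state.shape, k_state.shape)   # (1, 2, 16, 8) and (1, 2, 1, 16)

# One decode step: output for a single featurized query phi(q), as in forward()
q_new = torch.rand(b, h, 1, f)
y = torch.einsum('bhlf,bhfd->bhld', q_new, kv_state) / \
    torch.einsum('bhlf,bhlf->bhl', q_new, k_state)[..., None]
print(y.shape)                         # (1, 2, 1, 8)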
src/model/linear_attention/linear_window_attention_sw.py
ADDED
@@ -0,0 +1,339 @@
"""
Subquadratic attention combining sliding window and linear attentions
- Using "standard" sliding windows
- Didactically computes outputs with n^2 attention weights for now
- Copied + adapted from linear_window_attention_tk.py for single-file reference

For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""
from typing import List, Tuple, Optional, Callable
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.cache_utils import Cache

from .linear_attention import (
    LolcatsLinearAttention, LinearAttentionState,
    softmax_attention
)

# ----------------------
# Sliding window helpers
# ----------------------
def get_masks(window_size: int, q_len: int, k_len: int,
              device: torch.device) -> tuple[torch.Tensor]:
    """
    Return masks for softmax and linear attention terms
    -> 1 is include, 0 is ignore
    """
    kwargs = {'device': device, 'dtype': int}
    causal_mask = torch.ones((q_len, k_len), **kwargs).tril(k_len - q_len)
    linear_mask = torch.ones((q_len, k_len), **kwargs).tril(k_len - q_len - window_size)
    window_mask = causal_mask - linear_mask
    # Return softmax mask (window), linear attention mask
    # -> shapes broadcast over (b, h, q_len, k_len)
    return window_mask[None, None, ...], linear_mask[None, None, ...]


def hybrid_attention_quadratic(q: torch.Tensor, k: torch.Tensor,
                               f_q: torch.Tensor, f_k: torch.Tensor,
                               v: torch.Tensor,
                               window_factor: torch.Tensor,
                               linear_factor: torch.Tensor,
                               window_size: int,
                               kv_state: torch.Tensor = None,
                               k_state: torch.Tensor = None,
                               eps: float = 1e-12,
                               mask_value: float = -1e8):
    """
    Hybrid attention combining sliding window and linear attentions
    """

    mask_window, mask_linear = get_masks(window_size, q.shape[-2], k.shape[-2], q.device)

    # 1. Sliding window (softmax attention)
    a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5)
    a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
    # torch.softmax(a_sm, dim=-1), but we account for the max when combining
    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
    a_sm = window_factor * torch.exp(a_sm - a_sm_max)
    sum_sm = a_sm.sum(dim=-1, keepdim=True)

    # 2. Under window (linear attention)
    a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float())
    a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
    sum_ln = a_ln.sum(dim=-1, keepdim=True)

    # 3. Combine
    a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype)  # Save attention weights
    # Allow outputs to also depend on prior kv_state and k_state
    y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v.float())
    if kv_state is not None:  # Combine with prior kv_state and k_state
        y += linear_factor * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float())
        sum_ln += linear_factor * torch.einsum(
            'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None]
    y = (y / (sum_sm + sum_ln)).to(q.dtype)
    return y, a  # attention weights only for the last chunk
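A tiny check (not part of the commit) of how get_masks splits the causal pattern between the sliding-window (softmax) term and the linear-attention term; sizes are illustrative.

# Hedged sketch: window_size=2 over a length-5 sequence, just to visualize the split.
import torch
from src.model.linear_attention.linear_window_attention_sw import get_masks

window_mask, linear_mask = get_masks(window_size=2, q_len=5, k_len=5, device='cpu')
print(window_mask[0, 0])   # last 2 positions per query -> softmax attention
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 1, 1, 0, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 0, 1, 1]])
print(linear_mask[0, 0])   # everything older -> linear attention
# tensor([[0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0],
#         [1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0]])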

# ---------------------
# Attention layer class
# ---------------------
class LolcatsSlidingWindowAttention(LolcatsLinearAttention):
    """
    Lolcats attention combining sliding window and linear attention
    """
    def __init__(self,
                 window_size: int = 64,
                 decode_window_size: int = None,
                 affine_attention_factors: bool = False,
                 init_window_factor: float = 0,
                 train_window_factor: bool = True,
                 state_grad_enabled: bool = False,
                 **kwargs):
        self.window_size = window_size
        self.decode_window_size = (
            decode_window_size if decode_window_size is not None else window_size
        )
        self.window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1}
        super().__init__(**kwargs)
        self.attention_type = kwargs['attention_type']  # 'hedgehog_llama_window_sw'
        # Determine how we compute attentions
        self.quadratic_attention = hybrid_attention_quadratic
        self.attention_type = kwargs['attention_type']  # 'hedgehog_long_llama_window_sw'
        # Learnable factor for combining attentions
        self.affine_attention_factors = affine_attention_factors
        device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
        if train_window_factor:
            self.window_factors = nn.Parameter(
                init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype))
        else:
            self.register_buffer(
                "window_factors", init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
            )
        # Whether we use original flash attention 2 inference (use during attention transfer)
        self.base_inference = False
        self.state_grad_enabled = state_grad_enabled

    def forward(self,
                hidden_states: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_value: Optional[Cache] = None,
                output_attentions: bool = False,
                use_cache: bool = False,
                **kwargs,
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Forward pass with the option to compute attention weights multiple ways
        if self.train_attention is True
        -> Consistent with HuggingFace Transformers for easy use with their pretrained models
        """
        b, l, _ = hidden_states.size()
        q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask,
                                               position_ids, past_key_value)
        f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)  # Have to do after repeat for grouped-query attn if we use same fmap

        if self.train_attention:
            # 1. Compute "ground-truth" attention output and weights
            with torch.no_grad():
                _y_true, a_true = softmax_attention(q, k, v)[:2]
                y_true = _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
                y_true = self.o_proj(y_true)

            # 2. Compute "predicted" attention outputs
            # compute attn weights under sliding window
            window_factors = F.sigmoid(self.window_factors)
            linear_factors = 1 - window_factors if self.affine_attention_factors else 1
            y_pred, a_pred = self.quadratic_attention(q, k, f_q, f_k, v,
                                                      window_factors, linear_factors,
                                                      window_size=self.window_size)
            attn_weights = ((a_pred, a_true), (y_pred, _y_true))
        else:
            attn_weights = None
            # attention_mask = None  # For now this is always True
            if past_key_value is None:  # Regular training
                window_factors = F.sigmoid(self.window_factors)
                linear_factors = 1 - window_factors if self.affine_attention_factors else 1
                y_true, a_pred = self.quadratic_attention(q, k, f_q, f_k, v,
                                                          window_factors, linear_factors,
                                                          window_size=self.window_size)
                attn_weights = a_pred
            else:
                past_key_value.window_size = self.decode_window_size
                if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training:  # Generating
                    assert use_cache is True
                    _kv = past_key_value.update_for_decoding(k, v, self.layer_idx,
                                                             self.feature_map_k,
                                                             dtype=q.dtype)
                    k_cache, v_cache, f_kv_state, f_k_state = _kv

                    # Sliding window + linear attention decode
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = 1 - window_factors if self.affine_attention_factors else 1

                    # Softmax attention terms
+
a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5)
|
| 182 |
+
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
| 183 |
+
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
| 184 |
+
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
| 185 |
+
|
| 186 |
+
# Combine with linear attention terms
|
| 187 |
+
y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float())
|
| 188 |
+
+ linear_factors * torch.einsum('bhlf,bhfd->bhld', f_q.float(), f_kv_state.float()))
|
| 189 |
+
sum_ln = linear_factors * torch.einsum(
|
| 190 |
+
'bhlf,bhnf->bhl', f_q.float(), f_k_state.float())[..., None]
|
| 191 |
+
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
|
| 192 |
+
|
| 193 |
+
else: # Stateful training
|
| 194 |
+
try:
|
| 195 |
+
kv_state = past_key_value.kv_states[self.layer_idx]
|
| 196 |
+
k_state = past_key_value.k_states[self.layer_idx]
|
| 197 |
+
except IndexError:
|
| 198 |
+
kv_state, k_state = None, None
|
| 199 |
+
window_factors = F.sigmoid(self.window_factors)
|
| 200 |
+
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
| 201 |
+
y_true, _ = self.quadratic_attention(q, k, f_q, f_k, v,
|
| 202 |
+
window_factors, linear_factors,
|
| 203 |
+
window_size=self.window_size,
|
| 204 |
+
kv_state=kv_state,
|
| 205 |
+
k_state=k_state)
|
| 206 |
+
# Save and update KV cache and states
|
| 207 |
+
# past_key_value.update(k, v.detach(), self.layer_idx,
|
| 208 |
+
# fmap_key_states=f_k.detach(),
|
| 209 |
+
# accumulate_in_fp32=True)
|
| 210 |
+
past_key_value.update(k, v, self.layer_idx,
|
| 211 |
+
fmap_key_states=f_k,
|
| 212 |
+
accumulate_in_fp32=True)
|
| 213 |
+
# Concatenate heads and apply output projection
|
| 214 |
+
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
| 215 |
+
y_true = self.o_proj(y_true)
|
| 216 |
+
return y_true, attn_weights, past_key_value
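When `train_attention` is on, the forward pass above returns `((a_pred, a_true), (y_pred, y_true))`. A hedged sketch of how an attention-transfer loss could consume that tuple follows; the soft cross-entropy plus MSE form and the coefficients (echoing the `xent`/`mse` experiment config names) are assumptions for illustration, not the repository's trainer.

# Illustrative attention-transfer loss over the tuple returned when train_attention is True.
import torch
import torch.nn.functional as F

def attention_transfer_loss(attn_weights, xent_weight: float = 1.0, mse_weight: float = 1000.0):
    (a_pred, a_true), (y_pred, y_true) = attn_weights
    # Soft cross-entropy between predicted and "ground-truth" softmax attention rows
    xent = -(a_true * torch.log(a_pred.clamp(min=1e-12))).sum(dim=-1).mean()
    # MSE between hybrid-attention outputs and softmax-attention outputs
    mse = F.mse_loss(y_pred, y_true)
    return xent_weight * xent + mse_weight * mse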
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class LinearAttentionSlidingWindowCache(LinearAttentionState):
|
| 220 |
+
"""
|
| 221 |
+
Class for `past_key_values`
|
| 222 |
+
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
|
| 223 |
+
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
|
| 224 |
+
"""
|
| 225 |
+
def __init__(self, window_size: int = 64) -> None:
|
| 226 |
+
super().__init__()
|
| 227 |
+
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
| 228 |
+
self._seen_tokens_by_layer: List[int] = []
|
| 229 |
+
self.kv_states: List[torch.Tensor] = []
|
| 230 |
+
self.k_states: List[torch.Tensor] = []
|
| 231 |
+
|
| 232 |
+
# Account for sliding windows
|
| 233 |
+
self.decode_kv_states: List[torch.Tensor] = []
|
| 234 |
+
self.decode_k_states: List[torch.Tensor] = []
|
| 235 |
+
self.k_cache: List[torch.Tensor] = []
|
| 236 |
+
self.v_cache: List[torch.Tensor] = []
|
| 237 |
+
self.window_size = window_size
|
| 238 |
+
|
| 239 |
+
def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
|
| 240 |
+
layer_idx: Optional[int] = None, cache_kwargs: Optional[any] = None,
|
| 241 |
+
accumulate_in_fp32: bool = False,
|
| 242 |
+
fmap_key_states: torch.Tensor = None, # should not be None
|
| 243 |
+
grad_enabled: bool = False,
|
| 244 |
+
**kwargs: any,
|
| 245 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 246 |
+
"""
|
| 247 |
+
Update KV, K states; and KV cache during training
|
| 248 |
+
- For decoding, use `self.decode_kv_states` to keep track of KV states
|
| 249 |
+
up to sliding window terms
|
| 250 |
+
- For (chunked) training, use `self.kv_states` to keep track of KV states
|
| 251 |
+
up to end of sequence
|
| 252 |
+
- Likewise for `self.decode_k_states` and `self.k_states`
|
| 253 |
+
"""
|
| 254 |
+
with torch.set_grad_enabled(grad_enabled):
|
| 255 |
+
if layer_idx == 0:
|
| 256 |
+
self._seen_tokens += key_states.shape[-2]
|
| 257 |
+
|
| 258 |
+
dtype = key_states.dtype
|
| 259 |
+
if accumulate_in_fp32:
|
| 260 |
+
# key_states = key_states.float()
|
| 261 |
+
fmap_key_states = fmap_key_states.float()
|
| 262 |
+
value_states = value_states.float()
|
| 263 |
+
|
| 264 |
+
# Decoding KV state (KV terms up to last window_size)
|
| 265 |
+
decode_kv_state = torch.einsum(
|
| 266 |
+
'bhlf,bhld->bhfd', fmap_key_states[:, :, :-self.window_size], value_states[:, :, :-self.window_size]
|
| 267 |
+
)
|
| 268 |
+
# KV state
|
| 269 |
+
kv_state = decode_kv_state + torch.einsum(
|
| 270 |
+
'bhlf,bhld->bhfd', fmap_key_states[:, :, -self.window_size:], value_states[:, :, -self.window_size:]
|
| 271 |
+
)
|
| 272 |
+
# shape is b, h, 1, f; note the 1
|
| 273 |
+
decode_k_state = fmap_key_states[:, :, :-self.window_size].sum(dim=-2, keepdim=True)
|
| 274 |
+
k_state = (decode_k_state + fmap_key_states[:, :, -self.window_size:].sum(dim=-2, keepdim=True))
|
| 275 |
+
|
| 276 |
+
# Update the cache
|
| 277 |
+
if len(self.k_states) <= layer_idx: # Initializing kv and k states
|
| 278 |
+
self.kv_states.append(kv_state.to(dtype))
|
| 279 |
+
self.k_states.append(k_state.to(dtype))
|
| 280 |
+
|
| 281 |
+
self.decode_kv_states.append(decode_kv_state.to(dtype))
|
| 282 |
+
self.decode_k_states.append(decode_k_state.to(dtype))
|
| 283 |
+
|
| 284 |
+
self.k_cache.append(key_states[:, :, -self.window_size:, :])
|
| 285 |
+
self.v_cache.append(value_states[:, :, -self.window_size:, :].to(dtype))
|
| 286 |
+
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
|
| 287 |
+
else:
|
| 288 |
+
# Update kv and k states recurrently
|
| 289 |
+
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype)
|
| 290 |
+
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype)
|
| 291 |
+
self.kv_states[layer_idx] = kv_state
|
| 292 |
+
self.k_states[layer_idx] = k_state
|
| 293 |
+
|
| 294 |
+
decode_kv_state = (self.decode_kv_states[layer_idx].to(kv_state.dtype)
|
| 295 |
+
+ decode_kv_state).to(dtype)
|
| 296 |
+
decode_k_state = (self.decode_k_states[layer_idx].to(kv_state.dtype)
|
| 297 |
+
+ decode_k_state).to(dtype)
|
| 298 |
+
self.decode_kv_states[layer_idx] = decode_kv_state
|
| 299 |
+
self.decode_k_states[layer_idx] = decode_k_state
|
| 300 |
+
|
| 301 |
+
self.k_cache[layer_idx] = key_states[:, :, -self.window_size:, :]
|
| 302 |
+
self.v_cache[layer_idx] = value_states[:, :, -self.window_size:, :]
|
| 303 |
+
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
|
| 304 |
+
|
| 305 |
+
return self.kv_states[layer_idx], self.k_states[layer_idx]
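The recurrent update above relies on the KV and K states being plain sums over feature-mapped key terms, so accumulating them chunk by chunk matches computing them over the full sequence. A small numerical check of that identity (with an illustrative ELU+1 feature map) is sketched below.

# Sanity check: chunk-wise accumulation of kv_state / k_state equals the
# full-sequence einsum, which is what `update` exploits during chunked training.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, n, f, d, chunk = 1, 2, 256, 64, 64, 64
phi_k = F.elu(torch.randn(b, h, n, f)) + 1   # feature-mapped keys (illustrative)
v = torch.randn(b, h, n, d)

full_kv = torch.einsum('bhlf,bhld->bhfd', phi_k, v)
full_k = phi_k.sum(dim=-2, keepdim=True)

kv_state = torch.zeros(b, h, f, d)
k_state = torch.zeros(b, h, 1, f)
for i in range(0, n, chunk):
    kv_state += torch.einsum('bhlf,bhld->bhfd', phi_k[:, :, i:i+chunk], v[:, :, i:i+chunk])
    k_state += phi_k[:, :, i:i+chunk].sum(dim=-2, keepdim=True)

assert torch.allclose(kv_state, full_kv, atol=1e-4)
assert torch.allclose(k_state, full_k, atol=1e-4)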
|
| 306 |
+
|
| 307 |
+
def update_for_decoding(self, keys: torch.Tensor, values: torch.Tensor,
|
| 308 |
+
layer_idx: int, feature_map_k: Callable, dtype: torch.dtype):
|
| 309 |
+
"""
|
| 310 |
+
Update the decoding KV and K states, and KV cache, during decoding
|
| 311 |
+
"""
|
| 312 |
+
with torch.no_grad():
|
| 313 |
+
k_cache = self.k_cache[layer_idx]
|
| 314 |
+
v_cache = self.v_cache[layer_idx]
|
| 315 |
+
|
| 316 |
+
if k_cache.shape[-2] < self.window_size: # build window-size cache
|
| 317 |
+
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
|
| 318 |
+
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
|
| 319 |
+
else:
|
| 320 |
+
# MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
|
| 321 |
+
# if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
|
| 322 |
+
# f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
|
| 323 |
+
# else:
|
| 324 |
+
# f_k_state = feature_map_k(k_cache[:, :, :1, :])
|
| 325 |
+
# -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
|
| 326 |
+
k_state = feature_map_k(k_cache[:, :, :1, :])
|
| 327 |
+
v_state = v_cache[:, :, :1, :]
|
| 328 |
+
kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d
|
| 329 |
+
self.decode_kv_states[layer_idx] += kv_state
|
| 330 |
+
self.decode_k_states[layer_idx] += k_state
|
| 331 |
+
|
| 332 |
+
self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2)
|
| 333 |
+
self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], values], dim=-2)
|
| 334 |
+
|
| 335 |
+
if layer_idx == 0:
|
| 336 |
+
self._seen_tokens += keys.shape[-2]
|
| 337 |
+
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
|
| 338 |
+
return (self.k_cache[layer_idx], self.v_cache[layer_idx],
|
| 339 |
+
self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx])
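A compact walkthrough of the eviction step performed in `update_for_decoding`: once the window cache is full, the oldest key/value pair is folded into the linear-attention state before the new token is appended. Toy tensors and an assumed ELU+1 feature map; shapes only.

# Toy sketch of the rolling-window eviction above (illustrative only).
import torch
import torch.nn.functional as F

b, h, w, d = 1, 2, 64, 64
feature_map_k = lambda x: F.elu(x) + 1       # stand-in feature map

k_cache = torch.randn(b, h, w, d)            # full sliding-window key cache
v_cache = torch.randn(b, h, w, d)
kv_state = torch.zeros(b, h, d, d)           # linear-attention "KV state"
k_state = torch.zeros(b, h, 1, d)            # linear-attention "K state"
new_k, new_v = torch.randn(b, h, 1, d), torch.randn(b, h, 1, d)

# Fold the oldest (evicted) position into the linear state ...
f_old = feature_map_k(k_cache[:, :, :1, :])
kv_state += torch.einsum('bhlf,bhld->bhfd', f_old, v_cache[:, :, :1, :])
k_state += f_old
# ... then slide the window forward by one token.
k_cache = torch.cat([k_cache[:, :, 1:, :], new_k], dim=-2)
v_cache = torch.cat([v_cache[:, :, 1:, :], new_v], dim=-2)
assert k_cache.shape[-2] == w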
|
src/model/linear_attention/linear_window_attention_sw_linear.py
ADDED
|
@@ -0,0 +1,522 @@
|
| 1 |
+
"""
|
| 2 |
+
Subquadratic attention combining sliding window and linear attentions
|
| 3 |
+
- Using "standard" sliding windows
|
| 4 |
+
- Didactically computes outputs with n^2 attention weights for now
|
| 5 |
+
- Copied + adapted from linear_window_attention_tk.py for single-file reference
|
| 6 |
+
|
| 7 |
+
For each layer:
|
| 8 |
+
- We first compute (softmax) attention over sliding windows
|
| 9 |
+
- We then compute standard linear attention to "fill in" the earlier parts
|
| 10 |
+
- We combine to model the entire sequence
|
| 11 |
+
"""
|
| 12 |
+
from typing import List, Tuple, Optional, Callable
|
| 13 |
+
import math
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
from transformers.cache_utils import Cache
|
| 19 |
+
try:
|
| 20 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
| 21 |
+
except ModuleNotFoundError:
|
| 22 |
+
_flash_attention_forward = None # Transformers v4.36
|
| 23 |
+
|
| 24 |
+
# Causal linear attention dot product CUDA kernel from fast-transformers
|
| 25 |
+
from csrc import causal_dot_product
|
| 26 |
+
|
| 27 |
+
from src.model.rotary import apply_rotary_pos_emb
|
| 28 |
+
from .linear_attention import (
|
| 29 |
+
LolcatsLinearAttention, LinearAttentionState,
|
| 30 |
+
softmax_attention
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# ----------------------
|
| 34 |
+
# Sliding window helpers
|
| 35 |
+
# ----------------------
|
| 36 |
+
def get_masks(window_size: int, q_len: int, k_len: int,
|
| 37 |
+
device: torch.device) -> tuple[torch.Tensor]:
|
| 38 |
+
"""
|
| 39 |
+
Return masks for softmax and linear attention terms
|
| 40 |
+
-> 1 is include, 0 is ignore
|
| 41 |
+
"""
|
| 42 |
+
kwargs = {'device': device, 'dtype': int}
|
| 43 |
+
causal_mask = torch.ones((q_len, k_len), **kwargs).tril(max(k_len - q_len, 0))
|
| 44 |
+
linear_mask = torch.ones((q_len, k_len), **kwargs).tril(max(k_len - q_len, 0) - window_size)
|
| 45 |
+
window_mask = causal_mask - linear_mask
|
| 46 |
+
# Return softmax mask (window), linear attention mask
|
| 47 |
+
# -> shapes broadcast over (b, h, q_len, k_len)
|
| 48 |
+
return window_mask[None, None, ...], linear_mask[None, None, ...]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def hybrid_attention_quadratic(q: torch.Tensor, k: torch.Tensor,
|
| 52 |
+
f_q: torch.Tensor, f_k: torch.Tensor,
|
| 53 |
+
v: torch.Tensor,
|
| 54 |
+
window_factor: torch.Tensor,
|
| 55 |
+
linear_factor: torch.Tensor,
|
| 56 |
+
window_size: int,
|
| 57 |
+
kv_state: torch.Tensor = None,
|
| 58 |
+
k_state: torch.Tensor = None,
|
| 59 |
+
eps: float = 1e-12,
|
| 60 |
+
mask_value: float=-1e8):
|
| 61 |
+
"""
|
| 62 |
+
Hybrid attention combining sliding window and linear attentions
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
mask_window, mask_linear = get_masks(window_size, q.shape[-2], k.shape[-2], q.device)
|
| 66 |
+
|
| 67 |
+
# 1. Sliding window (softmax attention)
|
| 68 |
+
a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5)
|
| 69 |
+
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
|
| 70 |
+
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
|
| 71 |
+
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
| 72 |
+
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
|
| 73 |
+
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
| 74 |
+
|
| 75 |
+
# 2. Under window (linear attention)
|
| 76 |
+
a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float())
|
| 77 |
+
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
|
| 78 |
+
sum_ln = a_ln.sum(dim=-1, keepdim=True)
|
| 79 |
+
|
| 80 |
+
# 3. Combine
|
| 81 |
+
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
|
| 82 |
+
# Allow outputs to also depend on prior kv_state and k_state
|
| 83 |
+
y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v.float())
|
| 84 |
+
if kv_state is not None: # Combine with prior kv_state and k_state
|
| 85 |
+
y += linear_factor * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float())
|
| 86 |
+
sum_ln += linear_factor * torch.einsum(
|
| 87 |
+
'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None]
|
| 88 |
+
y = (y / (sum_sm + sum_ln)).to(q.dtype)
|
| 89 |
+
return y, a # attention weights only for the last chunk
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# ------------------------------
|
| 93 |
+
# Hybrid window attention linear
|
| 94 |
+
# ------------------------------
|
| 95 |
+
def under_window_linear_attention(f_q: torch.Tensor, f_k: torch.Tensor, v: torch.Tensor,
|
| 96 |
+
window_size: int, linear_factor: float, eps: float=1e-12):
|
| 97 |
+
"""Compute hybrid window attention dot product with linear complexity in q_len"""
|
| 98 |
+
dtype = f_q.dtype
|
| 99 |
+
w = window_size
|
| 100 |
+
f_k = F.pad(f_k, (0, 0, w, 0), value=0)[:, :, :-w, :]
|
| 101 |
+
v = F.pad(v, (0, 0, w, 0), value=0)[:, :, :-w, :]
|
| 102 |
+
qkv = linear_factor * causal_dot_product(f_q.contiguous().to(dtype=torch.float32),
|
| 103 |
+
f_k.contiguous().to(dtype=torch.float32),
|
| 104 |
+
v.contiguous().to(dtype=torch.float32)).to(dtype=dtype)
|
| 105 |
+
sum_f_k = f_k.float().cumsum(dim=2).to(dtype=dtype)
|
| 106 |
+
sum_qk = linear_factor * torch.einsum("bhld,bhld->bhl", f_q, sum_f_k)[..., None]
|
| 107 |
+
sum_qk[sum_qk == 0] += eps
|
| 108 |
+
return qkv, sum_qk
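`causal_dot_product` is the fast-transformers CUDA kernel bundled in `csrc`. For readers without the extension built, a slow pure-PyTorch reference of the same causal cumulative product is sketched below; it is an illustrative fallback under that assumption, not part of the repo.

# Reference semantics of causal_dot_product (illustrative; materializes the full
# (l, f, d) cumulative state, so only suitable for small inputs):
#   out[:, :, i] = f_q[:, :, i] @ sum_{j <= i} phi(k_j) v_j^T
import torch

def causal_dot_product_reference(f_q: torch.Tensor, f_k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    kv = torch.einsum('bhlf,bhld->bhlfd', f_k, v).cumsum(dim=2)   # running outer products
    return torch.einsum('bhlf,bhlfd->bhld', f_q, kv)

# e.g. torch.allclose(causal_dot_product_reference(f_q, f_k, v),
#                     causal_dot_product(f_q, f_k, v), atol=1e-4)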
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def sliding_window_softmax_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
| 112 |
+
window_size: int, window_factor: float, mask_value: float=-1e8):
|
| 113 |
+
"""
|
| 114 |
+
Compute sliding window softmax attention without materializing
|
| 115 |
+
O(seq_len^2) attention weights
|
| 116 |
+
"""
|
| 117 |
+
d = q.shape[-1]
|
| 118 |
+
# Compute windows for keys
|
| 119 |
+
window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1}
|
| 120 |
+
k = F.pad(k, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
|
| 121 |
+
v = F.pad(v, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
|
| 122 |
+
|
| 123 |
+
# Compute windowed_softmax(qk); causal in its construction
|
| 124 |
+
a_sm = torch.einsum('bhld,bhldw->bhlw', q, k) * (d ** -0.5)
|
| 125 |
+
a_sm[a_sm == 0] = -torch.finfo(q.dtype).max # heuristic for zeroing out padding above
|
| 126 |
+
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
| 127 |
+
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
|
| 128 |
+
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
| 129 |
+
return torch.einsum('bhlw,bhldw->bhld', a_sm, v), sum_sm
|
| 130 |
+
# return torch.einsum('bhlw,bhldw->bhld', torch.softmax(qk, dim=-1), v)
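The pad-and-unfold trick above materializes, for every query position, exactly its last `window_size` keys. A tiny shape demo (illustrative values) makes the layout of the unfolded tensor explicit.

# Shape demo for the pad + unfold windowing used above (illustrative).
import torch
import torch.nn.functional as F

b, h, n, d, w = 1, 1, 8, 4, 3
k = torch.arange(n, dtype=torch.float32).view(1, 1, n, 1).expand(b, h, n, d)
k_win = F.pad(k, (0, 0, w - 1, 0), value=0).unfold(dimension=2, size=w, step=1)
print(k_win.shape)          # (b, h, n, d, w): for each position, its last w keys
print(k_win[0, 0, 4, 0])    # tensor([2., 3., 4.]) -> keys at positions 2..4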
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def hybrid_attention_linear(q: torch.Tensor, k: torch.Tensor,
|
| 134 |
+
f_q: torch.Tensor, f_k: torch.Tensor,
|
| 135 |
+
v: torch.Tensor,
|
| 136 |
+
window_factor: torch.Tensor = None,
|
| 137 |
+
linear_factor: torch.Tensor = None,
|
| 138 |
+
window_size: int = 64,
|
| 139 |
+
kv_state: torch.Tensor = None,
|
| 140 |
+
k_state: torch.Tensor = None,
|
| 141 |
+
eps: float = 1e-12,
|
| 142 |
+
mask_value: float=-1e8):
|
| 143 |
+
"""
|
| 144 |
+
Alternative hybrid attention combining sliding window and linear attentions
|
| 145 |
+
-> Uses O(n) memory, where n is the sequence length, by padding and unfolding windows
|
| 146 |
+
"""
|
| 147 |
+
window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1}
|
| 148 |
+
# 1. Sliding window (softmax attention)
|
| 149 |
+
with torch.no_grad():
|
| 150 |
+
qkv_sm, sum_qk_sm = sliding_window_softmax_attention(q, k, v, window_size, window_factor, mask_value)
|
| 151 |
+
|
| 152 |
+
# 2. Under window (linear attention)
|
| 153 |
+
qkv_ln, sum_qk_ln = under_window_linear_attention(f_q, f_k, v, window_size, linear_factor, eps)
|
| 154 |
+
|
| 155 |
+
# 3. Combine
|
| 156 |
+
y = (qkv_sm + qkv_ln) / (sum_qk_sm + sum_qk_ln)
|
| 157 |
+
return y, None
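Because `hybrid_attention_linear` and the quadratic reference compute the same quantity (softmax attention inside the window, normalized linear attention outside it), the O(n) path can be sanity-checked against the O(n^2) path on small inputs. A hedged sketch follows, assuming both functions, a toy ELU+1 feature map, and the built `csrc` kernel.

# Hedged equivalence check on toy fp32 inputs (not part of the file).
import torch
import torch.nn.functional as F

b, h, n, d, w = 1, 2, 128, 32, 16
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
f_q, f_k = F.elu(q) + 1, F.elu(k) + 1
wf = torch.full((1, h, 1, 1), 0.5)
lf = 1 - wf

y_quad, _ = hybrid_attention_quadratic(q, k, f_q, f_k, v, wf, lf, window_size=w)
y_lin, _ = hybrid_attention_linear(q, k, f_q, f_k, v, wf, lf, window_size=w)
print((y_quad - y_lin).abs().max())   # expected to be small (fp32 tolerance)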
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# ---------------------
|
| 161 |
+
# Attention layer class
|
| 162 |
+
# ---------------------
|
| 163 |
+
class LolcatsLinearSlidingWindowAttention(LolcatsLinearAttention):
|
| 164 |
+
"""
|
| 165 |
+
Lolcats attention combining sliding window and linear attention
|
| 166 |
+
"""
|
| 167 |
+
def __init__(self,
|
| 168 |
+
window_size: int = 64,
|
| 169 |
+
decode_window_size: int = None,
|
| 170 |
+
affine_attention_factors: bool = False,
|
| 171 |
+
init_window_factor: float = 0,
|
| 172 |
+
train_window_factor: bool = True,
|
| 173 |
+
state_grad_enabled: bool = False,
|
| 174 |
+
**kwargs):
|
| 175 |
+
self.window_size = window_size
|
| 176 |
+
self.decode_window_size = (
|
| 177 |
+
decode_window_size if decode_window_size is not None else window_size
|
| 178 |
+
)
|
| 179 |
+
self.window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1}
|
| 180 |
+
super().__init__(**kwargs)
|
| 181 |
+
# Determine how we compute attentions
|
| 182 |
+
self.linear_attention = hybrid_attention_linear
|
| 183 |
+
self.attention_type = 'lolcats_llama_window_sw'
|
| 184 |
+
# Learnable factor for combining attentions
|
| 185 |
+
self.affine_attention_factors = affine_attention_factors
|
| 186 |
+
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
|
| 187 |
+
if train_window_factor:
|
| 188 |
+
self.window_factors = nn.Parameter(
|
| 189 |
+
init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype))
|
| 190 |
+
else:
|
| 191 |
+
self.register_buffer(
|
| 192 |
+
"window_factors", init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
|
| 193 |
+
)
|
| 194 |
+
# Whether we use original flash attention 2 inference (use during attention transfer)
|
| 195 |
+
self.base_inference = False
|
| 196 |
+
self.state_grad_enabled = state_grad_enabled
|
| 197 |
+
|
| 198 |
+
def forward(self,
|
| 199 |
+
hidden_states: torch.Tensor,
|
| 200 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 201 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 202 |
+
past_key_value: Optional[Cache] = None,
|
| 203 |
+
output_attentions: bool = False,
|
| 204 |
+
use_cache: bool = False,
|
| 205 |
+
**kwargs,
|
| 206 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 207 |
+
"""
|
| 208 |
+
Forward pass with the option to compute attention weights multiple ways
|
| 209 |
+
if self.train_attention is True
|
| 210 |
+
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
|
| 211 |
+
"""
|
| 212 |
+
b, l, _ = hidden_states.size()
|
| 213 |
+
|
| 214 |
+
if self.train_attention and self.base_inference:
|
| 215 |
+
with torch.no_grad():
|
| 216 |
+
_y_true = flash_attention_2(self, # self.base_attn,
|
| 217 |
+
hidden_states=hidden_states,
|
| 218 |
+
attention_mask=None,
|
| 219 |
+
position_ids=position_ids,
|
| 220 |
+
past_key_value=None,
|
| 221 |
+
output_attentions=False,
|
| 222 |
+
use_cache=False)[0]
|
| 223 |
+
# _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
|
| 224 |
+
y_true = _y_true.reshape(b, l, -1).contiguous()
|
| 225 |
+
y_true = self.o_proj(y_true)
|
| 226 |
+
# layer_io = (hidden_states, _y_true) # hack
|
| 227 |
+
layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack
|
| 228 |
+
return y_true, layer_io, None
|
| 229 |
+
|
| 230 |
+
else:
|
| 231 |
+
q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask,
|
| 232 |
+
position_ids, past_key_value)
|
| 233 |
+
f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) # Have to do after repeat for grouped-query attn if we use same fmap
|
| 234 |
+
|
| 235 |
+
attn_weights = None
|
| 236 |
+
# attention_mask = None # For now this is always True
|
| 237 |
+
if past_key_value is None: # Regular training
|
| 238 |
+
window_factors = F.sigmoid(self.window_factors)
|
| 239 |
+
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
| 240 |
+
y_true, a_pred = self.linear_attention(q, k, f_q, f_k, v,
|
| 241 |
+
window_factors, linear_factors,
|
| 242 |
+
window_size=self.window_size)
|
| 243 |
+
attn_weights = a_pred
|
| 244 |
+
else:
|
| 245 |
+
past_key_value.window_size = self.decode_window_size
|
| 246 |
+
if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating
|
| 247 |
+
assert use_cache is True
|
| 248 |
+
_kv = past_key_value.update_for_decoding(k, v, self.layer_idx,
|
| 249 |
+
self.feature_map_k,
|
| 250 |
+
dtype=q.dtype)
|
| 251 |
+
k_cache, v_cache, f_kv_state, f_k_state = _kv
|
| 252 |
+
|
| 253 |
+
# Sliding window + linear attention decode
|
| 254 |
+
window_factors = F.sigmoid(self.window_factors)
|
| 255 |
+
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
| 256 |
+
|
| 257 |
+
# Softmax attention terms
|
| 258 |
+
a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5)
|
| 259 |
+
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
| 260 |
+
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
| 261 |
+
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
| 262 |
+
|
| 263 |
+
# Combine with linear attention terms
|
| 264 |
+
y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float())
|
| 265 |
+
+ linear_factors * torch.einsum('bhlf,bhfd->bhld', f_q.float(), f_kv_state.float()))
|
| 266 |
+
sum_ln = linear_factors * torch.einsum(
|
| 267 |
+
'bhlf,bhnf->bhl', f_q.float(), f_k_state.float())[..., None]
|
| 268 |
+
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
|
| 269 |
+
|
| 270 |
+
else: # Stateful training
|
| 271 |
+
try:
|
| 272 |
+
kv_state = past_key_value.kv_states[self.layer_idx]
|
| 273 |
+
k_state = past_key_value.k_states[self.layer_idx]
|
| 274 |
+
except IndexError:
|
| 275 |
+
kv_state, k_state = None, None
|
| 276 |
+
window_factors = F.sigmoid(self.window_factors)
|
| 277 |
+
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
| 278 |
+
y_true, _ = self.linear_attention(q, k, f_q, f_k, v,
|
| 279 |
+
window_factors, linear_factors,
|
| 280 |
+
window_size=self.window_size,
|
| 281 |
+
kv_state=kv_state,
|
| 282 |
+
k_state=k_state)
|
| 283 |
+
# Save and update KV cache and states
|
| 284 |
+
# past_key_value.update(k, v.detach(), self.layer_idx,
|
| 285 |
+
# fmap_key_states=f_k.detach(),
|
| 286 |
+
# accumulate_in_fp32=True)
|
| 287 |
+
past_key_value.update(k, v, self.layer_idx,
|
| 288 |
+
fmap_key_states=f_k,
|
| 289 |
+
accumulate_in_fp32=True)
|
| 290 |
+
# Concatenate heads and apply output projection
|
| 291 |
+
_y_true = y_true.transpose(1, 2).contiguous()
|
| 292 |
+
y_true = self.o_proj(_y_true.view(b, l, self.hidden_size))
|
| 293 |
+
|
| 294 |
+
if self.train_attention:
|
| 295 |
+
attn_weights = _y_true # flash_attn outputs are shape (b, l, h, d)
|
| 296 |
+
return y_true, attn_weights, past_key_value
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class LinearAttentionSlidingWindowCache(LinearAttentionState):
|
| 300 |
+
"""
|
| 301 |
+
Class for `past_key_values`
|
| 302 |
+
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
|
| 303 |
+
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
|
| 304 |
+
"""
|
| 305 |
+
def __init__(self, window_size: int = 64) -> None:
|
| 306 |
+
super().__init__()
|
| 307 |
+
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
| 308 |
+
self._seen_tokens_by_layer: List[int] = []
|
| 309 |
+
self.kv_states: List[torch.Tensor] = []
|
| 310 |
+
self.k_states: List[torch.Tensor] = []
|
| 311 |
+
|
| 312 |
+
# Account for sliding windows
|
| 313 |
+
self.decode_kv_states: List[torch.Tensor] = []
|
| 314 |
+
self.decode_k_states: List[torch.Tensor] = []
|
| 315 |
+
self.k_cache: List[torch.Tensor] = []
|
| 316 |
+
self.v_cache: List[torch.Tensor] = []
|
| 317 |
+
self.window_size = window_size
|
| 318 |
+
|
| 319 |
+
def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
|
| 320 |
+
layer_idx: Optional[int] = None, cache_kwargs: Optional[any] = None,
|
| 321 |
+
accumulate_in_fp32: bool = False,
|
| 322 |
+
fmap_key_states: torch.Tensor = None, # should not be None
|
| 323 |
+
grad_enabled: bool = False,
|
| 324 |
+
**kwargs: any,
|
| 325 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 326 |
+
"""
|
| 327 |
+
Update KV, K states; and KV cache during training
|
| 328 |
+
- For decoding, use `self.decode_kv_states` to keep track of KV states
|
| 329 |
+
up to sliding window terms
|
| 330 |
+
- For (chunked) training, use `self.kv_states` to keep track of KV states
|
| 331 |
+
up to end of sequence
|
| 332 |
+
- Likewise for `self.decode_k_states` and `self.k_states`
|
| 333 |
+
"""
|
| 334 |
+
with torch.set_grad_enabled(grad_enabled):
|
| 335 |
+
if layer_idx == 0:
|
| 336 |
+
self._seen_tokens += key_states.shape[-2]
|
| 337 |
+
|
| 338 |
+
dtype = key_states.dtype
|
| 339 |
+
if accumulate_in_fp32:
|
| 340 |
+
# key_states = key_states.float()
|
| 341 |
+
fmap_key_states = fmap_key_states.float()
|
| 342 |
+
value_states = value_states.float()
|
| 343 |
+
|
| 344 |
+
# Decoding KV state (KV terms up to last window_size)
|
| 345 |
+
decode_kv_state = torch.einsum(
|
| 346 |
+
'bhlf,bhld->bhfd', fmap_key_states[:, :, :-self.window_size], value_states[:, :, :-self.window_size]
|
| 347 |
+
)
|
| 348 |
+
# KV state
|
| 349 |
+
kv_state = decode_kv_state + torch.einsum(
|
| 350 |
+
'bhlf,bhld->bhfd', fmap_key_states[:, :, -self.window_size:], value_states[:, :, -self.window_size:]
|
| 351 |
+
)
|
| 352 |
+
# shape is b, h, 1, f; note the 1
|
| 353 |
+
decode_k_state = fmap_key_states[:, :, :-self.window_size].sum(dim=-2, keepdim=True)
|
| 354 |
+
k_state = (decode_k_state + fmap_key_states[:, :, -self.window_size:].sum(dim=-2, keepdim=True))
|
| 355 |
+
|
| 356 |
+
# Update the cache
|
| 357 |
+
if len(self.k_states) <= layer_idx: # Initializing kv and k states
|
| 358 |
+
self.kv_states.append(kv_state.to(dtype))
|
| 359 |
+
self.k_states.append(k_state.to(dtype))
|
| 360 |
+
|
| 361 |
+
self.decode_kv_states.append(decode_kv_state.to(dtype))
|
| 362 |
+
self.decode_k_states.append(decode_k_state.to(dtype))
|
| 363 |
+
|
| 364 |
+
self.k_cache.append(key_states[:, :, -self.window_size:, :])
|
| 365 |
+
self.v_cache.append(value_states[:, :, -self.window_size:, :].to(dtype))
|
| 366 |
+
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
|
| 367 |
+
else:
|
| 368 |
+
# Update kv and k states recurrently
|
| 369 |
+
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype)
|
| 370 |
+
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype)
|
| 371 |
+
self.kv_states[layer_idx] = kv_state
|
| 372 |
+
self.k_states[layer_idx] = k_state
|
| 373 |
+
|
| 374 |
+
decode_kv_state = (self.decode_kv_states[layer_idx].to(kv_state.dtype)
|
| 375 |
+
+ decode_kv_state).to(dtype)
|
| 376 |
+
decode_k_state = (self.decode_k_states[layer_idx].to(kv_state.dtype)
|
| 377 |
+
+ decode_k_state).to(dtype)
|
| 378 |
+
self.decode_kv_states[layer_idx] = decode_kv_state
|
| 379 |
+
self.decode_k_states[layer_idx] = decode_k_state
|
| 380 |
+
|
| 381 |
+
self.k_cache[layer_idx] = key_states[:, :, -self.window_size:, :]
|
| 382 |
+
self.v_cache[layer_idx] = value_states[:, :, -self.window_size:, :]
|
| 383 |
+
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
|
| 384 |
+
|
| 385 |
+
return self.kv_states[layer_idx], self.k_states[layer_idx]
|
| 386 |
+
|
| 387 |
+
def update_for_decoding(self, keys: torch.Tensor, values: torch.Tensor,
|
| 388 |
+
layer_idx: int, feature_map_k: Callable, dtype: torch.dtype):
|
| 389 |
+
"""
|
| 390 |
+
Update the decoding KV and K states, and KV cache, during decoding
|
| 391 |
+
"""
|
| 392 |
+
with torch.no_grad():
|
| 393 |
+
k_cache = self.k_cache[layer_idx]
|
| 394 |
+
v_cache = self.v_cache[layer_idx]
|
| 395 |
+
|
| 396 |
+
if k_cache.shape[-2] < self.window_size: # build window-size cache
|
| 397 |
+
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
|
| 398 |
+
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
|
| 399 |
+
else:
|
| 400 |
+
# MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
|
| 401 |
+
# if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
|
| 402 |
+
# f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
|
| 403 |
+
# else:
|
| 404 |
+
# f_k_state = feature_map_k(k_cache[:, :, :1, :])
|
| 405 |
+
# -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
|
| 406 |
+
k_state = feature_map_k(k_cache[:, :, :1, :])
|
| 407 |
+
v_state = v_cache[:, :, :1, :]
|
| 408 |
+
kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d
|
| 409 |
+
self.decode_kv_states[layer_idx] += kv_state
|
| 410 |
+
self.decode_k_states[layer_idx] += k_state
|
| 411 |
+
|
| 412 |
+
self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2)
|
| 413 |
+
self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], values], dim=-2)
|
| 414 |
+
|
| 415 |
+
if layer_idx == 0:
|
| 416 |
+
self._seen_tokens += keys.shape[-2]
|
| 417 |
+
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
|
| 418 |
+
return (self.k_cache[layer_idx], self.v_cache[layer_idx],
|
| 419 |
+
self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx])
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
# -----------------
|
| 423 |
+
# Flash Attention 2
|
| 424 |
+
# -----------------
|
| 425 |
+
|
| 426 |
+
def flash_attention_2(self,
|
| 427 |
+
hidden_states: torch.Tensor,
|
| 428 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 429 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 430 |
+
past_key_value: Optional[Cache] = None,
|
| 431 |
+
output_attentions: bool = False,
|
| 432 |
+
use_cache: bool = False,
|
| 433 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 434 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 435 |
+
"""
|
| 436 |
+
Wrapper for LlamaFlashAttention2
|
| 437 |
+
Copied and modified from HF Transformers v4.36 and v4.43 implementations
|
| 438 |
+
- (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
|
| 439 |
+
- (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456
|
| 440 |
+
"""
|
| 441 |
+
output_attentions = False
|
| 442 |
+
|
| 443 |
+
bsz, q_len, _ = hidden_states.size()
|
| 444 |
+
|
| 445 |
+
query_states = self.q_proj(hidden_states)
|
| 446 |
+
key_states = self.k_proj(hidden_states)
|
| 447 |
+
value_states = self.v_proj(hidden_states)
|
| 448 |
+
|
| 449 |
+
# Flash attention requires the input to have the shape
|
| 450 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
| 451 |
+
# therefore we just need to keep the original shape
|
| 452 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 453 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 454 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 455 |
+
|
| 456 |
+
try: # As in Transformers v4.36
|
| 457 |
+
kv_seq_len = key_states.shape[-2]
|
| 458 |
+
if past_key_value is not None:
|
| 459 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 460 |
+
cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
|
| 461 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
| 462 |
+
except: # As in Transformers v4.39
|
| 463 |
+
cos, sin = self.rotary_emb(key_states, position_ids)
|
| 464 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 465 |
+
|
| 466 |
+
if past_key_value is not None:
|
| 467 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 468 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 469 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 470 |
+
|
| 471 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
| 472 |
+
# to be able to avoid many of these transpose/reshape/view.
|
| 473 |
+
query_states = query_states.transpose(1, 2)
|
| 474 |
+
key_states = key_states.transpose(1, 2)
|
| 475 |
+
value_states = value_states.transpose(1, 2)
|
| 476 |
+
|
| 477 |
+
dropout_rate = self.attention_dropout if self.training else 0.0
|
| 478 |
+
|
| 479 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
| 480 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
| 481 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
| 482 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
| 483 |
+
# in fp32. (LlamaRMSNorm handles it correctly)
|
| 484 |
+
|
| 485 |
+
input_dtype = query_states.dtype
|
| 486 |
+
if input_dtype == torch.float32:
|
| 487 |
+
if torch.is_autocast_enabled():
|
| 488 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 489 |
+
# Handle the case where the model is quantized
|
| 490 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
| 491 |
+
target_dtype = self.config._pre_quantization_dtype
|
| 492 |
+
else:
|
| 493 |
+
target_dtype = self.q_proj.weight.dtype
|
| 494 |
+
|
| 495 |
+
logger.warning_once(
|
| 496 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
| 497 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
| 498 |
+
f" {target_dtype}."
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
query_states = query_states.to(target_dtype)
|
| 502 |
+
key_states = key_states.to(target_dtype)
|
| 503 |
+
value_states = value_states.to(target_dtype)
|
| 504 |
+
|
| 505 |
+
if getattr(self, '_flash_attention_forward', False):
|
| 506 |
+
attn_output = self._flash_attention_forward(
|
| 507 |
+
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate,
|
| 508 |
+
is_causal=True,
|
| 509 |
+
)
|
| 510 |
+
else:
|
| 511 |
+
attn_output = _flash_attention_forward(
|
| 512 |
+
query_states,
|
| 513 |
+
key_states,
|
| 514 |
+
value_states,
|
| 515 |
+
attention_mask,
|
| 516 |
+
q_len,
|
| 517 |
+
dropout=0, # dropout_rate,
|
| 518 |
+
sliding_window=getattr(self, "sliding_window", None),
|
| 519 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
| 520 |
+
is_causal=True,
|
| 521 |
+
)
|
| 522 |
+
return attn_output, past_key_value
|
src/model/linear_attention/linear_window_attention_sw_long.py
ADDED
|
@@ -0,0 +1,23 @@
|
| 1 |
+
"""
|
| 2 |
+
LoLCATs attention combining sliding window and linear attentions
|
| 3 |
+
- Using standard sliding window arrangement
|
| 4 |
+
- Training over long sequences with fixed memory with recurrent view
|
| 5 |
+
- During attention transfer, use Flash Attention to compute softmax attention outputs
|
| 6 |
+
|
| 7 |
+
For each layer:
|
| 8 |
+
- We first compute (softmax) attention over sliding windows
|
| 9 |
+
- We then compute standard linear attention to "fill in" the earlier parts
|
| 10 |
+
- We combine to model the entire sequence
|
| 11 |
+
"""
|
| 12 |
+
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
| 13 |
+
from .linear_window_attention_sw import hybrid_attention_quadratic
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class LolcatsSlidingWindowLongAttention(LolcatsTKWindowLongAttention):
|
| 17 |
+
"""
|
| 18 |
+
Lolcats attention combining sliding window and linear attention
|
| 19 |
+
"""
|
| 20 |
+
def __init__(self, remove_base_attn=True, **kwargs):
|
| 21 |
+
# keep self.base_attn for Flash Attention inference
|
| 22 |
+
super().__init__(remove_base_attn=True, **kwargs)
|
| 23 |
+
self.quadratic_attention = hybrid_attention_quadratic
|
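The long-sequence variant trains with a recurrent, fixed-memory view: a long input is split into chunks, each chunk is processed with the hybrid kernel, and the linear-attention states are carried forward. Below is a hedged sketch of that loop on toy tensors with an assumed ELU+1 feature map; the actual class delegates this bookkeeping to `LolcatsTKWindowLongAttention` and the sliding-window cache rather than an explicit loop like this.

# Illustrative chunked processing with carried linear-attention states.
import torch
import torch.nn.functional as F

b, h, n, d, w, chunk = 1, 2, 1024, 64, 64, 256
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
phi = lambda x: F.elu(x) + 1                 # stand-in feature map
wf = torch.full((1, h, 1, 1), 0.5)
lf = 1 - wf

kv_state = k_state = None
outputs = []
for i in range(0, n, chunk):
    s = slice(i, i + chunk)
    y, _ = hybrid_attention_quadratic(q[:, :, s], k[:, :, s],
                                      phi(q[:, :, s]), phi(k[:, :, s]), v[:, :, s],
                                      wf, lf, window_size=w,
                                      kv_state=kv_state, k_state=k_state)
    outputs.append(y)
    # Carry linear-attention statistics into the next chunk
    new_kv = torch.einsum('bhlf,bhld->bhfd', phi(k[:, :, s]).float(), v[:, :, s].float())
    new_k = phi(k[:, :, s]).float().sum(dim=-2, keepdim=True)
    kv_state = new_kv if kv_state is None else kv_state + new_kv
    k_state = new_k if k_state is None else k_state + new_k
y_full = torch.cat(outputs, dim=2)           # (b, h, n, d)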