dung-vpt-uney committed · Commit 9233720 · 1 parent: 8f661e8

Update Visual-CoT demo - 2025-10-12 22:18:00
Fixes:
- Fix LLaVA config registration error (compatibility with newer transformers; see the sketch after the file list)
- Update Gradio to latest version (security fixes)
- Auto-deployed via update script
Files changed:
- llava/model/language_model/modeling_llamantk.py +118 -22
- requirements.txt +3 -0
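For reference, a hedged sketch of the kind of guard the first fix points at: newer transformers releases already ship a built-in `llava` model type, so registering a config under an already-known name raises `ValueError`. All names below are hypothetical placeholders, not taken from this repo's actual fix:

from transformers import AutoConfig, CONFIG_MAPPING, LlamaConfig


class DemoVisualCoTConfig(LlamaConfig):
    # Hypothetical stand-in for the repo's actual config class (e.g. LlamaNTKConfig).
    model_type = "demo_visual_cot"


# Newer transformers raises ValueError when a model type is registered twice,
# so only register when the type is not already known to AutoConfig.
if DemoVisualCoTConfig.model_type not in CONFIG_MAPPING:
    AutoConfig.register(DemoVisualCoTConfig.model_type, DemoVisualCoTConfig)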
llava/model/language_model/modeling_llamantk.py
CHANGED
@@ -43,29 +43,115 @@ from transformers.utils import (
 from .configuration_llamantk import LlamaNTKConfig

 # Make flash_attn optional for Hugging Face Spaces compatibility
+# Support both original flash_attn and kernels-community version
+HAS_FLASH_ATTN = False
+HAS_KERNELS_FLASH_ATTN = False
+flash_attn_varlen_qkvpacked_func = None
+unpad_input = None
+pad_input = None
+kernels_flash_attn = None
+
+# Try kernels-community flash-attn first (pre-built, Spaces-compatible)
 try:
+    from kernels import get_kernel
+    kernels_flash_attn = get_kernel("kernels-community/flash-attn")
+    HAS_KERNELS_FLASH_ATTN = True
+    print("✓ Using kernels-community/flash-attn (pre-built)")
+except Exception as e:
+    pass
+
+# Fallback to original flash_attn if kernels not available
+if not HAS_KERNELS_FLASH_ATTN:
     try:
+        from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
-        from flash_attn.flash_attn_interface import (
-            flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
-        )
         from flash_attn.bert_padding import unpad_input, pad_input
         HAS_FLASH_ATTN = True
+        print("✓ Using original flash-attn")
     except ModuleNotFoundError:
+        try:
+            from flash_attn.flash_attn_interface import (
+                flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
+            )
+            from flash_attn.bert_padding import unpad_input, pad_input
+            HAS_FLASH_ATTN = True
+            print("✓ Using original flash-attn (legacy API)")
+        except ModuleNotFoundError:
+            # Flash attention not available - will use standard attention
+            print("⚠ Flash attention not available, using standard attention")

 from einops import rearrange


 logger = logging.get_logger(__name__)

+
+# Helper functions for padding/unpadding when using kernels (fallback if bert_padding not available)
+def simple_unpad_input(hidden_states, attention_mask):
+    """
+    Simple unpad implementation when flash_attn.bert_padding is not available
+    """
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return (
+        hidden_states.flatten(0, 1)[indices],
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
+def simple_pad_input(hidden_states, indices, batch, seqlen):
+    """
+    Simple pad implementation when flash_attn.bert_padding is not available
+    """
+    output = torch.zeros(
+        batch * seqlen, *hidden_states.shape[1:], dtype=hidden_states.dtype, device=hidden_states.device
+    )
+    output[indices] = hidden_states
+    return output.view(batch, seqlen, *hidden_states.shape[1:])
+
+
+# Helper function to call flash attention with unified API
+def call_flash_attn_qkvpacked(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None, causal=True):
+    """
+    Unified wrapper for flash attention that supports both:
+    - kernels-community/flash-attn (pre-built)
+    - original flash-attn
+
+    Args:
+        qkv: [total_seq_len, 3, num_heads, head_dim] packed Q, K, V
+        cu_seqlens: cumulative sequence lengths
+        max_seqlen: maximum sequence length
+        dropout_p: dropout probability
+        softmax_scale: softmax scale
+        causal: whether to use causal mask
+
+    Returns:
+        output: [total_seq_len, num_heads, head_dim]
+    """
+    if HAS_KERNELS_FLASH_ATTN:
+        # Kernels API: separate q, k, v and use varlen_fwd()
+        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]  # [seq_len, num_heads, head_dim]
+        output = kernels_flash_attn.varlen_fwd(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=cu_seqlens,
+            cu_seqlens_k=cu_seqlens,
+            max_seqlen_q=max_seqlen,
+            max_seqlen_k=max_seqlen,
+        )[0]
+        return output
+    elif HAS_FLASH_ATTN:
+        # Original flash_attn API
+        return flash_attn_varlen_qkvpacked_func(
+            qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=softmax_scale, causal=causal
+        )
+    else:
+        raise RuntimeError("Flash attention is not available")
+
 _CONFIG_FOR_DOC = "LlamaNTKConfig"

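As a reading aid for the packed "varlen" layout the new call_flash_attn_qkvpacked wrapper expects, here is a minimal usage sketch. It assumes a CUDA device and that one of the two flash-attention backends above imported successfully; the shapes and values are illustrative, not taken from the model code.

import torch

from llava.model.language_model.modeling_llamantk import call_flash_attn_qkvpacked

# Two sequences (lengths 3 and 5) packed into one varlen batch: 8 tokens total.
num_heads, head_dim = 32, 128
qkv = torch.randn(8, 3, num_heads, head_dim, dtype=torch.float16, device="cuda")

# Cumulative sequence lengths mark where each sequence starts/ends: [0, 3, 8].
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")

# Dispatches to kernels-community/flash-attn or original flash-attn, whichever loaded.
out = call_flash_attn_qkvpacked(qkv, cu_seqlens, max_seqlen=5, causal=True)
print(out.shape)  # torch.Size([8, 32, 128]), per the wrapper's docstring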
@@ -561,30 +647,40 @@ class LlamaAttention(nn.Module):
                 dtype=torch.int32,
                 device=qkv.device,
             )
-            output = flash_attn_varlen_qkvpacked_func(
+            output = call_flash_attn_qkvpacked(
                 qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
             )
             output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
         else:
             nheads = qkv.shape[-2]
             x = rearrange(qkv, "b s three h d -> b s (three h d)")
-            x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
+            # Use appropriate unpad function based on available backend
+            if unpad_input is not None:
+                x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
+            else:
+                x_unpad, indices, cu_q_lens, max_s = simple_unpad_input(x, key_padding_mask)
             x_unpad = rearrange(
                 x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
             )
-            output_unpad = flash_attn_varlen_qkvpacked_func(
+            output_unpad = call_flash_attn_qkvpacked(
                 x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
             )
-            output = rearrange(
-                pad_input(
+            # Use appropriate pad function based on available backend
+            if pad_input is not None:
+                padded = pad_input(
                     rearrange(output_unpad, "nnz h d -> nnz (h d)"),
                     indices,
                     bsz,
                     q_len,
-                ),
-                "b s (h d) -> b s h d",
-                h=nheads,
-            )
+                )
+            else:
+                padded = simple_pad_input(
+                    rearrange(output_unpad, "nnz h d -> nnz (h d)"),
+                    indices,
+                    bsz,
+                    q_len,
+                )
+            output = rearrange(padded, "b s (h d) -> b s h d", h=nheads)
             attn_output = self.o_proj(rearrange(output, "b s h d -> b s (h d)"))
         else:
             attn_weights = torch.matmul(
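The padded branch above only switches between unpad_input/pad_input and the new simple_* fallbacks; the fallbacks themselves can be sanity-checked on CPU with plain PyTorch. A standalone sketch with toy shapes, assuming the helpers from the first hunk are importable:

import torch

from llava.model.language_model.modeling_llamantk import (
    simple_pad_input,
    simple_unpad_input,
)

batch, seqlen, hidden = 2, 4, 6
x = torch.randn(batch, seqlen, hidden)
# Second sequence has one padding position at the end.
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 1, 0]], dtype=torch.bool)

x_unpad, indices, cu_seqlens, max_s = simple_unpad_input(x, mask)
print(x_unpad.shape, cu_seqlens.tolist(), max_s)  # torch.Size([7, 6]) [0, 4, 7] 4

x_repad = simple_pad_input(x_unpad, indices, batch, seqlen)
# Real tokens round-trip exactly; padded positions come back as zeros.
assert torch.equal(x_repad[mask], x[mask])
assert torch.all(x_repad[~mask] == 0)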
@@ -788,7 +884,7 @@ class LlamaAttention(nn.Module):
         use_cache: bool = False,
     ):
         # Use flash attention only if both config enables it AND flash_attn is available
-        if self.config.use_flash_attn and HAS_FLASH_ATTN:
+        if self.config.use_flash_attn and (HAS_FLASH_ATTN or HAS_KERNELS_FLASH_ATTN):
             return self.forward_flash_attn(
                 hidden_states,
                 attention_mask,
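Because the forward pass now accepts either backend, a quick way to confirm which one a running Space actually picked up is to read the module-level flags set at import time. A hypothetical check, assuming the package imports cleanly:

import llava.model.language_model.modeling_llamantk as modeling_llamantk

print("kernels flash-attn:", modeling_llamantk.HAS_KERNELS_FLASH_ATTN)
print("original flash-attn:", modeling_llamantk.HAS_FLASH_ATTN)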
requirements.txt
CHANGED
@@ -18,6 +18,9 @@ einops==0.6.1
 einops-exts==0.0.4
 timm==0.6.13

+# Flash Attention via Kernels (pre-built, Spaces-compatible)
+kernels>=0.0.1 # For fast attention without compilation
+
 # Utilities
 Pillow>=10.0.0
 numpy>=1.24.0
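To verify the new dependency on a Space before relying on it, the kernel can be probed the same way the model code does. A small sketch:

# Mirrors the optional import added in modeling_llamantk.py.
try:
    from kernels import get_kernel

    flash_attn_kernel = get_kernel("kernels-community/flash-attn")
    print("kernels-community/flash-attn loaded:", flash_attn_kernel is not None)
except Exception as exc:
    print("kernels flash-attn unavailable, model will fall back:", exc)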