dung-vpt-uney committed
Commit 9233720 · 1 Parent(s): 8f661e8

Update Visual-CoT demo - 2025-10-12 22:18:00

Fixes:
- Fix LLaVA config registration error (compatibility with newer transformers; see the sketch after this list)
- Update Gradio to latest version (security fixes)
- Auto-deployed via update script
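
The first fix listed above is not part of the hunks shown below, so here is only a hypothetical sketch of the kind of guard it describes: newer transformers releases ship a built-in "llava" model type, and re-registering a custom config under the same name raises a ValueError unless the registration is guarded. The import path and class names are the ones LLaVA-style repos usually expose; they are assumptions, not code from this commit.

# Hypothetical registration guard; the actual fix in this commit may differ.
from transformers import AutoConfig, AutoModelForCausalLM
from llava.model import LlavaConfig, LlavaLlamaForCausalLM  # assumed import path

try:
    AutoConfig.register("llava", LlavaConfig)
    AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
except ValueError:
    # Newer transformers already register a "llava" config; keep the built-in
    # mapping and use the custom classes directly instead of re-registering.
    pass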

llava/model/language_model/modeling_llamantk.py CHANGED
@@ -43,29 +43,115 @@ from transformers.utils import (
 from .configuration_llamantk import LlamaNTKConfig
 
 # Make flash_attn optional for Hugging Face Spaces compatibility
+# Support both original flash_attn and kernels-community version
+HAS_FLASH_ATTN = False
+HAS_KERNELS_FLASH_ATTN = False
+flash_attn_varlen_qkvpacked_func = None
+unpad_input = None
+pad_input = None
+kernels_flash_attn = None
+
+# Try kernels-community flash-attn first (pre-built, Spaces-compatible)
 try:
-    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
-    from flash_attn.bert_padding import unpad_input, pad_input
-    HAS_FLASH_ATTN = True
-except ModuleNotFoundError:
+    from kernels import get_kernel
+    kernels_flash_attn = get_kernel("kernels-community/flash-attn")
+    HAS_KERNELS_FLASH_ATTN = True
+    print("✓ Using kernels-community/flash-attn (pre-built)")
+except Exception as e:
+    pass
+
+# Fallback to original flash_attn if kernels not available
+if not HAS_KERNELS_FLASH_ATTN:
     try:
-        from flash_attn.flash_attn_interface import (
-            flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
-        )
+        from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
         from flash_attn.bert_padding import unpad_input, pad_input
         HAS_FLASH_ATTN = True
+        print("✓ Using original flash-attn")
     except ModuleNotFoundError:
-        # Flash attention not available - will use standard attention
-        HAS_FLASH_ATTN = False
-        flash_attn_varlen_qkvpacked_func = None
-        unpad_input = None
-        pad_input = None
+        try:
+            from flash_attn.flash_attn_interface import (
+                flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
+            )
+            from flash_attn.bert_padding import unpad_input, pad_input
+            HAS_FLASH_ATTN = True
+            print("✓ Using original flash-attn (legacy API)")
+        except ModuleNotFoundError:
+            # Flash attention not available - will use standard attention
+            print("⚠ Flash attention not available, using standard attention")
 
 from einops import rearrange
 
 
 logger = logging.get_logger(__name__)
 
+
+# Helper functions for padding/unpadding when using kernels (fallback if bert_padding not available)
+def simple_unpad_input(hidden_states, attention_mask):
+    """
+    Simple unpad implementation when flash_attn.bert_padding is not available
+    """
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return (
+        hidden_states.flatten(0, 1)[indices],
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
+def simple_pad_input(hidden_states, indices, batch, seqlen):
+    """
+    Simple pad implementation when flash_attn.bert_padding is not available
+    """
+    output = torch.zeros(
+        batch * seqlen, *hidden_states.shape[1:], dtype=hidden_states.dtype, device=hidden_states.device
+    )
+    output[indices] = hidden_states
+    return output.view(batch, seqlen, *hidden_states.shape[1:])
+
+
+# Helper function to call flash attention with unified API
+def call_flash_attn_qkvpacked(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None, causal=True):
+    """
+    Unified wrapper for flash attention that supports both:
+    - kernels-community/flash-attn (pre-built)
+    - original flash-attn
+
+    Args:
+        qkv: [total_seq_len, 3, num_heads, head_dim] packed Q, K, V
+        cu_seqlens: cumulative sequence lengths
+        max_seqlen: maximum sequence length
+        dropout_p: dropout probability
+        softmax_scale: softmax scale
+        causal: whether to use causal mask
+
+    Returns:
+        output: [total_seq_len, num_heads, head_dim]
+    """
+    if HAS_KERNELS_FLASH_ATTN:
+        # Kernels API: separate q, k, v and use varlen_fwd()
+        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]  # [seq_len, num_heads, head_dim]
+        output = kernels_flash_attn.varlen_fwd(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=cu_seqlens,
+            cu_seqlens_k=cu_seqlens,
+            max_seqlen_q=max_seqlen,
+            max_seqlen_k=max_seqlen,
+        )[0]
+        return output
+    elif HAS_FLASH_ATTN:
+        # Original flash_attn API
+        return flash_attn_varlen_qkvpacked_func(
+            qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=softmax_scale, causal=causal
+        )
+    else:
+        raise RuntimeError("Flash attention is not available")
+
 _CONFIG_FOR_DOC = "LlamaNTKConfig"
 
 
@@ -561,30 +647,40 @@ class LlamaAttention(nn.Module):
                 dtype=torch.int32,
                 device=qkv.device,
             )
-            output = flash_attn_varlen_qkvpacked_func(
+            output = call_flash_attn_qkvpacked(
                 qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
             )
             output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
         else:
             nheads = qkv.shape[-2]
             x = rearrange(qkv, "b s three h d -> b s (three h d)")
-            x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
+            # Use appropriate unpad function based on available backend
+            if unpad_input is not None:
+                x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
+            else:
+                x_unpad, indices, cu_q_lens, max_s = simple_unpad_input(x, key_padding_mask)
             x_unpad = rearrange(
                 x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
             )
-            output_unpad = flash_attn_varlen_qkvpacked_func(
+            output_unpad = call_flash_attn_qkvpacked(
                 x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
             )
-            output = rearrange(
-                pad_input(
+            # Use appropriate pad function based on available backend
+            if pad_input is not None:
+                padded = pad_input(
                     rearrange(output_unpad, "nnz h d -> nnz (h d)"),
                     indices,
                     bsz,
                     q_len,
-                ),
-                "b s (h d) -> b s h d",
-                h=nheads,
-            )
+                )
+            else:
+                padded = simple_pad_input(
+                    rearrange(output_unpad, "nnz h d -> nnz (h d)"),
+                    indices,
+                    bsz,
+                    q_len,
+                )
+            output = rearrange(padded, "b s (h d) -> b s h d", h=nheads)
         attn_output = self.o_proj(rearrange(output, "b s h d -> b s (h d)"))
     else:
         attn_weights = torch.matmul(
@@ -788,7 +884,7 @@ class LlamaAttention(nn.Module):
         use_cache: bool = False,
     ):
         # Use flash attention only if both config enables it AND flash_attn is available
-        if self.config.use_flash_attn and HAS_FLASH_ATTN:
+        if self.config.use_flash_attn and (HAS_FLASH_ATTN or HAS_KERNELS_FLASH_ATTN):
             return self.forward_flash_attn(
                 hidden_states,
                 attention_mask,
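
To make the new helpers concrete, here is a minimal CPU-only sketch exercising the simple_unpad_input / simple_pad_input pair added above (it assumes the two functions have been imported or copied from modeling_llamantk.py; the tensor shapes are illustrative only):

# Minimal round-trip check of the padding helpers introduced in this commit.
import torch

batch, seqlen, hidden = 2, 5, 8
hidden_states = torch.randn(batch, seqlen, hidden)
# First sequence has two padded positions, second is full length.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

x_unpad, indices, cu_seqlens, max_s = simple_unpad_input(hidden_states, attention_mask)
print(x_unpad.shape)        # torch.Size([8, 8]) -- only unmasked positions remain
print(cu_seqlens.tolist())  # [0, 3, 8]          -- cumulative sequence lengths
print(max_s)                # 5

# Re-padding restores the (batch, seqlen, hidden) layout with zeros at pad slots.
repadded = simple_pad_input(x_unpad, indices, batch, seqlen)
assert repadded.shape == hidden_states.shape
assert torch.equal(repadded[attention_mask.bool()], hidden_states[attention_mask.bool()])

The varlen flash-attention call consumes exactly this unpadded layout: the cu_seqlens and max_s produced by the unpad step are the cu_q_lens and max_s arguments passed to call_flash_attn_qkvpacked in the hunk above.
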
requirements.txt CHANGED
@@ -18,6 +18,9 @@ einops==0.6.1
 einops-exts==0.0.4
 timm==0.6.13
 
+# Flash Attention via Kernels (pre-built, Spaces-compatible)
+kernels>=0.0.1 # For fast attention without compilation
+
 # Utilities
 Pillow>=10.0.0
 numpy>=1.24.0
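
After installing the updated requirements, a quick smoke test (hypothetical; the import path is inferred from the file location llava/model/language_model/modeling_llamantk.py and may differ) shows which attention backend the fallback chain selected:

# Prints the backend flags set at import time by the patched module.
from llava.model.language_model import modeling_llamantk as m

print("kernels-community backend:", m.HAS_KERNELS_FLASH_ATTN)
print("original flash-attn backend:", m.HAS_FLASH_ATTN)
# If both are False, the forward pass falls back to standard attention,
# matching the use_flash_attn dispatch in the diff above.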