SFLY5 committed on
Commit
ce3b626
·
1 Parent(s): ba47981

performance optimization

Browse files
Files changed (1) hide show
  1. modeling_ernie_45t_vl.py +18 -24
modeling_ernie_45t_vl.py CHANGED
@@ -3457,33 +3457,27 @@ class VisionAttention(nn.Module):
3457
  k = apply_rotary_pos_emb_vision(k.unsqueeze(dim=0), rotary_pos_emb).squeeze(
3458
  dim=0
3459
  )
3460
-
3461
- max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
3462
-
3463
- attention_mask = torch.full(
3464
- [1, seq_length, seq_length],
3465
- torch.finfo(q.dtype).min,
3466
- device=q.device,
3467
- dtype=q.dtype,
3468
- )
3469
- for i in range(1, len(cu_seqlens)):
3470
- attention_mask[
3471
- ...,
3472
- cu_seqlens[i - 1] : cu_seqlens[i],
3473
- cu_seqlens[i - 1] : cu_seqlens[i],
3474
- ] = 0
3475
-
3476
  q = q.transpose(0, 1)
3477
  k = k.transpose(0, 1)
3478
  v = v.transpose(0, 1)
3479
- attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
3480
- attn_weights = attn_weights + attention_mask
3481
- attn_weights = nn.functional.softmax(
3482
- attn_weights, dim=-1, dtype=torch.float32
3483
- ).to(q.dtype)
3484
- attn_output = torch.matmul(attn_weights, v)
3485
- attn_output = attn_output.transpose(0, 1)
3486
- attn_output = attn_output.reshape(seq_length, -1)
 
 
 
 
 
 
 
 
 
3487
  attn_output = self.proj(attn_output)
3488
  return attn_output
3489
 
 
3457
  k = apply_rotary_pos_emb_vision(k.unsqueeze(dim=0), rotary_pos_emb).squeeze(
3458
  dim=0
3459
  )
3460
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3461
  q = q.transpose(0, 1)
3462
  k = k.transpose(0, 1)
3463
  v = v.transpose(0, 1)
3464
+
3465
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
3466
+ splits = [
3467
+ torch.split(tensor, lengths.tolist(), dim=1) for tensor in (q, k, v)
3468
+ ]
3469
+
3470
+ attn_output = []
3471
+ for q, k, v in zip(*splits):
3472
+ attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
3473
+ attn_weights = nn.functional.softmax(
3474
+ attn_weights, dim=-1, dtype=torch.float32
3475
+ ).to(q.dtype)
3476
+ attn_output_splited = torch.matmul(attn_weights, v)
3477
+ attn_output_splited = attn_output_splited.transpose(0, 1)
3478
+ attn_output.append(attn_output_splited)
3479
+ attn_output = torch.cat(attn_output, dim=0)
3480
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
3481
  attn_output = self.proj(attn_output)
3482
  return attn_output
3483