performance optimization

modeling_ernie_45t_vl.py  (+18 −24)

@@ -3457,33 +3457,27 @@ class VisionAttention(nn.Module):
         k = apply_rotary_pos_emb_vision(k.unsqueeze(dim=0), rotary_pos_emb).squeeze(
             dim=0
         )
-
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-
-        attention_mask = torch.full(
-            [1, seq_length, seq_length],
-            torch.finfo(q.dtype).min,
-            device=q.device,
-            dtype=q.dtype,
-        )
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[
-                ...,
-                cu_seqlens[i - 1] : cu_seqlens[i],
-                cu_seqlens[i - 1] : cu_seqlens[i],
-            ] = 0
-
+
         q = q.transpose(0, 1)
         k = k.transpose(0, 1)
         v = v.transpose(0, 1)
-
-
-
-
-
-
-        attn_output =
-
+
+        lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+        splits = [
+            torch.split(tensor, lengths.tolist(), dim=1) for tensor in (q, k, v)
+        ]
+
+        attn_output = []
+        for q, k, v in zip(*splits):
+            attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+            attn_weights = nn.functional.softmax(
+                attn_weights, dim=-1, dtype=torch.float32
+            ).to(q.dtype)
+            attn_output_splited = torch.matmul(attn_weights, v)
+            attn_output_splited = attn_output_splited.transpose(0, 1)
+            attn_output.append(attn_output_splited)
+        attn_output = torch.cat(attn_output, dim=0)
+        attn_output = attn_output.reshape(seq_length, -1).contiguous()
         attn_output = self.proj(attn_output)
         return attn_output

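Note on the change: the old path allocated a [1, seq_length, seq_length] additive bias filled with torch.finfo(q.dtype).min, zeroed the diagonal blocks given by cu_seqlens, and ran one softmax over the whole sequence. The new path splits q, k, v at the cu_seqlens boundaries and runs plain attention inside each segment, so no quadratic mask is ever materialized and no cross-segment scores are computed only to be discarded; the max_seqlen computation disappears along with the mask. The two paths are numerically equivalent because the mask drove every cross-segment probability to effectively zero. Below is a minimal standalone sketch of that equivalence, not the repository's code: the head counts, segment lengths, and variable names are illustrative assumptions.

import math

import torch
import torch.nn.functional as F

# Standalone equivalence check (illustrative sizes, not from the model):
# per-segment attention over cu_seqlens splits should match full attention
# under the old block-diagonal additive mask.
torch.manual_seed(0)
num_heads, head_dim = 4, 8
cu_seqlens = torch.tensor([0, 3, 7, 12])  # three segments of lengths 3, 4, 5
seq_length = int(cu_seqlens[-1])
q = torch.randn(num_heads, seq_length, head_dim)  # already (heads, seq, dim)
k = torch.randn(num_heads, seq_length, head_dim)
v = torch.randn(num_heads, seq_length, head_dim)

# Old path: additive mask that is 0 inside each diagonal block and
# dtype-min everywhere else, then one softmax over the full sequence.
mask = torch.full([1, seq_length, seq_length], torch.finfo(q.dtype).min)
for i in range(1, len(cu_seqlens)):
    mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(head_dim) + mask
weights = F.softmax(weights, dim=-1, dtype=torch.float32).to(q.dtype)
full_out = torch.matmul(weights, v)

# New path: split the sequence axis at cu_seqlens and attend within
# each segment only; no mask is ever materialized.
lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
outs = []
for qs, ks, vs in zip(*(t.split(lengths, dim=1) for t in (q, k, v))):
    ws = torch.matmul(qs, ks.transpose(1, 2)) / math.sqrt(head_dim)
    ws = F.softmax(ws, dim=-1, dtype=torch.float32).to(qs.dtype)
    outs.append(torch.matmul(ws, vs))
split_out = torch.cat(outs, dim=1)

print(torch.allclose(full_out, split_out, atol=1e-6))  # True

The sketch concatenates along the sequence dimension before any reshape, whereas the committed code transposes each split back to (seq, heads, dim) and concatenates along dim=0; both orderings recover the same per-position outputs ahead of the final projection.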