LeoXing1996 committed
Commit 9121982
1 Parent(s): 00a4e2b

update memory efficient attention

Files changed:
- animatediff/models/motion_module.py +14 -1
- requirements.txt +1 -0
animatediff/models/motion_module.py
CHANGED
@@ -467,6 +467,14 @@ class CrossAttention(nn.Module):
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
+    def _memory_efficient_attention_pt20(self, query, key, value, attention_mask):
+        query = query.contiguous()
+        key = key.contiguous()
+        value = value.contiguous()
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0, is_causal=False)
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        return hidden_states
+
 
 class VersatileAttention(CrossAttention):
     def __init__(
@@ -532,7 +540,12 @@ class VersatileAttention(CrossAttention):
             attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
 
         # attention, what we cannot get enough of
-        if self._use_memory_efficient_attention_xformers:
+        if hasattr(F, 'scaled_dot_product_attention'):
+            # NOTE: pt20's scaled_dot_product_attention seems more memory efficient than
+            # xformers' memory_efficient_attention, set it as the first class citizen
+            hidden_states = self._memory_efficient_attention_pt20(query, key, value, attention_mask)
+            hidden_states = hidden_states.to(query.dtype)
+        elif self._use_memory_efficient_attention_xformers:
             hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
             # Some versions of xformers return output in fp32, cast it back to the dtype of the input
             hidden_states = hidden_states.to(query.dtype)

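The dispatch added above can be exercised outside the module. Below is a minimal, self-contained sketch of the same preference order (PyTorch 2.0's scaled_dot_product_attention first, xformers as a fallback); the helper name sdp_or_xformers_attention and the example (batch * heads, seq_len, head_dim) shapes are assumptions for illustration, not part of the upstream CrossAttention class.

# Minimal sketch of the attention dispatch introduced in this commit.
# Assumptions (not from the upstream code): the helper name
# `sdp_or_xformers_attention` and the example shapes below.
import torch
import torch.nn.functional as F


def sdp_or_xformers_attention(query, key, value, attention_mask=None):
    """Prefer torch>=2.0's fused kernel; fall back to xformers on older torch."""
    if hasattr(F, "scaled_dot_product_attention"):
        # torch >= 2.0 ships a fused, memory-efficient attention kernel.
        return F.scaled_dot_product_attention(
            query.contiguous(), key.contiguous(), value.contiguous(),
            attn_mask=attention_mask, dropout_p=0.0, is_causal=False,
        )
    # Older torch: rely on xformers' memory-efficient attention instead.
    import xformers.ops
    return xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias=attention_mask)


if __name__ == "__main__":
    # (batch * heads, seq_len, head_dim) -- the layout CrossAttention uses
    # before reshape_batch_dim_to_heads merges the heads back together.
    q = k = v = torch.randn(16, 64, 40)
    print(sdp_or_xformers_attention(q, k, v).shape)  # torch.Size([16, 64, 40])

Note that guarding on hasattr(F, 'scaled_dot_product_attention') rather than on the xformers flag means the PyTorch 2.0 path is taken whenever it is available, even if _use_memory_efficient_attention_xformers was never enabled, which is consistent with the torch>=2.0.0 pin added to requirements.txt below.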
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
 opencv-python
+torch>=2.0.0
 torchvision==0.14.1
 diffusers==0.24.0
 transformers==4.25.1