Update vision_tower_builder.py
Browse files- vision_tower_builder.py +3 -6
vision_tower_builder.py
CHANGED
|
@@ -2,9 +2,6 @@ from typing import Optional, Tuple, Union, Dict
|
|
| 2 |
from dataclasses import dataclass
|
| 3 |
from functools import partial, reduce
|
| 4 |
from PIL import Image
|
| 5 |
-
import torch
|
| 6 |
-
import torch.utils.checkpoint
|
| 7 |
-
from torch import nn
|
| 8 |
import os
|
| 9 |
from transformers.image_processing_utils import BatchFeature, get_size_dict
|
| 10 |
from transformers.image_transforms import (
|
|
@@ -29,7 +26,7 @@ try:
|
|
| 29 |
from flash_attn import flash_attn_qkvpacked_func
|
| 30 |
except:
|
| 31 |
print("You need to install flash_attn")
|
| 32 |
-
from timm.
|
| 33 |
|
| 34 |
|
| 35 |
|
|
@@ -516,7 +513,7 @@ def build_vit(config, pt_type='origin'):
|
|
| 516 |
drop_path_rate=0.,
|
| 517 |
num_frames=config.num_frames,
|
| 518 |
tubelet_size=1,
|
| 519 |
-
use_checkpoint=
|
| 520 |
checkpoint_num=24,
|
| 521 |
return_index=config.return_idx,
|
| 522 |
with_ln=True, # merge vision_layernorm in it
|
|
@@ -619,4 +616,4 @@ def build_vision_tower(vision_tower_cfg, **kwargs):
|
|
| 619 |
elif "umt" in vision_tower:
|
| 620 |
return UMTVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
|
| 621 |
|
| 622 |
-
raise ValueError(f"Unknown vision tower: {vision_tower}")
|
|
|
|
| 2 |
from dataclasses import dataclass
|
| 3 |
from functools import partial, reduce
|
| 4 |
from PIL import Image
|
|
|
|
|
|
|
|
|
|
| 5 |
import os
|
| 6 |
from transformers.image_processing_utils import BatchFeature, get_size_dict
|
| 7 |
from transformers.image_transforms import (
|
|
|
|
| 26 |
from flash_attn import flash_attn_qkvpacked_func
|
| 27 |
except:
|
| 28 |
print("You need to install flash_attn")
|
| 29 |
+
from timm.layers import drop_path, to_2tuple, trunc_normal_
|
| 30 |
|
| 31 |
|
| 32 |
|
|
|
|
| 513 |
drop_path_rate=0.,
|
| 514 |
num_frames=config.num_frames,
|
| 515 |
tubelet_size=1,
|
| 516 |
+
use_checkpoint=False,
|
| 517 |
checkpoint_num=24,
|
| 518 |
return_index=config.return_idx,
|
| 519 |
with_ln=True, # merge vision_layernorm in it
|
|
|
|
| 616 |
elif "umt" in vision_tower:
|
| 617 |
return UMTVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
|
| 618 |
|
| 619 |
+
raise ValueError(f"Unknown vision tower: {vision_tower}")
|