|
|
import collections |
|
|
import functools |
|
|
import inspect |
|
|
import logging |
|
|
import math |
|
|
import warnings |
|
|
from typing import ( |
|
|
Any, |
|
|
Callable, |
|
|
Dict, |
|
|
List, |
|
|
Literal, |
|
|
Optional, |
|
|
Protocol, |
|
|
Set, |
|
|
Tuple, |
|
|
Type, |
|
|
) |
|
|
|
|
|
import einops |
|
|
import numpy as np |
|
|
import roma |
|
|
import timm |
|
|
import torch |
|
|
import torch.distributed.fsdp |
|
|
import torch.distributed.tensor |
|
|
import transformers |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
from .common_spear import ( |
|
|
Configurable, |
|
|
DiffusionInput, |
|
|
FlowInput, |
|
|
LLMOutput, |
|
|
RoboticsFlowInput, |
|
|
RoboticsInput, |
|
|
RoboticsOutput, |
|
|
RotationFormat, |
|
|
VLMOutput, |
|
|
expand_dims, |
|
|
is_quaternion, |
|
|
is_rotmat, |
|
|
is_rotmat_3x3, |
|
|
is_rotmat_9, |
|
|
quaternion_half_cover, |
|
|
rotmat_as_3x3, |
|
|
rotmat_as_9, |
|
|
) |
|
|
from .configuration_spear import ( |
|
|
FourierFeaturesConfig, |
|
|
NoisedControlProjectorConfig, |
|
|
PaliGemmaVLMConfig, |
|
|
PiZeroFlowMatchingDecoderBlockConfig, |
|
|
PiZeroFlowMatchingDecoderConfig, |
|
|
PiZeroFlowMatchingModuleConfig, |
|
|
RobotStateProjectorConfig, |
|
|
RotaryPositionalEncodingConfig, |
|
|
SPEAR1Config, |
|
|
) |
|
|
from .processing_spear import ( |
|
|
EmptyTokenizer, |
|
|
PaliGemmaDepthProcessor, |
|
|
PiZeroFlowMatchingProcessor, |
|
|
) |
|
|
|
|
|
|
|
|
class ConfigurableModule(torch.nn.Module, Configurable):
    """Base class for torch modules that are constructed from a config object.

    Mixes the project's ``Configurable`` into ``torch.nn.Module``. Both base
    initializers are invoked explicitly instead of through one cooperative
    ``super().__init__`` call.
    """

    def __init__(self, config):
        """
        Args:
            config: Configuration object forwarded unchanged to ``Configurable.__init__``.
        """
        # Order is kept as-is: the Configurable part is set up first, then the
        # torch.nn.Module machinery.
        # NOTE(review): presumably Configurable stores `config` on the instance — confirm.
        Configurable.__init__(self, config)
        torch.nn.Module.__init__(self)
|
|
|
|
|
|
|
|
class GemmaRMSNorm(torch.nn.Module):
    """RMS normalization over the last dimension with a ``1 + weight`` gain.

    ``weight`` is initialized to zeros, so a freshly constructed layer applies
    plain RMS normalization with unit gain.
    """

    def __init__(self, dim: int, eps: float = 1e-06):
        """
        Args:
            dim: Size of the last (normalized) dimension.
            eps: Constant added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Scale ``x`` by the reciprocal root-mean-square of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalize in float32, apply the gain, and cast back to the dtype of ``x``."""
        normalized = self._norm(x.float())
        gain = 1.0 + self.weight.float()
        return (normalized * gain).type_as(x)

    def extra_repr(self):
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.eps}"
|
|
|
|
|
|
|
|
@torch.no_grad()
def make_module_params_non_trainable(
    module: torch.nn.Module, recursive: bool = True, dtype: Optional[torch.dtype] = None
):
    """
    Disable gradients for the parameters of `module` and optionally cast them to `dtype`.

    NOTE: dtype is applied only to module parameters, not buffers. This is different from the
    default torch.nn.Module.to(dtype=dtype) behavior, which applies to both parameters and buffers

    Args:
        module: Module whose parameters are frozen in place.
        recursive: Forwarded as `recurse` to `module.parameters`; when False only the
            parameters registered directly on `module` are touched.
        dtype: Target dtype for the parameter data, or None to leave dtypes unchanged.
    """
    should_cast = dtype is not None
    for parameter in module.parameters(recurse=recursive):
        parameter.requires_grad = False
        if should_cast:
            # Swap the underlying storage so the Parameter object itself stays the same.
            parameter.data = parameter.to(dtype=dtype)
|
|
|
|
|
|
|
|
@torch.no_grad()
def make_module_non_trainable(module: torch.nn.Module, dtype: Optional[torch.dtype] = None):
    """
    Freeze `module`, optionally cast it to `dtype`, and put it into eval mode.

    Args:
        module: Module to freeze in place.
        dtype: When given, the whole module is cast through `module.to(dtype=dtype)`
            instead of casting the parameters individually.
    """
    # Only freeze here; any dtype cast is left to Module.to below.
    make_module_params_non_trainable(module, dtype=None)
    cast_requested = dtype is not None
    if cast_requested:
        module.to(dtype=dtype)
    module.eval()
|
|
|
|
|
|
|
|
class ResidualConvBlock(torch.nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
in_channels: int, |
|
|
out_channels: int | None = None, |
|
|
hidden_channels: int | None = None, |
|
|
padding_mode: str = "replicate", |
|
|
activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu", |
|
|
norm: Literal["group_norm", "layer_norm"] = "group_norm", |
|
|
): |
|
|
super().__init__() |
|
|
if out_channels is None: |
|
|
out_channels = in_channels |
|
|
if hidden_channels is None: |
|
|
hidden_channels = in_channels |
|
|
if activation == "relu": |
|
|
activation_cls = functools.partial(torch.nn.ReLU, inplace=True) |
|
|
elif activation == "leaky_relu": |
|
|
activation_cls = functools.partial(torch.nn.LeakyReLU, negative_slope=0.2, inplace=True) |
|
|
elif activation == "silu": |
|
|
activation_cls = functools.partial(torch.nn.SiLU, inplace=True) |
|
|
elif activation == "elu": |
|
|
activation_cls = functools.partial(torch.nn.ELU, inplace=True) |
|
|
else: |
|
|
raise ValueError(f"Unsupported activation function: {activation}") |
|
|
self.layers = torch.nn.Sequential( |
|
|
torch.nn.GroupNorm(1, in_channels), |
|
|
activation_cls(), |
|
|
torch.nn.Conv2d( |
|
|
in_channels, |
|
|
hidden_channels, |
|
|
kernel_size=3, |
|
|
padding=1, |
|
|
padding_mode=padding_mode, |
|
|
), |
|
|
torch.nn.GroupNorm(hidden_channels // 32 if norm == "group_norm" else 1, hidden_channels), |
|
|
activation_cls(), |
|
|
torch.nn.Conv2d( |
|
|
hidden_channels, |
|
|
out_channels, |
|
|
kernel_size=3, |
|
|
padding=1, |
|
|
padding_mode=padding_mode, |
|
|
), |
|
|
) |
|
|
self.skip_connection = ( |
|
|
torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) |
|
|
if in_channels != out_channels |
|
|
else torch.nn.Identity() |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
skip = self.skip_connection(x) |
|
|
x = self.layers(x) |
|
|
x = x + skip |
|
|
return x |
|
|
|
|
|
|
|
|
def normalized_view_plane_uv( |
|
|
width: int, |
|
|
height: int, |
|
|
aspect_ratio: float | None = None, |
|
|
dtype: torch.dtype = None, |
|
|
device: torch.device = None, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal) |
|
|
""" |
|
|
if aspect_ratio is None: |
|
|
aspect_ratio = width / height |
|
|
span_x = aspect_ratio / (1 + aspect_ratio**2) ** 0.5 |
|
|
span_y = 1 / (1 + aspect_ratio**2) ** 0.5 |
|
|
u = torch.linspace( |
|
|
-span_x * (width - 1) / width, |
|
|
span_x * (width - 1) / width, |
|
|
width, |
|
|
dtype=dtype, |
|
|
device=device, |
|
|
) |
|
|
v = torch.linspace( |
|
|
-span_y * (height - 1) / height, |
|
|
span_y * (height - 1) / height, |
|
|
height, |
|
|
dtype=dtype, |
|
|
device=device, |
|
|
) |
|
|
(u, v) = torch.meshgrid(u, v, indexing="xy") |
|
|
uv = torch.stack([u, v], dim=-1) |
|
|
return uv |
|
|
|
|
|
|
|
|
class Head(torch.nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
num_features: int, |
|
|
dim_in: int, |
|
|
dim_out: List[int], |
|
|
dim_proj: int = 512, |
|
|
dim_upsample: List[int] = [256, 128, 128], |
|
|
dim_times_res_block_hidden: int = 1, |
|
|
num_res_blocks: int = 1, |
|
|
res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm", |
|
|
last_res_blocks: int = 0, |
|
|
last_conv_channels: int = 32, |
|
|
last_conv_size: int = 1, |
|
|
): |
|
|
super().__init__() |
|
|
self.projects = torch.nn.ModuleList( |
|
|
[ |
|
|
torch.nn.Conv2d( |
|
|
in_channels=dim_in, |
|
|
out_channels=dim_proj, |
|
|
kernel_size=1, |
|
|
stride=1, |
|
|
padding=0, |
|
|
) |
|
|
for _ in range(num_features) |
|
|
] |
|
|
) |
|
|
self.upsample_blocks = torch.nn.ModuleList( |
|
|
[ |
|
|
torch.nn.Sequential( |
|
|
self._make_upsampler(in_ch + 2, out_ch), |
|
|
*( |
|
|
ResidualConvBlock( |
|
|
out_ch, |
|
|
out_ch, |
|
|
dim_times_res_block_hidden * out_ch, |
|
|
activation="relu", |
|
|
norm=res_block_norm, |
|
|
) |
|
|
for _ in range(num_res_blocks) |
|
|
), |
|
|
) |
|
|
for (in_ch, out_ch) in zip([dim_proj] + dim_upsample[:-1], dim_upsample, strict=True) |
|
|
] |
|
|
) |
|
|
self.output_block = torch.nn.ModuleList( |
|
|
[ |
|
|
self._make_output_block( |
|
|
dim_upsample[-1] + 2, |
|
|
dim_out_, |
|
|
dim_times_res_block_hidden, |
|
|
last_res_blocks, |
|
|
last_conv_channels, |
|
|
last_conv_size, |
|
|
res_block_norm, |
|
|
) |
|
|
for dim_out_ in dim_out |
|
|
] |
|
|
) |
|
|
|
|
|
def _make_upsampler(self, in_channels: int, out_channels: int): |
|
|
upsampler = torch.nn.Sequential( |
|
|
torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2), |
|
|
torch.nn.Conv2d( |
|
|
out_channels, |
|
|
out_channels, |
|
|
kernel_size=3, |
|
|
stride=1, |
|
|
padding=1, |
|
|
padding_mode="replicate", |
|
|
), |
|
|
) |
|
|
upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1] |
|
|
return upsampler |
|
|
|
|
|
def _make_output_block( |
|
|
self, |
|
|
dim_in: int, |
|
|
dim_out: int, |
|
|
dim_times_res_block_hidden: int, |
|
|
last_res_blocks: int, |
|
|
last_conv_channels: int, |
|
|
last_conv_size: int, |
|
|
res_block_norm: Literal["group_norm", "layer_norm"], |
|
|
): |
|
|
return torch.nn.Sequential( |
|
|
torch.nn.Conv2d( |
|
|
dim_in, |
|
|
last_conv_channels, |
|
|
kernel_size=3, |
|
|
stride=1, |
|
|
padding=1, |
|
|
padding_mode="replicate", |
|
|
), |
|
|
*( |
|
|
ResidualConvBlock( |
|
|
last_conv_channels, |
|
|
last_conv_channels, |
|
|
dim_times_res_block_hidden * last_conv_channels, |
|
|
activation="relu", |
|
|
norm=res_block_norm, |
|
|
) |
|
|
for _ in range(last_res_blocks) |
|
|
), |
|
|
torch.nn.ReLU(inplace=True), |
|
|
torch.nn.Conv2d( |
|
|
last_conv_channels, |
|
|
dim_out, |
|
|
kernel_size=last_conv_size, |
|
|
stride=1, |
|
|
padding=last_conv_size // 2, |
|
|
padding_mode="replicate", |
|
|
), |
|
|
) |
|
|
|
|
|
def forward(self, hidden_states: List[torch.Tensor], image: torch.Tensor): |
|
|
(img_h, img_w) = image.shape[-2:] |
|
|
(patch_h, patch_w) = (img_h // 14, img_w // 14) |
|
|
x = torch.stack( |
|
|
[ |
|
|
proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous()) |
|
|
for (proj, feat) in zip(self.projects, hidden_states, strict=True) |
|
|
], |
|
|
dim=1, |
|
|
).sum(dim=1) |
|
|
for _, block in enumerate(self.upsample_blocks): |
|
|
uv = normalized_view_plane_uv( |
|
|
width=x.shape[-1], |
|
|
height=x.shape[-2], |
|
|
aspect_ratio=img_w / img_h, |
|
|
dtype=x.dtype, |
|
|
device=x.device, |
|
|
) |
|
|
uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) |
|
|
x = torch.cat([x, uv], dim=1) |
|
|
for layer in block: |
|
|
x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False) |
|
|
x = torch.nn.functional.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False) |
|
|
uv = normalized_view_plane_uv( |
|
|
width=x.shape[-1], |
|
|
height=x.shape[-2], |
|
|
aspect_ratio=img_w / img_h, |
|
|
dtype=x.dtype, |
|
|
device=x.device, |
|
|
) |
|
|
uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) |
|
|
x = torch.cat([x, uv], dim=1) |
|
|
if isinstance(self.output_block, torch.nn.ModuleList): |
|
|
output = [ |
|
|
torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) |
|
|
for block in self.output_block |
|
|
] |
|
|
else: |
|
|
output = torch.utils.checkpoint.checkpoint(self.output_block, x, use_reentrant=False) |
|
|
return output |
|
|
|
|
|
|
|
|
def _is_single_image_size(image_sizes: Dict[str, Dict[str, int]]) -> bool: |
|
|
return ( |
|
|
len(image_sizes) == 1 |
|
|
or len(set(((image_size["height"], image_size["width"]) for image_size in image_sizes.values()))) == 1 |
|
|
) |
|
|
|
|
|
|
|
|
class MoGe(torch.nn.Module):
    """
    Implementation of MoGe taken from https://github.com/microsoft/MoGe/blob/main/moge/model/v1.py
    Simplified and stripped down such that:
    - It doesn't rely on MoGe codebase
    - Uses timm for ViT backbone
    - Only predicts points and mask
    - Currently does NOT infer depth or intrinsics (but this could be added)
    - Requires image to be resized to the expected resolution. Note that this resolution must be
      in the range of num_tokens_range, (for square images, pixel sizes in the range ~[36*14, 50*14])
    """

    def __init__(
        self,
        image_sizes: Dict[str, Dict[str, int]] = {},
        backbone_id: str = "vit_large_patch14_dinov2.lvd142m",
        intermediate_layers: int | List[int] = 4,
        dim_proj: int = 512,
        dim_upsample: List[int] = [256, 128, 64],
        dim_times_res_block_hidden: int = 2,
        num_res_blocks: int = 2,
        remap_output: Literal[False, True, "linear", "sinh", "exp", "sinh_exp"] = "exp",
        res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm",
        num_tokens_range: Tuple[int, int] = [1275, 2551],
        last_res_blocks: int = 0,
        last_conv_channels: int = 32,
        last_conv_size: int = 1,
        mask_threshold: float = 0.5,
        output_dtype: torch.dtype = torch.bfloat16,
    ):
        """
        NOTE:
        - All defaults were taken from the checkpoint config for 'Ruicheng/moge-vitl'
        - output_dtype - by default was float32, changed to bfloat16 for model training

        Args:
            image_sizes: Mapping from a name to {"height": ..., "width": ...}; `forward`
                only accepts images whose (H, W) matches one of the entries.
            backbone_id: timm model id of the ViT backbone.
            intermediate_layers: Number (or explicit indices) of backbone layers whose
                features are fed to the head.
            remap_output: Point remapping applied in `_remap_points`.
            mask_threshold: Stored on the module; not used inside this class.
            output_dtype: float32 enables the float32 autocast region in `forward`.
            Remaining arguments are forwarded to `Head`.
        """
        super().__init__()
        self.remap_output = remap_output
        self.intermediate_layers = intermediate_layers
        self.num_tokens_range = num_tokens_range
        self.mask_threshold = mask_threshold
        self.output_dtype = output_dtype
        # Copy so the shared default dict / the caller's dict is never aliased.
        self.image_sizes = dict(image_sizes)

        # Needs `self.image_sizes` and `self.intermediate_layers` to be set already.
        self.backbone: timm.models.vision_transformer.VisionTransformer = self._make_vit_backbone(backbone_id)
        token_size: int = self.backbone.embed_dim

        if isinstance(intermediate_layers, int):
            num_features = intermediate_layers
        else:
            num_features = len(intermediate_layers)
        self.head = Head(
            num_features=num_features,
            dim_in=token_size,
            dim_out=[3, 1],
            dim_proj=dim_proj,
            dim_upsample=dim_upsample,
            dim_times_res_block_hidden=dim_times_res_block_hidden,
            num_res_blocks=num_res_blocks,
            res_block_norm=res_block_norm,
            last_res_blocks=last_res_blocks,
            last_conv_channels=last_conv_channels,
            last_conv_size=last_conv_size,
        )

    def _make_vit_backbone(self, backbone_id: str) -> timm.models.vision_transformer.VisionTransformer:
        """Create the (non-pretrained) timm ViT and make its `forward` return intermediate features."""
        if _is_single_image_size(self.image_sizes):
            # All entries share one resolution, so any of them can serve as reference;
            # this avoids requiring a camera literally named "main".
            reference_size = next(iter(self.image_sizes.values()))
            kwargs = {
                "img_size": (reference_size["height"], reference_size["width"]),
                "dynamic_img_size": False,
            }
        else:
            kwargs = {"img_size": (224, 224), "dynamic_img_size": True}

        vit_backbone: timm.models.vision_transformer.VisionTransformer = timm.create_model(
            backbone_id, pretrained=False, num_classes=0, **kwargs
        )
        vit_backbone.forward = functools.partial(
            vit_backbone.forward_intermediates,
            # Must agree with the number of feature levels the head was built for.
            indices=self.intermediate_layers,
            return_prefix_tokens=False,
            norm=True,
            stop_early=True,
            output_fmt="NLC",
            intermediates_only=True,
        )
        return vit_backbone

    def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            image: torch.Tensor of shape [B, 3, H, W] containing the preprocessed image, resized to the
                size expected by the model.
        Returns:
            A dictionary containing:
            - `points`: torch.Tensor of shape [B, 3, H, W] containing the predicted points.
            - `mask`: torch.Tensor of shape [B, 1, H, W] containing the predicted mask.
        """
        (height, width) = image.shape[-2:]
        assert (height, width) in [
            (image_size["height"], image_size["width"]) for image_size in self.image_sizes.values()
        ], f"{(height, width)} not in {self.image_sizes}"

        features: List[torch.Tensor] = self.backbone(image)
        (points, mask) = self.head(features, image)

        with torch.autocast(
            device_type=image.device.type,
            dtype=torch.float32,
            enabled=self.output_dtype == torch.float32,
        ):
            resize_kwargs = dict(mode="bilinear", align_corners=False, antialias=False)
            points = torch.nn.functional.interpolate(points, (height, width), **resize_kwargs)
            mask = torch.nn.functional.interpolate(mask, (height, width), **resize_kwargs)
            points = self._remap_points(points, dim=1)

        return {"points": points, "mask": mask}

    def _remap_points(self, points: torch.Tensor, dim: int = 1) -> torch.Tensor:
        """Map raw head outputs to points according to `self.remap_output`.

        Args:
            points: Tensor whose dimension `dim` has size 3 (xy followed by z).
            dim: Dimension holding the xyz components.

        Raises:
            ValueError: For any `remap_output` other than "linear", "sinh", "exp", "sinh_exp".
        """
        mode = self.remap_output
        if mode == "linear":
            return points
        if mode == "sinh":
            return torch.sinh(points)
        if mode == "exp":
            (xy, z) = points.split([2, 1], dim=dim)
            z = torch.exp(z)
            # xy is predicted relative to depth, hence the multiplication by z.
            return torch.cat([xy * z, z], dim=dim)
        if mode == "sinh_exp":
            (xy, z) = points.split([2, 1], dim=dim)
            return torch.cat([torch.sinh(xy), torch.exp(z)], dim=dim)
        raise ValueError(f"Invalid remap output type: {self.remap_output}")
|
|
|
|
|
|
|
|
class DepthBackboneConfig(transformers.PretrainedConfig):
    """Configuration of the depth backbone: checkpoint location and expected image sizes."""

    def __init__(
        self,
        hf_hub_repo: str = "",
        hf_filename: str = "",
        image_sizes: Dict[str, Dict[str, int]] = {},
        **kwargs,
    ):
        """
        Args:
            hf_hub_repo: Hugging Face hub repository id of the depth checkpoint.
            hf_filename: File name of the checkpoint inside `hf_hub_repo`.
            image_sizes: Mapping from a name to {"height": ..., "width": ...}.
            **kwargs: Forwarded to `transformers.PretrainedConfig`.
        """
        super().__init__(**kwargs)
        self.hf_hub_repo = hf_hub_repo
        self.hf_filename = hf_filename
        # Shallow copy: the default dict is shared between calls and must not be aliased.
        self.image_sizes = dict(image_sizes)
|
|
|
|
|
|
|
|
class PaliGemma3DConfig(transformers.models.paligemma.PaliGemmaConfig):
    """PaliGemma configuration extended with a depth-backbone sub-config and fusion options."""

    sub_configs = {
        "text_config": transformers.AutoConfig,
        "vision_config": transformers.AutoConfig,
        "depth_config": DepthBackboneConfig,
    }

    def __init__(
        self,
        depth_config={},
        depth_only: bool = False,
        mask_prob: float = 0.0,
        projection: str = "",
        depth_layers: int = 4,
        **kwargs,
    ):
        """
        Args:
            depth_config: Either a dict of `DepthBackboneConfig` keyword arguments or an
                already constructed config object (stored as given).
            depth_only: Stored flag; read by the model to use depth features only.
            mask_prob: Stored probability; read by the model during training.
            projection: Name of the projection / fusion scheme.
            depth_layers: Number of depth-backbone layers whose features are used.
            **kwargs: Forwarded to `PaliGemmaConfig`.
        """
        super().__init__(**kwargs)
        self.depth_config = (
            DepthBackboneConfig(**depth_config) if isinstance(depth_config, dict) else depth_config
        )
        self.mask_prob = mask_prob
        self.depth_only = depth_only
        self.projection = projection
        self.depth_layers = depth_layers

    @property
    def is_single_image_size(self) -> bool:
        """True when there is exactly one image size entry or all entries share (height, width)."""
        if len(self.depth_config.image_sizes) == 1:
            return True
        distinct_sizes = {
            (image_size["height"], image_size["width"])
            for image_size in self.depth_config.image_sizes.values()
        }
        return len(distinct_sizes) == 1

    @property
    def camera_names(self) -> List[str]:
        """Keys of `depth_config.image_sizes`, in insertion order."""
        return [*self.depth_config.image_sizes.keys()]
|
|
|
|
|
|
|
|
class NeRFPositionalEmbedding(torch.nn.Module):
    """Sinusoidal (NeRF-style) positional embedding with a fixed frequency bank."""

    def __init__(self, n_frequencies: int, log_scale: bool = True, scale: float = 1.0):
        """
        Args:
            n_frequencies: Dimension size, same as L parameter in the NeRF paper
            log_scale: If True the frequencies are 2^0 ... 2^(L-1); otherwise they are
                linearly spaced between 1 and 2^(L-1).
            scale: Scale factor for the frequencies. To match the formula from the paper
                [sin(2^k * pi * x), cos(2^k * pi * x)], set scale to math.pi. In practice, the paper
                authors don't multiply by pi and use scale=1.0.
                See https://github.com/bmild/nerf/issues/12
        """
        super().__init__()
        self.n_frequencies = n_frequencies
        if log_scale:
            exponents = torch.arange(self.n_frequencies, dtype=torch.float32)
            base_freq = 2**exponents
        else:
            highest = 2 ** (self.n_frequencies - 1)
            base_freq = torch.linspace(1, highest, self.n_frequencies)
        self.register_buffer("freq", base_freq * scale)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Maps features from dimensionality N to dimensionality N*(2*L + 1), i.e. `2L + 1` output
        features are produced for each input feature.

        Embeds x to (x, sin(2^k x), cos(2^k x), ...). NOTE: `x` is also in the output. This is
        different from the equation in the paper, but matches the actual author implementation.
        See https://github.com/bmild/nerf/issues/12

        Args:
            inputs: torch.Tensor of shape [B, ..., N]; input values to be transformed
        Returns: torch.Tensor of shape [B, ..., N*(2*L + 1) = embedding_dim], encoded input values
        """
        # NOTE(review): expand_dims presumably reshapes `freq` so it broadcasts against
        # inputs[..., None] along a trailing frequency axis — confirm in common_spear.
        freq = expand_dims(self.freq, ndim=inputs.ndim + 1, order=[-1, 1])
        phases = freq * inputs.unsqueeze(-1)
        sin_cos = torch.stack([torch.sin(phases), torch.cos(phases)], dim=-1)
        # Collapse the per-feature (frequency, sin/cos) axes into the feature dimension.
        flattened = sin_cos.view(*inputs.shape[:-1], -1)
        return torch.cat([inputs, flattened], dim=-1)
|
|
|
|
|
|
|
|
def make_mlp( |
|
|
layer_sizes: List[int], |
|
|
activation: str | Type[torch.nn.Module], |
|
|
norm: str | Type[torch.nn.Module] | None = torch.nn.LayerNorm, |
|
|
activate_final: bool = False, |
|
|
bias: bool = True, |
|
|
) -> torch.nn.Sequential: |
|
|
""" |
|
|
Args: |
|
|
layer_sizes: List of layer sizes. The first value is the number of input features and the last |
|
|
value is the number of output features |
|
|
activation: str or the class of the activation. If str, it should be the exact name of |
|
|
the activation module under torch.nn, e.g. 'ReLU', 'SiLU', 'GeLU'. Use 'Identity' if |
|
|
no activation wanted |
|
|
norm: type of normalization. Same type as `activation`. Ex: `torch.nn.LayerNorm`, 'LayerNorm', etc |
|
|
""" |
|
|
if len(layer_sizes) == 0: |
|
|
return torch.nn.Identity() |
|
|
assert len(layer_sizes) > 1, "Need to provide input and output layer sizes at least" |
|
|
if isinstance(activation, str): |
|
|
TorchActivation: Type[torch.nn.Module] = getattr(torch.nn, activation) |
|
|
else: |
|
|
TorchActivation: Type[torch.nn.Module] = activation |
|
|
assert issubclass(TorchActivation, torch.nn.Module), TorchActivation |
|
|
if isinstance(norm, str): |
|
|
TorchNorm: Type[torch.nn.Module] = getattr(torch.nn, norm) |
|
|
elif norm is None: |
|
|
TorchNorm: Type[torch.nn.Module] = torch.nn.Identity |
|
|
else: |
|
|
TorchNorm: Type[torch.nn.Module] = norm |
|
|
assert issubclass(TorchNorm, torch.nn.Module), TorchNorm |
|
|
|
|
|
def make_norm_act(modules: dict[str, torch.nn.Module], empty: bool): |
|
|
return {} if empty else modules |
|
|
|
|
|
module = torch.nn.Sequential( |
|
|
*[ |
|
|
torch.nn.Sequential( |
|
|
collections.OrderedDict( |
|
|
{ |
|
|
"linear": torch.nn.Linear(in_features, out_features, bias=bias), |
|
|
**make_norm_act( |
|
|
{"norm": TorchNorm(out_features), "act": TorchActivation()}, |
|
|
empty=i == len(layer_sizes) - 2 and not activate_final, |
|
|
), |
|
|
} |
|
|
) |
|
|
) |
|
|
for (i, (in_features, out_features)) in enumerate( |
|
|
zip(layer_sizes[:-1], layer_sizes[1:], strict=True) |
|
|
) |
|
|
] |
|
|
) |
|
|
return module |
|
|
|
|
|
|
|
|
class PaliGemma3D(transformers.models.paligemma.PaliGemmaForConditionalGeneration):
    """
    Transformers-like implementation of PaliGemma with additional depth encoder
    """

    config_class = PaliGemma3DConfig

    def __init__(self, config: PaliGemma3DConfig):
        """Build the base PaliGemma model plus the depth tower and its projector.

        Only `projection == "features_add"` is supported at construction time.
        """
        super().__init__(config)
        assert self.config.projection in ["features_add"]
        if self.config.projection in ["features_add"]:
            self.depth_tower = self._make_depth_encoder()
            self.depth_projector = torch.nn.Linear(
                # The selected depth layers are concatenated along the feature axis.
                in_features=self.depth_tower.embed_dim * self.config.depth_layers,
                out_features=self.config.text_config.hidden_size,
            )
            # Generator used for the training-time feature-masking coin flip.
            self.generator: torch.Generator = torch.Generator()
            if self.config.depth_only:
                make_module_non_trainable(self.vision_tower)
                make_module_non_trainable(self.siglip_projector)
            # Always defined: the `projectors` property reads it unconditionally.
            self.projector = None
        else:
            raise ValueError(f"Projection type `{self.config.projection}` not supported!")

    @property
    def projectors(self) -> Set[torch.nn.Module]:
        """Set of projection modules; includes MoGe head modules when the depth tower is a MoGe."""
        candidates = (self.projector, self.depth_projector, self.siglip_projector)
        modules = {module for module in candidates if module is not None}
        if isinstance(self.depth_tower, MoGe):
            modules = modules | {
                module
                for module in self.depth_tower.modules()
                if isinstance(module, (ResidualConvBlock, Head))
            }
        return modules

    @property
    def siglip_projector(self) -> torch.nn.Linear:
        """Linear layer of the base model's multimodal projector."""
        return self.multi_modal_projector.linear

    def _make_depth_encoder(self) -> torch.nn.Module:
        """Create the (non-pretrained) DINOv2 ViT-L/14 whose `forward` returns intermediate features."""
        # Dynamic image size is only needed when cameras use different resolutions;
        # the "main" camera provides the reference resolution in both cases.
        dynamic_img_size = not self.config.is_single_image_size
        main_size = self.config.depth_config.image_sizes["main"]
        model: timm.models.vision_transformer.VisionTransformer = timm.create_model(
            "vit_large_patch14_dinov2.lvd142m",
            pretrained=False,
            num_classes=0,
            img_size=(main_size["height"], main_size["width"]),
            dynamic_img_size=dynamic_img_size,
        )
        model.forward = functools.partial(
            model.forward_intermediates,
            indices=self.config.depth_layers,
            return_prefix_tokens=False,
            norm=True,
            stop_early=True,
            output_fmt="NLC",
            intermediates_only=True,
        )
        return model

    def _load_depth_model_state_dict(self, depth_model: torch.nn.Module):
        """Download the depth checkpoint from the HF hub and load it into `depth_model`."""
        depth_config = self.config.depth_config
        logging.info("Loading depth model from %s/%s", depth_config.hf_hub_repo, depth_config.hf_filename)
        checkpoint_path = hf_hub_download(
            repo_id=depth_config.hf_hub_repo,
            filename=depth_config.hf_filename,
        )
        # SECURITY: weights_only=False unpickles arbitrary objects from the downloaded
        # file. Only point `hf_hub_repo` at trusted repositories; switch to
        # weights_only=True once the checkpoint format is confirmed to allow it.
        state_dict = torch.load(
            checkpoint_path,
            map_location="cpu",
            mmap=True,
            weights_only=False,
        )
        if self.config.projection in ["spatial_add", "spatial_concat"]:
            # Only the positional embedding is passed through timm's checkpoint filter.
            pos_embed_state_dict = {"pos_embed": state_dict["backbone.pos_embed"]}
            pos_embed_state_dict = timm.models.vision_transformer.checkpoint_filter_fn(
                pos_embed_state_dict, depth_model.backbone
            )
            state_dict["backbone.pos_embed"] = pos_embed_state_dict["pos_embed"]
        else:
            state_dict = timm.models.vision_transformer.checkpoint_filter_fn(state_dict, depth_model)
        depth_model.load_state_dict(state_dict)

    def get_image_features(self, pixel_values: dict[str, torch.Tensor]) -> torch.Tensor:
        """Encode all cameras into image tokens.

        Args:
            pixel_values: Dict with keys `"<camera>.siglip"` and `"<camera>.depth"` for
                every camera in `config.camera_names`.

        Returns:
            Image tokens. With a single image size the cameras are stacked into the batch
            dimension ([(B N), tokens, D]); otherwise each camera is encoded separately
            and the tokens are concatenated along dim -2.

        Raises:
            ValueError: If `config.projection` is not a known projection type.
        """
        projection = self.config.projection
        if projection == "features_add":
            images_forward = self._get_image_features_add
        elif projection in ["spatial_add", "spatial_concat"]:
            images_forward = self._get_image_features_spatial
        else:
            raise ValueError(f"Project type `{self.config.projection}` not supported!")

        camera_names = self.config.camera_names
        if self.config.is_single_image_size:

            def stack_cameras(encoder_name: str) -> torch.Tensor:
                stacked = torch.stack(
                    [pixel_values[f"{camera_name}.{encoder_name}"] for camera_name in camera_names],
                    dim=1,
                )
                return einops.rearrange(stacked, "B N C H W -> (B N) C H W")

            # One forward pass for all cameras.
            return images_forward({"siglip": stack_cameras("siglip"), "depth": stack_cameras("depth")})

        camera_tokens: List[torch.Tensor] = [
            images_forward(
                {
                    "siglip": pixel_values[f"{camera_name}.siglip"],
                    "depth": pixel_values[f"{camera_name}.depth"],
                }
            )
            for camera_name in camera_names
        ]
        return torch.cat(camera_tokens, dim=-2)

    def _get_image_features_add(self, pixel_values: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values: images of shape `[B * num_images, C, H, W]`.
                Keys in the dict correspond to the specific vision encoder to use
        Returns:
            image_features of shape [B * num_images, num_tokens, token_size = 2048]
        """
        siglip_features = self.vision_tower(pixel_values["siglip"]).last_hidden_state
        depth_output: list[torch.Tensor] = self.depth_tower(pixel_values["depth"])
        depth_features = torch.cat(depth_output, dim=-1) if len(depth_output) > 1 else depth_output[0]

        siglip_features = self.siglip_projector(siglip_features)
        depth_features = self.depth_projector(depth_features)

        if self.config.depth_only:
            image_features = depth_features
        elif self.training and torch.bernoulli(torch.tensor(self.config.mask_prob), generator=self.generator):
            # With probability `mask_prob` (training only): per sample, a random half of
            # the token positions takes the depth features, the rest the SigLIP features.
            num_tokens = depth_features.shape[1]
            device = depth_features.device
            ones = torch.ones((depth_features.shape[0], num_tokens), device=device)
            indices = torch.multinomial(ones / num_tokens, num_samples=num_tokens // 2)
            mask = ones.to(dtype=torch.bool).scatter_(dim=-1, index=indices, value=0)
            mask = mask.unsqueeze(-1)
            image_features = siglip_features * mask + depth_features * ~mask
        else:
            image_features = (siglip_features + depth_features) / 2

        return image_features / self.config.text_config.hidden_size**0.5

    def _get_image_features_spatial(self, pixel_values: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        NOTE: the constructor currently only accepts `projection == "features_add"`, so this
        path is not reachable through a normally constructed model.

        Args:
            pixel_values: images of shape `[B * num_images, C, H, W]`.
                Keys in the dict correspond to the specific vision encoder to use
        Returns:
            image_features of shape [B * num_images, num_tokens, token_size = 2048]
        """
        siglip_features = self.vision_tower(pixel_values["siglip"]).last_hidden_state
        depth_output: dict[str, torch.Tensor] = self.depth_tower(pixel_values["depth"])

        # Zero out points wherever the predicted mask is at or below the threshold.
        mask_binary = depth_output["mask"] > self.depth_tower.mask_threshold
        points = torch.where(mask_binary, depth_output["points"], 0)
        points_embed: torch.Tensor = self.depth_projector(points)

        if self.config.projection == "spatial_concat":
            features = torch.cat([siglip_features, points_embed], dim=-1)
            image_features = self.projector(features)
        else:
            features = siglip_features + points_embed
            image_features = self.siglip_projector(features)

        return image_features / self.config.text_config.hidden_size**0.5

    @classmethod
    def from_pretrained(cls, *args, **kwargs) -> "PaliGemma3D":
        """Load the model, then load the depth-tower weights from the configured HF hub checkpoint."""
        model = super().from_pretrained(*args, **kwargs)
        model._load_depth_model_state_dict(model.depth_tower)
        return model
|
|
|
|
|
|
|
|
class RobotStateProjector(ConfigurableModule):
    """Pack robot state and project to a single token per timestep"""

    # Fields of `RoboticsInput` concatenated (along the last dim) for each config mode.
    _MODE_FIELDS = {
        "ee_pose": ("ee_pose_translation", "ee_pose_rotation"),
        "ee_pose_gripper": ("ee_pose_translation", "ee_pose_rotation", "gripper"),
        "ee_pose_joints": ("ee_pose_translation", "ee_pose_rotation", "joints"),
        "joints": ("joints",),
        "all": ("ee_pose_translation", "ee_pose_rotation", "gripper", "joints"),
    }

    def __init__(self, config: RobotStateProjectorConfig):
        """
        Raises:
            NotImplementedError: If `config.fourier` is set.
        """
        super().__init__(config)
        if self.config.fourier:
            raise NotImplementedError("Fourier robot state projector is not implemented yet")

        self.robot_state_tokens_proj = make_mlp(
            layer_sizes=self.config.layers,
            activation=self.config.activation,
            norm=torch.nn.LayerNorm,
        )

    def forward(self, inputs: RoboticsInput) -> Optional[torch.Tensor]:
        """
        Select the state components for `config.mode` and project them with the MLP.

        Returns:
            torch.Tensor of shape [B, num_past_steps, token_size]. For mode == 'none' an
            empty state with zero timesteps is projected, so the result has shape
            [B, 0, ...] rather than being None.

        Raises:
            NotImplementedError: If `config.mode` is not a known mode.
        """
        mode = self.config.mode
        if mode == "none":
            feature_dim = self.config.layers[0] if len(self.config.layers) > 0 else 0
            # Zero timesteps: keeps downstream concatenation uniform without adding tokens.
            robot_state = torch.empty(
                (inputs.ee_pose_translation.shape[0], 0, feature_dim),
                device=inputs.ee_pose_translation.device,
            )
        elif mode in self._MODE_FIELDS:
            parts = [getattr(inputs, field) for field in self._MODE_FIELDS[mode]]
            # A single component is passed through as-is (no copy via cat).
            robot_state = parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1)
        else:
            raise NotImplementedError(f"Unknown robot state mode {self.config.mode}")
        return self.robot_state_tokens_proj(robot_state)
|
|
|
|
|
|
|
|
class FourierFeatures(ConfigurableModule):
    """Fourier (cos/sin) features of a scalar input, followed by an MLP projection."""

    def __init__(self, config: FourierFeaturesConfig):
        """Create either a learnable frequency projection or a fixed geometric frequency bank.

        Both variants produce `num_features // 2` frequencies, i.e. `2 * (num_features // 2)`
        features before the MLP.
        """
        super().__init__(config)
        half_dim = self.config.num_features // 2
        if self.config.learnable_features:
            self.linear = torch.nn.Linear(in_features=1, out_features=half_dim, bias=False)
        else:
            # Geometric progression from 1 down to 1 / max_period. The divisor is clamped
            # so that a single frequency (half_dim == 1) yields 1.0 instead of NaN from 0/0.
            log_step = torch.log(torch.tensor(self.config.max_period)) / max(half_dim - 1, 1)
            freqs = torch.exp(-log_step * torch.arange(half_dim))
            self.register_buffer("freqs", freqs)

        self.layers: torch.nn.Sequential = make_mlp(
            self.config.layers,
            activation=self.config.activation,
            norm=self.config.norm,
            activate_final=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute Fourier features and project them via MLP
        Args:
            x: Input tensor of shape [..., 1]
        Returns:
            torch.Tensor: Fourier features of shape [..., num_features] or [..., layers[-1]]
        """
        assert x.shape[-1] == 1 and x.ndim > 1, x.shape
        if self.config.learnable_features:
            phases = 2 * math.pi * self.linear(x)
        else:
            # NOTE(review): expand_dims presumably reshapes `freqs` to broadcast over the
            # leading dims of `x` — confirm in common_spear.
            phases = x * expand_dims(self.freqs, x.ndim, [-1, 1])
        features = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        return self.layers(features)
|
|
|
|
|
|
|
|
class NoisedControlProjector(ConfigurableModule):
    """Pack noised control (translation, rotation, gripper) and project to a single token per timestep"""

    def __init__(self, config: NoisedControlProjectorConfig):
        super().__init__(config)
        layer_sizes = self.config.layers
        # The projected controls fill half of the first MLP layer's input; the other part
        # of the concatenated input is the timestep embedding (see `forward`)
        self.input_projector = torch.nn.Linear(
            in_features=layer_sizes[0],
            out_features=layer_sizes[1] // 2,
            bias=False,
        )
        self.time_embed = FourierFeatures(self.config.time_embed)
        self.layers = make_mlp(
            layer_sizes[1:],
            activation=self.config.activation,
            norm=self.config.norm,
            activate_final=False,
            bias=False,
        )

    def forward(self, inputs: FlowInput | DiffusionInput) -> Optional[torch.Tensor]:
        """
        Args:
            inputs: Provides `translation_t`, `rotation_t`, `gripper_t` (concatenated along the
                last dimension) and `timestep`
        Returns:
            torch.Tensor of shape [B, num_control_timesteps, token_size]
        """
        controls = torch.cat([inputs.translation_t, inputs.rotation_t, inputs.gripper_t], dim=-1)
        control_features = self.input_projector(controls)
        # Repeat the timestep embedding for every control timestep (dimension 1)
        time_features = self.time_embed(inputs.timestep).expand(-1, control_features.shape[1], -1)
        return self.layers(torch.cat([time_features, control_features], dim=-1))
|
|
|
|
|
|
|
|
def unmask_unattended(attn_mask: torch.Tensor, mask_value: Optional[float] = None) -> torch.Tensor:
    """
    Copy-pasted from `transformers.modeling_attn_mask_utils.AttentionMaskConverter._unmask_unattended`

    Attend to all tokens in fully-masked rows. This is required by the memory-efficient path of
    F.scaled_dot_product_attention, which otherwise produces NaN.
    Details: https://github.com/pytorch/pytorch/issues/110213

    Args:
        attn_mask: [B, 1 | num_heads, query_seq_len, kv_seq_len] or [B, query_seq_len, kv_seq_len], float dtype
        mask_value: The value inside `attn_mask` that corresponds to masked elements. Defaults to
            `torch.finfo(attn_mask.dtype).min`
    Returns:
        torch.Tensor of the same shape as `attn_mask`. Rows (along the last dimension) in which
        every element equals `mask_value` are multiplied by zero, i.e. become fully attended;
        all other rows are unchanged

    Example, written with 1 = attended and 0 = masked (e.g. here left-padding case). If the mask is
    ```
    [
        [[
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 1]
        ]],
        [[
            [1, 0, 0],
            [1, 1, 0],
            [1, 1, 1]
        ]],
        [[
            [0, 0, 0],
            [0, 1, 0],
            [0, 1, 1]
        ]]
    ]
    ```
    then the modified mask will be
    ```
    [
        [[
            [1, 1, 1],   <-- modified
            [1, 1, 1],   <-- modified
            [0, 0, 1]
        ]],
        [[
            [1, 0, 0],
            [1, 1, 0],
            [1, 1, 1]
        ]],
        [[
            [1, 1, 1],   <-- modified
            [0, 1, 0],
            [0, 1, 1]
        ]]
    ]
    ```
    """
    assert attn_mask.dtype.is_floating_point, attn_mask.dtype
    masked_marker = torch.finfo(attn_mask.dtype).min if mask_value is None else mask_value
    # True for every row whose entries are all masked out
    fully_masked_rows = (attn_mask == masked_marker).all(dim=-1, keepdim=True)
    # Multiplying by False zeroes such rows (0 == attend); multiplying by True is a no-op
    return attn_mask * ~fully_masked_rows
|
|
|
|
|
|
|
|
@torch.no_grad()
def attn_mask_to_float(attn_mask: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    """
    Convert a 4D boolean mask (False == masked out) to an additive mask of type `dtype`.

    Masked positions receive `torch.finfo(dtype).min`, all other positions 0. Rows that end up
    fully masked are reset via `unmask_unattended`. If `dtype` is None, the autocast dtype of
    the mask's device type is used. Fails an assertion if `attn_mask` isn't 4D or isn't bool
    """
    assert attn_mask.ndim == 4, attn_mask.shape
    assert attn_mask.dtype == torch.bool, attn_mask.dtype
    target_dtype = torch.get_autocast_dtype(attn_mask.device.type) if dtype is None else dtype
    min_value = torch.finfo(target_dtype).min
    additive_mask = torch.zeros(attn_mask.shape, dtype=target_dtype, device=attn_mask.device)
    additive_mask = additive_mask.masked_fill(~attn_mask, min_value)
    return unmask_unattended(additive_mask, min_value)
|
|
|
|
|
|
|
|
@torch.no_grad()
def make_4d_float_attn_mask(
    attn_mask: Optional[torch.Tensor],
    query_seq_length: int,
    kv_seq_length: int,
    dtype: torch.dtype,
    device: torch.device,
    batch_size: int,
) -> torch.Tensor:
    """
    Build an additive 4D attention mask from a boolean padding mask.

    - 2D input of shape [B, kv_seq_length]: broadcast over the query dimension
    - 4D input: converted to `dtype` if boolean; returned as is if it already has `dtype`
    - None: the output is a full bi-directional mask (all zeros)

    Args:
        attn_mask: Boolean mask of shape [B, kv_seq_length] or [B, 1, query_length, kv_seq_length].
            False values indicate masked out positions. A 4D mask may also already be of `dtype`
        query_seq_length: The query sequence length (L)
        kv_seq_length: The key-value sequence length (S). When `transformers.StaticCache` is used, this should
            equal the cache size to account for zero-padding the part of the cache that is not yet filled.
        dtype: Output dtype
        device: Output device
        batch_size: Batch size
    Returns:
        torch.Tensor of shape [B | 1, 1, query_length, kv_seq_length] (i.e. [B | 1, 1, L, S]).
        Contains zero at unmasked positions and `torch.finfo(dtype).min` at masked positions
    Raises:
        TypeError: If a 4D non-boolean mask has a dtype different from `dtype`
        NotImplementedError: If the length of a 2D mask differs from `kv_seq_length`
    """
    if attn_mask is not None and attn_mask.ndim == 4:
        if attn_mask.dtype == torch.bool:
            return attn_mask_to_float(attn_mask, dtype=dtype)
        if attn_mask.dtype != dtype:
            raise TypeError(f"Expected attn_mask.dtype={dtype}, but got {attn_mask.dtype}")
        return attn_mask

    min_value = torch.finfo(dtype).min
    additive_mask = torch.zeros([batch_size, 1, query_seq_length, kv_seq_length], dtype=dtype, device=device)
    if attn_mask is None:
        return additive_mask

    assert attn_mask.dtype == torch.bool, f"Unsupported dtype {attn_mask.dtype}"
    mask_length = attn_mask.shape[-1]
    if mask_length != kv_seq_length:
        raise NotImplementedError(f"{mask_length} != {kv_seq_length} not properly supported yet")
    # [B, 1, 1, S]; broadcasts over the query dimension
    masked_out = ~attn_mask.view(batch_size, 1, 1, mask_length)
    additive_mask[..., :mask_length] = additive_mask[..., :mask_length].masked_fill(masked_out, min_value)
    return additive_mask
|
|
|
|
|
|
|
|
class VLMInput(Protocol):
    """Structural interface of the inputs consumed by the VLM `forward` in this module.

    `inputs_embeds` and `past_key_values` default to None here.
    """

    # Token ids; the VLM forward indexes them as [B, seq_len]
    input_ids: torch.Tensor
    # Attention mask; the VLM forward accepts a 4D boolean mask or passes other masks through
    attn_mask: torch.Tensor
    # Image tensors keyed by encoder/camera name
    images: Dict[str, torch.Tensor]
    # NOTE(review): presumably indices of samples with/without images - confirm with callers
    multimodal_indices: torch.Tensor
    unimodal_indices: torch.Tensor

    @property
    def inputs_embeds(self) -> Optional[torch.Tensor]:
        """Optional input embeddings; None by default"""
        return None

    @property
    def past_key_values(self) -> Optional[List[torch.Tensor]]:
        """Optional cached keys/values; None by default"""
        return None
|
|
|
|
|
|
|
|
def zero_out_param_pretrained_grad( |
|
|
param: torch.nn.Parameter | torch.distributed.tensor.DTensor, |
|
|
module: torch.nn.Module, |
|
|
) -> None: |
|
|
"""Zero out the gradients of pretrained embeddings""" |
|
|
module.mask = module.mask.to(param.device) |
|
|
if isinstance(param, torch.distributed.tensor.DTensor) and not isinstance( |
|
|
module.mask, torch.distributed.tensor.DTensor |
|
|
): |
|
|
module.mask = torch.distributed.tensor.distribute_tensor( |
|
|
module.mask, device_mesh=param.device_mesh, placements=param.placements |
|
|
) |
|
|
mask = module.mask |
|
|
if type(param) is torch.distributed.tensor.DTensor and type(mask) is torch.distributed.tensor.DTensor: |
|
|
assert param.grad.shape == mask.shape, f"{param.grad.shape} != {mask.shape}" |
|
|
param.grad._local_tensor = torch.where(mask._local_tensor, param.grad._local_tensor, 0) |
|
|
elif type(param) in (torch.Tensor, torch.nn.Parameter) and type(mask) is torch.Tensor: |
|
|
assert param.grad.shape == mask.shape, f"{param.grad.shape} != {mask.shape}" |
|
|
param.grad = torch.where(mask, param.grad, 0) |
|
|
elif type(param) in (torch.Tensor, torch.nn.Parameter) and type(mask) is torch.distributed.tensor.DTensor: |
|
|
mask = mask.full_tensor() |
|
|
assert param.grad.shape == mask.shape, f"{param.grad.shape} != {mask.shape}" |
|
|
param.grad = torch.where(mask, param.grad, 0) |
|
|
else: |
|
|
raise ValueError(f"Unsupported parameter type: {type(param)} and mask type: {type(mask)}") |
|
|
|
|
|
|
|
|
class PaliGemmaVLM(ConfigurableModule):
    """Wraps PaliGemma to make compatible with VLM API"""

    def __init__(self, config: PaliGemmaVLMConfig):
        """
        Load the pretrained model and processor and adapt them:
        resize the SigLIP position embeddings, optionally override image feature extraction,
        optionally add depth tokens to the LLM vocabulary, optionally drop the LM head and
        disable causal attention in every LLM decoder layer.
        """
        super().__init__(config)
        if self.config.with_depth:
            model_config = PaliGemma3DConfig.from_pretrained(
                self.config.model_id, **self.config.paligemma_3d_config_dict
            )
            self.model = PaliGemma3D.from_pretrained(
                self.config.model_id,
                config=model_config,
                attn_implementation=self.config.attn_implementation,
            )
        else:
            self.model = transformers.AutoModelForVision2Seq.from_pretrained(
                self.config.model_id,
                attn_implementation=self.config.attn_implementation,
            )
        hf_processor = transformers.AutoProcessor.from_pretrained(self.config.model_id)
        self.processor = PaliGemmaDepthProcessor(
            config=self.config.processor_config,
            hf_processor=hf_processor,
            depth_tokens=self.config.depth_tokens,
        )
        self._resize_siglip_image_input()
        self._maybe_override_get_image_features()
        if self.config.depth_tokens > 0:
            self._resize_llm_token_embeddings(hf_processor.tokenizer)
        if not self.config.lm_head:
            # The LLM then returns hidden states instead of logits
            self.model.language_model.lm_head = torch.nn.Identity()
        self.model.train(True)
        for decoder in self.model.language_model.model.layers:
            decoder.self_attn.is_causal = False

    def forward(
        self,
        inputs: VLMInput,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> VLMOutput:
        """
        Run the wrapped model on `inputs`.

        Args:
            inputs: Provides `input_ids`, `attn_mask` and `images`. A 4D `attn_mask` is converted
                from bool to float; any other mask is forwarded unchanged. Every image tensor has
                its first two dimensions merged before being passed to the model
            use_cache, output_attentions, output_hidden_states: Forwarded to the wrapped model
            **kwargs: Ignored
        Returns:
            VLMOutput holding the LLM output, the image hidden states and the original `attn_mask`
        """
        del kwargs
        # FSDP wrapping may drop parameter hooks, so re-check at every step
        self._maybe_register_zero_out_grad_hooks()
        cache = transformers.DynamicCache()
        if inputs.attn_mask.ndim == 4:
            attn_mask = attn_mask_to_float(inputs.attn_mask)
        else:
            attn_mask = inputs.attn_mask
        images = {
            encoder_camera_name: camera_images.view(-1, *camera_images.shape[2:])
            for (encoder_camera_name, camera_images) in inputs.images.items()
        }
        llm_output: transformers.models.paligemma.modeling_paligemma.PaliGemmaCausalLMOutputWithPast = (
            self.model(
                input_ids=inputs.input_ids,
                pixel_values=images,
                attention_mask=attn_mask,
                use_cache=use_cache,
                past_key_values=cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
            )
        )
        # Image/text positions are derived from the FIRST sequence of the batch only
        indices = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device=inputs.input_ids.device)
        image_token_id = self.processor.hf_processor.image_token_id
        image_indices = indices[inputs.input_ids[0] == image_token_id]
        text_indices = indices[inputs.input_ids[0] != image_token_id]
        output = VLMOutput(
            llm_output=LLMOutput.from_transformers(
                input_ids=inputs.input_ids,
                llm_output=llm_output,
                text_indices=text_indices,
                image_indices=image_indices,
            ),
            vit_tokens=llm_output.image_hidden_states,
            attn_mask=inputs.attn_mask,
        )
        return output

    @property
    def fsdp_wrap_modules(self) -> Set[torch.nn.Module]:
        """Modules to wrap individually with FSDP: transformer blocks, vision towers, the LLM
        token embedding and final norm, and the multimodal projector(s)"""
        transformer_modules = {
            module
            for module in self.modules()
            if isinstance(
                module,
                (
                    transformers.models.siglip.modeling_siglip.SiglipEncoderLayer,
                    transformers.models.siglip.modeling_siglip.SiglipVisionTransformer,
                    timm.models.vision_transformer.Block,
                    timm.models.vision_transformer.VisionTransformer,
                    transformers.models.gemma.modeling_gemma.GemmaDecoderLayer,
                ),
            )
            or module
            in (
                self.model.language_model.model.embed_tokens,
                self.model.language_model.model.norm,
            )
        }
        if self.config.with_depth:
            # NOTE(review): assumes `projectors` is a set of modules - confirm in PaliGemma3D
            projectors = self.model.projectors
        else:
            projectors = {self.model.multi_modal_projector}
        return projectors | transformer_modules

    def _resize_siglip_image_input(self) -> None:
        """
        Enables resizing SigLIP positional embeddings to a new image size.

        The position embedding is resampled once to the size of the "main" camera image and
        `interpolate_pos_encoding` of the SigLIP embeddings is replaced by a version that
        resamples from that size. If cameras have different image sizes, the vision tower is
        always called with `interpolate_pos_encoding=True`.
        """
        num_image_tokens: int = self.config.processor_config.num_image_tokens["main"]
        image_size: Dict[str, int] = self.config.processor_config.image_sizes["main"].as_json()
        siglip_embeddings: transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings = (
            self.model.vision_tower.vision_model.embeddings
        )
        # NOTE(review): 14 is presumably the SigLIP patch size of the checkpoint - confirm
        embedding_weight: torch.Tensor = timm.layers.pos_embed.resample_abs_pos_embed(
            posemb=siglip_embeddings.position_embedding.weight.unsqueeze(0),
            new_size=(image_size["height"] // 14, image_size["width"] // 14),
            old_size=(
                siglip_embeddings.image_size // 14,
                siglip_embeddings.image_size // 14,
            ),
            num_prefix_tokens=0,
            interpolation="bicubic",
            antialias=True,
            verbose=False,
        ).squeeze(0)
        with torch.no_grad():
            siglip_embeddings.position_embedding.weight.data = embedding_weight
        siglip_embeddings.num_patches = siglip_embeddings.num_positions = num_image_tokens
        siglip_embeddings.image_size = dict(image_size)
        siglip_embeddings.register_buffer(
            "position_ids",
            torch.arange(siglip_embeddings.num_positions).expand((1, -1)),
            persistent=False,
        )

        def interpolate_pos_encoding(embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
            """Return position embeddings for an image of size (height, width)"""
            del embeddings
            new_size = (height // 14, width // 14)
            old_size = (image_size["height"] // 14, image_size["width"] // 14)
            if old_size == new_size:
                patch_pos_embedding = siglip_embeddings.position_embedding.weight
            else:
                patch_pos_embedding: torch.Tensor = timm.layers.pos_embed.resample_abs_pos_embed(
                    posemb=siglip_embeddings.position_embedding.weight.unsqueeze(0),
                    new_size=new_size,
                    old_size=(
                        image_size["height"] // 14,
                        image_size["width"] // 14,
                    ),
                    num_prefix_tokens=0,
                    interpolation="bicubic",
                    antialias=True,
                    verbose=False,
                ).squeeze(0)
            return patch_pos_embedding

        siglip_embeddings.interpolate_pos_encoding = interpolate_pos_encoding
        if not self.processor.config.is_single_image_size:
            self.model.vision_tower.forward = functools.partial(
                self.model.vision_tower.forward, interpolate_pos_encoding=True
            )

    def _maybe_override_get_image_features(self) -> None:
        """
        Override PaliGemmaForConditionalGeneration.get_image_features() from transformers
        such that it can handle multiple cameras with different resolutions.
        Does nothing when `config.with_depth` is set.
        """
        if self.config.with_depth:
            return
        images_forward = self.model.get_image_features
        camera_names: List[str] = self.processor.config.camera_names

        def get_image_features(pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor:
            """Encode the "<camera>.siglip" images of all cameras into image tokens"""
            if self.processor.config.is_single_image_size:
                # Same resolution everywhere: encode all cameras in a single batch
                inputs = einops.rearrange(
                    torch.stack(
                        [pixel_values[f"{camera_name}.siglip"] for camera_name in camera_names],
                        dim=1,
                    ),
                    "B N C H W -> (B N) C H W",
                )
                image_tokens = images_forward(inputs)
            else:
                # Different resolutions: encode per camera, concatenate along the token dimension
                camera_tokens: List[torch.Tensor] = [
                    images_forward(pixel_values[f"{camera_name}.siglip"]) for camera_name in camera_names
                ]
                image_tokens = torch.cat(camera_tokens, dim=-2)
            return image_tokens

        self.model.get_image_features = get_image_features

    def _resize_llm_token_embeddings(self, tokenizer: transformers.PreTrainedTokenizer) -> None:
        """
        Add `config.depth_tokens` tokens named "<dist0000>", "<dist0001>", ... to the tokenizer and
        resize the LLM token embeddings accordingly. If `config.train_only_depth_tokens`, build
        the gradient mask (True for rows added beyond the original vocabulary) and register the
        hooks that zero out the gradients of the pretrained rows.
        """
        assert self.config.depth_tokens > 0, self.config.depth_tokens
        tokenizer.add_tokens([f"<dist{i:04d}>" for i in range(self.config.depth_tokens)])
        total_num_tokens = len(tokenizer)
        # `vocab_size` is used as the number of pretrained rows (the added tokens come after it)
        vocab_size = tokenizer.vocab_size
        llm = self.model.language_model
        self.model.resize_token_embeddings(
            total_num_tokens,
            pad_to_multiple_of=64,
            mean_resizing=self.config.mean_resizing,
        )
        if self.config.train_only_depth_tokens:
            assert len(self.model.language_model._tied_weights_keys) > 0
            (weight_size, hidden_size) = llm.lm_head.weight.shape
            mask = torch.cat(
                [
                    torch.zeros([vocab_size, hidden_size], dtype=torch.bool),
                    torch.ones([weight_size - vocab_size, hidden_size], dtype=torch.bool),
                ],
                dim=0,
            )
            self.mask = mask
            self.embed_handle = None
            self.lm_head_handle = None
            self._maybe_register_zero_out_grad_hooks()

    @staticmethod
    def _grad_hook_missing(handle, weight: torch.Tensor) -> bool:
        """Return True if `handle` does not refer to a post-accumulate-grad hook alive on `weight`"""
        return (
            handle is None
            or weight._post_accumulate_grad_hooks is None
            or handle.id not in weight._post_accumulate_grad_hooks
        )

    def _maybe_register_zero_out_grad_hooks(self) -> None:
        """
        Register hooks to zero out the gradients of pretrained embeddings and LM head.
        Skips registering the hooks if they already exist. This runs at every step as wrapping
        in FSDP removes any hooks that were registered on the *parameters* of the original module
        and the only way to run this reliably is to check if the hooks exist.
        """
        if not self.config.train_only_depth_tokens:
            return
        llm = self.model.language_model
        embed_weight = llm.model.embed_tokens.weight
        if self._grad_hook_missing(self.embed_handle, embed_weight):
            self.embed_handle = embed_weight.register_post_accumulate_grad_hook(
                functools.partial(zero_out_param_pretrained_grad, module=self)
            )
        # `lm_head` is replaced by `torch.nn.Identity` when `config.lm_head` is False; it then
        # has no weight and there is nothing to hook
        lm_head_weight = getattr(llm.lm_head, "weight", None)
        if (
            lm_head_weight is not None
            and embed_weight is not lm_head_weight
            and self._grad_hook_missing(self.lm_head_handle, lm_head_weight)
        ):
            self.lm_head_handle = lm_head_weight.register_post_accumulate_grad_hook(
                functools.partial(zero_out_param_pretrained_grad, module=self)
            )
|
|
|
|
|
|
|
|
def make_position_indices(
    position_indices: Optional[torch.Tensor],
    seq_length: int,
    device: torch.device,
    max_seq_length: Optional[int],
) -> torch.Tensor:
    """
    Return int64 position indices, generating default ones if none are given.

    Args:
        position_indices: Optional indices of the tokens within the sequence. Converted to int64
            (device unchanged). If None, `torch.arange(seq_length)` of shape [1, seq_length] is
            created on `device`
        seq_length: Sequence length, used only when `position_indices` is None
        device: Device of the generated indices, used only when `position_indices` is None
        max_seq_length: Exclusive upper bound for the indices. If None, no bound is checked
    Returns:
        torch.Tensor of dtype torch.int64
    Raises:
        IndexError: If any index is >= `max_seq_length`
    """
    if position_indices is not None:
        position_indices = position_indices.to(dtype=torch.int64)
    else:
        position_indices = torch.arange(seq_length, dtype=torch.int64, device=device).view(1, -1)
    # `max_seq_length` is optional and `torch.max` fails on empty tensors, so only check the
    # bound when there is both a bound and something to check
    if max_seq_length is not None and position_indices.numel() > 0:
        if not torch.max(position_indices) < max_seq_length:
            raise IndexError(
                f"position_indices={position_indices} contains index out of bounds of num_embeddings={max_seq_length}"
            )
    return position_indices
|
|
|
|
|
|
|
|
class RotaryPositionalEncoding(ConfigurableModule):
    """
    Rotary Positional Embeddings (RoPE) from https://arxiv.org/abs/2104.09864
    Reference implementations:
    - https://github.com/meta-llama/llama/blob/main/llama/model.py#L80
    - transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding
    - transformers.models.llama.modeling_llama.LlamaRotaryEmbedding

    If cached=True, we cache the embeddings for each position up to `num_embeddings`
    """

    def __init__(self, config: RotaryPositionalEncodingConfig):
        super().__init__(config)
        exponents = torch.arange(0, self.config.embedding_dim, 2, dtype=torch.float32) / self.config.embedding_dim
        inv_freq = 1.0 / self.config.base**exponents
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._build_cache()

    def _build_cache(self) -> None:
        """Precompute sin/cos tables of shape [num_embeddings, embedding_dim // 2] when caching is on"""
        if not self.config.cached:
            return
        positions = torch.arange(self.config.num_embeddings, dtype=torch.float32)
        angles = torch.einsum("i, j -> ij", positions, self.inv_freq)
        self.register_buffer("sin_cache", torch.sin(angles), persistent=False)
        self.register_buffer("cos_cache", torch.cos(angles), persistent=False)

    def forward(
        self,
        tokens: torch.Tensor,
        position_indices: Optional[torch.Tensor] = None,
        apply: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            tokens: torch.Tensor of shape [B, ..., S, head_dim], where `...` might be any number of dims
            position_indices: torch.Tensor of shape [B | 1, S]. The indices of tokens within the sequence
            apply: If True, apply the positional embedding on tokens and return the result.
                Must be True; anything else fails an assertion
        Returns:
            torch.Tensor of the same shape as `tokens` with positional embedding applied on tokens
        """
        assert apply, f"{self.__class__} does not support applying embeddings externally"
        position_indices = make_position_indices(
            position_indices,
            seq_length=tokens.shape[-2],
            device=tokens.device,
            max_seq_length=self.config.num_embeddings,
        )
        if self.config.cached:
            # Look up the half-dimension tables and duplicate them to cover the full head_dim.
            # NOTE(review): unlike the uncached branch, sin/cos are not cast to `tokens.dtype` here
            half_sin = self.sin_cache[position_indices]
            half_cos = self.cos_cache[position_indices]
            sin = torch.cat([half_sin, half_sin], dim=-1)
            cos = torch.cat([half_cos, half_cos], dim=-1)
        else:
            inv_freq = self.inv_freq.view(1, -1, 1).to(dtype=torch.float32)
            float_positions = position_indices.to(dtype=torch.float32).unsqueeze(1)
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    message="In CPU autocast, but the target dtype is not supported. Disabling autocast.",
                )
                # Compute the angles in float32 regardless of any surrounding autocast context
                with torch.autocast(device_type=tokens.device.type, dtype=torch.float32):
                    freqs = (inv_freq @ float_positions).transpose(1, 2)
                    angles = torch.cat((freqs, freqs), dim=-1)
                    (sin, cos) = (torch.sin(angles), torch.cos(angles))
            (sin, cos) = (sin.to(dtype=tokens.dtype), cos.to(dtype=tokens.dtype))
        # `expand_dims` is a project helper; presumably it inserts singleton dims so that
        # sin/cos broadcast against `tokens` - TODO confirm
        sin = expand_dims(sin, tokens.ndim, order=[1, -1, 1, 1])
        cos = expand_dims(cos, tokens.ndim, order=[1, -1, 1, 1])
        return tokens * cos + self._rotate_invert_half(tokens) * sin

    @staticmethod
    def _rotate_invert_half(x: torch.Tensor) -> torch.Tensor:
        """Split the last dimension into halves (x1, x2) and return (-x2, x1) concatenated"""
        half = x.shape[-1] // 2
        first_half = x[..., :half]
        second_half = x[..., half:]
        return torch.cat((-second_half, first_half), dim=-1)
|
|
|
|
|
|
|
|
# Names of the supported attention implementations (values accepted as `attn_implementation`)
EAGER_ATTN = "eager"
SDPA_ATTN = "sdpa"
FLASH_ATTN = "flash_attention_2"
|
|
|
|
|
|
|
|
def is_full_attn(attn_mask: Optional[torch.Tensor]) -> bool:
    """
    Return True if attn_mask doesn't contain any masked out positions, False otherwise.

    None counts as full attention. A boolean mask is full if all values are True, a floating
    point (additive) mask is full if all values are 0. Any other dtype raises TypeError
    """
    if attn_mask is None:
        return True
    mask_dtype = attn_mask.dtype
    if mask_dtype == torch.bool:
        return (attn_mask == 1).all().item()
    if mask_dtype.is_floating_point:
        return (attn_mask == 0).all().item()
    raise TypeError(f"Unrecognized dtype {attn_mask.dtype}")
|
|
|
|
|
|
|
|
@torch.no_grad()
def make_attn_mask_causal(attn_mask: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
    """
    Mask out all non-causal positions of `attn_mask`. The input tensor is NOT modified;
    a new tensor is returned.

    Args:
        attn_mask: 4D tensor of shape [B | 1, 1, query_seq_len, kv_seq_len] (i.e. [B | 1, 1, L, S]) of float
            dtype. Masked positions contain the value `torch.finfo(dtype).min`. A boolean mask is
            also accepted, in which case masked positions are set to False
        cache_position: torch.Tensor of type torch.int64 and shape [query_seq_len]. Contained values
            are index positions of the query tokens in the sequence. During training, this would usually
            be torch.arange(query_seq_len), but during generate, this would usually be a tensor sequence
            with 1 element indicating the position of the token currently being generated
    Returns:
        torch.Tensor of the same shape as attn_mask. Contains zero at unmasked positions and
        `torch.finfo(dtype).min` at masked positions
    Raises:
        TypeError: If `attn_mask` is neither floating point nor boolean
    """
    if attn_mask.dtype.is_floating_point:
        mask_value = torch.finfo(attn_mask.dtype).min
    elif attn_mask.dtype == torch.bool:
        mask_value = 0
    else:
        raise TypeError(f"Unsupported mask type {attn_mask.dtype}")
    (_, _, query_seq_length, kv_seq_length) = attn_mask.shape
    # True strictly above the main diagonal of every [L, S] slice
    above_diagonal = torch.triu(
        torch.ones(attn_mask.shape, dtype=torch.bool, device=attn_mask.device), diagonal=1
    )
    # True where the key position lies after the query position within the whole sequence
    after_query = (
        torch.arange(kv_seq_length, device=cache_position.device).view(1, -1) > cache_position.view(-1, 1)
    ).view(1, 1, query_seq_length, kv_seq_length)
    causal_mask = above_diagonal & after_query
    # Out-of-place on purpose: `attn_mask` may be the caller's own tensor (4D float masks are
    # passed through unchanged upstream), so it must not be overwritten
    return attn_mask.masked_fill(causal_mask, mask_value)
|
|
|
|
|
|
|
|
def update_attn_mask(
    attn_mask: Optional[torch.Tensor],
    attn_implementation: str,
    query_seq_length: int,
    kv_seq_length: int,
    cache_position: Optional[torch.Tensor],
    cache: Optional[transformers.Cache],
    batch_size: int,
    causal: bool,
    dtype: torch.dtype,
    device: torch.device,
    output_attentions: bool = False,
) -> Optional[torch.Tensor]:
    """
    Update attn_mask such that it's compatible with the attention implementation.
    Meant to be used with `MultiheadAttention` and its derivatives

    Args:
        attn_mask: dtype torch.bool, torch.float32, torch.float16 or torch.bfloat16 and shape one of:
            - [B, kv_seq_length] (i.e. [B, S])
            - [B, 1, query_seq_length, kv_seq_length] (i.e. [B, 1, L, S])
            - [1, 1, query_seq_length, kv_seq_length] (i.e. [1, 1, L, S])
            If bool, False values indicate masked positions.
            If float, must contain only 0.0 and torch.finfo(dtype).min
            If attn_mask is None, full-bidirectional attention is assumed. The output might be None or
            a tensor. Refer to the return value documentation
        attn_implementation: One of [FLASH_ATTN, SDPA_ATTN, EAGER_ATTN]
        query_seq_length: The query sequence length (L)
        kv_seq_length: The key-value sequence length (S)
        cache_position: dtype torch.int64, shape [query_seq_len]. Used only when causal=True.
            Contained values are index positions of the query tokens in the sequence. During training,
            this would usually be torch.arange(query_seq_len), but during generate, this would usually be
            a tensor sequence with 1 element indicating the position of the token currently being generated.
            If None, default `cache_positions` are autocomputed from `query_seq_length` and cache size
        cache: Optional cache. Usually not None when running generate at inference.
        batch_size: Batch size of the generated attention mask
        causal: If True, make the attn_mask causal -> all non-causal positions are masked out, regardless
            of their attn_mask values. When using flash attention or SDPA with `causal == True`, this
            function may return None and delegate causal masking (see Returns); make sure to pass
            `causal` to the attention operation in that case
        dtype: dtype of the output attention mask. Must be the dtype of the attn computation
        device: device of the output attention mask
        output_attentions: If True, the attention operation is required to output attention weights
    Returns:
        - `None` in either of these cases:
            - `attn_mask` doesn't contain any masked out positions and causal=False
            - `attn_implementation in [FLASH_ATTN, SDPA_ATTN]` and `attn_mask` doesn't contain any
                masked out positions. If causal=True, we instead rely on the causal argument to
                flash attention or `torch.nn.functional.scaled_dot_product_attention`. This happens
                only if the cache is empty and cache_position is None
        - `attn_mask` if `attn_implementation == FLASH_ATTN` and `attn_mask` can't be ignored TODO(FLASH)
        - torch.Tensor of shape [B, 1, query_length, kv_seq_length] (i.e. [B, 1, L, S]) and type `dtype`.
            Contains zero at unmasked positions and `torch.finfo(dtype).min` at masked positions.
    Raises:
        NotImplementedError: When called while tracing, scripting or compiling, or when a
            `transformers.StaticCache` is used with a mask whose length differs from the cache size
    """
    assert attn_implementation in [FLASH_ATTN, SDPA_ATTN, EAGER_ATTN]
    assert dtype.is_floating_point, dtype
    if torch.jit.is_tracing() or torch.jit.is_scripting() or torch.compiler.is_compiling():
        raise NotImplementedError("Complete correctness not confirmed yet")

    static_cache = isinstance(cache, transformers.StaticCache)
    if static_cache:
        if attn_mask is not None and attn_mask.shape[-1] != cache.get_max_cache_shape():
            raise NotImplementedError("Complete correctness not confirmed yet")

    full_attn = is_full_attn(attn_mask)
    past_seen_tokens = cache.get_seq_length() if cache is not None else 0

    # Nothing is masked and no causality is required: no mask needed at all
    if full_attn and not causal:
        return None
    # Plain causal attention from an empty cache: delegated to the attention kernel
    if (
        full_attn
        and causal
        and attn_implementation in [SDPA_ATTN, FLASH_ATTN]
        and past_seen_tokens == 0
        and cache_position is None
    ):
        return None

    if static_cache and kv_seq_length < cache.get_max_cache_shape():
        # A static cache is zero-padded up to its full size, the mask has to cover all of it
        kv_seq_length = cache.get_max_cache_shape()
    elif attn_mask is not None:
        assert kv_seq_length == attn_mask.shape[-1], f"{kv_seq_length}, {attn_mask.shape}"

    output_mask = make_4d_float_attn_mask(
        attn_mask=attn_mask,
        query_seq_length=query_seq_length,
        kv_seq_length=kv_seq_length,
        dtype=dtype,
        device=device,
        batch_size=batch_size,
    )
    if causal:
        cache_position = (
            torch.arange(past_seen_tokens, past_seen_tokens + query_seq_length, device=device)
            if cache_position is None
            else cache_position
        )
        output_mask = make_attn_mask_causal(output_mask, cache_position)
    if (
        attn_implementation == SDPA_ATTN
        and attn_mask is not None
        and attn_mask.device.type == "cuda"
        and not output_attentions
    ):
        # Fully masked rows would otherwise produce NaN in the SDPA memory-efficient path
        output_mask = unmask_unattended(output_mask, mask_value=torch.finfo(dtype).min)
    return output_mask
|
|
|
|
|
|
|
|
def expand_kv_heads(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key-value head `n_rep` times, equivalent to
    torch.repeat_interleave(x, dim=1, repeats=n_rep). Converts hidden_states from
    [batch, num_kv_heads, seqlen, head_dim] -> [batch, num_attention_heads, seqlen, head_dim].
    For n_rep == 1 the input tensor itself is returned
    """
    (batch, num_kv_heads, slen, head_dim) = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a repeat axis right after the head axis, then merge both axes
    repeated = hidden_states.unsqueeze(2).expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return repeated.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
|
|
|
|
|
|
|
|
class MultiheadAttention(torch.nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper

    Different implementation from torch.nn.MultiheadAttention to support:
    - Easy switch between EAGER_ATTN, SDPA_ATTN and FLASH_ATTN
    - Number of key-value heads different from query heads
    - Key-value cache during forward, in the same way as transformers. Useful for generation or
      cross-attention to projected keys and values
    - Ability to apply positional encodings to query and key after input linear projection
    - Different linear projection output size

    Adapted from transformers.models.llama.modeling_llama.LlamaAttention
    """

    def __init__(
        self,
        in_features: int,
        num_heads: int,
        head_dim: Optional[int] = None,
        out_features: Optional[int] = None,
        key_features: Optional[int] = None,
        value_features: Optional[int] = None,
        num_kv_heads: Optional[int] = None,
        bias: bool = False,
        dropout: float = 0.0,
        cache_layer: Optional[int] = None,
        query_position_embed: Optional[Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]] = None,
        key_position_embed: Optional[Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]] = None,
    ):
        """
        Args:
            in_features: Input dimension for query linear projection.
            num_heads: Number of heads for query
            head_dim: Head dimension. If None, defaults to `in_features // num_heads`
            out_features: Output dimension for the output linear layer. If None, defaults to `in_features`
            key_features: Input dimension for key linear projection. If None, defaults to `in_features`
            value_features: Input dimension for value linear projection. If None, defaults to `in_features`
            num_kv_heads: Number of heads for keys and values. If None, defaults to `num_heads`
            bias: Whether the query, key, value and output linear layers have a bias term
            dropout: Dropout probability applied to the attention weights in training mode
            cache_layer: Index of the layer in the cache. Needed only when `cache` is passed to
                the `forward()` call, usually during generation or when the projected keys and values need
                to be cached during training. Can be omitted when `cache_layer` is passed to `forward`
            query_position_embed: Callable invoked as
                `query_position_embed(query_states, position_indices=...)` on the linearly projected
                query of shape [B, num_heads, L, head_dim]. Returns the query with positional embeddings
                applied. Note these embeddings are applied after linear projection. If you want to apply
                embeddings before the linear projection, do so before calling the forward method and
                keep the default None, which leaves the query untouched. Note you can also pass
                torch.nn.Module
            key_position_embed: Same as `query_position_embed`, but applied to the linearly projected
                key of shape [B, num_kv_heads, S, head_dim]
        """
        super().__init__()
        self.in_features = in_features
        self.key_features = key_features or in_features
        self.value_features = value_features or in_features
        self.bias = bias
        self.out_features = out_features or in_features
        self.num_heads = num_heads
        self.head_dim = head_dim or in_features // num_heads
        self.num_kv_heads = num_kv_heads or num_heads
        self.dropout = dropout
        self.query_position_embed = query_position_embed
        self.key_position_embed = key_position_embed
        self.cache_layer = cache_layer
        self.q_proj = torch.nn.Linear(self.in_features, self.num_heads * self.head_dim, bias=self.bias)
        self.k_proj = torch.nn.Linear(self.key_features, self.num_kv_heads * self.head_dim, bias=self.bias)
        self.v_proj = torch.nn.Linear(self.value_features, self.num_kv_heads * self.head_dim, bias=self.bias)
        self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.out_features, bias=self.bias)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        query_position_indices: Optional[torch.Tensor] = None,
        key_position_indices: Optional[torch.Tensor] = None,
        cache: Optional[transformers.Cache] = None,
        cache_layer: Optional[int] = None,
        output_attentions: bool = False,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            query: Query embedding of shape [B, L, in_features]
            key: Key embedding of shape [B, S, key_features]
            value: Value embedding of shape [B, S, value_features]
            attn_mask: dtype torch.bool or same dtype as query/key/value and shape one of:
                - [B, S]
                - [B | 1, 1 | num_heads, L, S]
                If bool, False values indicate masked positions (opposite of torch.nn.MultiheadAttention)
                If float, must contain only 0.0 and torch.finfo(dtype).min
                If attn_mask is None, full-bidirectional attention or causal attention is used depending
                on the value of `is_causal`.
            is_causal: If True, apply additional causal masking to `attn_mask`
            query_position_indices: torch.Tensor of shape [1 | B, L] containing the indices of the `query`
                tokens within the entire sequence. Passed through to query_position_embed. If None and
                `cache` is not None, indices are autogenerated as [0, 1, ..., L - 1] and offset by
                `cache.get_seq_length()`
            key_position_indices: Same as `query_position_indices`, but applied to key
            cache: transformers.Cache containing cached key-value pairs. The linearly projected
                `key` and `value` passed to this function get added to the cache and concatenated after the
                key-value pairs in the cache and then attention is computed on the concatenated sequence.
                This is most commonly used at inference when generating auto-regressively or when one needs
                to cross attend to the keys and values outside this module forward pass.
            cache_layer: Index of the layer in the cache. Needed only when `cache` is passed to
                the `forward()` call, usually during generation or when the projected keys and values need
                to be cached during training. Can be omitted when `cache_layer` was passed to `__init__`
            output_attentions: If True, output also the attention weights. Otherwise output None.
                Note that only the eager implementation of MultiheadAttention supports this.
            cache_kwargs: kwargs directly passed to `cache.update()`. None is treated as an empty dict
        Returns:
            Tuple with entries:
                - Attention block output: torch.Tensor of shape [B, L, out_features]
                - Optional attention weights if `output_attentions=True`, shape [B, num_heads, L, S]
        Raises:
            RuntimeError: If `cache` is given but no cache layer index is available
        """
        # A fresh dict per call: a `{}` default would be one object shared by every call
        cache_kwargs = {} if cache_kwargs is None else cache_kwargs
        batch_size = query.shape[0]
        query_states = self.q_proj(query)
        key_states = self.k_proj(key)
        value_states = self.v_proj(value)
        # Split the projected features into heads: [B, seq, heads * head_dim] -> [B, heads, seq, head_dim]
        query_states = query_states.view(
            batch_size, query_states.shape[1], self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            batch_size, key_states.shape[1], self.num_kv_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, value_states.shape[1], self.num_kv_heads, self.head_dim
        ).transpose(1, 2)
        # Positional embeddings are applied before the cache update, so cached keys already carry them
        (query_states, key_states) = self._maybe_apply_positional_embeddings(
            query_states=query_states,
            key_states=key_states,
            query_position_indices=query_position_indices,
            key_position_indices=key_position_indices,
            cache=cache,
        )
        (key_states, value_states) = self._maybe_update_cache(
            key_states,
            value_states,
            cache_layer=cache_layer,
            cache=cache,
            cache_kwargs=cache_kwargs,
        )
        # Replicate key/value heads so their count matches the query heads
        key_states = expand_kv_heads(key_states, self.num_heads // self.num_kv_heads)
        value_states = expand_kv_heads(value_states, self.num_heads // self.num_kv_heads)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        attn_mask = update_attn_mask(
            attn_mask,
            attn_implementation=EAGER_ATTN,
            query_seq_length=query_states.shape[2],
            kv_seq_length=value_states.shape[2],
            cache_position=query_position_indices,
            cache=cache,
            batch_size=batch_size,
            causal=is_causal,
            dtype=query_states.dtype,
            device=query_states.device,
            output_attentions=output_attentions,
        )
        if attn_mask is not None:
            # Crop the mask to the number of keys actually attended to, then add it to the logits
            attn_mask = attn_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + attn_mask
        # Softmax is computed in float32 and cast back to the working dtype
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
            query_states.dtype
        )
        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)
        shape = (batch_size, self.num_heads, query.shape[1], self.head_dim)
        assert attn_output.shape == shape, f"{attn_output.shape} != {shape}"
        # Merge heads back: [B, heads, L, head_dim] -> [B, L, heads * head_dim]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)
        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights

    def _maybe_apply_positional_embeddings(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        query_position_indices: Optional[torch.Tensor],
        key_position_indices: Optional[torch.Tensor],
        cache: Optional[transformers.Cache],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply `query_position_embed` / `key_position_embed` to the projected states when they are set.

        When position indices are not given and a cache is present, indices are generated as
        `arange(seq_len) + cache.get_seq_length()` with shape [1, seq_len]. Without a cache, None is
        passed through to the embedding callable.

        Returns:
            Tuple of (query_states, key_states), each unchanged when its callable is None
        """
        device = query_states.device
        if self.query_position_embed is not None:
            if query_position_indices is None and cache is not None:
                query_position_indices = (
                    torch.arange(query_states.shape[-2], dtype=torch.int64, device=device).view(1, -1)
                    + cache.get_seq_length()
                )
            query_states = self.query_position_embed(query_states, position_indices=query_position_indices)
        if self.key_position_embed is not None:
            if key_position_indices is None and cache is not None:
                key_position_indices = (
                    torch.arange(key_states.shape[-2], dtype=torch.int64, device=device).view(1, -1)
                    + cache.get_seq_length()
                )
            key_states = self.key_position_embed(key_states, position_indices=key_position_indices)
        return query_states, key_states

    def _maybe_update_cache(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_layer: Optional[int],
        cache: Optional[transformers.Cache],
        cache_kwargs: Optional[Dict[str, Any]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Push the projected key/value states into `cache` and return what `cache.update()` returns.

        Without a cache the inputs are returned unchanged. The `cache_layer` argument takes precedence
        over the one given at construction.

        Raises:
            RuntimeError: If `cache` is given but neither `cache_layer` nor `self.cache_layer` is set
        """
        if cache is not None:
            if cache_layer is None and self.cache_layer is None:
                raise RuntimeError("When cache != None, cache_layer has to be set")
            cache_layer = cache_layer if cache_layer is not None else self.cache_layer
            cache_kwargs = {} if cache_kwargs is None else cache_kwargs
            (key_states, value_states) = cache.update(key_states, value_states, cache_layer, cache_kwargs)
        return key_states, value_states
|
|
|
|
|
|
|
|
class MultiheadFlashAttention2(MultiheadAttention):
    """
    MultiheadAttention implemented using flash attention module. Inherits `MultiheadAttention` as the weights
    of the module stay untouched. The only change is on the forward pass where we call flash attention.

    NOTE: `forward` currently always ends in NotImplementedError("Correctness not yet confirmed");
    the flash attention call that follows the raise is not reachable yet.
    """

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        query_position_indices: Optional[torch.Tensor] = None,
        key_position_indices: Optional[torch.Tensor] = None,
        cache: Optional[transformers.Cache] = None,
        cache_layer: Optional[int] = None,
        output_attentions: bool = False,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            query: Query embedding of shape [B, L, in_features]
            key: Key embedding of shape [B, S, key_features]
            value: Value embedding of shape [B, S, value_features]
            attn_mask: dtype torch.bool and shape [B, S].
                If bool, False values indicate masked positions (opposite of torch.nn.MultiheadAttention)
                If attn_mask is None, full-bidirectional attention or causal attention is used depending
                on the value of `is_causal`.
                NOTE: Doesn't support 4D attn_mask, unlike MultiheadAttention
            is_causal: If True, apply additional causal masking to `attn_mask`
            query_position_indices: torch.Tensor of shape [1 | B, L] containing the indices of the `query`
                tokens within the entire sequence. Passed through to query_position_embed. If None and
                `cache` is not None, indices are autogenerated as [0, 1, ..., L - 1] and offset by
                `cache.get_seq_length()`
            key_position_indices: Same as `query_position_indices`, but applied to key
            cache: transformers.Cache containing cached key-value pairs. The linearly projected
                `key` and `value` passed to this function get added to the cache and concatenated after the
                key-value pairs in the cache and then attention is computed on the concatenated sequence.
                This is most commonly used at inference when generating auto-regressively or when one needs
                to cross attend to the keys and values outside this module forward pass.
                transformers.StaticCache is rejected.
            cache_layer: Index of the layer in the cache. Needed only when `cache` is passed to
                the `forward()` call, usually during generation or when the projected keys and values need
                to be cached during training. Can be omitted when `cache_layer` was passed to `__init__`
            output_attentions: Must be False; this implementation cannot return attention weights
            cache_kwargs: kwargs directly passed to `cache.update()`. None is treated as an empty dict
        Returns:
            Tuple with entries:
                - Attention block output: torch.Tensor of shape [B, L, out_features]
                - None (attention weights are never returned)
        Raises:
            ValueError: If `cache` is a transformers.StaticCache
            NotImplementedError: Always, once the inputs were projected and the cache was updated
        """
        if isinstance(cache, transformers.StaticCache):
            raise ValueError(
                "transformers.StaticCache not compatible with flash attention. Use `sdpa` instead (for now)."
            )
        assert output_attentions is False, f"{self.__class__} doesn't support output_attentions=True"
        # A fresh dict per call: a `{}` default would be one object shared by every call
        cache_kwargs = {} if cache_kwargs is None else cache_kwargs
        batch_size = query.shape[0]
        query_states = self.q_proj(query)
        key_states = self.k_proj(key)
        value_states = self.v_proj(value)
        # Split into heads: [B, seq, heads * head_dim] -> [B, heads, seq, head_dim]
        query_states = query_states.view(
            batch_size, query_states.shape[1], self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            batch_size, key_states.shape[1], self.num_kv_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, value_states.shape[1], self.num_kv_heads, self.head_dim
        ).transpose(1, 2)
        (query_states, key_states) = self._maybe_apply_positional_embeddings(
            query_states=query_states,
            key_states=key_states,
            query_position_indices=query_position_indices,
            key_position_indices=key_position_indices,
            cache=cache,
        )
        (key_states, value_states) = self._maybe_update_cache(
            key_states,
            value_states,
            cache_layer=cache_layer,
            cache=cache,
            cache_kwargs=cache_kwargs,
        )
        # Back to the [B, seq, heads, head_dim] layout
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        attn_mask = update_attn_mask(
            attn_mask,
            attn_implementation=FLASH_ATTN,
            # After the transpose above the sequence axis is 1 (axis 2 holds the heads)
            query_seq_length=query_states.shape[1],
            kv_seq_length=value_states.shape[1],
            cache_position=query_position_indices,
            cache=cache,
            batch_size=batch_size,
            causal=is_causal,
            dtype=query_states.dtype,
            device=query_states.device,
            output_attentions=output_attentions,
        )
        raise NotImplementedError("Correctness not yet confirmed")
        # NOTE(review): everything below is unreachable until the raise above is removed
        attn_output = transformers.modeling_flash_attention_utils._flash_attention_forward(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            attention_mask=attn_mask,
            query_length=query.shape[1],
            position_ids=None,
            dropout=self.dropout if self.training else 0.0,
            sliding_window=None,
            use_top_left_mask=False,
            is_causal=is_causal,
            deterministic=True,
        )
        # NOTE(review): output is expected in the [B, L, heads, head_dim] layout, matching the
        # reshape to [B, L, heads * head_dim] below - confirm when enabling this path
        size = (batch_size, query.shape[1], self.num_heads, self.head_dim)
        if attn_output.size() != size:
            raise ValueError(f"`attn_output` should be of size {size}, but is {attn_output.size()}")
        attn_output = attn_output.reshape(batch_size, query.shape[1], -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, None
|
|
|
|
|
|
|
|
class MultiheadSdpaAttention(MultiheadAttention):
    """
    MultiheadAttention SDPA attention. Inherits `MultiheadAttention` as the weights of the module stay untouched.
    The only change is on the forward pass where we call `torch.nn.functional.scaled_dot_product_attention`
    """

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        query_position_indices: Optional[torch.Tensor] = None,
        key_position_indices: Optional[torch.Tensor] = None,
        cache: Optional[transformers.Cache] = None,
        cache_layer: Optional[int] = None,
        output_attentions: bool = False,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            query: Query embedding of shape [B, L, in_features]
            key: Key embedding of shape [B, S, key_features]
            value: Value embedding of shape [B, S, value_features]
            attn_mask: dtype torch.bool or same dtype as query/key/value and shape one of:
                - [B, S]
                - [B | 1, 1 | num_heads, L, S]
                If bool, False values indicate masked positions (opposite of torch.nn.MultiheadAttention)
                If float, must contain only 0.0 and torch.finfo(dtype).min
                If attn_mask is None, full-bidirectional attention or causal attention is used depending
                on the value of `is_causal`.
            is_causal: If True, apply additional causal masking to `attn_mask`
            query_position_indices: torch.Tensor of shape [1 | B, L] containing the indices of the `query`
                tokens within the entire sequence. Passed through to query_position_embed. If None and
                `cache` is not None, indices are autogenerated as [0, 1, ..., L - 1] and offset by
                `cache.get_seq_length()`
            key_position_indices: Same as `query_position_indices`, but applied to key
            cache: transformers.Cache containing cached key-value pairs. The linearly projected
                `key` and `value` passed to this function get added to the cache and concatenated after the
                key-value pairs in the cache and then attention is computed on the concatenated sequence.
                This is most commonly used at inference when generating auto-regressively or when one needs
                to cross attend to the keys and values outside this module forward pass.
            cache_layer: Index of the layer in the cache. Needed only when `cache` is passed to
                the `forward()` call, usually during generation or when the projected keys and values need
                to be cached during training. Can be omitted when `cache_layer` was passed to `__init__`
            output_attentions: Must be False; this implementation cannot return attention weights
            cache_kwargs: kwargs directly passed to `cache.update()`. None is treated as an empty dict
        Returns:
            Tuple with entries:
                - Attention block output: torch.Tensor of shape [B, L, out_features]
                - None (attention weights are never returned)
        """
        assert output_attentions is False, f"{self.__class__} doesn't support output_attentions=True"
        # A fresh dict per call: a `{}` default would be one object shared by every call
        cache_kwargs = {} if cache_kwargs is None else cache_kwargs
        batch_size = query.shape[0]
        query_states = self.q_proj(query)
        key_states = self.k_proj(key)
        value_states = self.v_proj(value)
        # Split into heads: [B, seq, heads * head_dim] -> [B, heads, seq, head_dim]
        query_states = query_states.view(
            batch_size, query_states.shape[1], self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            batch_size, key_states.shape[1], self.num_kv_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, value_states.shape[1], self.num_kv_heads, self.head_dim
        ).transpose(1, 2)
        (query_states, key_states) = self._maybe_apply_positional_embeddings(
            query_states=query_states,
            key_states=key_states,
            query_position_indices=query_position_indices,
            key_position_indices=key_position_indices,
            cache=cache,
        )
        (key_states, value_states) = self._maybe_update_cache(
            key_states,
            value_states,
            cache_layer=cache_layer,
            cache=cache,
            cache_kwargs=cache_kwargs,
        )
        # Replicate key/value heads so their count matches the query heads
        key_states = expand_kv_heads(key_states, self.num_heads // self.num_kv_heads)
        value_states = expand_kv_heads(value_states, self.num_heads // self.num_kv_heads)
        attn_mask = update_attn_mask(
            attn_mask,
            attn_implementation=SDPA_ATTN,
            query_seq_length=query_states.shape[2],
            kv_seq_length=value_states.shape[2],
            cache_position=query_position_indices,
            cache=cache,
            batch_size=batch_size,
            causal=is_causal,
            dtype=query_states.dtype,
            device=query_states.device,
            output_attentions=output_attentions,
        )
        if attn_mask is not None:
            # Crop the mask to the number of keys actually attended to
            attn_mask = attn_mask[:, :, :, : key_states.shape[-2]]
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            # scaled_dot_product_attention rejects an explicit mask combined with is_causal=True.
            # When a mask comes back from update_attn_mask(causal=True) the causal structure is
            # expected to be part of it already, so the flag is only forwarded without a mask.
            is_causal=is_causal and attn_mask is None,
        )
        shape = (batch_size, self.num_heads, query.shape[1], self.head_dim)
        assert attn_output.shape == shape, f"{attn_output.shape} != {shape}"
        # Merge heads back: [B, heads, L, head_dim] -> [B, L, heads * head_dim]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, query.shape[1], self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)
        return attn_output, None
|
|
|
|
|
|
|
|
# Registry from attention-implementation identifier to the module class implementing it.
# All entries share the constructor and weights of `MultiheadAttention`; only `forward` differs.
ATTN_TYPES = {
    EAGER_ATTN: MultiheadAttention,
    SDPA_ATTN: MultiheadSdpaAttention,
    FLASH_ATTN: MultiheadFlashAttention2,
}
|
|
|
|
|
|
|
|
def make_activation(activation: str | Type[torch.nn.Module], **kwargs) -> torch.nn.Module: |
|
|
if isinstance(activation, str): |
|
|
TorchActivation: Type[torch.nn.Module] = getattr(torch.nn, activation) |
|
|
else: |
|
|
TorchActivation: Type[torch.nn.Module] = activation |
|
|
assert issubclass(TorchActivation, torch.nn.Module), TorchActivation |
|
|
return TorchActivation(**kwargs) |
|
|
|
|
|
|
|
|
class PiZeroMLP(torch.nn.Module):
    """
    Gated feed-forward block computing `down_proj(activation(gate_proj(x)) * up_proj(x))`.

    All three linear layers are bias-free. The activation is built through `make_activation`
    with `approximate="tanh"`, so the activation class has to accept that keyword.
    """

    def __init__(self, feature_size: int, hidden_size: int, activation: str):
        """
        Args:
            feature_size: Size of the last dimension of the input and of the output
            hidden_size: Size of the gated intermediate representation
            activation: Activation identifier passed to `make_activation`
        """
        super().__init__()
        self.gate_proj = torch.nn.Linear(feature_size, hidden_size, bias=False)
        self.up_proj = torch.nn.Linear(feature_size, hidden_size, bias=False)
        self.down_proj = torch.nn.Linear(hidden_size, feature_size, bias=False)
        self.activation = make_activation(activation, approximate="tanh")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to the last dimension of `x`; the output has the same shape as `x`."""
        gate = self.activation(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
|
|
|
|
|
|
|
|
class PiZeroFlowMatchingDecoderBlock(ConfigurableModule):
    """
    Pre-norm transformer block: RMS norm -> self-attention -> residual, then
    RMS norm -> gated MLP -> residual.
    """

    def __init__(self, config: PiZeroFlowMatchingDecoderBlockConfig, **attn_kwargs):
        """
        Args:
            config: Block configuration. Selects the attention implementation via
                `attn_implementation` and provides the sizes of the attention and MLP layers
            **attn_kwargs: Extra keyword arguments forwarded to the attention module constructor
        """
        super().__init__(config)
        cfg = self.config
        attn_cls = ATTN_TYPES[cfg.attn_implementation]
        self.norm_in = GemmaRMSNorm(cfg.feature_size, eps=1e-06)
        self.self_attn = attn_cls(
            in_features=cfg.feature_size,
            num_heads=cfg.num_heads,
            head_dim=cfg.head_dim,
            num_kv_heads=cfg.num_kv_heads,
            **attn_kwargs,
        )
        self.mlp = PiZeroMLP(
            feature_size=cfg.feature_size,
            hidden_size=cfg.hidden_size,
            activation=cfg.activation,
        )
        self.norm_out = GemmaRMSNorm(cfg.feature_size, eps=1e-06)

    def forward(
        self,
        query: torch.Tensor,
        attn_mask: torch.Tensor,
        cache: transformers.Cache,
        attn_kwargs: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Args:
            query: torch.Tensor of shape [B, L, token_size]. The token sequence processed by the block
            attn_mask: torch.Tensor of shape [B, 1, L, L+S] and dtype torch.bool, where S is the VLM
                sequence length
            cache: Cache that contains only the VLM tokens during training and VLM + past query tokens
                during generation
            attn_kwargs: Extra keyword arguments forwarded to the attention module call, e.g. position
                indices and `cache_kwargs`
        Returns:
            torch.Tensor of same shape as query [B, L, token_size]
        """
        # Self-attention sub-layer; the normalized input serves as query, key and value
        normed = self.norm_in(query)
        (attn_out, _) = self.self_attn(
            query=normed,
            key=normed,
            value=normed,
            attn_mask=attn_mask,
            is_causal=False,
            cache=cache,
            **attn_kwargs,
        )
        hidden = query + attn_out
        # Feed-forward sub-layer
        mlp_out = self.mlp(self.norm_out(hidden))
        return hidden + mlp_out
|
|
|
|
|
|
|
|
class PiZeroFlowMatchingDecoder(ConfigurableModule):
    """PiZero Flow Matching control decoder"""

    def __init__(self, config: PiZeroFlowMatchingDecoderConfig):
        super().__init__(config)
        # One query and one key positional encoding instance, shared by all blocks
        query_position_embed = RotaryPositionalEncoding(config=self.config.block_config.position_embed_config)
        key_position_embed = RotaryPositionalEncoding(config=self.config.block_config.position_embed_config)
        self.blocks = torch.nn.ModuleList(
            [
                PiZeroFlowMatchingDecoderBlock(
                    self.config.block_config,
                    query_position_embed=query_position_embed,
                    key_position_embed=key_position_embed,
                    cache_layer=i,
                )
                for i in range(self.config.num_blocks)
            ]
        )
        self.norm = GemmaRMSNorm(self.config.block_config.feature_size, eps=1e-06)

    def forward(
        self,
        control_tokens: torch.Tensor,
        robot_state_tokens: torch.Tensor,
        llm_kv_tokens: List[Tuple[torch.Tensor, torch.Tensor]],
        attn_mask: Optional[torch.Tensor],
        cache: Optional[transformers.Cache] = None,
    ) -> torch.Tensor:
        """
        Args:
            control_tokens: torch.Tensor of shape [B, N, token_size], contains sequence of controls
            robot_state_tokens: torch.Tensor of shape [B, num_state_tokens, token_size]
            llm_kv_tokens: List of linearly projected key-value pairs from LLM, right before attention
                operation. Each tensor is of the shape [B, num_kv_heads, S, head_dim]
            attn_mask: Required (must not be None). One of
                - shape [B, S], dtype torch.bool -> padding attention mask for LLM tokens
                - shape [B, 1, L, S], dtype torch.bool -> full attention mask for LLM tokens
            cache: Optional key-value cache. When None or empty, this call is treated as step zero:
                the LLM key-value pairs are written into the cache and the robot state tokens are
                processed together with the control tokens. Otherwise only the control tokens are
                processed against the existing cache.
        Returns:
            torch.Tensor, shape [B, N, token_size]
        Raises:
            ValueError: If `attn_mask` is None
        """
        assert (
            len(llm_kv_tokens) == self.config.num_blocks
        ), f"{len(llm_kv_tokens)} != {self.config.num_blocks}"
        if attn_mask is None:
            # The mask provides the VLM sequence length and the device, so it can't be omitted
            raise ValueError("`attn_mask` is required and can't be None")
        is_step_zero = cache is None or cache.get_seq_length() == 0
        vlm_seq_len = attn_mask.shape[-1]
        device = attn_mask.device
        if cache is None:
            cache = transformers.DynamicCache()
        if is_step_zero:
            # Seed every layer of the cache with the key-value pairs coming from the LLM
            position_indices = torch.arange(vlm_seq_len, dtype=torch.int64, device=device)
            for block_index, kv_tokens in enumerate(llm_kv_tokens):
                (key_states, value_states) = kv_tokens
                cache.update(
                    key_states,
                    value_states,
                    block_index,
                    cache_kwargs={"cache_position": position_indices},
                )
        num_control_tokens = control_tokens.shape[1]
        num_robot_state_tokens = robot_state_tokens.shape[1]
        attn_mask = self._build_attn_mask(
            num_control_tokens=num_control_tokens,
            num_robot_state_tokens=num_robot_state_tokens,
            attn_mask=attn_mask,
        )
        if is_step_zero:
            # Sequence layout after the VLM tokens: [robot state tokens, control tokens]
            tokens = torch.cat([robot_state_tokens, control_tokens], dim=1)
            query_position_indices = key_position_indices = vlm_seq_len + torch.arange(
                tokens.shape[1], dtype=torch.int64, device=device
            ).view(1, -1)
        else:
            # Robot state tokens are skipped here: only control tokens are processed, so keep just
            # their mask rows and positions
            tokens = control_tokens
            attn_mask = attn_mask[:, :, -num_control_tokens:]
            query_position_indices = key_position_indices = (
                vlm_seq_len
                + num_robot_state_tokens
                + torch.arange(tokens.shape[1], dtype=torch.int64, device=device).view(1, -1)
            )
        for block in self.blocks:
            tokens = block(
                query=tokens,
                attn_mask=attn_mask,
                cache=cache,
                attn_kwargs={
                    "query_position_indices": query_position_indices,
                    "key_position_indices": key_position_indices,
                    "cache_kwargs": {"cache_position": key_position_indices.view(-1)},
                },
            )
        if is_step_zero:
            # Drop the robot state outputs, only the control tokens are returned
            (_, control_tokens) = torch.split(tokens, [num_robot_state_tokens, num_control_tokens], dim=1)
        else:
            control_tokens = tokens
        control_tokens = self.norm(control_tokens)
        return control_tokens

    @torch.no_grad()
    def _build_attn_mask(
        self,
        num_control_tokens: int,
        num_robot_state_tokens: int,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Expand `attn_mask` (which is effectively a padding mask) to 4D such that:
        - robot state tokens and control tokens can't attend to padding tokens
        - robot state tokens can't attend to control tokens
        Note: We can't keep the mask in 2D as it doesn't allow masking of padding tokens from the
        VLM sequence. Furthermore, in a 2D mask you can't disable attention from robot state tokens
        to control tokens

        Args:
            num_control_tokens: Number of control tokens (last rows / last columns of the result)
            num_robot_state_tokens: Number of robot state tokens (first rows of the result)
            attn_mask: torch.bool mask over the VLM tokens, 2D [B, S] or 4D with S as last dimension.
                A 4D mask is reduced with `any` over its query dimension.
        Returns:
            torch.bool tensor of shape [B, 1, R + N, S + R + N], True where attention is allowed.
            Keys are ordered [VLM tokens, robot state tokens, control tokens].
        """
        assert attn_mask.dtype == torch.bool, attn_mask.dtype
        assert attn_mask.ndim in [2, 4], attn_mask.shape
        device = attn_mask.device
        batch_size = attn_mask.shape[0]
        query_seq_len = num_robot_state_tokens + num_control_tokens
        vlm_seq_len = attn_mask.shape[-1]
        kv_seq_len = query_seq_len + vlm_seq_len
        cross_attn_mask = torch.ones(
            [batch_size, 1, query_seq_len, kv_seq_len], dtype=torch.bool, device=device
        )
        if attn_mask.ndim == 2:
            attn_mask = attn_mask.view(batch_size, 1, 1, vlm_seq_len)
        else:
            # A VLM key stays visible if any query row of the incoming mask could attend to it
            attn_mask = torch.any(attn_mask, dim=-2, keepdim=True)
        cross_attn_mask[..., :vlm_seq_len] = attn_mask
        # Block the full (robot state query) x (control key) rectangle. Slicing covers every pair,
        # whereas indexing with two column vectors only pairs up entries element-wise.
        control_key_start = vlm_seq_len + num_robot_state_tokens
        cross_attn_mask[:, :, :num_robot_state_tokens, control_key_start:] = False
        return cross_attn_mask

    @property
    def fsdp_wrap_modules(self) -> Set[torch.nn.Module]:
        """All decoder block modules plus the final norm."""
        return {module for module in self.modules() if isinstance(module, type(self.blocks[0]))} | {self.norm}
|
|
|
|
|
|
|
|
def integrate_unitquat(
    qt: torch.Tensor,
    dq_dt: torch.Tensor,
    dt: float | torch.Tensor,
    body_frame: bool = True,
    half_cover: bool = True,
) -> torch.Tensor:
    """
    Advance the unit quaternion `qt` along its time derivative `dq_dt` for a time step `dt`.

    The derivative is converted to an angular velocity, the rotation over `dt` is built from the
    rotation vector `omega * dt`, and that rotation is composed with `qt`.

    Args:
        qt: Unit quaternion, shape [..., 4]
        dq_dt: Derivative of the unit quaternion, shape [..., 4]
        dt: Time interval to integrate over, scalar or a tensor of shape () or [..., 1]
        body_frame: If True, the integration is done in the body frame (post-multiply),
            otherwise in the inertial frame (pre-multiply).
        half_cover: If True, the result is passed through `quaternion_half_cover`
    Returns:
        Integrated unit quaternion, shape [..., 4]
    """
    assert qt.shape == dq_dt.shape, f"{qt.shape} != {dq_dt.shape}"
    assert is_quaternion(qt), f"{qt.shape} not a quaternion"
    if isinstance(dt, torch.Tensor):
        assert dt.ndim in (0, qt.ndim), f"dt.ndim = {dt.ndim} | {qt.ndim}"
    qt_conj = roma.quat_conjugation(qt)
    # Angular velocity as a quaternion: 2 * conj(q) * dq/dt (body) or 2 * dq/dt * conj(q) (inertial)
    if body_frame:
        omega_q = 2.0 * roma.quat_product(qt_conj, dq_dt)
    else:
        omega_q = 2.0 * roma.quat_product(dq_dt, qt_conj)
    # Only the first three components are used as the angular velocity vector
    omega = omega_q[..., :-1]
    step = roma.rotvec_to_unitquat(omega * dt)
    result = roma.quat_product(qt, step) if body_frame else roma.quat_product(step, qt)
    return quaternion_half_cover(result) if half_cover else result
|
|
|
|
|
|
|
|
def rotmat_inverse(rotation: torch.Tensor) -> torch.Tensor:
    """
    Invert a rotation matrix by transposing it.

    Args:
        rotation: Rotation matrix in any layout accepted by `is_rotmat`
    Returns:
        The transposed rotation. When `is_rotmat_9(rotation)` holds, the result is converted
        back with `rotmat_as_9`; otherwise the 3x3 form is returned.
    """
    assert is_rotmat(rotation), f"Expected a rotation matrix, but got shape {rotation.shape}"
    inverse = rotmat_as_3x3(rotation).transpose(-1, -2)
    return rotmat_as_9(inverse) if is_rotmat_9(rotation) else inverse
|
|
|
|
|
|
|
|
def skew_symmetric_to_rotvec(skew_symmetric: torch.Tensor) -> torch.Tensor:
    """
    Extract the rotation vector (x, y, z) from a skew-symmetric matrix in a differentiable way

    [
        [ 0, -z,  y],
        [ z,  0, -x],
        [-y,  x,  0],
    ]

    Each component is half the difference of the two mirrored off-diagonal entries.

    Args:
        skew_symmetric: Skew-symmetric matrix of shape [..., 3, 3]
    Returns:
        torch.Tensor of shape [..., 3]
    """
    assert is_rotmat(skew_symmetric), skew_symmetric.shape
    x_twice = skew_symmetric[..., 2, 1] - skew_symmetric[..., 1, 2]
    y_twice = skew_symmetric[..., 0, 2] - skew_symmetric[..., 2, 0]
    z_twice = skew_symmetric[..., 1, 0] - skew_symmetric[..., 0, 1]
    return torch.stack((x_twice, y_twice, z_twice), dim=-1) / 2.0
|
|
|
|
|
|
|
|
def integrate_rotmat(
    rt: torch.Tensor,
    dr_dt: torch.Tensor,
    dt: float | torch.Tensor,
    body_frame: bool = True,
) -> torch.Tensor:
    """
    Advance the rotation matrix `rt` along its time derivative `dr_dt` for a duration `dt`.

    Args:
        rt: Rotation matrix, shape [..., 3, 3]. Any other layout accepted by `is_rotmat`
            is converted with `rotmat_as_3x3` first and converted back with `rotmat_as_9`
            on return.
        dr_dt: Time derivative of the rotation matrix, same shape as `rt`
        dt: Integration interval, a Python scalar or a tensor of shape (), [..., 1]
            or [..., 1, 1]
        body_frame: If True, the increment is applied in the body frame (post-multiply),
            otherwise in the inertial frame (pre-multiply).
    Returns:
        Integrated rotation matrix, in the same layout as the input `rt`
    """
    assert rt.shape == dr_dt.shape, f"{rt.shape} != {dr_dt.shape}"
    assert is_rotmat(rt), f"{rt.shape} not a rotation matrix"

    needs_flatten = not is_rotmat_3x3(rt)
    if needs_flatten:
        rt = rotmat_as_3x3(rt)
        dr_dt = rotmat_as_3x3(dr_dt)

    if isinstance(dt, torch.Tensor):
        assert dt.ndim in (
            0,
            rt.ndim,
            rt.ndim - 1,
        ), f"dt.ndim = {dt.ndim} | {rt.ndim} | {rt.ndim - 1}"
        if dt.ndim == rt.ndim:
            # Collapse [..., 1, 1] to [..., 1] so it broadcasts against the rotation vector.
            assert dt.shape[-2:] == (1, 1), dt.shape
            dt = dt.squeeze(-1)

    rt_inv = rotmat_inverse(rt)
    if body_frame:
        omega = skew_symmetric_to_rotvec(rt_inv @ dr_dt)
        integrated = rt @ roma.rotvec_to_rotmat(omega * dt)
    else:
        omega = skew_symmetric_to_rotvec(dr_dt @ rt_inv)
        integrated = roma.rotvec_to_rotmat(omega * dt) @ rt

    return rotmat_as_9(integrated) if needs_flatten else integrated
|
|
|
|
|
|
|
|
def integrate_rotation(
    rt: torch.Tensor,
    dr_dt: torch.Tensor,
    dt: float | torch.Tensor,
    body_frame: bool = True,
    half_cover: bool = True,
) -> torch.Tensor:
    """
    Integrate the rotation `rt` by the derivative `dr_dt` over the time interval `dt` on the
    SO(3) manifold, dispatching on the representation of `rt`.

    Args:
        rt: Rotation, either a quaternion (per `is_quaternion`) or a rotation matrix
            (per `is_rotmat`)
        dr_dt: Time derivative of `rt`, in the same representation
        dt: Integration interval, scalar or tensor
        body_frame: Forwarded to the representation-specific integrator
        half_cover: Forwarded to `integrate_unitquat`; unused for rotation matrices
    Returns:
        Integrated rotation in the same representation as `rt`
    Raises:
        NotImplementedError: If `rt` is neither a quaternion nor a rotation matrix
    """
    if is_quaternion(rt):
        return integrate_unitquat(rt, dr_dt, dt, body_frame=body_frame, half_cover=half_cover)
    if not is_rotmat(rt):
        raise NotImplementedError(f"integrate_rotation not yet implemented for format {rt.shape}")
    return integrate_rotmat(rt, dr_dt, dt, body_frame=body_frame)
|
|
|
|
|
|
|
|
class PiZeroFlowMatchingModule(ConfigurableModule):
    """
    Flow-matching control head.

    Projects the robot state and the noised controls to tokens, decodes them with
    `PiZeroFlowMatchingDecoder` against the VLM key/value tokens, and maps the decoder output
    to translation, rotation and gripper components.
    """

    def __init__(self, config: PiZeroFlowMatchingModuleConfig, control_tokenizer: EmptyTokenizer):
        super().__init__(config)
        # The tokenizer is accepted but never used by this module.
        del control_tokenizer
        module_cfg = self.config
        self.noised_control_proj = NoisedControlProjector(module_cfg.noised_control_proj_config)
        self.robot_state_proj = RobotStateProjector(module_cfg.robot_state_proj_config)
        self.control_decoder = PiZeroFlowMatchingDecoder(config=module_cfg.control_decoder_config)
        # Output width: 3 translation values + rotation components + 1 gripper value.
        output_size = 3 + module_cfg.rotation_components + 1
        self.output_proj = make_mlp(
            [module_cfg.token_size, output_size],
            activation=torch.nn.GELU,
            activate_final=False,
        )

    def forward(
        self,
        vlm_input: RoboticsFlowInput,
        vlm_output: VLMOutput,
        cache: Optional[transformers.Cache] = None,
    ) -> RoboticsOutput:
        """
        Run one pass of the control head.

        Args:
            vlm_input: Model input; its `flow_input` and `attn_mask` fields are read here
            vlm_output: VLM output whose `llm_output.past_key_values` feed the decoder
            cache: Optional cache forwarded to the control decoder
        Returns:
            RoboticsOutput with `translation`, `rotation` and `gripper` set from the split
            output projection
        """
        state_tokens = self.robot_state_proj(vlm_input)
        control_tokens = self.noised_control_proj(vlm_input.flow_input)
        decoded = self.control_decoder(
            control_tokens=control_tokens,
            robot_state_tokens=state_tokens,
            llm_kv_tokens=vlm_output.llm_output.past_key_values,
            attn_mask=vlm_input.attn_mask,
            cache=cache,
        )
        controls = self.output_proj(decoded)
        split_sizes = [3, self.config.rotation_components, 1]
        translation, rotation, gripper = torch.split(controls, split_sizes, dim=-1)
        return RoboticsOutput.make_empty().replace(
            translation=translation, rotation=rotation, gripper=gripper
        )

    @staticmethod
    def _with_flow_state(vlm_input, timestep, translation, rotation, gripper):
        """Return a copy of `vlm_input` whose flow-input fields hold the given state."""
        return vlm_input.replace(
            **{
                "flow_input.timestep": timestep,
                "flow_input.translation_t": translation,
                "flow_input.rotation_t": rotation,
                "flow_input.gripper_t": gripper,
            }
        )

    @torch.inference_mode()
    def generate(
        self,
        vlm_input: RoboticsFlowInput,
        vlm_output: VLMOutput,
        processor: PiZeroFlowMatchingProcessor,
        use_cache: bool = True,
        **kwargs,
    ) -> RoboticsOutput:
        """
        Sample controls by integrating the predicted derivative from the t0 sample.

        Starting from `processor.sample_t0_input`, performs
        `processor.config.num_inference_steps` explicit steps of size
        `1 / num_inference_steps`. Translation and gripper are updated additively;
        rotation is updated with `integrate_rotation`.

        Args:
            vlm_input: Model input; `input_ids` provides batch size, sequence length and device
            vlm_output: VLM output passed unchanged to every forward call
            processor: Supplies the initial sample and the inference configuration
            use_cache: If True, a `transformers.StaticCache` is allocated and passed to
                every forward call; otherwise no cache is used
            **kwargs: Ignored
        Returns:
            RoboticsOutput holding the final translation, rotation and gripper
        """
        del kwargs
        batch_size, vlm_seq_len = vlm_input.input_ids.shape[:2]
        device = vlm_input.input_ids.device

        cache = None
        if use_cache:
            io_cfg = processor.config.control_io_config
            # Cache length: VLM tokens plus future-control and past-scalar sequence lengths.
            max_cache_len = (
                vlm_seq_len
                + io_cfg.future_controls_sequence_length
                + io_cfg.past_scalars_sequence_length
            )
            decoder_cfg = self.config.control_decoder_config
            cache = transformers.StaticCache(
                config=transformers.PretrainedConfig(
                    head_dim=decoder_cfg.block_config.head_dim,
                    num_key_value_heads=decoder_cfg.block_config.num_kv_heads,
                    num_hidden_layers=decoder_cfg.num_blocks,
                ),
                max_batch_size=batch_size,
                max_cache_len=max_cache_len,
                device=device,
            )

        flow_input: FlowInput = processor.sample_t0_input(batch_size=batch_size, device=device)
        num_steps = processor.config.num_inference_steps
        step_size = 1 / num_steps

        translation = flow_input.translation_t0
        rotation = flow_input.rotation_t0
        gripper = flow_input.gripper_t0
        vlm_input = self._with_flow_state(
            vlm_input, flow_input.timestep, translation, rotation, gripper
        )

        for _ in range(num_steps):
            velocity: RoboticsOutput = self(vlm_input, vlm_output, cache)
            translation = translation + step_size * velocity.translation
            rotation = integrate_rotation(rt=rotation, dr_dt=velocity.rotation, dt=step_size)
            gripper = gripper + step_size * velocity.gripper
            timestep = vlm_input.flow_input.timestep + step_size
            if processor.config.rotation_format == RotationFormat.QUATERNION:
                rotation = quaternion_half_cover(rotation)
            vlm_input = self._with_flow_state(vlm_input, timestep, translation, rotation, gripper)

        return RoboticsOutput.make_empty().replace(
            translation=translation, rotation=rotation, gripper=gripper
        )

    @property
    def fsdp_wrap_modules(self) -> Set[torch.nn.Module]:
        """
        Modules reported for FSDP wrapping: the decoder's own set plus this module and its
        three projection submodules.
        """
        own_modules = {
            self,
            self.robot_state_proj,
            self.noised_control_proj,
            self.output_proj,
        }
        return self.control_decoder.fsdp_wrap_modules | own_modules
|
|
|
|
|
|
|
|
# Rotation by pi radians about the x-axis.
# The entries are written out exactly (cos(pi) = -1, sin(pi) = 0): evaluating
# np.sin(np.pi) in floating point gives ~1.2e-16 instead of 0, which leaves spurious
# off-diagonal terms and makes the matrix only approximately orthogonal.
CANONICAL_TO_BRIDGE_ROTATION = np.array(
    [
        [1, 0, 0],
        [0, -1, 0],
        [0, 0, -1],
    ],
    dtype=np.float32,
)
|
|
|
|
|
|
|
|
class SPEAR1(ConfigurableModule, transformers.PreTrainedModel):
    """
    Full policy model: a `PaliGemmaVLM` backbone, a `PiZeroFlowMatchingProcessor` for building
    inputs and decoding outputs, and a `PiZeroFlowMatchingModule` control head.
    """

    config_class: transformers.PretrainedConfig = SPEAR1Config

    def __init__(self, config: SPEAR1Config):
        super().__init__(config)
        self.vlm = PaliGemmaVLM(config=self.config.vlm_config)
        self.processor = PiZeroFlowMatchingProcessor(
            config=self.config.processor_config, vlm_processor=self.vlm.processor
        )
        self.control_module = PiZeroFlowMatchingModule(
            config=self.config.control_module_config,
            control_tokenizer=self.processor.control_tokenizer,
        )
        self.generation_config = transformers.GenerationConfig()

    def forward(
        self,
        inputs: RoboticsInput,
        use_cache: Optional[bool] = True,
        output_hidden_states: Optional[bool] = None,
    ) -> RoboticsOutput:
        """
        Run the VLM followed by a single pass of the control head.

        Args:
            inputs: Model input, passed to both the VLM and the control module
            use_cache: Forwarded to the VLM
            output_hidden_states: Ignored; the VLM is always called with
                `output_hidden_states=True`
        Returns:
            The control head output with `llm_output` set from the VLM output
        """
        del output_hidden_states
        vlm_output = self.vlm(inputs=inputs, use_cache=use_cache, output_hidden_states=True)
        control_output = self.control_module(vlm_input=inputs, vlm_output=vlm_output)
        return control_output.replace(llm_output=vlm_output.llm_output)

    @torch.inference_mode()
    def generate(
        self,
        inputs: RoboticsInput,
        use_cache: Optional[bool] = True,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> RoboticsOutput:
        """
        Run the VLM once, then sample controls with the control head's `generate`.

        Args:
            inputs: Model input, passed to both the VLM and the control module
            use_cache: Forwarded to the VLM only
            output_attentions: Forwarded to the VLM
            output_hidden_states: Ignored; the VLM is always called with
                `output_hidden_states=True`
        Returns:
            The sampled controls with `llm_output` set from the VLM output
        """
        del output_hidden_states
        vlm_output = self.vlm(
            inputs=inputs,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
        )
        control_output = self.control_module.generate(
            vlm_input=inputs, vlm_output=vlm_output, processor=self.processor
        )
        return control_output.replace(llm_output=vlm_output.llm_output)

    def predict_action(self, inputs: Dict) -> Dict[str, np.ndarray]:
        """
        Predict a control plan from a raw observation dictionary.

        The caller's `inputs` dictionary (including `inputs["images"]`) is not modified, so
        the same dictionary can safely be passed again.

        Args:
            inputs: Dictionary with the keys
                "images": mapping from camera name to image,
                "ee_translation": end-effector translation, reshaped to (1, 3),
                "ee_rotation": end-effector rotation, reshaped to (1, 3, 3),
                "gripper": gripper value, reshaped to (1, 1),
                "dataset_name": dataset identifier passed to the processor,
                "language_instruction": instruction text used as the chat prompt.
        Returns:
            Dictionary with numpy arrays under "translation", "rotation" and "gripper".
        Raises:
            ValueError: If `self.processor.resize_image` takes neither one nor two arguments.
        """
        raw_images = inputs["images"]
        ee_translation = inputs["ee_translation"]
        ee_rotation = inputs["ee_rotation"]
        gripper = inputs["gripper"]

        # `resize_image` is called with or without the camera name depending on how many
        # parameters its signature declares.
        num_resize_args = len(inspect.signature(self.processor.resize_image).parameters)

        # Build a fresh mapping instead of overwriting entries of the caller's dict:
        # writing back into `inputs["images"]` would corrupt it for any later reuse.
        images = {}
        for camera_name, camera_image in raw_images.items():
            if num_resize_args == 1:
                resized = self.processor.resize_image(camera_image)
            elif num_resize_args == 2:
                resized = self.processor.resize_image(camera_name, camera_image)
            else:
                raise ValueError(f"Unexpected number of arguments for resize_image: {num_resize_args}")
            # Each camera entry is wrapped in a single-element list.
            images[camera_name] = [resized]

        ee_translation = np.array(ee_translation, dtype=np.float32).reshape(1, 3)
        ee_rotation = np.array(ee_rotation, dtype=np.float32).reshape(1, 3, 3) @ CANONICAL_TO_BRIDGE_ROTATION
        gripper = np.array(gripper, dtype=np.float32).reshape(1, 1)
        # Joint values are not taken from `inputs`; zeros are passed instead.
        joints = np.zeros((1, 7), dtype=np.float32)

        dataset_name = np.array([inputs["dataset_name"]])

        chat = [f"{inputs['language_instruction']}", ""]

        model_input = self.processor.create_input(
            images=images,
            chat=chat,
            ee_pose_translation=ee_translation,
            ee_pose_rotation=ee_rotation,
            gripper=gripper,
            dataset_name=dataset_name,
            joints=joints,
            inference_mode=True,
        )

        # Add a leading batch dimension and move every tensor to the GPU.
        # NOTE: the device is hard-coded to "cuda"; the model must live on a CUDA device.
        model_input = model_input.apply(
            lambda x: x.unsqueeze(0).to("cuda") if isinstance(x, torch.Tensor) else x
        )

        with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
            model_output = self.generate(model_input)

        control_plan = self.processor.policy_control_plan_from_model_output(
            model_output=model_output,
            dataset_name=dataset_name,
            valid_mask=torch.ones(
                model_output.gripper.shape[:2], dtype=torch.bool, device=model_output.gripper.device
            ),
        )
        translation_m = control_plan.translation_m.to(dtype=torch.float32, device="cpu")
        rotation = control_plan.rotmat.to(dtype=torch.float32, device="cpu")
        gripper_prob = control_plan.gripper_prob.to(dtype=torch.float32, device="cpu")

        if self.processor.config.eef_control_frame:
            # Left-multiply the plan by the last end-effector rotation from the model input.
            robot_base_rotmat = rotmat_as_3x3(model_input.ee_pose_rotation[:, -1:, ...]).cpu()
            translation_m = torch.matmul(
                robot_base_rotmat, translation_m.unsqueeze(-1)
            ).squeeze(-1)
            rotation = rotmat_as_3x3(
                torch.matmul(robot_base_rotmat, rotmat_as_3x3(rotation))
            )

        rotation = rotmat_as_3x3(rotation)

        # Drop the batch dimension and convert to numpy.
        translation = translation_m.squeeze(0).numpy()
        rotation = rotation.squeeze(0).numpy()
        gripper = gripper_prob.squeeze(0).numpy()

        # Conjugate the rotation with the fixed frame-change matrix.
        rotation = CANONICAL_TO_BRIDGE_ROTATION @ rotation @ CANONICAL_TO_BRIDGE_ROTATION.T

        return {
            "translation": translation,
            "rotation": rotation,
            "gripper": gripper,
        }

    @property
    def fsdp_wrap_modules(self) -> Set[torch.nn.Module]:
        """
        Modules reported for FSDP wrapping: the VLM, the control module, and the sets each
        of them reports.
        """
        return (
            {self.vlm, self.control_module}
            | self.vlm.fsdp_wrap_modules
            | self.control_module.fsdp_wrap_modules
        )
|
|
|