# spear1-franka / modeling_spear.py
# Last commit by giu-alb: a8bf2f3 (verified) - "Super-squash branch 'main' using huggingface_hub"
import collections
import functools
import inspect
import logging
import math
import warnings
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Protocol,
    Set,
    Tuple,
    Type,
)

import einops
import numpy as np
import roma
import timm
import torch
import torch.distributed.fsdp
import torch.distributed.tensor
import torch.utils.checkpoint
import transformers
from huggingface_hub import hf_hub_download
from .common_spear import (
Configurable,
DiffusionInput,
FlowInput,
LLMOutput,
RoboticsFlowInput,
RoboticsInput,
RoboticsOutput,
RotationFormat,
VLMOutput,
expand_dims,
is_quaternion,
is_rotmat,
is_rotmat_3x3,
is_rotmat_9,
quaternion_half_cover,
rotmat_as_3x3,
rotmat_as_9,
)
from .configuration_spear import (
FourierFeaturesConfig,
NoisedControlProjectorConfig,
PaliGemmaVLMConfig,
PiZeroFlowMatchingDecoderBlockConfig,
PiZeroFlowMatchingDecoderConfig,
PiZeroFlowMatchingModuleConfig,
RobotStateProjectorConfig,
RotaryPositionalEncodingConfig,
SPEAR1Config,
)
from .processing_spear import (
EmptyTokenizer,
PaliGemmaDepthProcessor,
PiZeroFlowMatchingProcessor,
)
class ConfigurableModule(torch.nn.Module, Configurable):
    """
    Base class for modules that are both a `torch.nn.Module` and `Configurable`.

    Subclasses call `super().__init__(config)` and then read their settings from `self.config`.
    """

    def __init__(self, config):
        # Both base initializers are called explicitly instead of relying on a cooperative
        # `super()` chain, because the two bases take different constructor arguments.
        # NOTE(review): Configurable.__init__ runs before torch.nn.Module.__init__; this presumably
        # works because it only stores plain (non-Module, non-Parameter) attributes such as
        # `self.config` - confirm before registering submodules/parameters there.
        Configurable.__init__(self, config)
        torch.nn.Module.__init__(self)
class GemmaRMSNorm(torch.nn.Module):
    """
    Gemma-style RMS normalization over the last dimension.

    The learned scale is stored as an offset from one: the output is `rmsnorm(x) * (1 + weight)`,
    so the zero-initialized `weight` starts as the identity scale. The computation runs in float32
    and the result is cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-06):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        # Multiply by the reciprocal root-mean-square of the last dimension.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        scale = 1.0 + self.weight.float()
        normalized = self._norm(x.float())
        return (normalized * scale).type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
@torch.no_grad()
def make_module_params_non_trainable(
    module: torch.nn.Module, recursive: bool = True, dtype: Optional[torch.dtype] = None
):
    """
    Freeze the parameters of `module` (set `requires_grad=False`) and optionally cast them.

    Args:
        module: Module whose parameters are frozen in place.
        recursive: If False, only parameters registered directly on `module` are affected.
        dtype: If given, the parameter data is cast to this dtype.

    NOTE: dtype is applied only to module parameters, not buffers. This is different from the
    default torch.nn.Module.to(dtype=dtype) behavior, which applies to both parameters and buffers
    """
    frozen = list(module.parameters(recurse=recursive))
    for p in frozen:
        p.requires_grad_(False)
    if dtype is None:
        return
    for p in frozen:
        p.data = p.to(dtype=dtype)
@torch.no_grad()
def make_module_non_trainable(module: torch.nn.Module, dtype: Optional[torch.dtype] = None):
    """
    Freeze `module` completely, optionally cast it, and switch it to eval mode.

    Unlike `make_module_params_non_trainable`, a given `dtype` is applied through
    `torch.nn.Module.to`, so buffers are cast as well as parameters.
    """
    # Equivalent to freezing every parameter recursively without casting.
    module.requires_grad_(False)
    if dtype is not None:
        module.to(dtype=dtype)
    module.eval()
class ResidualConvBlock(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int | None = None,
hidden_channels: int | None = None,
padding_mode: str = "replicate",
activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu",
norm: Literal["group_norm", "layer_norm"] = "group_norm",
):
super().__init__()
if out_channels is None:
out_channels = in_channels
if hidden_channels is None:
hidden_channels = in_channels
if activation == "relu":
activation_cls = functools.partial(torch.nn.ReLU, inplace=True)
elif activation == "leaky_relu":
activation_cls = functools.partial(torch.nn.LeakyReLU, negative_slope=0.2, inplace=True)
elif activation == "silu":
activation_cls = functools.partial(torch.nn.SiLU, inplace=True)
elif activation == "elu":
activation_cls = functools.partial(torch.nn.ELU, inplace=True)
else:
raise ValueError(f"Unsupported activation function: {activation}")
self.layers = torch.nn.Sequential(
torch.nn.GroupNorm(1, in_channels),
activation_cls(),
torch.nn.Conv2d(
in_channels,
hidden_channels,
kernel_size=3,
padding=1,
padding_mode=padding_mode,
),
torch.nn.GroupNorm(hidden_channels // 32 if norm == "group_norm" else 1, hidden_channels),
activation_cls(),
torch.nn.Conv2d(
hidden_channels,
out_channels,
kernel_size=3,
padding=1,
padding_mode=padding_mode,
),
)
self.skip_connection = (
torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
if in_channels != out_channels
else torch.nn.Identity()
)
def forward(self, x):
skip = self.skip_connection(x)
x = self.layers(x)
x = x + skip
return x
def normalized_view_plane_uv(
width: int,
height: int,
aspect_ratio: float | None = None,
dtype: torch.dtype = None,
device: torch.device = None,
) -> torch.Tensor:
"""
UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)
"""
if aspect_ratio is None:
aspect_ratio = width / height
span_x = aspect_ratio / (1 + aspect_ratio**2) ** 0.5
span_y = 1 / (1 + aspect_ratio**2) ** 0.5
u = torch.linspace(
-span_x * (width - 1) / width,
span_x * (width - 1) / width,
width,
dtype=dtype,
device=device,
)
v = torch.linspace(
-span_y * (height - 1) / height,
span_y * (height - 1) / height,
height,
dtype=dtype,
device=device,
)
(u, v) = torch.meshgrid(u, v, indexing="xy")
uv = torch.stack([u, v], dim=-1)
return uv
class Head(torch.nn.Module):
    """
    Dense prediction head operating on ViT intermediate token features (MoGe-style).

    Each of the `num_features` token maps is projected to `dim_proj` channels with a 1x1 conv and
    the projections are summed. The result goes through one stage per entry of `dim_upsample`
    (2x ConvTranspose2d + 3x3 conv, then `num_res_blocks` residual blocks), is bilinearly resized to
    the input image resolution, and is finally fed to one output block per entry of `dim_out`.
    Two normalized view-plane UV channels are concatenated before every stage and before the output
    blocks, so the head knows the image aspect ratio.
    """

    def __init__(
        self,
        num_features: int,
        dim_in: int,
        dim_out: List[int],
        dim_proj: int = 512,
        dim_upsample: List[int] | Tuple[int, ...] = (256, 128, 128),
        dim_times_res_block_hidden: int = 1,
        num_res_blocks: int = 1,
        res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm",
        last_res_blocks: int = 0,
        last_conv_channels: int = 32,
        last_conv_size: int = 1,
    ):
        super().__init__()
        # Copy into a list: avoids a shared mutable default and accepts any sequence (the list
        # concatenation below failed for tuples).
        dim_upsample = list(dim_upsample)
        self.projects = torch.nn.ModuleList(
            [
                torch.nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=dim_proj,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
                for _ in range(num_features)
            ]
        )
        # `+ 2` input channels everywhere: the UV grid concatenated in `forward`.
        self.upsample_blocks = torch.nn.ModuleList(
            [
                torch.nn.Sequential(
                    self._make_upsampler(in_ch + 2, out_ch),
                    *(
                        ResidualConvBlock(
                            out_ch,
                            out_ch,
                            dim_times_res_block_hidden * out_ch,
                            activation="relu",
                            norm=res_block_norm,
                        )
                        for _ in range(num_res_blocks)
                    ),
                )
                for (in_ch, out_ch) in zip([dim_proj] + dim_upsample[:-1], dim_upsample, strict=True)
            ]
        )
        self.output_block = torch.nn.ModuleList(
            [
                self._make_output_block(
                    dim_upsample[-1] + 2,
                    dim_out_,
                    dim_times_res_block_hidden,
                    last_res_blocks,
                    last_conv_channels,
                    last_conv_size,
                    res_block_norm,
                )
                for dim_out_ in dim_out
            ]
        )

    def _make_upsampler(self, in_channels: int, out_channels: int):
        """2x upsampling: ConvTranspose2d(kernel=2, stride=2) followed by a 3x3 conv."""
        upsampler = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            torch.nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="replicate",
            ),
        )
        # Copy the top-left tap of every 2x2 kernel to all four taps: at initialization each input
        # pixel then writes the same value to its whole 2x2 output block (nearest-neighbour upsampling
        # of a per-pixel linear map).
        upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1]
        return upsampler

    def _make_output_block(
        self,
        dim_in: int,
        dim_out: int,
        dim_times_res_block_hidden: int,
        last_res_blocks: int,
        last_conv_channels: int,
        last_conv_size: int,
        res_block_norm: Literal["group_norm", "layer_norm"],
    ):
        """3x3 conv -> `last_res_blocks` residual blocks -> ReLU -> conv to `dim_out` channels."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(
                dim_in,
                last_conv_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="replicate",
            ),
            *(
                ResidualConvBlock(
                    last_conv_channels,
                    last_conv_channels,
                    dim_times_res_block_hidden * last_conv_channels,
                    activation="relu",
                    norm=res_block_norm,
                )
                for _ in range(last_res_blocks)
            ),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(
                last_conv_channels,
                dim_out,
                kernel_size=last_conv_size,
                stride=1,
                padding=last_conv_size // 2,
                padding_mode="replicate",
            ),
        )

    @staticmethod
    def _concat_uv(x: torch.Tensor, aspect_ratio: float) -> torch.Tensor:
        """Append two normalized view-plane UV channels, sized to `x`, to the [B, C, H, W] tensor `x`."""
        uv = normalized_view_plane_uv(
            width=x.shape[-1],
            height=x.shape[-2],
            aspect_ratio=aspect_ratio,
            dtype=x.dtype,
            device=x.device,
        )
        uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
        return torch.cat([x, uv], dim=1)

    def forward(self, hidden_states: List[torch.Tensor], image: torch.Tensor):
        """
        Args:
            hidden_states: One token tensor per projection, each [B, patch_h * patch_w, dim_in]
                (tokens in row-major patch order, no prefix tokens).
            image: The backbone input; only its spatial size [..., H, W] is used.

        Returns:
            List with one tensor per `dim_out` entry, each [B, dim_out[i], H, W] (for odd
            `last_conv_size`).
        """
        img_h, img_w = image.shape[-2:]
        # Token grid of a patch-14 ViT.
        patch_h, patch_w = img_h // 14, img_w // 14
        aspect_ratio = img_w / img_h
        x = torch.stack(
            [
                proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous())
                for (proj, feat) in zip(self.projects, hidden_states, strict=True)
            ],
            dim=1,
        ).sum(dim=1)
        for block in self.upsample_blocks:
            x = self._concat_uv(x, aspect_ratio)
            # Activation checkpointing per layer to reduce memory on the high-resolution maps.
            for layer in block:
                x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
        x = torch.nn.functional.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
        x = self._concat_uv(x, aspect_ratio)
        if isinstance(self.output_block, torch.nn.ModuleList):
            output = [
                torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
                for block in self.output_block
            ]
        else:
            output = torch.utils.checkpoint.checkpoint(self.output_block, x, use_reentrant=False)
        return output
def _is_single_image_size(image_sizes: Dict[str, Dict[str, int]]) -> bool:
return (
len(image_sizes) == 1
or len(set(((image_size["height"], image_size["width"]) for image_size in image_sizes.values()))) == 1
)
class MoGe(torch.nn.Module):
    """
    Implementation of MoGe taken from https://github.com/microsoft/MoGe/blob/main/moge/model/v1.py
    Simplified and stripped down such that:
    - It doesn't rely on MoGe codebase
    - Uses timm for ViT backbone
    - Only predicts points and mask
    - Currently does NOT infer depth or intrinsics (but this could be added)
    - Requires image to be resized to the expected resolution. Note that this resolution must be
    in the range of num_tokens_range, (for square images, pixel sizes in the range ~[36*14, 50*14])
    """

    def __init__(
        self,
        image_sizes: Dict[str, Dict[str, int]] | None = None,
        backbone_id: str = "vit_large_patch14_dinov2.lvd142m",
        intermediate_layers: int | List[int] = 4,
        dim_proj: int = 512,
        dim_upsample: List[int] | Tuple[int, ...] = (256, 128, 64),
        dim_times_res_block_hidden: int = 2,
        num_res_blocks: int = 2,
        remap_output: Literal[False, True, "linear", "sinh", "exp", "sinh_exp"] = "exp",
        res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm",
        num_tokens_range: Tuple[int, int] = (1275, 2551),
        last_res_blocks: int = 0,
        last_conv_channels: int = 32,
        last_conv_size: int = 1,
        mask_threshold: float = 0.5,
        output_dtype: torch.dtype = torch.bfloat16,
    ):
        """
        NOTE:
        - All defaults were taken from the checkpoint config for 'Ruicheng/moge-vitl'
        - output_dtype - by default was float32, changed to bfloat16 for model training

        Args:
            image_sizes: Mapping camera name -> {"height": int, "width": int}. A single distinct
                size gives a fixed-size backbone; several sizes enable timm's dynamic image size.
            intermediate_layers: Backbone layers fed to the head, passed to timm's
                `forward_intermediates(indices=...)` (int = last n blocks, list = explicit indices).
        """
        super().__init__()
        self.remap_output = remap_output
        self.intermediate_layers = intermediate_layers
        self.num_tokens_range = num_tokens_range
        self.mask_threshold = mask_threshold
        self.output_dtype = output_dtype
        # Copied (and read by `_make_vit_backbone`), so it must be set before the backbone is built.
        self.image_sizes = dict(image_sizes) if image_sizes is not None else {}
        self.backbone: timm.models.vision_transformer.VisionTransformer = self._make_vit_backbone(backbone_id)
        token_size: int = self.backbone.embed_dim
        self.head = Head(
            num_features=(
                intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers)
            ),
            dim_in=token_size,
            dim_out=[3, 1],  # points (xyz) and mask logits
            dim_proj=dim_proj,
            dim_upsample=list(dim_upsample),
            dim_times_res_block_hidden=dim_times_res_block_hidden,
            num_res_blocks=num_res_blocks,
            res_block_norm=res_block_norm,
            last_res_blocks=last_res_blocks,
            last_conv_channels=last_conv_channels,
            last_conv_size=last_conv_size,
        )

    def _make_vit_backbone(self, backbone_id: str) -> timm.models.vision_transformer.VisionTransformer:
        """Create the timm ViT and rebind `forward` to return the head's intermediate token maps."""
        if _is_single_image_size(self.image_sizes):
            # All cameras share one resolution, so any entry gives the fixed input size (this no
            # longer requires a camera literally named "main").
            size = next(iter(self.image_sizes.values()))
            kwargs = {"img_size": (size["height"], size["width"]), "dynamic_img_size": False}
        else:
            kwargs = {"img_size": (224, 224), "dynamic_img_size": True}
        vit_backbone: timm.models.vision_transformer.VisionTransformer = timm.create_model(
            backbone_id, pretrained=False, num_classes=0, **kwargs
        )
        # Use the same layer selection as the head's `num_features`; the previous hard-coded
        # `indices=4` broke the head's strict zip for any other `intermediate_layers` value.
        vit_backbone.forward = functools.partial(
            vit_backbone.forward_intermediates,
            indices=self.intermediate_layers,
            return_prefix_tokens=False,
            norm=True,
            stop_early=True,
            output_fmt="NLC",
            intermediates_only=True,
        )
        return vit_backbone

    def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            image: torch.Tensor of shape [B, 3, H, W] containing the preprocessed image, resized to the
                size expected by the model.
        Returns:
            A dictionary containing:
            - `points`: torch.Tensor of shape [B, 3, H, W] containing the predicted points.
            - `mask`: torch.Tensor of shape [B, 1, H, W] containing the predicted mask.
        """
        (height, width) = image.shape[-2:]
        assert (height, width) in [
            (image_size["height"], image_size["width"]) for image_size in self.image_sizes.values()
        ], f"{(height, width)} not in {self.image_sizes}"
        features: List[torch.Tensor] = self.backbone(image)
        (points, mask) = self.head(features, image)
        # Float32 autocast is only entered when float32 outputs are requested; otherwise the
        # resize and remap below run in the head's output dtype.
        with torch.autocast(
            device_type=image.device.type,
            dtype=torch.float32,
            enabled=self.output_dtype == torch.float32,
        ):
            points = torch.nn.functional.interpolate(
                points,
                (height, width),
                mode="bilinear",
                align_corners=False,
                antialias=False,
            )
            mask = torch.nn.functional.interpolate(
                mask,
                (height, width),
                mode="bilinear",
                align_corners=False,
                antialias=False,
            )
            points = self._remap_points(points, dim=1)
        return {"points": points, "mask": mask}

    def _remap_points(self, points: torch.Tensor, dim: int = 1) -> torch.Tensor:
        """
        Map raw head outputs to 3D points along `dim` (x, y, z channels).

        - "linear": unchanged
        - "sinh": sinh applied to all three channels
        - "exp": z = exp(z_raw) and (x, y) = (x_raw, y_raw) * z
        - "sinh_exp": (x, y) = sinh(xy_raw) and z = exp(z_raw)

        Raises:
            ValueError: for any other `remap_output` value.
        """
        if self.remap_output == "linear":
            pass
        elif self.remap_output == "sinh":
            points = torch.sinh(points)
        elif self.remap_output == "exp":
            (xy, z) = points.split([2, 1], dim=dim)
            z = torch.exp(z)
            points = torch.cat([xy * z, z], dim=dim)
        elif self.remap_output == "sinh_exp":
            (xy, z) = points.split([2, 1], dim=dim)
            points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=dim)
        else:
            raise ValueError(f"Invalid remap output type: {self.remap_output}")
        return points
class DepthBackboneConfig(transformers.PretrainedConfig):
    """
    Config of the depth backbone: where its checkpoint lives on the Hugging Face Hub and the
    per-camera input sizes.

    Args:
        hf_hub_repo: Hub repository id holding the depth checkpoint.
        hf_filename: Checkpoint file name inside `hf_hub_repo`.
        image_sizes: Mapping camera name -> {"height": int, "width": int}; stored as a copy.
        **kwargs: Forwarded to `transformers.PretrainedConfig`.
    """

    def __init__(
        self,
        hf_hub_repo: str = "",
        hf_filename: str = "",
        image_sizes: Dict[str, Dict[str, int]] = {},
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Copy so neither the shared default nor the caller's dict is aliased by the config.
        copied_sizes = dict(image_sizes)
        self.hf_hub_repo, self.hf_filename, self.image_sizes = hf_hub_repo, hf_filename, copied_sizes
class PaliGemma3DConfig(transformers.models.paligemma.PaliGemmaConfig):
    """
    PaliGemma config extended with a depth encoder.

    Args:
        depth_config: `DepthBackboneConfig` or a dict of its kwargs (None -> default config).
        depth_only: Use only depth features as image tokens (the SigLIP branch is frozen).
        mask_prob: Probability of the train-time SigLIP/depth token mixing in `PaliGemma3D`.
        projection: How depth features are fused with SigLIP features (e.g. "features_add").
        depth_layers: Number of depth-backbone intermediate layers that are used.
        **kwargs: Forwarded to `PaliGemmaConfig`.
    """

    sub_configs = {
        "text_config": transformers.AutoConfig,
        "vision_config": transformers.AutoConfig,
        "depth_config": DepthBackboneConfig,
    }

    def __init__(
        self,
        depth_config: DepthBackboneConfig | dict | None = None,
        depth_only: bool = False,
        mask_prob: float = 0.0,
        projection: str = "",
        depth_layers: int = 4,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # `None` default instead of a shared `{}`; an explicit None previously left
        # `self.depth_config = None`, breaking the properties below.
        if depth_config is None:
            depth_config = {}
        if isinstance(depth_config, dict):
            self.depth_config = DepthBackboneConfig(**depth_config)
        else:
            self.depth_config = depth_config
        self.mask_prob = mask_prob
        self.depth_only = depth_only
        self.projection = projection
        self.depth_layers = depth_layers

    @property
    def is_single_image_size(self) -> bool:
        """True when all cameras in `depth_config.image_sizes` share one (height, width)."""
        return _is_single_image_size(self.depth_config.image_sizes)

    @property
    def camera_names(self) -> List[str]:
        """Camera names, in the insertion order of `depth_config.image_sizes`."""
        return list(self.depth_config.image_sizes.keys())
class NeRFPositionalEmbedding(torch.nn.Module):
    """NeRF-style sinusoidal positional encoding that also keeps the raw input."""

    def __init__(self, n_frequencies: int, log_scale: bool = True, scale: float = 1.0):
        """
        Args:
            n_frequencies: Number of frequencies (the L parameter in the NeRF paper).
            log_scale: If True the frequencies are 2^0 .. 2^(L-1); otherwise they are L evenly
                spaced values from 1 to 2^(L-1).
            scale: Scale factor for the frequencies. To match the formula from the paper
                [sin(2^k * pi * x), cos(2^k * pi * x)], set scale to math.pi. In practice, the paper
                authors don't multiply by pi and use scale=1.0.
                See https://github.com/bmild/nerf/issues/12
        """
        super().__init__()
        self.n_frequencies = n_frequencies
        if not log_scale:
            freq = torch.linspace(1, 2 ** (n_frequencies - 1), n_frequencies)
        else:
            freq = 2 ** torch.arange(n_frequencies, dtype=torch.float32)
        self.register_buffer("freq", freq * scale)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Encode each of the N input features into 2L + 1 features.

        The output is (x, sin(f_k x), cos(f_k x), ...): the raw input comes first, which differs
        from the paper's equation but matches the authors' implementation
        (https://github.com/bmild/nerf/issues/12).

        Args:
            inputs: torch.Tensor of shape [B, ..., N]; input values to be transformed
        Returns:
            torch.Tensor of shape [B, ..., N*(2*L + 1) = embedding_dim], encoded input values
        """
        spectrum = inputs.unsqueeze(-1) * expand_dims(self.freq, ndim=inputs.ndim + 1, order=[-1, 1])
        # [..., N, L, 2] -> [..., N * L * 2], interleaving sin/cos per frequency.
        harmonics = torch.stack([spectrum.sin(), spectrum.cos()], dim=-1).flatten(start_dim=-3)
        return torch.cat([inputs, harmonics], dim=-1)
def make_mlp(
layer_sizes: List[int],
activation: str | Type[torch.nn.Module],
norm: str | Type[torch.nn.Module] | None = torch.nn.LayerNorm,
activate_final: bool = False,
bias: bool = True,
) -> torch.nn.Sequential:
"""
Args:
layer_sizes: List of layer sizes. The first value is the number of input features and the last
value is the number of output features
activation: str or the class of the activation. If str, it should be the exact name of
the activation module under torch.nn, e.g. 'ReLU', 'SiLU', 'GeLU'. Use 'Identity' if
no activation wanted
norm: type of normalization. Same type as `activation`. Ex: `torch.nn.LayerNorm`, 'LayerNorm', etc
"""
if len(layer_sizes) == 0:
return torch.nn.Identity()
assert len(layer_sizes) > 1, "Need to provide input and output layer sizes at least"
if isinstance(activation, str):
TorchActivation: Type[torch.nn.Module] = getattr(torch.nn, activation)
else:
TorchActivation: Type[torch.nn.Module] = activation
assert issubclass(TorchActivation, torch.nn.Module), TorchActivation
if isinstance(norm, str):
TorchNorm: Type[torch.nn.Module] = getattr(torch.nn, norm)
elif norm is None:
TorchNorm: Type[torch.nn.Module] = torch.nn.Identity
else:
TorchNorm: Type[torch.nn.Module] = norm
assert issubclass(TorchNorm, torch.nn.Module), TorchNorm
def make_norm_act(modules: dict[str, torch.nn.Module], empty: bool):
return {} if empty else modules
module = torch.nn.Sequential(
*[
torch.nn.Sequential(
collections.OrderedDict(
{
"linear": torch.nn.Linear(in_features, out_features, bias=bias),
**make_norm_act(
{"norm": TorchNorm(out_features), "act": TorchActivation()},
empty=i == len(layer_sizes) - 2 and not activate_final,
),
}
)
)
for (i, (in_features, out_features)) in enumerate(
zip(layer_sizes[:-1], layer_sizes[1:], strict=True)
)
]
)
return module
class PaliGemma3D(transformers.models.paligemma.PaliGemmaForConditionalGeneration):
    """
    Transformers-like implementation of PaliGemma with additional depth encoder.

    A DINOv2 ViT-L backbone (built with timm; its weights are downloaded in `from_pretrained`) runs
    next to PaliGemma's SigLIP vision tower. With `config.projection == "features_add"`, the only
    projection `__init__` builds, the concatenated intermediate depth features are linearly
    projected to the text hidden size and fused with the projected SigLIP features.
    """

    config_class = PaliGemma3DConfig

    def __init__(self, config: PaliGemma3DConfig):
        super().__init__(config)
        # Explicit raise instead of `assert`: the previous assert made this ValueError unreachable
        # and disappeared entirely under `python -O`.
        if self.config.projection != "features_add":
            raise ValueError(f"Projection type `{self.config.projection}` not supported!")
        self.depth_tower = self._make_depth_encoder()
        # Intermediate layers are concatenated along the feature axis in `_get_image_features_add`,
        # hence `embed_dim * depth_layers` input features.
        self.depth_projector = torch.nn.Linear(
            in_features=self.depth_tower.embed_dim * self.config.depth_layers,
            out_features=self.config.text_config.hidden_size,
        )
        # Dedicated RNG for the train-time decision whether to mix SigLIP/depth tokens.
        self.generator: torch.Generator = torch.Generator()
        if self.config.depth_only:
            make_module_non_trainable(self.vision_tower)
            make_module_non_trainable(self.siglip_projector)
        self.projector = None

    @property
    def projectors(self) -> Set[torch.nn.Module]:
        """All projection modules (those that exist), plus MoGe head/residual blocks if the depth tower is a MoGe."""
        modules = {
            module
            for module in (self.projector, self.depth_projector, self.siglip_projector)
            if module is not None
        }
        if isinstance(self.depth_tower, MoGe):
            modules |= {
                module
                for module in self.depth_tower.modules()
                if isinstance(module, (ResidualConvBlock, Head))
            }
        return modules

    @property
    def siglip_projector(self) -> torch.nn.Linear:
        """Linear layer of PaliGemma's multi-modal projector, applied to SigLIP features."""
        return self.multi_modal_projector.linear

    def _make_depth_encoder(self) -> torch.nn.Module:
        """
        Build the DINOv2 ViT-L depth encoder whose `forward` returns intermediate token features.

        The base image size comes from the "main" camera. Dynamic image size (on-the-fly position
        embedding interpolation) is enabled only when the cameras have different sizes.
        """
        dynamic_img_size = not self.config.is_single_image_size
        main_size = self.config.depth_config.image_sizes["main"]
        model: timm.models.vision_transformer.VisionTransformer = timm.create_model(
            "vit_large_patch14_dinov2.lvd142m",
            pretrained=False,
            num_classes=0,
            img_size=(main_size["height"], main_size["width"]),
            dynamic_img_size=dynamic_img_size,
        )
        # `model(x)` returns a list of the last `depth_layers` normalized intermediate outputs,
        # each in NLC layout without prefix tokens.
        model.forward = functools.partial(
            model.forward_intermediates,
            indices=self.config.depth_layers,
            return_prefix_tokens=False,
            norm=True,
            stop_early=True,
            output_fmt="NLC",
            intermediates_only=True,
        )
        return model

    def _load_depth_model_state_dict(self, depth_model: torch.nn.Module):
        """Download the depth checkpoint (`depth_config.hf_hub_repo` / `hf_filename`) and load it into `depth_model`."""
        depth_config = self.config.depth_config
        logging.info("Loading depth model from %s/%s", depth_config.hf_hub_repo, depth_config.hf_filename)
        # SECURITY NOTE: `weights_only=False` unpickles arbitrary objects from a file downloaded from
        # the Hub. Only point `hf_hub_repo` at trusted repositories; consider `weights_only=True` if
        # the checkpoint contains only tensors.
        state_dict = torch.load(
            hf_hub_download(repo_id=depth_config.hf_hub_repo, filename=depth_config.hf_filename),
            map_location="cpu",
            mmap=True,
            weights_only=False,
        )
        if self.config.projection in ["spatial_add", "spatial_concat"]:
            # NOTE(review): unreachable with the current `__init__` (only "features_add" is built),
            # and `depth_model.backbone` assumes a MoGe-style wrapper rather than a bare timm ViT.
            pos_embed_state_dict = {"pos_embed": state_dict["backbone.pos_embed"]}
            pos_embed_state_dict = timm.models.vision_transformer.checkpoint_filter_fn(
                pos_embed_state_dict, depth_model.backbone
            )
            state_dict["backbone.pos_embed"] = pos_embed_state_dict["pos_embed"]
        else:
            state_dict = timm.models.vision_transformer.checkpoint_filter_fn(state_dict, depth_model)
        depth_model.load_state_dict(state_dict)

    def get_image_features(self, pixel_values: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Encode every camera in `config.camera_names` with both encoders and fuse the results.

        Args:
            pixel_values: Dict with keys "<camera>.siglip" and "<camera>.depth" for each camera.
                Values are presumably [B, C, H, W] (the single-size path stacks them on dim 1 and
                rearranges "B N C H W").
        Returns:
            If all cameras share one size: the fused tokens of all cameras batched together,
            [(B * num_cameras), num_tokens, hidden]. Otherwise: per-camera tokens concatenated along
            the token axis, [B, num_cameras * num_tokens, hidden].
            NOTE(review): the two layouts differ but flatten to the same row-major order;
            presumably callers only rely on that - confirm.
        """
        if self.config.projection == "features_add":
            images_forward = self._get_image_features_add
        elif self.config.projection in ["spatial_add", "spatial_concat"]:
            images_forward = self._get_image_features_spatial
        else:
            raise ValueError(f"Project type `{self.config.projection}` not supported!")
        camera_names = self.config.camera_names
        if self.config.is_single_image_size:
            # Same resolution for every camera: run all cameras as a single batch.
            def fold_cameras(encoder: str) -> torch.Tensor:
                stacked = torch.stack(
                    [pixel_values[f"{camera_name}.{encoder}"] for camera_name in camera_names],
                    dim=1,
                )
                return einops.rearrange(stacked, "B N C H W -> (B N) C H W")

            return images_forward({"siglip": fold_cameras("siglip"), "depth": fold_cameras("depth")})
        camera_tokens: List[torch.Tensor] = [
            images_forward(
                {
                    "siglip": pixel_values[f"{camera_name}.siglip"],
                    "depth": pixel_values[f"{camera_name}.depth"],
                }
            )
            for camera_name in camera_names
        ]
        return torch.cat(camera_tokens, dim=-2)

    def _get_image_features_add(self, pixel_values: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        SigLIP and depth features are projected separately and then:
        - `depth_only`: only depth features are used;
        - training, with probability `mask_prob`: each token is taken from SigLIP or depth, with a
          random half of the tokens taken from depth;
        - otherwise: the two are averaged.
        The result is divided by sqrt(text hidden size).

        Args:
            pixel_values: images of shape `[B * num_images, C, H, W]`.
                Keys in the dict correspond to the specific vision encoder to use
        Returns:
            image_features of shape [B * num_images, num_tokens, token_size = 2048]
        """
        siglip_input = pixel_values["siglip"]
        depth_input = pixel_values["depth"]
        siglip_output = self.vision_tower(siglip_input)
        siglip_features = siglip_output.last_hidden_state
        depth_output: list[torch.Tensor] = self.depth_tower(depth_input)
        depth_features = torch.cat(depth_output, dim=-1) if len(depth_output) > 1 else depth_output[0]
        siglip_features = self.siglip_projector(siglip_features)
        depth_features = self.depth_projector(depth_features)
        if self.config.depth_only:
            image_features = depth_features
        elif self.training and torch.bernoulli(torch.tensor(self.config.mask_prob), generator=self.generator):
            num_tokens = depth_features.shape[1]
            device = depth_features.device
            ones = torch.ones((depth_features.shape[0], num_tokens), device=device)
            # Pick num_tokens // 2 distinct token positions per sample to take from depth.
            indices = torch.multinomial(ones / num_tokens, num_samples=num_tokens // 2)
            mask = ones.to(dtype=torch.bool).scatter_(dim=-1, index=indices, value=0)
            mask = mask.unsqueeze(-1)
            image_features = siglip_features * mask + depth_features * ~mask
        else:
            image_features = (siglip_features + depth_features) / 2
        # Presumably compensates the sqrt(hidden_size) scaling Gemma applies to input embeddings,
        # as in upstream PaliGemma - TODO confirm.
        image_features = image_features / self.config.text_config.hidden_size**0.5
        return image_features

    def _get_image_features_spatial(self, pixel_values: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        NOTE(review): unreachable with the current `__init__`, which only builds "features_add"
        (the depth tower is a timm ViT returning a list, not a dict, and `self.projector` is None).

        Args:
            pixel_values: images of shape `[B * num_images, C, H, W]`.
                Keys in the dict correspond to the specific vision encoder to use
        Returns:
            image_features of shape [B * num_images, num_tokens, token_size = 2048]
        """
        siglip_input = pixel_values["siglip"]
        depth_input = pixel_values["depth"]
        siglip_output = self.vision_tower(siglip_input)
        siglip_features = siglip_output.last_hidden_state
        depth_output: dict[str, torch.Tensor] = self.depth_tower(depth_input)
        points = depth_output["points"]
        mask = depth_output["mask"]
        mask_binary = mask > self.depth_tower.mask_threshold
        points = torch.where(mask_binary, points, 0)
        points_embed: torch.Tensor = self.depth_projector(points)
        if self.config.projection == "spatial_concat":
            features = torch.cat([siglip_features, points_embed], dim=-1)
            image_features = self.projector(features)
        else:
            features = siglip_features + points_embed
            image_features = self.siglip_projector(features)
        image_features = image_features / self.config.text_config.hidden_size**0.5
        return image_features

    @classmethod
    def from_pretrained(cls, *args, **kwargs) -> "PaliGemma3D":
        """Load the PaliGemma weights, then load the depth-tower weights from the HF Hub."""
        model = super().from_pretrained(*args, **kwargs)
        model._load_depth_model_state_dict(model.depth_tower)
        return model
class RobotStateProjector(ConfigurableModule):
    """Pack robot state and project it to a single token per timestep."""

    # Input fields concatenated (in this order) along the last dim for each state mode.
    _MODE_FIELDS = {
        "ee_pose": ("ee_pose_translation", "ee_pose_rotation"),
        "ee_pose_gripper": ("ee_pose_translation", "ee_pose_rotation", "gripper"),
        "ee_pose_joints": ("ee_pose_translation", "ee_pose_rotation", "joints"),
        "joints": ("joints",),
        "all": ("ee_pose_translation", "ee_pose_rotation", "gripper", "joints"),
    }

    def __init__(self, config: RobotStateProjectorConfig):
        super().__init__(config)
        if self.config.fourier:
            raise NotImplementedError("Fourier robot state projector is not implemented yet")
        self.robot_state_tokens_proj = make_mlp(
            layer_sizes=self.config.layers,
            activation=self.config.activation,
            norm=torch.nn.LayerNorm,
        )

    def forward(self, inputs: RoboticsInput) -> Optional[torch.Tensor]:
        """
        Concatenate the robot-state fields selected by `config.mode` and project them.

        Returns:
            torch.Tensor of shape [B, num_past_steps, token_size]. For mode "none" the tensor has
            zero timesteps, [B, 0, token_size] (it is never None).

        Raises:
            NotImplementedError: for an unknown `config.mode`.
        """
        mode = self.config.mode
        if mode == "none":
            reference = inputs.ee_pose_translation
            feature_dim = self.config.layers[0] if len(self.config.layers) > 0 else 0
            # Match the input dtype/device (the previous `torch.tensor([])` was always float32).
            robot_state = reference.new_zeros((reference.shape[0], 0, feature_dim))
        elif mode in self._MODE_FIELDS:
            parts = [getattr(inputs, name) for name in self._MODE_FIELDS[mode]]
            robot_state = parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1)
        else:
            raise NotImplementedError(f"Unknown robot state mode {mode}")
        return self.robot_state_tokens_proj(robot_state)
class FourierFeatures(ConfigurableModule):
    """
    Sinusoidal (Fourier) embedding of a scalar input, followed by an MLP (`config.layers`).

    With `config.learnable_features` the frequencies are the weights of a bias-free
    `Linear(1, num_features // 2)`, scaled by 2*pi. Otherwise they are the fixed geometric sequence
    exp(-log(max_period) * k / (num_features // 2 - 1)) for k = 0 .. num_features // 2 - 1.
    """

    def __init__(self, config: FourierFeaturesConfig):
        super().__init__(config)
        half_dim = self.config.num_features // 2
        if not self.config.learnable_features:
            log_max_period = torch.log(torch.tensor(self.config.max_period))
            step = log_max_period / (half_dim - 1)
            self.register_buffer("freqs", torch.exp(-step * torch.arange(half_dim)))
        else:
            self.linear = torch.nn.Linear(in_features=1, out_features=half_dim, bias=False)
        self.layers: torch.nn.Sequential = make_mlp(
            self.config.layers,
            activation=self.config.activation,
            norm=self.config.norm,
            activate_final=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute Fourier features and project them via MLP.

        Args:
            x: Input tensor of shape [..., 1]
        Returns:
            torch.Tensor: [cos | sin] features of shape [..., num_features] passed through the MLP,
            giving [..., layers[-1]] when `config.layers` is non-empty.
        """
        assert x.shape[-1] == 1 and x.ndim > 1, x.shape
        if not self.config.learnable_features:
            angles = x * expand_dims(self.freqs, x.ndim, [-1, 1])
        else:
            angles = 2 * math.pi * self.linear(x)
        return self.layers(torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1))
class NoisedControlProjector(ConfigurableModule):
    """
    Pack the noised control (translation, rotation, gripper) together with the timestep embedding
    and project them to a single token per control timestep.

    The concatenated noised controls are mapped to `layers[1] // 2` features. The timestep
    embedding is concatenated before them and must supply the remaining features, so that the
    first layer of `self.layers` receives `layers[1]` inputs.
    """

    def __init__(self, config: NoisedControlProjectorConfig):
        super().__init__(config)
        hidden_width = self.config.layers[1]
        self.input_projector = torch.nn.Linear(
            in_features=self.config.layers[0],
            out_features=hidden_width // 2,
            bias=False,
        )
        self.time_embed = FourierFeatures(self.config.time_embed)
        self.layers = make_mlp(
            self.config.layers[1:],
            activation=self.config.activation,
            norm=self.config.norm,
            activate_final=False,
            bias=False,
        )

    def forward(self, inputs: FlowInput | DiffusionInput) -> Optional[torch.Tensor]:
        """
        Returns:
            torch.Tensor of shape [B, num_control_timesteps, token_size]
        """
        control_tokens = self.input_projector(
            torch.cat([inputs.translation_t, inputs.rotation_t, inputs.gripper_t], dim=-1)
        )
        # Presumably a single timestep per sample ([B, 1, D] after embedding) that is broadcast to
        # every control timestep - TODO confirm the timestep shape.
        time_tokens = self.time_embed(inputs.timestep).expand(-1, control_tokens.shape[1], -1)
        return self.layers(torch.cat([time_tokens, control_tokens], dim=-1))
def unmask_unattended(attn_mask: torch.Tensor, mask_value: Optional[float] = None) -> torch.Tensor:
    """
    Make fully-masked rows of an additive float attention mask attend to every position.

    Copy-pasted from `transformers.modeling_attn_mask_utils.AttentionMaskConverter._unmask_unattended`.
    This is required by the memory-efficient attention path of F.scaled_dot_product_attention;
    otherwise the results are NaN. Details: https://github.com/pytorch/pytorch/issues/110213

    Args:
        attn_mask: [B, 1 | num_heads, query_seq_len, kv_seq_len] or [B, query_seq_len, kv_seq_len],
            float dtype
        mask_value: The value inside `attn_mask` that marks masked elements
            (default: `torch.finfo(attn_mask.dtype).min`)

    Returns:
        `attn_mask` with every row consisting only of `mask_value` replaced by zeros (attend to all);
        other rows are unchanged. For example, with m = mask_value:
        ```
        [[[m, m, m],        [[[0, 0, 0],   <-- modified
          [m, m, m],   ->     [0, 0, 0],   <-- modified
          [m, m, 0]]]         [m, m, 0]]]
        ```

    Raises:
        AssertionError: if `attn_mask` is not a floating-point tensor.
    """
    assert attn_mask.dtype.is_floating_point, attn_mask.dtype
    fill = torch.finfo(attn_mask.dtype).min if mask_value is None else mask_value
    row_fully_masked = (attn_mask == fill).all(dim=-1, keepdim=True)
    # Multiplying by a bool tensor zeroes the fully-masked rows and leaves the others untouched.
    return attn_mask * row_fully_masked.logical_not()
@torch.no_grad()
def attn_mask_to_float(attn_mask: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    """
    Convert a 4D boolean attention mask (True = attend) into an additive float mask.

    True positions become 0 and False positions become `torch.finfo(dtype).min`. Rows that end up
    fully masked are then reset to 0 by `unmask_unattended`.

    Args:
        attn_mask: Bool tensor with exactly 4 dims.
        dtype: Output dtype; defaults to the current autocast dtype of the mask's device type.

    Raises:
        AssertionError: if `attn_mask` is not 4D or not bool.
    """
    assert attn_mask.ndim == 4, attn_mask.shape
    assert attn_mask.dtype == torch.bool, attn_mask.dtype
    target_dtype = torch.get_autocast_dtype(attn_mask.device.type) if dtype is None else dtype
    min_value = torch.finfo(target_dtype).min
    additive = torch.zeros(attn_mask.shape, dtype=target_dtype, device=attn_mask.device)
    additive = additive.masked_fill(attn_mask.logical_not(), min_value)
    return unmask_unattended(additive, min_value)
@torch.no_grad()
def make_4d_float_attn_mask(
    attn_mask: Optional[torch.Tensor],
    query_seq_length: int,
    kv_seq_length: int,
    dtype: torch.dtype,
    device: torch.device,
    batch_size: int,
) -> torch.Tensor:
    """
    Build an additive float attention mask of shape [batch_size, 1, query_seq_length, kv_seq_length].

    - `attn_mask` is None: all zeros, i.e. full bi-directional attention.
    - `attn_mask` is 2D bool [B, kv_seq_length] (or [1, kv_seq_length], broadcast over the batch):
      False positions get `torch.finfo(dtype).min` in every query row; True positions get 0.
    - `attn_mask` is already 4D: a bool mask is converted with `attn_mask_to_float` (which also
      unmasks fully-masked rows); a float mask that already has `dtype` is returned unchanged.

    Args:
        attn_mask: A 2D attention mask of shape [B, kv_seq_length] or [B, 1, query_length, kv_seq_length].
            False values indicate masked out positions
        query_seq_length: The query sequence length (L)
        kv_seq_length: The key-value sequence length (S). When `transformers.StaticCache` is used, this should
            equal the cache size to account for zero-padding the part of the cache that is not yet filled.
        dtype: Output dtype
        device: Output device
        batch_size: Batch size
    Returns:
        torch.Tensor of shape [batch_size, 1, L, S] (a 4D input is returned as converted above).
        Contains zero at unmasked positions and `torch.finfo(dtype).min` at masked positions
    Raises:
        TypeError: a 4D float mask whose dtype differs from `dtype`.
        AssertionError: a non-4D mask that is not bool.
        NotImplementedError: mask length different from `kv_seq_length`.
    """
    if attn_mask is not None and attn_mask.ndim == 4:
        if attn_mask.dtype == torch.bool:
            return attn_mask_to_float(attn_mask, dtype=dtype)
        if attn_mask.dtype != dtype:
            raise TypeError(f"Expected attn_mask.dtype={dtype}, but got {attn_mask.dtype}")
        return attn_mask
    mask_value = torch.finfo(dtype).min
    output_mask = torch.zeros([batch_size, 1, query_seq_length, kv_seq_length], dtype=dtype, device=device)
    if attn_mask is None:
        return output_mask
    assert attn_mask.dtype == torch.bool, f"Unsupported dtype {attn_mask.dtype}"
    mask_length = attn_mask.shape[-1]
    if mask_length != kv_seq_length:
        raise NotImplementedError(f"{mask_length} != {kv_seq_length} not properly supported yet")
    # [B, S] -> [B, 1, 1, S]: the key-padding mask is shared by every query row. A [1, S] mask
    # becomes [1, 1, 1, S] and is broadcast across the batch by masked_fill_.
    keep = attn_mask.reshape(-1, 1, 1, mask_length)
    return output_mask.masked_fill_(~keep, mask_value)
class VLMInput(Protocol):
    """
    Structural type describing what a VLM forward pass reads from its input object.

    Any object exposing these attributes satisfies the protocol; no inheritance is required.
    The two properties below provide `None` defaults that concrete inputs may override.
    """

    # Token ids fed to the language model
    input_ids: torch.Tensor
    # Attention mask; consumers in this file accept both 4D and non-4D masks
    attn_mask: torch.Tensor
    # Image tensors keyed by camera / encoder name
    images: Dict[str, torch.Tensor]
    multimodal_indices: torch.Tensor
    unimodal_indices: torch.Tensor

    @property
    def inputs_embeds(self) -> Optional[torch.Tensor]:
        """Precomputed input embeddings; `None` unless an implementation provides them."""
        return None

    @property
    def past_key_values(self) -> Optional[List[torch.Tensor]]:
        """Cached key/value states; `None` unless an implementation provides them."""
        return None
def zero_out_param_pretrained_grad(
    param: torch.nn.Parameter | torch.distributed.tensor.DTensor,
    module: torch.nn.Module,
) -> None:
    """
    Zero out the gradients of pretrained embeddings

    Meant to be registered as a post-accumulate-grad hook (via `functools.partial(..., module=...)`).
    Keeps `param.grad` only where the boolean `module.mask` is True and sets every other entry to 0.

    Side effects: `module.mask` is moved to `param.device`. If `param` is a DTensor but the mask is
    not, the mask is replaced by a DTensor distributed with `param`'s device mesh and placements.
    The converted mask is stored back on `module`, so later calls reuse it.

    Args:
        param: Parameter whose `.grad` is masked in place
        module: Module holding a boolean `mask` attribute with the same shape as `param.grad`
    Raises:
        ValueError: for an unsupported combination of parameter type and mask type
    """
    module.mask = module.mask.to(param.device)
    # Shard the mask the same way as the (FSDP2 / DTensor) parameter so local shards line up
    if isinstance(param, torch.distributed.tensor.DTensor) and not isinstance(
        module.mask, torch.distributed.tensor.DTensor
    ):
        module.mask = torch.distributed.tensor.distribute_tensor(
            module.mask, device_mesh=param.device_mesh, placements=param.placements
        )
    mask = module.mask
    # Exact type checks (not isinstance) on purpose: each branch handles one concrete combination
    if type(param) is torch.distributed.tensor.DTensor and type(mask) is torch.distributed.tensor.DTensor:
        assert param.grad.shape == mask.shape, f"{param.grad.shape} != {mask.shape}"
        # Operate on the local shards directly (private DTensor attribute)
        param.grad._local_tensor = torch.where(mask._local_tensor, param.grad._local_tensor, 0)
    elif type(param) in (torch.Tensor, torch.nn.Parameter) and type(mask) is torch.Tensor:
        assert param.grad.shape == mask.shape, f"{param.grad.shape} != {mask.shape}"
        param.grad = torch.where(mask, param.grad, 0)
    elif type(param) in (torch.Tensor, torch.nn.Parameter) and type(mask) is torch.distributed.tensor.DTensor:
        # Regular parameter but a previously distributed mask: gather the full mask locally
        mask = mask.full_tensor()
        assert param.grad.shape == mask.shape, f"{param.grad.shape} != {mask.shape}"
        param.grad = torch.where(mask, param.grad, 0)
    else:
        raise ValueError(f"Unsupported parameter type: {type(param)} and mask type: {type(mask)}")
class PaliGemmaVLM(ConfigurableModule):
    """
    Wraps PaliGemma to make compatible with VLM API

    The wrapped HF model is patched in place during construction:
      - SigLIP positional embeddings are resampled to the configured "main" image size
      - `get_image_features` is replaced to encode several cameras (skipped when `with_depth`)
      - `depth_tokens` new tokens are added to the tokenizer and embeddings when `depth_tokens > 0`
      - the LM head is replaced with `torch.nn.Identity` when `config.lm_head` is False
      - `is_causal` is set to False on every Gemma decoder self-attention module
    """

    def __init__(self, config: PaliGemmaVLMConfig):
        super().__init__(config)
        if self.config.with_depth:
            # Depth-aware variant; `config` is rebound here to the HF model config (shadows the argument)
            config = PaliGemma3DConfig.from_pretrained(
                self.config.model_id, **self.config.paligemma_3d_config_dict
            )
            self.model = PaliGemma3D.from_pretrained(
                self.config.model_id,
                config=config,
                attn_implementation=self.config.attn_implementation,
            )
        else:
            self.model = transformers.AutoModelForVision2Seq.from_pretrained(
                self.config.model_id,
                attn_implementation=self.config.attn_implementation,
            )
        hf_processor = transformers.AutoProcessor.from_pretrained(self.config.model_id)
        self.processor = PaliGemmaDepthProcessor(
            config=self.config.processor_config,
            hf_processor=hf_processor,
            depth_tokens=self.config.depth_tokens,
        )
        # Order matters: image-input patching uses the processor config; token resizing needs the
        # real lm_head, so it must run before lm_head is possibly replaced by Identity below
        self._resize_siglip_image_input()
        self._maybe_override_get_image_features()
        if self.config.depth_tokens > 0:
            self._resize_llm_token_embeddings(hf_processor.tokenizer)
        if not self.config.lm_head:
            self.model.language_model.lm_head = torch.nn.Identity()
        self.model.train(True)
        # Clear the causal flag on every decoder self-attention module.
        # NOTE(review): how masking then behaves depends on the configured attn_implementation — confirm
        for decoder in self.model.language_model.model.layers:
            decoder.self_attn.is_causal = False

    def forward(
        self,
        inputs: VLMInput,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> VLMOutput:
        """
        Run the wrapped PaliGemma model on `inputs`.

        Args:
            inputs: VLMInput. Each tensor in `inputs.images` has its first two dims flattened into one
                before being passed to the model as a dict of `pixel_values`
            use_cache: forwarded to the HF model (a fresh `transformers.DynamicCache` is always passed)
            output_attentions: forwarded to the HF model
            output_hidden_states: forwarded to the HF model
            kwargs: ignored
        Returns:
            VLMOutput with the LLM output, `image_hidden_states` as `vit_tokens` and the original
            (unconverted) `inputs.attn_mask`
        """
        del kwargs
        # FSDP wrapping may drop parameter hooks; re-register them if needed
        self._maybe_register_zero_out_grad_hooks()
        cache = transformers.DynamicCache()
        # 4D masks are converted to the additive float form; other masks are passed through as-is
        if inputs.attn_mask.ndim == 4:
            attn_mask = attn_mask_to_float(inputs.attn_mask)
        else:
            attn_mask = inputs.attn_mask
        images = {
            encoder_camera_name: camera_images.view(-1, *camera_images.shape[2:])
            for (encoder_camera_name, camera_images) in inputs.images.items()
        }
        llm_output: transformers.models.paligemma.modeling_paligemma.PaliGemmaCausalLMOutputWithPast = (
            self.model(
                input_ids=inputs.input_ids,
                pixel_values=images,
                attention_mask=attn_mask,
                use_cache=use_cache,
                past_key_values=cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
            )
        )
        # Image / text token positions are derived from the first sample only
        # NOTE(review): assumes every sample in the batch has the same token layout
        indices = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device=inputs.input_ids.device)
        image_indices = indices[inputs.input_ids[0] == self.processor.hf_processor.image_token_id]
        text_indices = indices[inputs.input_ids[0] != self.processor.hf_processor.image_token_id]
        output = VLMOutput(
            llm_output=LLMOutput.from_transformers(
                input_ids=inputs.input_ids,
                llm_output=llm_output,
                text_indices=text_indices,
                image_indices=image_indices,
            ),
            vit_tokens=llm_output.image_hidden_states,
            attn_mask=inputs.attn_mask,
        )
        return output

    @property
    def fsdp_wrap_modules(self) -> Set[torch.nn.Module]:
        """
        Submodules to wrap individually with FSDP: SigLIP / timm ViT blocks and towers, Gemma decoder
        layers, the token embedding, the final norm, and the multimodal projector(s).
        """
        transformer_modules = {
            module
            for module in self.modules()
            if isinstance(
                module,
                (
                    transformers.models.siglip.modeling_siglip.SiglipEncoderLayer,
                    transformers.models.siglip.modeling_siglip.SiglipVisionTransformer,
                    timm.models.vision_transformer.Block,
                    timm.models.vision_transformer.VisionTransformer,
                    transformers.models.gemma.modeling_gemma.GemmaDecoderLayer,
                ),
            )
            or module
            in (
                self.model.language_model.model.embed_tokens,
                self.model.language_model.model.norm,
            )
        }
        if self.config.with_depth:
            # NOTE(review): assumes PaliGemma3D.projectors is a set-like collection (used with `|`)
            projectors = self.model.projectors
        else:
            projectors = {self.model.multi_modal_projector}
        return projectors | transformer_modules

    def _resize_siglip_image_input(self) -> None:
        """
        Enables resizing SigLIP positional embeddings to a new image size.

        Resamples the SigLIP position-embedding table (bicubic, via timm) from the checkpoint's square
        `image_size` to `processor_config.image_sizes["main"]`. Updates the embedding module's
        `num_patches`, `num_positions`, `image_size` and `position_ids`, and replaces its
        `interpolate_pos_encoding` so other image sizes are resampled on the fly from the "main" table.
        """
        num_image_tokens: int = self.config.processor_config.num_image_tokens["main"]
        image_size: Dict[str, int] = self.config.processor_config.image_sizes["main"].as_json()
        siglip_embeddings: transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings = (
            self.model.vision_tower.vision_model.embeddings
        )
        # The constant 14 is used as the patch size throughout this method
        # NOTE(review): presumably the SigLIP patch size; it is not read from the model config
        embedding_weight: torch.Tensor = timm.layers.pos_embed.resample_abs_pos_embed(
            posemb=siglip_embeddings.position_embedding.weight.unsqueeze(0),
            new_size=(image_size["height"] // 14, image_size["width"] // 14),
            old_size=(
                siglip_embeddings.image_size // 14,
                siglip_embeddings.image_size // 14,
            ),
            num_prefix_tokens=0,
            interpolation="bicubic",
            antialias=True,
            verbose=False,
        ).squeeze(0)
        with torch.no_grad():
            siglip_embeddings.position_embedding.weight.data = embedding_weight
        siglip_embeddings.num_patches = siglip_embeddings.num_positions = num_image_tokens
        # image_size becomes a {"height", "width"} dict instead of the original square int
        siglip_embeddings.image_size = dict(image_size)
        siglip_embeddings.register_buffer(
            "position_ids",
            torch.arange(siglip_embeddings.num_positions).expand((1, -1)),
            persistent=False,
        )

        def interpolate_pos_encoding(embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
            # Replacement for SiglipVisionEmbeddings.interpolate_pos_encoding; `embeddings` is unused.
            # Returns the "main" table unchanged when the requested size matches it, otherwise a
            # bicubically resampled copy
            del embeddings
            new_size = (height // 14, width // 14)
            old_size = (image_size["height"] // 14, image_size["width"] // 14)
            if old_size == new_size:
                patch_pos_embedding = siglip_embeddings.position_embedding.weight
            else:
                patch_pos_embedding: torch.Tensor = timm.layers.pos_embed.resample_abs_pos_embed(
                    posemb=siglip_embeddings.position_embedding.weight.unsqueeze(0),
                    new_size=new_size,
                    old_size=(
                        image_size["height"] // 14,
                        image_size["width"] // 14,
                    ),
                    num_prefix_tokens=0,
                    interpolation="bicubic",
                    antialias=True,
                    verbose=False,
                ).squeeze(0)
            return patch_pos_embedding

        siglip_embeddings.interpolate_pos_encoding = interpolate_pos_encoding
        if not self.processor.config.is_single_image_size:
            # Cameras with different resolutions: always request positional-encoding interpolation
            self.model.vision_tower.forward = functools.partial(
                self.model.vision_tower.forward, interpolate_pos_encoding=True
            )

    def _maybe_override_get_image_features(self) -> None:
        """
        Override PaliGemmaForConditionalGeneration.get_image_features() from transformers
        such that it can handle multiple cameras with different resolutions.

        The replacement expects `pixel_values` to be a dict keyed by f"{camera_name}.siglip" and returns
        the per-camera image tokens concatenated along dim -2. No-op when `config.with_depth` is set.
        """
        if self.config.with_depth:
            return
        # Capture the original bound method before it is replaced below
        images_forward = self.model.get_image_features
        camera_names: List[str] = self.processor.config.camera_names

        def get_image_features(pixel_values: torch.Tensor) -> torch.Tensor:
            if self.processor.config.is_single_image_size:
                # All cameras share one size: encode them in a single batched call
                inputs = einops.rearrange(
                    torch.stack(
                        [pixel_values[f"{camera_name}.siglip"] for camera_name in camera_names],
                        dim=1,
                    ),
                    "B N C H W -> (B N) C H W",
                )
                image_tokens = images_forward(inputs)
            else:
                # Different sizes: encode each camera separately and concatenate the token sequences
                camera_tokens: List[torch.Tensor] = [
                    images_forward(pixel_values[f"{camera_name}.siglip"]) for camera_name in camera_names
                ]
                image_tokens = torch.cat(camera_tokens, dim=-2)
            return image_tokens

        self.model.get_image_features = get_image_features

    def _resize_llm_token_embeddings(self, tokenizer: transformers.PreTrainedTokenizer) -> None:
        """
        Add `depth_tokens` tokens named "<dist0000>", "<dist0001>", ... to `tokenizer` and resize the
        model's token embeddings (padded to a multiple of 64).

        If `config.train_only_depth_tokens` is set, also build `self.mask` (True only for rows at index
        >= the tokenizer's original `vocab_size`) and register gradient hooks that zero out the gradient
        of all other rows.
        """
        assert self.config.depth_tokens > 0, self.config.depth_tokens
        tokenizer.add_tokens([f"<dist{i:04d}>" for i in range(self.config.depth_tokens)])
        total_num_tokens = len(tokenizer)
        vocab_size = tokenizer.vocab_size
        llm = self.model.language_model
        (_, hidden_size) = llm.lm_head.weight.shape
        self.model.resize_token_embeddings(
            total_num_tokens,
            pad_to_multiple_of=64,
            mean_resizing=self.config.mean_resizing,
        )
        if self.config.train_only_depth_tokens:
            assert len(self.model.language_model._tied_weights_keys) > 0
            # Re-read the shape after resizing (rows include the padding to a multiple of 64)
            (weight_size, hidden_size) = llm.lm_head.weight.shape
            mask = torch.cat(
                [
                    torch.zeros([vocab_size, hidden_size], dtype=torch.bool),
                    torch.ones([weight_size - vocab_size, hidden_size], dtype=torch.bool),
                ],
                dim=0,
            )
            self.mask = mask
            self.embed_handle = None
            self.lm_head_handle = None
            self._maybe_register_zero_out_grad_hooks()

    def _maybe_register_zero_out_grad_hooks(self) -> None:
        """
        Register hooks to zero out the gradients of pretrained embeddings and LM head.
        Skips registering the hooks if they already exist. This runs at every step as wrapping
        in FSDP removes any hooks that were registered on the *parameters* of the original module
        and the only way to run this reliably is to check if the hooks exist.
        """
        if not self.config.train_only_depth_tokens:
            return
        # NOTE(review): `self.embed_handle` / `self.lm_head_handle` are only created by
        # `_resize_llm_token_embeddings`, which runs only when depth_tokens > 0. With
        # train_only_depth_tokens=True and depth_tokens == 0 this raises AttributeError — confirm the
        # config forbids that combination.
        llm = self.model.language_model
        if (
            self.embed_handle is None
            or llm.model.embed_tokens.weight._post_accumulate_grad_hooks is None
            or self.embed_handle.id not in llm.model.embed_tokens.weight._post_accumulate_grad_hooks
        ):
            self.embed_handle = llm.model.embed_tokens.weight.register_post_accumulate_grad_hook(
                functools.partial(zero_out_param_pretrained_grad, module=self)
            )
        # Tied weights share one parameter, so a second hook is only needed when they are distinct.
        # NOTE(review): if config.lm_head is False, __init__ replaces lm_head with torch.nn.Identity,
        # which has no `.weight`; the check below would then raise AttributeError in forward.
        if llm.model.embed_tokens.weight is not llm.lm_head.weight and (
            self.lm_head_handle is None
            or llm.lm_head.weight._post_accumulate_grad_hooks is None
            or self.lm_head_handle.id not in llm.lm_head.weight._post_accumulate_grad_hooks
        ):
            self.lm_head_handle = llm.lm_head.weight.register_post_accumulate_grad_hook(
                functools.partial(zero_out_param_pretrained_grad, module=self)
            )
def make_position_indices(
    position_indices: Optional[torch.Tensor],
    seq_length: int,
    device: torch.device,
    max_seq_length: Optional[int],
) -> torch.Tensor:
    """
    Return int64 position indices, generating `[[0, 1, ..., seq_length - 1]]` when none are given.

    Args:
        position_indices: Optional tensor of positions. If given, it is cast to torch.int64 and its
            device is left unchanged
        seq_length: Length of the generated sequence (used only when `position_indices` is None)
        device: Device of the generated indices (used only when `position_indices` is None)
        max_seq_length: Exclusive upper bound on the indices. If None, no bound check is performed
    Returns:
        torch.Tensor of dtype torch.int64; of shape [1, seq_length] when generated
    Raises:
        IndexError: if `max_seq_length` is not None and some index is >= `max_seq_length`
    """
    if position_indices is not None:
        position_indices = position_indices.to(dtype=torch.int64)
    else:
        position_indices = torch.arange(seq_length, dtype=torch.int64, device=device).view(1, -1)
    # Skip the check when there is no bound (the argument is Optional) or nothing to check
    # (torch.max raises on an empty tensor)
    if max_seq_length is not None and position_indices.numel() > 0:
        if not torch.max(position_indices) < max_seq_length:
            raise IndexError(
                f"position_indices={position_indices} contains index out of bounds of num_embeddings={max_seq_length}"
            )
    return position_indices
class RotaryPositionalEncoding(ConfigurableModule):
    """
    Rotary Positional Embeddings (RoPE) from https://arxiv.org/abs/2104.09864
    Reference implementations:
    - https://github.com/meta-llama/llama/blob/main/llama/model.py#L80
    - transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding
    - transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
    If cached=True, we cache the embeddings for each position up to `num_embeddings`
    (float32 sin/cos tables with one row per position and one column per inverse frequency)
    """

    def __init__(self, config: RotaryPositionalEncodingConfig):
        super().__init__(config)
        # inv_freq[k] = 1 / base ** (2k / embedding_dim) for k = 0, 1, ... (embedding_dim // 2 entries for even dims)
        inv_freq = 1.0 / self.config.base ** (
            torch.arange(0, self.config.embedding_dim, 2, dtype=torch.float32) / self.config.embedding_dim
        )
        # Non-persistent: derived from the config and not stored in the state dict
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._build_cache()

    def _build_cache(self) -> None:
        """Precompute sin/cos of position * inv_freq for positions [0, num_embeddings) when `cached`."""
        if not self.config.cached:
            return
        position_indices = torch.arange(self.config.num_embeddings, dtype=torch.float32)
        # Outer product: [num_embeddings] x [len(inv_freq)] -> [num_embeddings, len(inv_freq)]
        indices_inv_freq = torch.einsum("i, j -> ij", position_indices, self.inv_freq)
        sin = torch.sin(indices_inv_freq)
        cos = torch.cos(indices_inv_freq)
        self.register_buffer("sin_cache", sin, persistent=False)
        self.register_buffer("cos_cache", cos, persistent=False)

    def forward(
        self,
        tokens: torch.Tensor,
        position_indices: Optional[torch.Tensor] = None,
        apply: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            tokens: torch.Tensor of shape [B, ..., S, head_dim], where `...` might be any number of dims
            position_indices: torch.Tensor of shape [B | 1, S]. The indices of tokens within the sequence.
                If None, defaults to [[0, 1, ..., S - 1]]
            apply: If True, apply the positional embedding on tokens and return the result
        Returns:
            torch.Tensor of the same shape as `tokens` with positional embedding applied on tokens
        Raises:
            AssertionError: if apply=False
            IndexError: if a position index is >= num_embeddings
        """
        assert apply, f"{self.__class__} does not support applying embeddings externally"
        position_indices = make_position_indices(
            position_indices,
            seq_length=tokens.shape[-2],
            device=tokens.device,
            max_seq_length=self.config.num_embeddings,
        )
        if self.config.cached:
            # Table lookup: [B | 1, S] -> [B | 1, S, len(inv_freq)], then duplicated along the last dim
            sin = self.sin_cache[position_indices]
            cos = self.cos_cache[position_indices]
            sin = torch.cat([sin, sin], dim=-1)
            cos = torch.cat([cos, cos], dim=-1)
        else:
            # [1, len(inv_freq), 1] @ [B | 1, 1, S] -> [B | 1, len(inv_freq), S] -> transpose -> [B | 1, S, len(inv_freq)]
            inv_freq = self.inv_freq.view(1, -1, 1).to(dtype=torch.float32)
            position_indices = position_indices.to(dtype=torch.float32).unsqueeze(1)
            # Silence the warning emitted when a float32 autocast region is entered on CPU
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    message="In CPU autocast, but the target dtype is not supported. Disabling autocast.",
                )
                with torch.autocast(device_type=tokens.device.type, dtype=torch.float32):
                    freqs = (inv_freq @ position_indices).transpose(1, 2)
                    emb = torch.cat((freqs, freqs), dim=-1)
                    (sin, cos) = (torch.sin(emb), torch.cos(emb))
        # Tables are computed in float32; cast to the tokens' dtype before applying them
        (sin, cos) = (sin.to(dtype=tokens.dtype), cos.to(dtype=tokens.dtype))
        # NOTE(review): `expand_dims` is defined elsewhere; presumably it inserts singleton dims so sin/cos
        # broadcast against `tokens` ([B | 1, S, head_dim] -> [B | 1, 1, ..., S, head_dim]) — confirm
        sin = expand_dims(sin, tokens.ndim, order=[1, -1, 1, 1])
        cos = expand_dims(cos, tokens.ndim, order=[1, -1, 1, 1])
        tokens = tokens * cos + self._rotate_invert_half(tokens) * sin
        return tokens

    @staticmethod
    def _rotate_invert_half(x: torch.Tensor) -> torch.Tensor:
        """Split the last dim into halves (x1, x2) and return concat(-x2, x1)."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)
# Attention implementation identifiers; the strings match the values accepted by transformers'
# `attn_implementation` argument
EAGER_ATTN, SDPA_ATTN, FLASH_ATTN = "eager", "sdpa", "flash_attention_2"
def is_full_attn(attn_mask: Optional[torch.Tensor]) -> bool:
    """
    Check whether `attn_mask` leaves every position unmasked.

    Args:
        attn_mask: None, a bool mask (True / 1 = attend) or an additive float mask (0 = attend)
    Returns:
        True when nothing is masked out (always True for None), False otherwise
    Raises:
        TypeError: if the mask dtype is neither bool nor floating point
    """
    if attn_mask is None:
        return True
    mask_dtype = attn_mask.dtype
    # The "attend" value depends on the mask convention
    if mask_dtype == torch.bool:
        attend_value = 1
    elif mask_dtype.is_floating_point:
        attend_value = 0
    else:
        raise TypeError(f"Unrecognized dtype {attn_mask.dtype}")
    return torch.all(attn_mask == attend_value).item()
@torch.no_grad()
def make_attn_mask_causal(attn_mask: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
    """
    Mask out every key position that lies in the future of its query position.

    Args:
        attn_mask: 4D tensor of shape [B | 1, 1, query_seq_len, kv_seq_len] (i.e. [B | 1, 1, L, S]),
            either an additive float mask (masked positions contain `torch.finfo(dtype).min`) or a bool
            mask (False = masked). The input is NOT modified.
        cache_position: torch.Tensor of type torch.int64 with query_seq_len elements (e.g. shape [L] or
            [1, L]). Contained values are index positions of the query tokens in the sequence. During
            training, this would usually be torch.arange(query_seq_len), but during generate, this would
            usually be a tensor sequence with 1 element indicating the position of the token currently
            being generated. It is moved to `attn_mask.device` if needed.
    Returns:
        New torch.Tensor with the same shape and dtype as attn_mask. For float masks: zero at unmasked
        positions and `torch.finfo(dtype).min` at masked positions; for bool masks: False at masked positions
    Raises:
        TypeError: if attn_mask is neither floating point nor bool
    """
    if attn_mask.dtype.is_floating_point:
        mask_value = torch.finfo(attn_mask.dtype).min
    elif attn_mask.dtype == torch.bool:
        mask_value = False
    else:
        raise TypeError(f"Unsupported mask type {attn_mask.dtype}")
    (_, _, query_seq_length, kv_seq_length) = attn_mask.shape
    device = attn_mask.device
    # [L, S]: True where the key index is after the query's absolute position
    key_positions = torch.arange(kv_seq_length, device=device).view(1, -1)
    query_positions = cache_position.to(device=device).view(-1, 1)
    future = key_positions > query_positions
    # Combined with the strict upper triangle, as in the original formulation
    upper = torch.ones((query_seq_length, kv_seq_length), dtype=torch.bool, device=device).triu(diagonal=1)
    causal_mask = (upper & future).view(*[1] * (attn_mask.ndim - 2), query_seq_length, kv_seq_length)
    # Out-of-place: the caller's mask may be shared (e.g. returned by identity from
    # make_4d_float_attn_mask) or an expanded view that cannot be written in place
    return attn_mask.masked_fill(causal_mask, mask_value)
def update_attn_mask(
    attn_mask: Optional[torch.Tensor],
    attn_implementation: str,
    query_seq_length: int,
    kv_seq_length: int,
    cache_position: Optional[torch.Tensor],
    cache: Optional[transformers.Cache],
    batch_size: int,
    causal: bool,
    dtype: torch.dtype,
    device: torch.device,
    output_attentions: bool = False,
) -> Optional[torch.Tensor]:
    """
    Update attn_mask such that it's compatible with the attention implementation.
    Meant to be used with `MultiheadAttention` and its derivatives
    Args:
        attn_mask: dtype torch.bool, torch.float32, torch.float16 or torch.bfloat16 and shape one of:
            - [B, kv_seq_length] (i.e. [B, S])
            - [B, 1, query_seq_length, kv_seq_length] (i.e. [B, 1, L, S])
            - [1, 1, query_seq_length, kv_seq_length] (i.e. [1, 1, L, S])
            If bool, False values indicate masked positions.
            If float, must contain only 0.0 and torch.finfo(dtype).min
            If attn_mask is None, full-bidirectional attention is assumed. The output might be None or
            a tensor. Refer to the return value documentation
        attn_implementation: One of [FLASH_ATTN, SDPA_ATTN, EAGER_ATTN]
        query_seq_length: The query sequence length (L)
        kv_seq_length: The key-value sequence length (S)
        cache_position: dtype torch.int64 with query_seq_len elements. Used only when causal=True.
            Contained values are index positions of the query tokens in the sequence. During training,
            this would usually be torch.arange(query_seq_len), but during generate, this would usually be
            a tensor sequence with 1 element indicating the position of the token currently being generated.
            If None, defaults to [past, past + query_seq_length) with past = `cache.get_seq_length()`
        cache: Optional cache. Usually not None when running generate at inference.
        batch_size: Batch size of the generated attention mask
        causal: If True, make the attn_mask causal -> all non-causal positions are masked out, regardless
            of their attn_mask values. When using flash attention or SDPA and `causal == True`, make sure
            to pass `causal` to the attention operation, in case this function delegates causal masking
            (i.e. returns None)
        dtype: dtype of the output attention mask. Must be the dtype of the attn computation
        device: device of the output attention mask
        output_attentions: If True, the attention operation is required to output attention weights
    Returns:
        - `None` in either of these cases:
            - `attn_mask` doesn't contain any masked out positions and causal=False
            - `attn_implementation in [FLASH_ATTN, SDPA_ATTN]`, causal=True and `attn_mask` doesn't
              contain any masked out positions. We then rely on the causal argument to flash attention
              or `torch.nn.functional.scaled_dot_product_attention`. This happens only if the cache is
              empty and cache_position is None
        - Otherwise torch.Tensor of shape [B | 1, 1, query_length, kv_seq_length] (i.e. [B | 1, 1, L, S])
          and type `dtype`. Contains zero at unmasked positions and `torch.finfo(dtype).min` at masked
          positions. NOTE: FLASH_ATTN currently also receives this 4D mask (TODO(FLASH))
    Raises:
        NotImplementedError: under torch.jit tracing/scripting or torch.compile, or when a StaticCache is
            used with an attn_mask whose last dim differs from the cache size
    """
    assert attn_implementation in [FLASH_ATTN, SDPA_ATTN, EAGER_ATTN]
    assert dtype.is_floating_point, dtype
    if torch.jit.is_tracing() or torch.jit.is_scripting() or torch.compiler.is_compiling():
        raise NotImplementedError("Complete correctness not confirmed yet")
    static_cache = isinstance(cache, transformers.StaticCache)
    if static_cache and attn_mask is not None and attn_mask.shape[-1] != cache.get_max_cache_shape():
        raise NotImplementedError("Complete correctness not confirmed yet")
    full_attn = is_full_attn(attn_mask)
    # Queried once; reused below for the default causal cache_position
    past_seen_tokens = cache.get_seq_length() if cache is not None else 0
    if full_attn and not causal:
        return None
    if (
        full_attn
        and causal
        and attn_implementation in [SDPA_ATTN, FLASH_ATTN]
        and past_seen_tokens == 0
        and cache_position is None
    ):
        # Delegate causal masking to the kernel's own `causal` flag
        return None
    if static_cache and kv_seq_length < cache.get_max_cache_shape():
        # A static cache is zero-padded up to its full size; the mask must cover all of it
        kv_seq_length = cache.get_max_cache_shape()
    elif attn_mask is not None:
        assert kv_seq_length == attn_mask.shape[-1], f"{kv_seq_length}, {attn_mask.shape}"
    output_mask = make_4d_float_attn_mask(
        attn_mask=attn_mask,
        query_seq_length=query_seq_length,
        kv_seq_length=kv_seq_length,
        dtype=dtype,
        device=device,
        batch_size=batch_size,
    )
    if causal:
        if cache_position is None:
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + query_seq_length, device=device)
        output_mask = make_attn_mask_causal(output_mask, cache_position)
    if (
        attn_implementation == SDPA_ATTN
        and attn_mask is not None
        and attn_mask.device.type == "cuda"
        and not output_attentions
    ):
        # NOTE(review): `unmask_unattended` is defined elsewhere; presumably it un-masks fully masked
        # rows (as transformers' AttentionMaskConverter._unmask_unattended does) to avoid NaNs in SDPA
        output_mask = unmask_unattended(output_mask, mask_value=torch.finfo(dtype).min)
    return output_mask
def expand_kv_heads(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head `n_rep` times along dim 1 (same result as
    torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)).

    Args:
        hidden_states: tensor of shape [batch, num_kv_heads, seqlen, head_dim]
        n_rep: number of copies per head; 1 returns the input object unchanged
    Returns:
        tensor of shape [batch, num_kv_heads * n_rep, seqlen, head_dim]
    """
    # Unpack first so non-4D inputs are rejected even when n_rep == 1
    batch_size, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis next to the head axis (a broadcast view), then fold it into the head axis
    repeated = hidden_states.unsqueeze(2).expand(batch_size, kv_heads, n_rep, seq_len, dim)
    return repeated.reshape(batch_size, kv_heads * n_rep, seq_len, dim)
class MultiheadAttention(torch.nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper
    Different implementation from torch.nn.MultiheadAttention to support:
    - Easy switch between EAGER_ATTN, SDPA_ATTN and FLASH_ATTN (this class is the eager implementation)
    - Number of key-value heads different from query heads
    - Key-value cache during forward, in the same way as transformers. Useful for generation or
    cross-attention to projected keys and values
    - Ability to apply positional encodings to query and key after input linear projection
    - Different linear projection output size
    Adapted from transformers.models.llama.modeling_llama.LlamaAttention
    """

    def __init__(
        self,
        in_features: int,
        num_heads: int,
        head_dim: Optional[int] = None,
        out_features: Optional[int] = None,
        key_features: Optional[int] = None,
        value_features: Optional[int] = None,
        num_kv_heads: Optional[int] = None,
        bias: bool = False,
        dropout: float = 0.0,
        cache_layer: Optional[int] = None,
        query_position_embed: Optional[Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]] = None,
        key_position_embed: Optional[Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]] = None,
    ):
        """
        Args:
            in_features: Input dimension for query linear projection.
            num_heads: Number of heads for query
            head_dim: Head dimension. If None, defaults to `in_features // num_heads`
            out_features: Output dimension for the output linear layer. If None, defaults to `in_features`
            key_features: Input dimension for key linear projection. If None, defaults to `in_features`
            value_features: Input dimension for value linear projection. If None, defaults to `in_features`
            num_kv_heads: Number of heads for keys and values. If None, defaults to `num_heads`. Key/value
                heads are repeated `num_heads // num_kv_heads` times, so it is expected to divide `num_heads`
            bias: Whether the q/k/v/o linear projections have a bias term
            dropout: Dropout probability applied to the attention weights in training mode
            cache_layer: Index of the layer in the cache. Needed only when `cache` is passed to
                the `forward()` call, usually during generation or when the projected keys and values need
                to be cached during training. Can be omitted when `cache_layer` is passed to `forward`
            query_position_embed: Callable invoked as `query_position_embed(query_states, position_indices=...)`
                on the linearly projected queries of shape [B, num_heads, L, head_dim]; returns the queries with
                positional embeddings applied. Note these embeddings are applied after linear projection. If you
                want to apply embeddings before the linear projection, do so before calling the forward method
                and leave this as None. A torch.nn.Module can be passed as well
            key_position_embed: Same as `query_position_embed`, but applied to the linearly projected keys of
                shape [B, num_kv_heads, S, head_dim] (before they are added to the cache)
        """
        super().__init__()
        self.in_features = in_features
        self.key_features = key_features or in_features
        self.value_features = value_features or in_features
        self.bias = bias
        self.out_features = out_features or in_features
        self.num_heads = num_heads
        self.head_dim = head_dim or in_features // num_heads
        self.num_kv_heads = num_kv_heads or num_heads
        self.dropout = dropout
        self.query_position_embed = query_position_embed
        self.key_position_embed = key_position_embed
        self.cache_layer = cache_layer
        self.q_proj = torch.nn.Linear(self.in_features, self.num_heads * self.head_dim, bias=self.bias)
        self.k_proj = torch.nn.Linear(self.key_features, self.num_kv_heads * self.head_dim, bias=self.bias)
        self.v_proj = torch.nn.Linear(self.value_features, self.num_kv_heads * self.head_dim, bias=self.bias)
        self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.out_features, bias=self.bias)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        query_position_indices: Optional[torch.Tensor] = None,
        key_position_indices: Optional[torch.Tensor] = None,
        cache: Optional[transformers.Cache] = None,
        cache_layer: Optional[int] = None,
        output_attentions: bool = False,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            query: Query embedding of shape [B, L, in_features]
            key: Key embedding of shape [B, S, key_features]
            value: Value embedding of shape [B, S, value_features]
            attn_mask: dtype torch.bool or same dtype as query/key/value and shape one of:
                - [B, S]
                - [B | 1, 1 | num_heads, L, S]
                If bool, False values indicate masked positions (opposite of torch.nn.MultiheadAttention)
                If float, must contain only 0.0 and torch.finfo(dtype).min
                If attn_mask is None, full-bidirectional attention or causal attention is used depending
                on the value of `is_causal`.
            is_causal: If True, apply additional causal masking to `attn_mask`
            query_position_indices: torch.Tensor of shape [1 | B, L] containing the indices of the `query`
                tokens within the entire sequence. Passed through to query_position_embed and used as the
                query positions of the causal mask. If None and `cache` is not None, indices are autogenerated
                [0, 1, ..., L - 1] and offset by `cache.get_seq_length()` measured before the cache update
            key_position_indices: Same as `query_position_indices`, but applied to key
            cache: transformers.Cache containing cached key-value pairs. The linearly projected
                `key` and `value` passed to this function get added to the cache and concatenated after the
                key-value pairs in the cache and then attention is computed on the concatenated sequence.
                This is most commonly used at inference when generating auto-regressively or when one needs
                to cross attend to the keys and values outside this module forward pass.
            cache_layer: Index of the layer in the cache. Needed only when `cache` is passed to
                the `forward()` call, usually during generation or when the projected keys and values need
                to be cached during training. Can be omitted when `cache_layer` was passed to `__init__`
            output_attentions: If True, output also the attention weights. Otherwise output None.
                Note that only the eager implementation of MultiheadAttention supports this.
            cache_kwargs: kwargs directly passed to `cache.update()`. None is treated as an empty dict
        Returns:
            Tuple with entries:
            - Attention block output: torch.Tensor of shape [B, L, out_features]
            - Optional attention weights if `output_attentions=True`, shape [B, num_heads, L, S]
        Raises:
            RuntimeError: if `cache` is given but no cache layer index is known
        """
        batch_size = query.shape[0]
        query_states = self.q_proj(query)
        key_states = self.k_proj(key)
        value_states = self.v_proj(value)
        # [B, seq, heads * head_dim] -> [B, heads, seq, head_dim]
        query_states = query_states.view(
            batch_size, query_states.shape[1], self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            batch_size, key_states.shape[1], self.num_kv_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, value_states.shape[1], self.num_kv_heads, self.head_dim
        ).transpose(1, 2)
        if (
            query_position_indices is None
            and cache is not None
            and (is_causal or self.query_position_embed is not None)
        ):
            # Resolve default query positions BEFORE the cache update below: afterwards
            # `cache.get_seq_length()` already counts the new tokens, so positions derived from it would be
            # shifted by L and the causal mask would mask nothing. This also keeps the causal mask and the
            # query positional embedding on the same positions.
            query_position_indices = self._default_position_indices(
                query_states.shape[-2], cache=cache, device=query_states.device
            )
        (query_states, key_states) = self._maybe_apply_positional_embeddings(
            query_states=query_states,
            key_states=key_states,
            query_position_indices=query_position_indices,
            key_position_indices=key_position_indices,
            cache=cache,
        )
        (key_states, value_states) = self._maybe_update_cache(
            key_states,
            value_states,
            cache_layer=cache_layer,
            cache=cache,
            cache_kwargs=cache_kwargs,
        )
        # Grouped-query attention: repeat kv heads to match the number of query heads
        key_states = expand_kv_heads(key_states, self.num_heads // self.num_kv_heads)
        value_states = expand_kv_heads(value_states, self.num_heads // self.num_kv_heads)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        attn_mask = update_attn_mask(
            attn_mask,
            attn_implementation=EAGER_ATTN,
            query_seq_length=query_states.shape[2],
            kv_seq_length=value_states.shape[2],
            cache_position=query_position_indices,
            cache=cache,
            batch_size=batch_size,
            causal=is_causal,
            dtype=query_states.dtype,
            device=query_states.device,
            output_attentions=output_attentions,
        )
        if attn_mask is not None:
            # A static cache mask may be longer than the keys actually returned
            attn_mask = attn_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + attn_mask
        # Softmax in float32 for numerical stability, then back to the compute dtype
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
            query_states.dtype
        )
        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)
        shape = (batch_size, self.num_heads, query.shape[1], self.head_dim)
        assert attn_output.shape == shape, f"{attn_output.shape} != {shape}"
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)
        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights

    @staticmethod
    def _default_position_indices(
        seq_length: int, cache: transformers.Cache, device: torch.device
    ) -> torch.Tensor:
        """Positions [[c, c + 1, ..., c + seq_length - 1]] (int64) where c = cache.get_seq_length()."""
        return torch.arange(seq_length, dtype=torch.int64, device=device).view(1, -1) + cache.get_seq_length()

    def _maybe_apply_positional_embeddings(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        query_position_indices: Optional[torch.Tensor],
        key_position_indices: Optional[torch.Tensor],
        cache: Optional[transformers.Cache],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply `query_position_embed` / `key_position_embed` (when set) to the projected states.

        Missing position indices are generated only when `cache` is given (see `_default_position_indices`);
        otherwise the embedding callables receive `position_indices=None`.
        """
        device = query_states.device
        if self.query_position_embed is not None:
            if query_position_indices is None and cache is not None:
                query_position_indices = self._default_position_indices(
                    query_states.shape[-2], cache=cache, device=device
                )
            query_states = self.query_position_embed(query_states, position_indices=query_position_indices)
        if self.key_position_embed is not None:
            if key_position_indices is None and cache is not None:
                key_position_indices = self._default_position_indices(
                    key_states.shape[-2], cache=cache, device=device
                )
            key_states = self.key_position_embed(key_states, position_indices=key_position_indices)
        return query_states, key_states

    def _maybe_update_cache(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_layer: Optional[int],
        cache: Optional[transformers.Cache],
        cache_kwargs: Optional[Dict[str, Any]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Add the new keys/values to `cache` (if given) and return what `cache.update()` returns;
        without a cache the inputs are returned unchanged.

        Raises:
            RuntimeError: if `cache` is given but neither `cache_layer` nor `self.cache_layer` is set
        """
        if cache is not None:
            if cache_layer is None and self.cache_layer is None:
                raise RuntimeError("When cache != None, cache_layer has to be set")
            cache_layer = cache_layer if cache_layer is not None else self.cache_layer
            (key_states, value_states) = cache.update(
                key_states, value_states, cache_layer, cache_kwargs if cache_kwargs is not None else {}
            )
        return key_states, value_states
class MultiheadFlashAttention2(MultiheadAttention):
    """
    MultiheadAttention implemented using flash attention module. Inherits `MultiheadAttention` as the weights
    of the module stay untouched. The only change is on the forward pass where we call flash attention.

    NOTE: This path is not validated yet. `forward()` always raises NotImplementedError after the
    projections, positional embeddings, cache update and mask construction have run.
    """

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        query_position_indices: Optional[torch.Tensor] = None,
        key_position_indices: Optional[torch.Tensor] = None,
        cache: Optional[transformers.Cache] = None,
        cache_layer: Optional[int] = None,
        output_attentions: bool = False,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            query: Query embedding of shape [B, L, in_features]
            key: Key embedding of shape [B, S, key_features]
            value: Value embedding of shape [B, S, value_features]
            attn_mask: dtype torch.bool and shape [B, S].
                If bool, False values indicate masked positions (opposite of torch.nn.MultiheadAttention)
                If attn_mask is None, full-bidirectional attention or causal attention is used depending
                on the value of `is_causal`.
                NOTE: Doesn't support 4D attn_mask, unlike MultiheadAttention
            is_causal: If True, apply additional causal masking to `attn_mask`
            query_position_indices: torch.Tensor of shape [1 | B, L] containing the indices of the `query`
                tokens within the entire sequence. Passed through to query_position_embed. If None and `cache`
                is not None, indices are autogenerated as [0, 1, ..., L - 1] offset by the current cache length
            key_position_indices: Same as `query_position_indices`, but applied to key
            cache: transformers.Cache containing cached key-value pairs. The linearly projected
                `key` and `value` passed to this function get added to the cache and concatenated after the
                key-value pairs in the cache and then attention is computed on the concatenated sequence.
                This is most commonly used at inference when generating auto-regressively or when one needs
                to cross attend to the keys and values outside this module forward pass.
            cache_layer: Index of the layer in the cache. Needed only when `cache` is passed to
                the `forward()` call, usually during generation or when the projected keys and values need
                to be cached during training. Can be omitted when `cache_layer` was passed to `__init__`
            output_attentions: Must be False; this implementation never returns attention weights.
            cache_kwargs: kwargs directly passed to `cache.update()`. None is treated as an empty dict.
        Returns:
            Tuple with entries:
                - Attention block output: torch.Tensor of shape [B, L, out_features]
                - None (attention weights are not supported)
        Raises:
            ValueError: if `cache` is a transformers.StaticCache
            NotImplementedError: always, until the flash attention path has been validated
        """
        if isinstance(cache, transformers.StaticCache):
            raise ValueError(
                "transformers.StaticCache not compatible with flash attention. Use `sdpa` instead (for now)."
            )
        assert output_attentions is False, f"{self.__class__} doesn't support output_attentions=True"
        # Fresh dict per call instead of a shared mutable default argument.
        cache_kwargs = {} if cache_kwargs is None else cache_kwargs

        batch_size = query.shape[0]
        query_states = self.q_proj(query)
        key_states = self.k_proj(key)
        value_states = self.v_proj(value)
        # [B, seq, features] -> [B, heads, seq, head_dim]
        query_states = query_states.view(
            batch_size, query_states.shape[1], self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            batch_size, key_states.shape[1], self.num_kv_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, value_states.shape[1], self.num_kv_heads, self.head_dim
        ).transpose(1, 2)
        (query_states, key_states) = self._maybe_apply_positional_embeddings(
            query_states=query_states,
            key_states=key_states,
            query_position_indices=query_position_indices,
            key_position_indices=key_position_indices,
            cache=cache,
        )
        (key_states, value_states) = self._maybe_update_cache(
            key_states,
            value_states,
            cache_layer=cache_layer,
            cache=cache,
            cache_kwargs=cache_kwargs,
        )
        # Read the sequence lengths while the states are still [B, heads, seq, head_dim]; after the
        # transpose below, dim 2 is the head dimension.
        query_seq_length = query_states.shape[2]
        kv_seq_length = value_states.shape[2]
        # Flash attention consumes [B, seq, heads, head_dim].
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        attn_mask = update_attn_mask(
            attn_mask,
            attn_implementation=FLASH_ATTN,
            query_seq_length=query_seq_length,
            kv_seq_length=kv_seq_length,
            cache_position=query_position_indices,
            cache=cache,
            batch_size=batch_size,
            causal=is_causal,
            dtype=query_states.dtype,
            device=query_states.device,
            output_attentions=output_attentions,
        )
        raise NotImplementedError("Correctness not yet confirmed")
        # --- Unreachable until the flash attention path is validated. ---
        attn_output = transformers.modeling_flash_attention_utils._flash_attention_forward(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            attention_mask=attn_mask,
            query_length=query.shape[1],
            position_ids=None,
            dropout=self.dropout if self.training else 0.0,
            sliding_window=None,
            use_top_left_mask=False,
            is_causal=is_causal,
            deterministic=True,
        )
        # NOTE(review): `_flash_attention_forward` is expected to keep the input layout
        # [B, L, heads, head_dim]; re-check against the installed transformers version when enabling.
        shape = (batch_size, query.shape[1], self.num_heads, self.head_dim)
        if attn_output.shape != shape:
            raise ValueError(f"`attn_output` should be of size {shape}, but is {attn_output.size()}")
        attn_output = attn_output.reshape(batch_size, query.shape[1], -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, None
class MultiheadSdpaAttention(MultiheadAttention):
    """
    MultiheadAttention SDPA attention. Inherits `MultiheadAttention` as the weights of the module stay untouched.
    The only change is on the forward pass where we call `torch.nn.functional.scaled_dot_product_attention`
    """

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        query_position_indices: Optional[torch.Tensor] = None,
        key_position_indices: Optional[torch.Tensor] = None,
        cache: Optional[transformers.Cache] = None,
        cache_layer: Optional[int] = None,
        output_attentions: bool = False,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            query: Query embedding of shape [B, L, in_features]
            key: Key embedding of shape [B, S, key_features]
            value: Value embedding of shape [B, S, value_features]
            attn_mask: dtype torch.bool or same dtype as query/key/value and shape one of:
                - [B, S]
                - [B | 1, 1 | num_heads, L, S]
                If bool, False values indicate masked positions (opposite of torch.nn.MultiheadAttention)
                If float, must contain only 0.0 and torch.finfo(dtype).min
                If attn_mask is None, full-bidirectional attention or causal attention is used depending
                on the value of `is_causal`.
            is_causal: If True, apply additional causal masking to `attn_mask`
            query_position_indices: torch.Tensor of shape [1 | B, L] containing the indices of the `query`
                tokens within the entire sequence. Passed through to query_position_embed. If None and `cache`
                is not None, indices are autogenerated as [0, 1, ..., L - 1] offset by the current cache length
            key_position_indices: Same as `query_position_indices`, but applied to key
            cache: transformers.Cache containing cached key-value pairs. The linearly projected
                `key` and `value` passed to this function get added to the cache and concatenated after the
                key-value pairs in the cache and then attention is computed on the concatenated sequence.
                This is most commonly used at inference when generating auto-regressively or when one needs
                to cross attend to the keys and values outside this module forward pass.
            cache_layer: Index of the layer in the cache. Needed only when `cache` is passed to
                the `forward()` call, usually during generation or when the projected keys and values need
                to be cached during training. Can be omitted when `cache_layer` was passed to `__init__`
            output_attentions: Must be False; this implementation never returns attention weights.
            cache_kwargs: kwargs directly passed to `cache.update()`. None is treated as an empty dict.
        Returns:
            Tuple with entries:
                - Attention block output: torch.Tensor of shape [B, L, out_features]
                - None (attention weights are not supported)
        """
        assert output_attentions is False, f"{self.__class__} doesn't support output_attentions=True"
        # Fresh dict per call instead of a shared mutable default argument.
        cache_kwargs = {} if cache_kwargs is None else cache_kwargs

        batch_size = query.shape[0]
        query_states = self.q_proj(query)
        key_states = self.k_proj(key)
        value_states = self.v_proj(value)
        # [B, seq, features] -> [B, heads, seq, head_dim]
        query_states = query_states.view(
            batch_size, query_states.shape[1], self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            batch_size, key_states.shape[1], self.num_kv_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, value_states.shape[1], self.num_kv_heads, self.head_dim
        ).transpose(1, 2)
        (query_states, key_states) = self._maybe_apply_positional_embeddings(
            query_states=query_states,
            key_states=key_states,
            query_position_indices=query_position_indices,
            key_position_indices=key_position_indices,
            cache=cache,
        )
        (key_states, value_states) = self._maybe_update_cache(
            key_states,
            value_states,
            cache_layer=cache_layer,
            cache=cache,
            cache_kwargs=cache_kwargs,
        )
        # Grouped-query attention: repeat the kv heads so they line up with the query heads.
        key_states = expand_kv_heads(key_states, self.num_heads // self.num_kv_heads)
        value_states = expand_kv_heads(value_states, self.num_heads // self.num_kv_heads)
        attn_mask = update_attn_mask(
            attn_mask,
            attn_implementation=SDPA_ATTN,
            query_seq_length=query_states.shape[2],
            kv_seq_length=value_states.shape[2],
            cache_position=query_position_indices,
            cache=cache,
            batch_size=batch_size,
            causal=is_causal,
            dtype=query_states.dtype,
            device=query_states.device,
            output_attentions=output_attentions,
        )
        if attn_mask is not None:
            attn_mask = attn_mask[:, :, :, : key_states.shape[-2]]
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            # SDPA raises when an explicit attn_mask is combined with is_causal=True. The causal
            # request was already handed to `update_attn_mask` (causal=is_causal), which presumably
            # folds it into the mask it returns, so only use SDPA's built-in causal path when no
            # explicit mask is left.
            is_causal=is_causal and attn_mask is None,
        )
        shape = (batch_size, self.num_heads, query.shape[1], self.head_dim)
        assert attn_output.shape == shape, f"{attn_output.shape} != {shape}"
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, query.shape[1], self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)
        return attn_output, None
# Registry used to resolve a block config's `attn_implementation` to an attention class.
# The SDPA and flash variants subclass MultiheadAttention and only override `forward()`, so they share
# its weights and are interchangeable at load time.
ATTN_TYPES = dict(
    (
        (EAGER_ATTN, MultiheadAttention),
        (SDPA_ATTN, MultiheadSdpaAttention),
        (FLASH_ATTN, MultiheadFlashAttention2),
    )
)
def make_activation(activation: str | Type[torch.nn.Module], **kwargs) -> torch.nn.Module:
if isinstance(activation, str):
TorchActivation: Type[torch.nn.Module] = getattr(torch.nn, activation)
else:
TorchActivation: Type[torch.nn.Module] = activation
assert issubclass(TorchActivation, torch.nn.Module), TorchActivation
return TorchActivation(**kwargs)
class PiZeroMLP(torch.nn.Module):
    """
    Gated, bias-free feed-forward block: `down_proj(activation(gate_proj(x)) * up_proj(x))`.

    Args:
        feature_size: Input/output feature size.
        hidden_size: Width of the gated hidden layer.
        activation: Activation name/class passed to `make_activation`; it is constructed with
            `approximate="tanh"`, so it must accept that keyword (e.g. torch.nn.GELU).
    """

    def __init__(self, feature_size: int, hidden_size: int, activation: str):
        super().__init__()
        # Attribute names define the state-dict keys; creation order fixes the random init order.
        self.gate_proj = torch.nn.Linear(feature_size, hidden_size, bias=False)
        self.up_proj = torch.nn.Linear(feature_size, hidden_size, bias=False)
        self.down_proj = torch.nn.Linear(hidden_size, feature_size, bias=False)
        self.activation = make_activation(activation, approximate="tanh")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = self.activation(self.gate_proj(x))
        hidden = gated * self.up_proj(x)
        return self.down_proj(hidden)
class PiZeroFlowMatchingDecoderBlock(ConfigurableModule):
    """
    Pre-norm transformer block of the flow matching decoder:
    RMSNorm -> self-attention (over cached keys/values plus the block's own tokens) -> residual add,
    then RMSNorm -> gated MLP -> residual add.

    Extra keyword arguments (`attn_kwargs` of `__init__`) are forwarded to the attention constructor,
    e.g. shared positional embedding modules and the block's `cache_layer`.
    """

    def __init__(self, config: PiZeroFlowMatchingDecoderBlockConfig, **attn_kwargs):
        super().__init__(config)
        # Creation order and attribute names are kept stable (random init order / state-dict keys).
        self.norm_in = GemmaRMSNorm(self.config.feature_size, eps=1e-06)
        self.self_attn = ATTN_TYPES[self.config.attn_implementation](
            in_features=self.config.feature_size,
            num_heads=self.config.num_heads,
            head_dim=self.config.head_dim,
            num_kv_heads=self.config.num_kv_heads,
            **attn_kwargs,
        )
        self.mlp = PiZeroMLP(
            feature_size=self.config.feature_size,
            hidden_size=self.config.hidden_size,
            activation=self.config.activation,
        )
        self.norm_out = GemmaRMSNorm(self.config.feature_size, eps=1e-06)

    def forward(
        self,
        query: torch.Tensor,
        attn_mask: torch.Tensor,
        cache: transformers.Cache,
        attn_kwargs: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Args:
            query: torch.Tensor of shape [B, L, feature_size]. When driven by PiZeroFlowMatchingDecoder this
                is [robot state tokens, control tokens] on the first flow step and only the control tokens
                on later cached steps.
            attn_mask: torch.Tensor of dtype torch.bool and shape [B, 1, L, K], where K is the number of
                keys after the cache update (cached VLM tokens followed by the decoder tokens).
                False marks masked positions.
            cache: Cache holding the VLM keys/values (and, on later steps, the decoder's own past keys/values).
                The block's keys/values are written into it by the attention module.
            attn_kwargs: Extra keyword arguments forwarded to the attention call (e.g.
                `query_position_indices`, `key_position_indices`, `cache_kwargs`).
        Returns:
            torch.Tensor of same shape as query [B, L, feature_size]
        """
        residual = query
        x = self.norm_in(query)
        # Self-attention: query, key and value all come from the same normalized tokens; the cached
        # VLM keys/values are concatenated inside the attention module.
        (x, _) = self.self_attn(
            query=x,
            key=x,
            value=x,
            attn_mask=attn_mask,
            is_causal=False,
            cache=cache,
            **attn_kwargs,
        )
        x = residual + x
        residual = x
        x = self.norm_out(x)
        x = self.mlp(x)
        return residual + x
class PiZeroFlowMatchingDecoder(ConfigurableModule):
    """
    PiZero Flow Matching control decoder.

    A stack of PiZeroFlowMatchingDecoderBlock layers that processes [robot state tokens, control tokens]
    while attending to the per-layer LLM key/value projections, which are written into the KV cache
    (cache layer `i` for block `i`) before the first decoding step.
    """

    def __init__(self, config: PiZeroFlowMatchingDecoderConfig):
        super().__init__(config)
        block_config = self.config.block_config
        # A single query RoPE module and a single key RoPE module, shared by every block.
        query_position_embed = RotaryPositionalEncoding(config=block_config.position_embed_config)
        key_position_embed = RotaryPositionalEncoding(config=block_config.position_embed_config)
        self.blocks = torch.nn.ModuleList(
            [
                PiZeroFlowMatchingDecoderBlock(
                    block_config,
                    query_position_embed=query_position_embed,
                    key_position_embed=key_position_embed,
                    cache_layer=i,
                )
                for i in range(self.config.num_blocks)
            ]
        )
        self.norm = GemmaRMSNorm(block_config.feature_size, eps=1e-06)

    def forward(
        self,
        control_tokens: torch.Tensor,
        robot_state_tokens: torch.Tensor,
        llm_kv_tokens: List[Tuple[torch.Tensor, torch.Tensor]],
        attn_mask: Optional[torch.Tensor],
        cache: Optional[transformers.Cache] = None,
    ) -> torch.Tensor:
        """
        Args:
            control_tokens: torch.Tensor of shape [B, N, token_size], contains sequence of controls
            robot_state_tokens: torch.Tensor of shape [B, num_state_tokens, token_size]
            llm_kv_tokens: List of linearly projected key-value pairs from LLM, right before attention
                operation. Each tensor is of the shape [B, num_kv_heads, S, head_dim]
            attn_mask: Required (None is rejected). One of
                - shape [B, S], dtype torch.bool -> padding attention mask for LLM tokens
                - shape [B, 1, L, S], dtype torch.bool -> full attention mask for LLM tokens
            cache: Optional cache. If None or empty, the LLM key/values are written into it and the blocks
                run on [robot state tokens, control tokens]. If it already holds entries, only the control
                tokens are decoded and their key/values are written at the same positions used on the first
                step (via `cache_position`), reusing the cached LLM and robot state key/values. This relies on
                a cache that honours `cache_position` (e.g. transformers.StaticCache).
        Returns:
            torch.Tensor, shape [B, N, token_size]
        Raises:
            ValueError: if `attn_mask` is None.
        """
        assert (
            len(llm_kv_tokens) == self.config.num_blocks
        ), f"{len(llm_kv_tokens)} != {self.config.num_blocks}"
        if attn_mask is None:
            raise ValueError(
                f"{self.__class__.__name__} requires `attn_mask`; it defines the LLM sequence length"
            )
        is_step_zero = cache.get_seq_length() == 0 if cache is not None else True
        vlm_seq_len = attn_mask.shape[-1]
        device = attn_mask.device
        if cache is None:
            cache = transformers.DynamicCache()
        if is_step_zero:
            # Seed every cache layer with the LLM key/values at positions [0, vlm_seq_len).
            position_indices = torch.arange(vlm_seq_len, dtype=torch.int64, device=device)
            for block_index, (key_states, value_states) in enumerate(llm_kv_tokens):
                cache.update(
                    key_states,
                    value_states,
                    block_index,
                    cache_kwargs={"cache_position": position_indices},
                )
        num_control_tokens = control_tokens.shape[1]
        num_robot_state_tokens = robot_state_tokens.shape[1]
        attn_mask = self._build_attn_mask(
            num_control_tokens=num_control_tokens,
            num_robot_state_tokens=num_robot_state_tokens,
            attn_mask=attn_mask,
        )
        if is_step_zero:
            tokens = torch.cat([robot_state_tokens, control_tokens], dim=1)
            query_position_indices = key_position_indices = vlm_seq_len + torch.arange(
                tokens.shape[1], dtype=torch.int64, device=device
            ).view(1, -1)
        else:
            # Robot state key/values are reused from the cache; only the control rows of the mask apply.
            tokens = control_tokens
            attn_mask = attn_mask[:, :, -num_control_tokens:]
            query_position_indices = key_position_indices = (
                vlm_seq_len
                + num_robot_state_tokens
                + torch.arange(tokens.shape[1], dtype=torch.int64, device=device).view(1, -1)
            )
        for block in self.blocks:
            tokens = block(
                query=tokens,
                attn_mask=attn_mask,
                cache=cache,
                attn_kwargs={
                    "query_position_indices": query_position_indices,
                    "key_position_indices": key_position_indices,
                    "cache_kwargs": {"cache_position": key_position_indices.view(-1)},
                },
            )
        if is_step_zero:
            (_, control_tokens) = torch.split(tokens, [num_robot_state_tokens, num_control_tokens], dim=1)
        else:
            control_tokens = tokens
        return self.norm(control_tokens)

    @torch.no_grad()
    def _build_attn_mask(
        self,
        num_control_tokens: int,
        num_robot_state_tokens: int,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Expand `attn_mask` (which is effectively a padding mask) to 4D such that:
            - robot state tokens and control tokens can't attend to padding tokens
            - robot state tokens can't attend to any control token
        Note: We can't keep the mask in 2D as it doesn't allow masking of padding tokens from the
        VLM sequence. Furthermore, in a 2D mask you can't disable attention from robot state tokens
        to control tokens

        Returns:
            Bool mask of shape [B, 1, R + N, S + R + N] (R robot state tokens, N control tokens, S VLM
            tokens). Query rows are ordered [robot state, control]; key columns [VLM, robot state, control].
        """
        assert attn_mask.dtype == torch.bool, attn_mask.dtype
        assert attn_mask.ndim in [2, 4], attn_mask.shape
        device = attn_mask.device
        batch_size = attn_mask.shape[0]
        query_seq_len = num_robot_state_tokens + num_control_tokens
        vlm_seq_len = attn_mask.shape[-1]
        kv_seq_len = query_seq_len + vlm_seq_len
        cross_attn_mask = torch.ones(
            [batch_size, 1, query_seq_len, kv_seq_len], dtype=torch.bool, device=device
        )
        if attn_mask.ndim == 2:
            attn_mask = attn_mask.view(batch_size, 1, 1, vlm_seq_len)
        else:
            # A VLM key is visible if any VLM query row could attend to it (i.e. it is not padding).
            attn_mask = torch.any(attn_mask, dim=-2, keepdim=True)
        cross_attn_mask[..., :vlm_seq_len] = attn_mask
        # Block the full [robot state rows] x [control key columns] rectangle. (Pairing two column-vector
        # index tensors would broadcast element-wise and only mask a diagonal, or fail when R != N.)
        control_keys_start = vlm_seq_len + num_robot_state_tokens
        cross_attn_mask[:, :, :num_robot_state_tokens, control_keys_start:] = False
        return cross_attn_mask

    @property
    def fsdp_wrap_modules(self) -> Set[torch.nn.Module]:
        return {module for module in self.modules() if isinstance(module, type(self.blocks[0]))} | {self.norm}
def integrate_unitquat(
    qt: torch.Tensor,
    dq_dt: torch.Tensor,
    dt: float | torch.Tensor,
    body_frame: bool = True,
    half_cover: bool = True,
) -> torch.Tensor:
    """
    Integrate a unit quaternion `qt` by the derivative `dq_dt` over the time interval `dt`.

    The angular velocity is recovered as the vector part of 2 * conj(q) * dq/dt (body frame) or
    2 * dq/dt * conj(q) (inertial frame), turned into an incremental rotation over `dt` and composed
    with `qt` on the matching side.

    Args:
        qt: Unit quaternion, shape [..., 4]
        dq_dt: Derivative of the unit quaternion, shape [..., 4]
        dt: Time interval to integrate over, scalar or a tensor of shape () or [..., 1]
        body_frame: If True, the integration is done in the body frame (post-multiply),
            otherwise in the inertial frame (pre-multiply).
        half_cover: If True, the result is mapped with `quaternion_half_cover` (canonical hemisphere).
    Returns:
        Integrated unit quaternion, shape [..., 4]
    """
    assert qt.shape == dq_dt.shape, f"{qt.shape} != {dq_dt.shape}"
    assert is_quaternion(qt), f"{qt.shape} not a quaternion"
    if isinstance(dt, torch.Tensor):
        assert dt.ndim in (0, qt.ndim), f"dt.ndim = {dt.ndim} | {qt.ndim}"

    q_conj = roma.quat_conjugation(qt)
    if body_frame:
        omega_quat = 2.0 * roma.quat_product(q_conj, dq_dt)
    else:
        omega_quat = 2.0 * roma.quat_product(dq_dt, q_conj)
    # roma stores quaternions as XYZW, so dropping the last component keeps the vector part.
    increment = roma.rotvec_to_unitquat(omega_quat[..., :-1] * dt)
    integrated = roma.quat_product(qt, increment) if body_frame else roma.quat_product(increment, qt)
    return quaternion_half_cover(integrated) if half_cover else integrated
def rotmat_inverse(rotation: torch.Tensor) -> torch.Tensor:
    """
    Invert a rotation matrix by transposing it.

    The result keeps the input layout: if `rotation` is in the 9-component layout (per `is_rotmat_9`),
    the inverse is converted back to it; otherwise the [..., 3, 3] form is returned.
    """
    assert is_rotmat(rotation), f"Expected a rotation matrix, but got shape {rotation.shape}"
    inverse = rotmat_as_3x3(rotation).transpose(-1, -2)
    return rotmat_as_9(inverse) if is_rotmat_9(rotation) else inverse
def skew_symmetric_to_rotvec(skew_symmetric: torch.Tensor) -> torch.Tensor:
    """
    Convert a skew-symmetric matrix to a rotation vector in a differentiable way (the "vee" map).

    For an input of the form
        [
            [ 0, -z,  y],
            [ z,  0, -x],
            [-y,  x,  0],
        ]
    the result is [x, y, z]. Each component is half the difference of two mirrored entries, so any
    symmetric part of the input cancels out and the result is the vee of (S - S^T) / 2.

    Args:
        skew_symmetric: Skew-symmetric matrix of shape [..., 3, 3]
    Returns:
        torch.Tensor of shape [..., 3]
    """
    # Require the explicit [..., 3, 3] layout. A flattened rotation-matrix layout such as [B, 9] would
    # also be indexable below (for B >= 3) and silently produce meaningless values.
    assert skew_symmetric.shape[-2:] == (3, 3), skew_symmetric.shape
    rotvec = torch.stack(
        (
            skew_symmetric[..., 2, 1] - skew_symmetric[..., 1, 2],
            skew_symmetric[..., 0, 2] - skew_symmetric[..., 2, 0],
            skew_symmetric[..., 1, 0] - skew_symmetric[..., 0, 1],
        ),
        dim=-1,
    )
    return rotvec / 2.0
def integrate_rotmat(
    rt: torch.Tensor,
    dr_dt: torch.Tensor,
    dt: float | torch.Tensor,
    body_frame: bool = True,
) -> torch.Tensor:
    """
    Integrate a rotation matrix `rt` by the derivative `dr_dt` over the time interval `dt`.

    The angular velocity is recovered from R^T dR/dt (body frame) or dR/dt R^T (inertial frame),
    converted into an incremental rotation over `dt` and composed with `rt` on the matching side.

    Args:
        rt: Rotation matrix, shape [..., 3, 3] or the 9-component layout accepted by `rotmat_as_3x3`
        dr_dt: Derivative of the rotation matrix, same shape/layout as `rt`
        dt: Time interval to integrate over: a scalar, or a tensor of shape (), [..., 1] or [..., 1, 1]
            (the batch dims refer to the [..., 3, 3] form)
        body_frame: If True, the integration is done in the body frame (post-multiply),
            otherwise in the inertial frame (pre-multiply).
    Returns:
        Integrated rotation matrix, in the same layout as `rt`
    """
    assert rt.shape == dr_dt.shape, f"{rt.shape} != {dr_dt.shape}"
    assert is_rotmat(rt), f"{rt.shape} not a rotation matrix"
    is_3x3 = is_rotmat_3x3(rt)
    if not is_3x3:
        rt = rotmat_as_3x3(rt)
        dr_dt = rotmat_as_3x3(dr_dt)
    if isinstance(dt, torch.Tensor):
        assert dt.ndim in (
            0,
            rt.ndim,
            rt.ndim - 1,
        ), f"dt.ndim = {dt.ndim} | {rt.ndim} | {rt.ndim - 1}"
        if dt.ndim == rt.ndim:
            # [..., 1, 1] -> [..., 1] so it broadcasts against the [..., 3] angular velocity.
            assert dt.shape[-2:] == (1, 1), dt.shape
            dt = dt.squeeze(-1)
    if body_frame:
        omega = skew_symmetric_to_rotvec(rotmat_inverse(rt) @ dr_dt)
    else:
        omega = skew_symmetric_to_rotvec(dr_dt @ rotmat_inverse(rt))
    dr = roma.rotvec_to_rotmat(omega * dt)
    if body_frame:
        rt = rt @ dr
    else:
        rt = dr @ rt
    if not is_3x3:
        rt = rotmat_as_9(rt)
    return rt
def integrate_rotation(
    rt: torch.Tensor,
    dr_dt: torch.Tensor,
    dt: float | torch.Tensor,
    body_frame: bool = True,
    half_cover: bool = True,
) -> torch.Tensor:
    """
    Integrate the rotation `rt` by the derivative `dr_dt` over the time interval `dt` on the SO(3) manifold.

    Dispatches on the representation of `rt`: unit quaternions go to `integrate_unitquat` (which honours
    `half_cover`), rotation matrices to `integrate_rotmat` (where `half_cover` is not used).

    Raises:
        NotImplementedError: if `rt` is neither a quaternion nor a rotation matrix.
    """
    if is_quaternion(rt):
        return integrate_unitquat(rt, dr_dt, dt, body_frame=body_frame, half_cover=half_cover)
    if not is_rotmat(rt):
        raise NotImplementedError(f"integrate_rotation not yet implemented for format {rt.shape}")
    return integrate_rotmat(rt, dr_dt, dt, body_frame=body_frame)
class PiZeroFlowMatchingModule(ConfigurableModule):
    """
    Flow matching control head.

    Projects the robot state and the noised controls into tokens, decodes them against the VLM's
    key/values with PiZeroFlowMatchingDecoder and maps each output token to
    [translation (3), rotation (`rotation_components`), gripper (1)].
    """

    def __init__(self, config: PiZeroFlowMatchingModuleConfig, control_tokenizer: EmptyTokenizer):
        super().__init__(config)
        # Accepted for interface compatibility; this module does not use a control tokenizer.
        del control_tokenizer
        self.noised_control_proj = NoisedControlProjector(self.config.noised_control_proj_config)
        self.robot_state_proj = RobotStateProjector(self.config.robot_state_proj_config)
        self.control_decoder = PiZeroFlowMatchingDecoder(config=self.config.control_decoder_config)
        output_size = 3 + self.config.rotation_components + 1
        self.output_proj = make_mlp(
            [self.config.token_size, output_size],
            activation=torch.nn.GELU,
            activate_final=False,
        )

    def forward(
        self,
        vlm_input: RoboticsFlowInput,
        vlm_output: VLMOutput,
        cache: Optional[transformers.Cache] = None,
    ) -> RoboticsOutput:
        """Predict the per-token control velocities (translation, rotation, gripper) for the current flow state."""
        robot_state_tokens = self.robot_state_proj(vlm_input)
        noised_tokens = self.noised_control_proj(vlm_input.flow_input)
        decoded_tokens = self.control_decoder(
            control_tokens=noised_tokens,
            robot_state_tokens=robot_state_tokens,
            llm_kv_tokens=vlm_output.llm_output.past_key_values,
            attn_mask=vlm_input.attn_mask,
            cache=cache,
        )
        controls = self.output_proj(decoded_tokens)
        split_sizes = [3, self.config.rotation_components, 1]
        translation, rotation, gripper = torch.split(controls, split_sizes, dim=-1)
        return RoboticsOutput.make_empty().replace(
            translation=translation, rotation=rotation, gripper=gripper
        )

    @torch.inference_mode()
    def generate(
        self,
        vlm_input: RoboticsFlowInput,
        vlm_output: VLMOutput,
        processor: PiZeroFlowMatchingProcessor,
        use_cache: bool = True,
        **kwargs,
    ) -> RoboticsOutput:
        """
        Integrate the flow from the processor's t0 sample over `num_inference_steps` Euler steps.

        With `use_cache`, a transformers.StaticCache sized for the VLM sequence plus the control/state
        tokens is shared across steps so the VLM and robot state key/values are computed once.
        """
        del kwargs
        batch_size, vlm_seq_len = vlm_input.input_ids.shape[:2]
        device = vlm_input.input_ids.device

        cache = None
        if use_cache:
            io_config = processor.config.control_io_config
            block_config = self.config.control_decoder_config.block_config
            cache = transformers.StaticCache(
                config=transformers.PretrainedConfig(
                    head_dim=block_config.head_dim,
                    num_key_value_heads=block_config.num_kv_heads,
                    num_hidden_layers=self.config.control_decoder_config.num_blocks,
                ),
                max_batch_size=batch_size,
                max_cache_len=(
                    vlm_seq_len
                    + io_config.future_controls_sequence_length
                    + io_config.past_scalars_sequence_length
                ),
                device=device,
            )

        def with_flow_state(inputs, timestep, translation_t, rotation_t, gripper_t):
            # Write the current flow state into the nested `flow_input` fields.
            return inputs.replace(
                **{
                    "flow_input.timestep": timestep,
                    "flow_input.translation_t": translation_t,
                    "flow_input.rotation_t": rotation_t,
                    "flow_input.gripper_t": gripper_t,
                }
            )

        flow_input: FlowInput = processor.sample_t0_input(batch_size=batch_size, device=device)
        num_steps = processor.config.num_inference_steps
        step_size = 1 / num_steps
        translation, rotation, gripper = (
            flow_input.translation_t0,
            flow_input.rotation_t0,
            flow_input.gripper_t0,
        )
        vlm_input = with_flow_state(vlm_input, flow_input.timestep, translation, rotation, gripper)
        for _ in range(num_steps):
            velocity: RoboticsOutput = self(vlm_input, vlm_output, cache)
            translation = translation + step_size * velocity.translation
            rotation = integrate_rotation(rt=rotation, dr_dt=velocity.rotation, dt=step_size)
            gripper = gripper + step_size * velocity.gripper
            timestep = vlm_input.flow_input.timestep + step_size
            if processor.config.rotation_format == RotationFormat.QUATERNION:
                rotation = quaternion_half_cover(rotation)
            vlm_input = with_flow_state(vlm_input, timestep, translation, rotation, gripper)

        return RoboticsOutput.make_empty().replace(
            translation=translation, rotation=rotation, gripper=gripper
        )

    @property
    def fsdp_wrap_modules(self) -> Set[torch.nn.Module]:
        own_modules = {self, self.robot_state_proj, self.noised_control_proj, self.output_proj}
        return self.control_decoder.fsdp_wrap_modules | own_modules
# Rotation by pi about the x-axis, R_x(pi) = diag(1, -1, -1), written with exact entries.
# (Evaluating np.cos/np.sin at np.pi leaves ~1.2e-16 residue off the diagonal.)
# The matrix is symmetric and its own inverse, so it is also equal to its transpose.
CANONICAL_TO_BRIDGE_ROTATION = np.array(
    [
        [1, 0, 0],
        [0, -1, 0],
        [0, 0, -1],
    ],
    dtype=np.float32,
)
class SPEAR1(ConfigurableModule, transformers.PreTrainedModel):
    """SPEAR-1 policy: a PaliGemmaVLM backbone followed by a PiZero flow matching control module."""

    config_class: Type[transformers.PretrainedConfig] = SPEAR1Config

    def __init__(self, config: SPEAR1Config):
        super().__init__(config)
        self.vlm = PaliGemmaVLM(config=self.config.vlm_config)
        self.processor = PiZeroFlowMatchingProcessor(
            config=self.config.processor_config, vlm_processor=self.vlm.processor
        )
        self.control_module = PiZeroFlowMatchingModule(
            config=self.config.control_module_config,
            control_tokenizer=self.processor.control_tokenizer,
        )
        self.generation_config = transformers.GenerationConfig()

    def forward(
        self,
        inputs: RoboticsInput,
        use_cache: Optional[bool] = True,
        output_hidden_states: Optional[bool] = None,
    ) -> RoboticsOutput:
        """Single forward pass (one flow step) returning control predictions plus the LLM output."""
        del output_hidden_states
        vlm_output = self.vlm(inputs=inputs, use_cache=use_cache, output_hidden_states=True)
        control_output = self.control_module(vlm_input=inputs, vlm_output=vlm_output)
        return control_output.replace(llm_output=vlm_output.llm_output)

    @torch.inference_mode()
    def generate(
        self,
        inputs: RoboticsInput,
        use_cache: Optional[bool] = True,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> RoboticsOutput:
        """Run the VLM once, then integrate the full flow with the control module."""
        del output_hidden_states
        vlm_output = self.vlm(
            inputs=inputs,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
        )
        control_output = self.control_module.generate(
            vlm_input=inputs, vlm_output=vlm_output, processor=self.processor
        )
        return control_output.replace(llm_output=vlm_output.llm_output)

    def _resize_images(self, images: Dict) -> Dict:
        """
        Resize every camera image with the processor and wrap it in a one-element list (batch of 1).

        Returns a new dict; the caller's dict is not modified.
        """
        # Old `resize_image` signatures take only the image, newer ones also take the camera name.
        num_resize_args = len(inspect.signature(self.processor.resize_image).parameters)
        resized_images = {}
        for camera_name, camera_image in images.items():
            if num_resize_args == 1:
                resized = self.processor.resize_image(camera_image)
            elif num_resize_args == 2:
                resized = self.processor.resize_image(camera_name, camera_image)
            else:
                raise ValueError(f"Unexpected number of arguments for resize_image: {num_resize_args}")
            # add batch dimension and wrap into list to match processor expected format
            resized_images[camera_name] = [resized]
        return resized_images

    def predict_action(self, inputs: Dict) -> Dict[str, np.ndarray]:
        """
        Predict a control plan for a single (unbatched) observation.

        Args:
            inputs: dict with keys "images" (camera name -> image), "ee_translation" (3 values),
                "ee_rotation" (9 values, reshaped to 3x3), "gripper" (1 value), "dataset_name" and
                "language_instruction". It is not modified.
        Returns:
            dict with numpy arrays "translation" [num_future_control_steps, 3],
            "rotation" [num_future_control_steps, 3, 3] and "gripper" [num_future_control_steps, 1].
        """
        ee_translation = inputs["ee_translation"]
        ee_rotation = inputs["ee_rotation"]
        gripper = inputs["gripper"]
        images = self._resize_images(inputs["images"])
        # add batch dimensions to state obs
        ee_translation = np.array(ee_translation, dtype=np.float32).reshape(1, 3)
        ee_rotation = np.array(ee_rotation, dtype=np.float32).reshape(1, 3, 3) @ CANONICAL_TO_BRIDGE_ROTATION
        gripper = np.array(gripper, dtype=np.float32).reshape(1, 1)
        joints = np.zeros((1, 7), dtype=np.float32)
        dataset_name = np.array([inputs["dataset_name"]])
        chat = [f"{inputs['language_instruction']}", ""]
        model_input = self.processor.create_input(
            images=images,
            chat=chat,
            ee_pose_translation=ee_translation,
            ee_pose_rotation=ee_rotation,
            gripper=gripper,
            dataset_name=dataset_name,
            joints=joints,
            inference_mode=True,
        )
        # Run on whichever device the model's parameters live on, with autocast matching its type.
        device = self.device
        model_input = model_input.apply(
            lambda x: x.unsqueeze(0).to(device) if isinstance(x, torch.Tensor) else x
        )
        with torch.no_grad(), torch.autocast(device.type, dtype=torch.bfloat16):
            model_output = self.generate(model_input)
            control_plan = self.processor.policy_control_plan_from_model_output(
                model_output=model_output,
                dataset_name=dataset_name,
                valid_mask=torch.ones(
                    model_output.gripper.shape[:2], dtype=torch.bool, device=model_output.gripper.device
                ),
            )
        translation_m = control_plan.translation_m.to(dtype=torch.float32, device="cpu")
        rotation = control_plan.rotmat.to(dtype=torch.float32, device="cpu")
        gripper_prob = control_plan.gripper_prob.to(dtype=torch.float32, device="cpu")
        # Convert controls back to robot base frame
        if self.processor.config.eef_control_frame:
            # Get the robot base rotation matrix R_BE - the same as the robot EEF pose.
            # R_BE - converts from end-effector frame E to robot base frame B
            robot_base_rotmat = rotmat_as_3x3(model_input.ee_pose_rotation[:, -1:, ...]).cpu()  # [B, 1, 3, 3]
            translation_m = torch.matmul(  # [B, num_future_control_steps, 3]
                robot_base_rotmat, translation_m.unsqueeze(-1)
            ).squeeze(-1)
            rotation = rotmat_as_3x3(  # [B, num_future_control_steps, 3, 3]
                torch.matmul(robot_base_rotmat, rotmat_as_3x3(rotation))
            )
        translation = translation_m.squeeze(0).numpy()  # [num_future_control_steps, 3]
        rotation = rotmat_as_3x3(rotation).squeeze(0).numpy()  # [num_future_control_steps, 3, 3]
        gripper = gripper_prob.squeeze(0).numpy()  # [num_future_control_steps, 1]
        rotation = CANONICAL_TO_BRIDGE_ROTATION @ rotation @ CANONICAL_TO_BRIDGE_ROTATION.T
        return {
            "translation": translation,
            "rotation": rotation,
            "gripper": gripper,
        }

    @property
    def fsdp_wrap_modules(self) -> Set[torch.nn.Module]:
        return (
            {self.vlm, self.control_module}
            | self.vlm.fsdp_wrap_modules
            | self.control_module.fsdp_wrap_modules
        )