import collections import functools import inspect import logging import math import warnings from typing import ( Any, Callable, Dict, List, Literal, Optional, Protocol, Set, Tuple, Type, ) import einops import numpy as np import roma import timm import torch import torch.distributed.fsdp import torch.distributed.tensor import transformers from huggingface_hub import hf_hub_download from .common_spear import ( Configurable, DiffusionInput, FlowInput, LLMOutput, RoboticsFlowInput, RoboticsInput, RoboticsOutput, RotationFormat, VLMOutput, expand_dims, is_quaternion, is_rotmat, is_rotmat_3x3, is_rotmat_9, quaternion_half_cover, rotmat_as_3x3, rotmat_as_9, ) from .configuration_spear import ( FourierFeaturesConfig, NoisedControlProjectorConfig, PaliGemmaVLMConfig, PiZeroFlowMatchingDecoderBlockConfig, PiZeroFlowMatchingDecoderConfig, PiZeroFlowMatchingModuleConfig, RobotStateProjectorConfig, RotaryPositionalEncodingConfig, SPEAR1Config, ) from .processing_spear import ( EmptyTokenizer, PaliGemmaDepthProcessor, PiZeroFlowMatchingProcessor, ) class ConfigurableModule(torch.nn.Module, Configurable): def __init__(self, config): Configurable.__init__(self, config) torch.nn.Module.__init__(self) class GemmaRMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-06): super().__init__() self.eps = eps self.weight = torch.nn.Parameter(torch.zeros(dim)) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): output = self._norm(x.float()) output = output * (1.0 + self.weight.float()) return output.type_as(x) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" @torch.no_grad() def make_module_params_non_trainable( module: torch.nn.Module, recursive: bool = True, dtype: Optional[torch.dtype] = None ): """ NOTE: dtype is applied only to module parameters, not buffers. 
This is different from the default torch.nn.Module.to(dtype=dtype) behavior, which applies to both parameters and buffers """ for param in module.parameters(recurse=recursive): param.requires_grad = False if dtype is not None: param.data = param.to(dtype=dtype) @torch.no_grad() def make_module_non_trainable(module: torch.nn.Module, dtype: Optional[torch.dtype] = None): make_module_params_non_trainable(module, dtype=None) if dtype is not None: module.to(dtype=dtype) module.eval() class ResidualConvBlock(torch.nn.Module): def __init__( self, in_channels: int, out_channels: int | None = None, hidden_channels: int | None = None, padding_mode: str = "replicate", activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu", norm: Literal["group_norm", "layer_norm"] = "group_norm", ): super().__init__() if out_channels is None: out_channels = in_channels if hidden_channels is None: hidden_channels = in_channels if activation == "relu": activation_cls = functools.partial(torch.nn.ReLU, inplace=True) elif activation == "leaky_relu": activation_cls = functools.partial(torch.nn.LeakyReLU, negative_slope=0.2, inplace=True) elif activation == "silu": activation_cls = functools.partial(torch.nn.SiLU, inplace=True) elif activation == "elu": activation_cls = functools.partial(torch.nn.ELU, inplace=True) else: raise ValueError(f"Unsupported activation function: {activation}") self.layers = torch.nn.Sequential( torch.nn.GroupNorm(1, in_channels), activation_cls(), torch.nn.Conv2d( in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode, ), torch.nn.GroupNorm(hidden_channels // 32 if norm == "group_norm" else 1, hidden_channels), activation_cls(), torch.nn.Conv2d( hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode, ), ) self.skip_connection = ( torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else torch.nn.Identity() ) def forward(self, x): skip = self.skip_connection(x) x = self.layers(x) x = x + skip return x def normalized_view_plane_uv( width: int, height: int, aspect_ratio: float | None = None, dtype: torch.dtype = None, device: torch.device = None, ) -> torch.Tensor: """ UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal) """ if aspect_ratio is None: aspect_ratio = width / height span_x = aspect_ratio / (1 + aspect_ratio**2) ** 0.5 span_y = 1 / (1 + aspect_ratio**2) ** 0.5 u = torch.linspace( -span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device, ) v = torch.linspace( -span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device, ) (u, v) = torch.meshgrid(u, v, indexing="xy") uv = torch.stack([u, v], dim=-1) return uv class Head(torch.nn.Module): def __init__( self, num_features: int, dim_in: int, dim_out: List[int], dim_proj: int = 512, dim_upsample: List[int] = [256, 128, 128], dim_times_res_block_hidden: int = 1, num_res_blocks: int = 1, res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm", last_res_blocks: int = 0, last_conv_channels: int = 32, last_conv_size: int = 1, ): super().__init__() self.projects = torch.nn.ModuleList( [ torch.nn.Conv2d( in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0, ) for _ in range(num_features) ] ) self.upsample_blocks = torch.nn.ModuleList( [ torch.nn.Sequential( self._make_upsampler(in_ch + 2, out_ch), *( ResidualConvBlock( out_ch, out_ch, 
dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm, ) for _ in range(num_res_blocks) ), ) for (in_ch, out_ch) in zip([dim_proj] + dim_upsample[:-1], dim_upsample, strict=True) ] ) self.output_block = torch.nn.ModuleList( [ self._make_output_block( dim_upsample[-1] + 2, dim_out_, dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm, ) for dim_out_ in dim_out ] ) def _make_upsampler(self, in_channels: int, out_channels: int): upsampler = torch.nn.Sequential( torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2), torch.nn.Conv2d( out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode="replicate", ), ) upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1] return upsampler def _make_output_block( self, dim_in: int, dim_out: int, dim_times_res_block_hidden: int, last_res_blocks: int, last_conv_channels: int, last_conv_size: int, res_block_norm: Literal["group_norm", "layer_norm"], ): return torch.nn.Sequential( torch.nn.Conv2d( dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode="replicate", ), *( ResidualConvBlock( last_conv_channels, last_conv_channels, dim_times_res_block_hidden * last_conv_channels, activation="relu", norm=res_block_norm, ) for _ in range(last_res_blocks) ), torch.nn.ReLU(inplace=True), torch.nn.Conv2d( last_conv_channels, dim_out, kernel_size=last_conv_size, stride=1, padding=last_conv_size // 2, padding_mode="replicate", ), ) def forward(self, hidden_states: List[torch.Tensor], image: torch.Tensor): (img_h, img_w) = image.shape[-2:] (patch_h, patch_w) = (img_h // 14, img_w // 14) x = torch.stack( [ proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous()) for (proj, feat) in zip(self.projects, hidden_states, strict=True) ], dim=1, ).sum(dim=1) for _, block in enumerate(self.upsample_blocks): uv = normalized_view_plane_uv( width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device, ) uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) x = torch.cat([x, uv], dim=1) for layer in block: x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False) x = torch.nn.functional.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False) uv = normalized_view_plane_uv( width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device, ) uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) x = torch.cat([x, uv], dim=1) if isinstance(self.output_block, torch.nn.ModuleList): output = [ torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) for block in self.output_block ] else: output = torch.utils.checkpoint.checkpoint(self.output_block, x, use_reentrant=False) return output def _is_single_image_size(image_sizes: Dict[str, Dict[str, int]]) -> bool: return ( len(image_sizes) == 1 or len(set(((image_size["height"], image_size["width"]) for image_size in image_sizes.values()))) == 1 ) class MoGe(torch.nn.Module): """ Implementation of MoGe taken from https://github.com/microsoft/MoGe/blob/main/moge/model/v1.py Simplified and stripped down such that: - It doesn't rely on MoGe codebase - Uses timm for ViT backbone - Only predicts points and mask - Currently does NOT infer depth or intrinsics (but this could be added) - Requires image to be resized to the expected resolution. 
Note that this resolution must be in the range of num_tokens_range, (for square images, pixel sizes in the range ~[36*14, 50*14]) """ def __init__( self, image_sizes: Dict[str, Dict[str, int]] = {}, backbone_id: str = "vit_large_patch14_dinov2.lvd142m", intermediate_layers: int | List[int] = 4, dim_proj: int = 512, dim_upsample: List[int] = [256, 128, 64], dim_times_res_block_hidden: int = 2, num_res_blocks: int = 2, remap_output: Literal[False, True, "linear", "sinh", "exp", "sinh_exp"] = "exp", res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm", num_tokens_range: Tuple[int, int] = [1275, 2551], last_res_blocks: int = 0, last_conv_channels: int = 32, last_conv_size: int = 1, mask_threshold: float = 0.5, output_dtype: torch.dtype = torch.bfloat16, ): """ NOTE: - All defaults were taken from the checkpoint config for 'Ruicheng/moge-vitl' - output_dtype - by default was float32, changed to bfloat16 for model training """ super().__init__() self.remap_output = remap_output self.intermediate_layers = intermediate_layers self.num_tokens_range = num_tokens_range self.mask_threshold = mask_threshold self.output_dtype = output_dtype self.image_sizes = dict(image_sizes) self.backbone: timm.models.vision_transformer.VisionTransformer = self._make_vit_backbone(backbone_id) token_size: int = self.backbone.embed_dim self.head = Head( num_features=( intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers) ), dim_in=token_size, dim_out=[3, 1], dim_proj=dim_proj, dim_upsample=dim_upsample, dim_times_res_block_hidden=dim_times_res_block_hidden, num_res_blocks=num_res_blocks, res_block_norm=res_block_norm, last_res_blocks=last_res_blocks, last_conv_channels=last_conv_channels, last_conv_size=last_conv_size, ) def _make_vit_backbone(self, backbone_id: str) -> timm.models.vision_transformer.VisionTransformer: if _is_single_image_size(self.image_sizes): kwargs = { "img_size": ( self.image_sizes["main"]["height"], self.image_sizes["main"]["width"], ), "dynamic_img_size": False, } else: kwargs = {"img_size": (224, 224), "dynamic_img_size": True} vit_backbone: timm.models.vision_transformer.VisionTransformer = timm.create_model( backbone_id, pretrained=False, num_classes=0, **kwargs ) vit_backbone.forward = functools.partial( vit_backbone.forward_intermediates, indices=4, return_prefix_tokens=False, norm=True, stop_early=True, output_fmt="NLC", intermediates_only=True, ) return vit_backbone def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]: """ Args: image: torch.Tensor of shape [B, 3, H, W] containing the preprocessed image, resized to the size expected by the model. Returns: A dictionary containing: - `points`: torch.Tensor of shape [B, 3, H, W] containing the predicted points. - `mask`: torch.Tensor of shape [B, 1, H, W] containing the predicted mask. 
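        Example (illustrative sketch with untrained weights; the 518x518 resolution and the
        single "main" camera entry are assumptions, not values from the original config):

            model = MoGe(image_sizes={"main": {"height": 518, "width": 518}})
            image = torch.rand(2, 3, 518, 518)  # preprocessed RGB batch
            with torch.no_grad():
                output = model(image)
            # output["points"]: [2, 3, 518, 518], output["mask"]: [2, 1, 518, 518]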
""" (height, width) = image.shape[-2:] assert (height, width) in [ (image_size["height"], image_size["width"]) for image_size in self.image_sizes.values() ], f"{(height, width)} not in {self.image_sizes}" features: List[torch.Tensor] = self.backbone(image) output = self.head(features, image) (points, mask) = output with torch.autocast( device_type=image.device.type, dtype=torch.float32, enabled=self.output_dtype == torch.float32, ): points = torch.nn.functional.interpolate( points, (height, width), mode="bilinear", align_corners=False, antialias=False, ) mask = torch.nn.functional.interpolate( mask, (height, width), mode="bilinear", align_corners=False, antialias=False, ) points = self._remap_points(points, dim=1) output = {"points": points, "mask": mask} return output def _remap_points(self, points: torch.Tensor, dim: int = 1) -> torch.Tensor: if self.remap_output == "linear": pass elif self.remap_output == "sinh": points = torch.sinh(points) elif self.remap_output == "exp": (xy, z) = points.split([2, 1], dim=dim) z = torch.exp(z) points = torch.cat([xy * z, z], dim=dim) elif self.remap_output == "sinh_exp": (xy, z) = points.split([2, 1], dim=dim) points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=dim) else: raise ValueError(f"Invalid remap output type: {self.remap_output}") return points class DepthBackboneConfig(transformers.PretrainedConfig): def __init__( self, hf_hub_repo: str = "", hf_filename: str = "", image_sizes: Dict[str, Dict[str, int]] = {}, **kwargs, ): super().__init__(**kwargs) self.hf_hub_repo = hf_hub_repo self.hf_filename = hf_filename self.image_sizes = dict(image_sizes) class PaliGemma3DConfig(transformers.models.paligemma.PaliGemmaConfig): sub_configs = { "text_config": transformers.AutoConfig, "vision_config": transformers.AutoConfig, "depth_config": DepthBackboneConfig, } def __init__( self, depth_config={}, depth_only: bool = False, mask_prob: float = 0.0, projection: str = "", depth_layers: int = 4, **kwargs, ): super().__init__(**kwargs) if isinstance(depth_config, dict): self.depth_config = DepthBackboneConfig(**depth_config) else: self.depth_config = depth_config self.mask_prob = mask_prob self.depth_only = depth_only self.projection = projection self.depth_layers = depth_layers @property def is_single_image_size(self) -> bool: return ( len(self.depth_config.image_sizes) == 1 or len( set( ( (image_size["height"], image_size["width"]) for image_size in self.depth_config.image_sizes.values() ) ) ) == 1 ) @property def camera_names(self) -> List[str]: return list(self.depth_config.image_sizes.keys()) class NeRFPositionalEmbedding(torch.nn.Module): def __init__(self, n_frequencies: int, log_scale: bool = True, scale: float = 1.0): """ Args: n_frequencies: Dimension size, same as L parameter in the NeRF paper scale: Scale factor for the frequencies. To match the formula from the paper [sin(2^k * pi * x), cos(2^k * pi * x)], set scale to math.pi. In practice, the paper authors don't multiply by pi and use scale=1.0. See https://github.com/bmild/nerf/issues/12 """ super().__init__() self.n_frequencies = n_frequencies if log_scale: freq = 2 ** torch.arange(self.n_frequencies, dtype=torch.float32) * scale else: freq = torch.linspace(1, 2 ** (self.n_frequencies - 1), self.n_frequencies) * scale self.register_buffer("freq", freq) def forward(self, inputs: torch.Tensor) -> torch.Tensor: """ Maps features from dimensionality N to dimensionality N*(2*L + 1), i.e. `2L + 1` output features are produced for each input feature. Embeds x to (x, sin(2^k x), cos(2^k x), ...). 
NOTE: `x` is also in the output. This is different from the equation in the paper, but matches the actual author implementation. See https://github.com/bmild/nerf/issues/12 Args: inputs: torch.Tensor of shape [B, ..., N]; input values to be transformed Returns: torch.Tensor of shape [B, ..., N*(2*L + 1) = embedding_dim], encoded input values """ freq = expand_dims(self.freq, ndim=inputs.ndim + 1, order=[-1, 1]) spectrum = freq * inputs.unsqueeze(-1) sin = torch.sin(spectrum) cos = torch.cos(spectrum) encoding = torch.stack([sin, cos], dim=-1) encoding = encoding.view(*inputs.shape[:-1], -1) encoding = torch.cat([inputs, encoding], dim=-1) return encoding def make_mlp( layer_sizes: List[int], activation: str | Type[torch.nn.Module], norm: str | Type[torch.nn.Module] | None = torch.nn.LayerNorm, activate_final: bool = False, bias: bool = True, ) -> torch.nn.Sequential: """ Args: layer_sizes: List of layer sizes. The first value is the number of input features and the last value is the number of output features activation: str or the class of the activation. If str, it should be the exact name of the activation module under torch.nn, e.g. 'ReLU', 'SiLU', 'GeLU'. Use 'Identity' if no activation wanted norm: type of normalization. Same type as `activation`. Ex: `torch.nn.LayerNorm`, 'LayerNorm', etc """ if len(layer_sizes) == 0: return torch.nn.Identity() assert len(layer_sizes) > 1, "Need to provide input and output layer sizes at least" if isinstance(activation, str): TorchActivation: Type[torch.nn.Module] = getattr(torch.nn, activation) else: TorchActivation: Type[torch.nn.Module] = activation assert issubclass(TorchActivation, torch.nn.Module), TorchActivation if isinstance(norm, str): TorchNorm: Type[torch.nn.Module] = getattr(torch.nn, norm) elif norm is None: TorchNorm: Type[torch.nn.Module] = torch.nn.Identity else: TorchNorm: Type[torch.nn.Module] = norm assert issubclass(TorchNorm, torch.nn.Module), TorchNorm def make_norm_act(modules: dict[str, torch.nn.Module], empty: bool): return {} if empty else modules module = torch.nn.Sequential( *[ torch.nn.Sequential( collections.OrderedDict( { "linear": torch.nn.Linear(in_features, out_features, bias=bias), **make_norm_act( {"norm": TorchNorm(out_features), "act": TorchActivation()}, empty=i == len(layer_sizes) - 2 and not activate_final, ), } ) ) for (i, (in_features, out_features)) in enumerate( zip(layer_sizes[:-1], layer_sizes[1:], strict=True) ) ] ) return module class PaliGemma3D(transformers.models.paligemma.PaliGemmaForConditionalGeneration): """ Transformers-like implementation of PaliGemma with additional depth encoder """ config_class = PaliGemma3DConfig def __init__(self, config: PaliGemma3DConfig): super().__init__(config) assert self.config.projection in ["features_add"] if self.config.projection in ["features_add"]: self.depth_tower = self._make_depth_encoder() self.depth_projector = torch.nn.Linear( in_features=self.depth_tower.embed_dim * self.config.depth_layers, out_features=self.config.text_config.hidden_size, ) self.generator: torch.Generator = torch.Generator() if self.config.depth_only: make_module_non_trainable(self.vision_tower) make_module_non_trainable(self.siglip_projector) self.projector = None else: raise ValueError(f"Projection type `{self.config.projection}` not supported!") @property def projectors(self) -> Set[torch.nn.Module]: modules = set( ( module for module in [ self.projector, self.depth_projector, self.siglip_projector, ] if module is not None ) ) if isinstance(self.depth_tower, MoGe): modules = 
modules | { module for module in self.depth_tower.modules() if isinstance(module, (ResidualConvBlock, Head)) } return modules @property def siglip_projector(self) -> torch.nn.Linear: return self.multi_modal_projector.linear def _make_depth_encoder(self) -> torch.nn.Module: if self.config.is_single_image_size: kwargs = { "img_size": ( self.config.depth_config.image_sizes["main"]["height"], self.config.depth_config.image_sizes["main"]["width"], ), "dynamic_img_size": False, } else: kwargs = { "img_size": ( self.config.depth_config.image_sizes["main"]["height"], self.config.depth_config.image_sizes["main"]["width"], ), "dynamic_img_size": True, } model: timm.models.vision_transformer.VisionTransformer = timm.create_model( "vit_large_patch14_dinov2.lvd142m", pretrained=False, num_classes=0, **kwargs, ) model.forward = functools.partial( model.forward_intermediates, indices=self.config.depth_layers, return_prefix_tokens=False, norm=True, stop_early=True, output_fmt="NLC", intermediates_only=True, ) return model def _load_depth_model_state_dict(self, depth_model: torch.nn.Module): logging.info( f"Loading depth model from {self.config.depth_config.hf_hub_repo}/{self.config.depth_config.hf_filename}" ) state_dict = torch.load( hf_hub_download( repo_id=self.config.depth_config.hf_hub_repo, filename=self.config.depth_config.hf_filename, ), map_location="cpu", mmap=True, weights_only=False, ) if self.config.projection in ["spatial_add", "spatial_concat"]: pos_embed_state_dict = {"pos_embed": state_dict["backbone.pos_embed"]} pos_embed_state_dict = timm.models.vision_transformer.checkpoint_filter_fn( pos_embed_state_dict, depth_model.backbone ) state_dict["backbone.pos_embed"] = pos_embed_state_dict["pos_embed"] else: state_dict = timm.models.vision_transformer.checkpoint_filter_fn(state_dict, depth_model) depth_model.load_state_dict(state_dict) def get_image_features(self, pixel_values: dict[str, torch.Tensor]) -> torch.Tensor: if self.config.projection == "features_add": images_forward = self._get_image_features_add elif self.config.projection in ["spatial_add", "spatial_concat"]: images_forward = self._get_image_features_spatial else: raise ValueError(f"Project type `{self.config.projection}` not supported!") camera_names = self.config.camera_names if self.config.is_single_image_size: inputs = { "siglip": einops.rearrange( torch.stack( [pixel_values[f"{camera_name}.siglip"] for camera_name in camera_names], dim=1, ), "B N C H W -> (B N) C H W", ), "depth": einops.rearrange( torch.stack( [pixel_values[f"{camera_name}.depth"] for camera_name in camera_names], dim=1, ), "B N C H W -> (B N) C H W", ), } image_tokens = images_forward(inputs) else: camera_tokens: List[torch.Tensor] = [ images_forward( { "siglip": pixel_values[f"{camera_name}.siglip"], "depth": pixel_values[f"{camera_name}.depth"], } ) for camera_name in camera_names ] image_tokens = torch.cat(camera_tokens, dim=-2) return image_tokens def _get_image_features_add(self, pixel_values: dict[str, torch.Tensor]) -> torch.Tensor: """ Obtains image last hidden states from the vision tower and apply multimodal projection. Args: pixel_values: images of shape `[B * num_images, C, H, W]`. 
Keys in the dict correspond to the specific vision encoder to use Returns: image_features of shape [B * num_images, num_tokens, token_size = 2048] """ siglip_input = pixel_values["siglip"] depth_input = pixel_values["depth"] siglip_output: ViTOutput = self.vision_tower(siglip_input) siglip_features = siglip_output.last_hidden_state depth_output: list[torch.Tensor] = self.depth_tower(depth_input) depth_features = torch.cat(depth_output, dim=-1) if len(depth_output) > 1 else depth_output[0] siglip_features = self.siglip_projector(siglip_features) depth_features = self.depth_projector(depth_features) if self.config.depth_only: image_features = depth_features elif self.training and torch.bernoulli(torch.tensor(self.config.mask_prob), generator=self.generator): num_tokens = depth_features.shape[1] device = depth_features.device ones = torch.ones((depth_features.shape[0], num_tokens), device=device) indices = torch.multinomial(ones / num_tokens, num_samples=num_tokens // 2) mask = ones.to(dtype=torch.bool).scatter_(dim=-1, index=indices, value=0) mask = mask.unsqueeze(-1) image_features = siglip_features * mask + depth_features * ~mask else: image_features = (siglip_features + depth_features) / 2 image_features = image_features / self.config.text_config.hidden_size**0.5 return image_features def _get_image_features_spatial(self, pixel_values: dict[str, torch.Tensor]) -> torch.Tensor: """ Obtains image last hidden states from the vision tower and apply multimodal projection. Args: pixel_values: images of shape `[B * num_images, C, H, W]`. Keys in the dict correspond to the specific vision encoder to use Returns: image_features of shape [B * num_images, num_tokens, token_size = 2048] """ siglip_input = pixel_values["siglip"] depth_input = pixel_values["depth"] siglip_output: ViTOutput = self.vision_tower(siglip_input) siglip_features = siglip_output.last_hidden_state depth_output: dict[str, torch.Tensor] = self.depth_tower(depth_input) points = depth_output["points"] mask = depth_output["mask"] mask_binary = mask > self.depth_tower.mask_threshold points = torch.where(mask_binary, points, 0) points_embed: torch.Tensor = self.depth_projector(points) if self.config.projection == "spatial_concat": features = torch.cat([siglip_features, points_embed], dim=-1) image_features = self.projector(features) else: features = siglip_features + points_embed image_features = self.siglip_projector(features) image_features = image_features / self.config.text_config.hidden_size**0.5 return image_features @classmethod def from_pretrained(cls, *args, **kwargs) -> "PaliGemma3D": model = super().from_pretrained(*args, **kwargs) model._load_depth_model_state_dict(model.depth_tower) return model class RobotStateProjector(ConfigurableModule): """Pack robot state and project to a single token per timestep""" def __init__(self, config: RobotStateProjectorConfig): super().__init__(config) if self.config.fourier: raise NotImplementedError("Fourier robot state projector is not implemented yet") self.robot_state_tokens_proj = make_mlp( layer_sizes=self.config.layers, activation=self.config.activation, norm=torch.nn.LayerNorm, ) def forward(self, inputs: RoboticsInput) -> Optional[torch.Tensor]: """ Returns: torch.Tensor of shape [B, num_past_steps, token_size] or None (if mode == 'none') """ if self.config.mode == "ee_pose": robot_state = torch.cat([inputs.ee_pose_translation, inputs.ee_pose_rotation], dim=-1) elif self.config.mode == "ee_pose_gripper": robot_state = torch.cat( [inputs.ee_pose_translation, 
inputs.ee_pose_rotation, inputs.gripper], dim=-1, ) elif self.config.mode == "ee_pose_joints": robot_state = torch.cat( [inputs.ee_pose_translation, inputs.ee_pose_rotation, inputs.joints], dim=-1, ) elif self.config.mode == "joints": robot_state = inputs.joints elif self.config.mode == "all": robot_state = torch.cat( [ inputs.ee_pose_translation, inputs.ee_pose_rotation, inputs.gripper, inputs.joints, ], dim=-1, ) elif self.config.mode == "none": robot_state = torch.tensor([], device=inputs.ee_pose_translation.device).view( inputs.ee_pose_translation.shape[0], 0, self.config.layers[0] if len(self.config.layers) > 0 else 0, ) else: raise NotImplementedError(f"Unknown image tokens mode {self.config.mode}") output = self.robot_state_tokens_proj(robot_state) return output class FourierFeatures(ConfigurableModule): def __init__(self, config: FourierFeaturesConfig): super().__init__(config) if self.config.learnable_features: self.linear = torch.nn.Linear( in_features=1, out_features=self.config.num_features // 2, bias=False ) else: half_dim = self.config.num_features // 2 freqs = torch.log(torch.tensor(self.config.max_period)) / (half_dim - 1) freqs = torch.exp(-freqs * torch.arange(half_dim)) self.register_buffer("freqs", freqs) self.layers: torch.nn.Sequential = make_mlp( self.config.layers, activation=self.config.activation, norm=self.config.norm, activate_final=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Compute Fourier features and project them via MLP Args: x: Input tensor of shape [..., 1] Returns: torch.Tensor: Fourier features of shape [..., num_features] or [..., layers[-1]] """ assert x.shape[-1] == 1 and x.ndim > 1, x.shape if self.config.learnable_features: frequencies = 2 * math.pi * self.linear(x) else: frequencies = x * expand_dims(self.freqs, x.ndim, [-1, 1]) output = torch.cat([torch.cos(frequencies), torch.sin(frequencies)], dim=-1) output = self.layers(output) return output class NoisedControlProjector(ConfigurableModule): """Pack noised control (translation, rotation, gripper) and project to a single token per timestep""" def __init__(self, config: NoisedControlProjectorConfig): super().__init__(config) self.input_projector = torch.nn.Linear( in_features=self.config.layers[0], out_features=self.config.layers[1] // 2, bias=False, ) self.time_embed = FourierFeatures(self.config.time_embed) self.layers = make_mlp( self.config.layers[1:], activation=self.config.activation, norm=self.config.norm, activate_final=False, bias=False, ) def forward(self, inputs: FlowInput | DiffusionInput) -> Optional[torch.Tensor]: """ Returns: torch.Tensor of shape [B, num_control_timesteps, token_size] """ noised_controls = torch.cat([inputs.translation_t, inputs.rotation_t, inputs.gripper_t], dim=-1) noised_controls = self.input_projector(noised_controls) timestep = self.time_embed(inputs.timestep) timestep = timestep.expand(-1, noised_controls.shape[1], -1) features = torch.cat([timestep, noised_controls], dim=-1) output = self.layers(features) return output def unmask_unattended(attn_mask: torch.Tensor, mask_value: Optional[float] = None) -> torch.Tensor: """ Copy-pased from `transformers.modeling_attn_mask_utils.AttentionMaskConverter._unmask_unattended` Attend to all tokens in fully-masked rows. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
Otherwise, results are NaN Details: https://github.com/pytorch/pytorch/issues/110213 Args: attn_mask: [B, 1 | num_heads, query_seq_len, kv_seq_len] or [B, query_seq_len, kv_seq_len], float dtype mask_value: The value inside `attn_mask` that corresponds to masked elements Returns: For example, if `attn_mask` is (e.g. here left-padding case) ``` [ [[ [0, 0, 0], [0, 0, 0], [0, 0, 1] ]], [[ [1, 0, 0], [1, 1, 0], [1, 1, 1] ]], [[ [0, 0, 0], [0, 1, 0], [0, 1, 1] ]] ] ``` then the modified `attn_mask` will be ``` [ [[ [1, 1, 1], <-- modified [1, 1, 1], <-- modified [0, 0, 1] ]], [[ [1, 0, 0], [1, 1, 0], [1, 1, 1] ]], [[ [1, 1, 1], <-- modified [0, 1, 0], [0, 1, 1] ]] ] ``` """ assert attn_mask.dtype.is_floating_point, attn_mask.dtype if mask_value is None: mask_value = torch.finfo(attn_mask.dtype).min return attn_mask * ~torch.all(attn_mask == mask_value, dim=-1, keepdim=True) @torch.no_grad() def attn_mask_to_float(attn_mask: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ Convert a 4D mask of type bool to `dtype`. If the attn_mask isn't 4D or isn't bool, raise error """ assert attn_mask.ndim == 4, attn_mask.shape assert attn_mask.dtype == torch.bool, attn_mask.dtype if dtype is None: dtype = torch.get_autocast_dtype(attn_mask.device.type) mask_value = torch.finfo(dtype).min attn_mask = torch.zeros(attn_mask.shape, dtype=dtype, device=attn_mask.device).masked_fill( ~attn_mask, mask_value ) attn_mask = unmask_unattended(attn_mask, mask_value) return attn_mask @torch.no_grad() def make_4d_float_attn_mask( attn_mask: Optional[torch.Tensor], query_seq_length: int, kv_seq_length: int, dtype: torch.dtype, device: torch.device, batch_size: int, ) -> torch.Tensor: """ Creates a 4D mask of shape [B | 1, 1, query_length, kv_seq_length] from a 2D mask of shape [B, kv_seq_length]. If the input `attn_mask` is already 4D: if dtype=torch.bool, convert to dtype, else do nothing If the input is None, output is a full bi-directional attn_mask Args: attn_mask: A 2D attention mask of shape [B, kv_seq_length] or [B, 1, query_length, kv_seq_length] and dtype bool. False values indicate masked out positions query_seq_length: The query sequence length (L) kv_seq_length: The key-value sequence length (S). When `transformers.StaticCache` is used, this should equal the cache size to account for zero-padding the part of the cache that is not yet filled. dtype: Output dtype device: Output device batch_size: Batch size Returns: torch.Tensor of shape [B | 1, 1, query_length, kv_seq_length] (i.e. [B | 1, 1, L, S]). 
Contains zero at unmasked positions and `torch.finfo(dtype).min` at masked positions """ if attn_mask is not None and attn_mask.ndim == 4: if attn_mask.dtype == torch.bool: attn_mask = attn_mask_to_float(attn_mask, dtype=dtype) elif attn_mask.dtype != dtype: raise TypeError(f"Expected attn_mask.dtype={dtype}, but got {attn_mask.dtype}") return attn_mask mask_value = torch.finfo(dtype).min output_mask = torch.zeros([batch_size, 1, query_seq_length, kv_seq_length], dtype=dtype, device=device) if attn_mask is not None: assert attn_mask.dtype == torch.bool, f"Unsupported dtype {attn_mask.dtype}" mask_length = attn_mask.shape[-1] if mask_length != kv_seq_length: raise NotImplementedError(f"{mask_length} != {kv_seq_length} not properly supported yet") inverted_mask = ~attn_mask.view(batch_size, 1, 1, mask_length) output_mask[..., :mask_length] = output_mask[..., :mask_length].masked_fill(inverted_mask, mask_value) return output_mask class VLMInput(Protocol): input_ids: torch.Tensor attn_mask: torch.Tensor images: Dict[str, torch.Tensor] multimodal_indices: torch.Tensor unimodal_indices: torch.Tensor @property def inputs_embeds(self) -> Optional[torch.Tensor]: return None @property def past_key_values(self) -> Optional[List[torch.Tensor]]: return None def zero_out_param_pretrained_grad( param: torch.nn.Parameter | torch.distributed.tensor.DTensor, module: torch.nn.Module, ) -> None: """Zero out the gradients of pretrained embeddings""" module.mask = module.mask.to(param.device) if isinstance(param, torch.distributed.tensor.DTensor) and not isinstance( module.mask, torch.distributed.tensor.DTensor ): module.mask = torch.distributed.tensor.distribute_tensor( module.mask, device_mesh=param.device_mesh, placements=param.placements ) mask = module.mask if type(param) is torch.distributed.tensor.DTensor and type(mask) is torch.distributed.tensor.DTensor: assert param.grad.shape == mask.shape, f"{param.grad.shape} != {mask.shape}" param.grad._local_tensor = torch.where(mask._local_tensor, param.grad._local_tensor, 0) elif type(param) in (torch.Tensor, torch.nn.Parameter) and type(mask) is torch.Tensor: assert param.grad.shape == mask.shape, f"{param.grad.shape} != {mask.shape}" param.grad = torch.where(mask, param.grad, 0) elif type(param) in (torch.Tensor, torch.nn.Parameter) and type(mask) is torch.distributed.tensor.DTensor: mask = mask.full_tensor() assert param.grad.shape == mask.shape, f"{param.grad.shape} != {mask.shape}" param.grad = torch.where(mask, param.grad, 0) else: raise ValueError(f"Unsupported parameter type: {type(param)} and mask type: {type(mask)}") class PaliGemmaVLM(ConfigurableModule): """Wraps PaliGemma to make compatible with VLM API""" def __init__(self, config: PaliGemmaVLMConfig): super().__init__(config) if self.config.with_depth: config = PaliGemma3DConfig.from_pretrained( self.config.model_id, **self.config.paligemma_3d_config_dict ) self.model = PaliGemma3D.from_pretrained( self.config.model_id, config=config, attn_implementation=self.config.attn_implementation, ) else: self.model = transformers.AutoModelForVision2Seq.from_pretrained( self.config.model_id, attn_implementation=self.config.attn_implementation, ) hf_processor = transformers.AutoProcessor.from_pretrained(self.config.model_id) self.processor = PaliGemmaDepthProcessor( config=self.config.processor_config, hf_processor=hf_processor, depth_tokens=self.config.depth_tokens, ) self._resize_siglip_image_input() self._maybe_override_get_image_features() if self.config.depth_tokens > 0: 
self._resize_llm_token_embeddings(hf_processor.tokenizer) if not self.config.lm_head: self.model.language_model.lm_head = torch.nn.Identity() self.model.train(True) for decoder in self.model.language_model.model.layers: decoder.self_attn.is_causal = False def forward( self, inputs: VLMInput, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, **kwargs, ) -> VLMOutput: del kwargs self._maybe_register_zero_out_grad_hooks() cache = transformers.DynamicCache() if inputs.attn_mask.ndim == 4: attn_mask = attn_mask_to_float(inputs.attn_mask) else: attn_mask = inputs.attn_mask images = { encoder_camera_name: camera_images.view(-1, *camera_images.shape[2:]) for (encoder_camera_name, camera_images) in inputs.images.items() } llm_output: transformers.models.paligemma.modeling_paligemma.PaliGemmaCausalLMOutputWithPast = ( self.model( input_ids=inputs.input_ids, pixel_values=images, attention_mask=attn_mask, use_cache=use_cache, past_key_values=cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, ) ) indices = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device=inputs.input_ids.device) image_indices = indices[inputs.input_ids[0] == self.processor.hf_processor.image_token_id] text_indices = indices[inputs.input_ids[0] != self.processor.hf_processor.image_token_id] output = VLMOutput( llm_output=LLMOutput.from_transformers( input_ids=inputs.input_ids, llm_output=llm_output, text_indices=text_indices, image_indices=image_indices, ), vit_tokens=llm_output.image_hidden_states, attn_mask=inputs.attn_mask, ) return output @property def fsdp_wrap_modules(self) -> Set[torch.nn.Module]: transformer_modules = { module for module in self.modules() if isinstance( module, ( transformers.models.siglip.modeling_siglip.SiglipEncoderLayer, transformers.models.siglip.modeling_siglip.SiglipVisionTransformer, timm.models.vision_transformer.Block, timm.models.vision_transformer.VisionTransformer, transformers.models.gemma.modeling_gemma.GemmaDecoderLayer, ), ) or module in ( self.model.language_model.model.embed_tokens, self.model.language_model.model.norm, ) } if self.config.with_depth: projectors = self.model.projectors else: projectors = {self.model.multi_modal_projector} return projectors | transformer_modules def _resize_siglip_image_input(self) -> None: """ Enables resizing SigLIP positional embeddings to a new image size. 
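        The core resizing step is equivalent to this sketch (a minimal illustration; the
        224 -> 448 image sizes and the embedding width of 1152 are assumptions):

            import timm
            import torch

            pos_embed = torch.randn(1, (224 // 14) ** 2, 1152)  # [1, num_old_tokens, width]
            resized = timm.layers.pos_embed.resample_abs_pos_embed(
                posemb=pos_embed,
                new_size=(448 // 14, 448 // 14),
                old_size=(224 // 14, 224 // 14),
                num_prefix_tokens=0,
                interpolation="bicubic",
                antialias=True,
                verbose=False,
            )
            # resized: [1, (448 // 14) ** 2, 1152] = [1, 1024, 1152]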
""" num_image_tokens: int = self.config.processor_config.num_image_tokens["main"] image_size: Dict[str, int] = self.config.processor_config.image_sizes["main"].as_json() siglip_embeddings: transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings = ( self.model.vision_tower.vision_model.embeddings ) embedding_weight: torch.Tensor = timm.layers.pos_embed.resample_abs_pos_embed( posemb=siglip_embeddings.position_embedding.weight.unsqueeze(0), new_size=(image_size["height"] // 14, image_size["width"] // 14), old_size=( siglip_embeddings.image_size // 14, siglip_embeddings.image_size // 14, ), num_prefix_tokens=0, interpolation="bicubic", antialias=True, verbose=False, ).squeeze(0) with torch.no_grad(): siglip_embeddings.position_embedding.weight.data = embedding_weight siglip_embeddings.num_patches = siglip_embeddings.num_positions = num_image_tokens siglip_embeddings.image_size = dict(image_size) siglip_embeddings.register_buffer( "position_ids", torch.arange(siglip_embeddings.num_positions).expand((1, -1)), persistent=False, ) def interpolate_pos_encoding(embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: del embeddings new_size = (height // 14, width // 14) old_size = (image_size["height"] // 14, image_size["width"] // 14) if old_size == new_size: patch_pos_embedding = siglip_embeddings.position_embedding.weight else: patch_pos_embedding: torch.Tensor = timm.layers.pos_embed.resample_abs_pos_embed( posemb=siglip_embeddings.position_embedding.weight.unsqueeze(0), new_size=new_size, old_size=( image_size["height"] // 14, image_size["width"] // 14, ), num_prefix_tokens=0, interpolation="bicubic", antialias=True, verbose=False, ).squeeze(0) return patch_pos_embedding siglip_embeddings.interpolate_pos_encoding = interpolate_pos_encoding if not self.processor.config.is_single_image_size: self.model.vision_tower.forward = functools.partial( self.model.vision_tower.forward, interpolate_pos_encoding=True ) def _maybe_override_get_image_features(self) -> None: """ Override PaliGemmaForConditionalGeneration.get_image_features() from transformers such that it can handle multiple cameras with different resolutions. 
""" if self.config.with_depth: return images_forward = self.model.get_image_features camera_names: List[str] = self.processor.config.camera_names def get_image_features(pixel_values: torch.Tensor) -> torch.Tensor: if self.processor.config.is_single_image_size: inputs = einops.rearrange( torch.stack( [pixel_values[f"{camera_name}.siglip"] for camera_name in camera_names], dim=1, ), "B N C H W -> (B N) C H W", ) image_tokens = images_forward(inputs) else: camera_tokens: List[torch.Tensor] = [ images_forward(pixel_values[f"{camera_name}.siglip"]) for camera_name in camera_names ] image_tokens = torch.cat(camera_tokens, dim=-2) return image_tokens self.model.get_image_features = get_image_features def _resize_llm_token_embeddings(self, tokenizer: transformers.PreTrainedTokenizer) -> None: assert self.config.depth_tokens > 0, self.config.depth_tokens tokenizer.add_tokens([f"" for i in range(self.config.depth_tokens)]) total_num_tokens = len(tokenizer) vocab_size = tokenizer.vocab_size llm = self.model.language_model (_, hidden_size) = llm.lm_head.weight.shape self.model.resize_token_embeddings( total_num_tokens, pad_to_multiple_of=64, mean_resizing=self.config.mean_resizing, ) if self.config.train_only_depth_tokens: assert len(self.model.language_model._tied_weights_keys) > 0 (weight_size, hidden_size) = llm.lm_head.weight.shape mask = torch.cat( [ torch.zeros([vocab_size, hidden_size], dtype=torch.bool), torch.ones([weight_size - vocab_size, hidden_size], dtype=torch.bool), ], dim=0, ) self.mask = mask self.embed_handle = None self.lm_head_handle = None self._maybe_register_zero_out_grad_hooks() def _maybe_register_zero_out_grad_hooks(self) -> None: """ Register hooks to zero out the gradients of pretrained embeddings and LM head. Skips registering the hooks if they already exist. This runs at every step as wrapping in FSDP removes any hooks that were registered on the *parameters* of the original module and the only way to run this reliably is to check if the hooks exist. 
""" if not self.config.train_only_depth_tokens: return llm = self.model.language_model if ( self.embed_handle is None or llm.model.embed_tokens.weight._post_accumulate_grad_hooks is None or self.embed_handle.id not in llm.model.embed_tokens.weight._post_accumulate_grad_hooks ): self.embed_handle = llm.model.embed_tokens.weight.register_post_accumulate_grad_hook( functools.partial(zero_out_param_pretrained_grad, module=self) ) if llm.model.embed_tokens.weight is not llm.lm_head.weight and ( self.lm_head_handle is None or llm.lm_head.weight._post_accumulate_grad_hooks is None or self.lm_head_handle.id not in llm.lm_head.weight._post_accumulate_grad_hooks ): self.lm_head_handle = llm.lm_head.weight.register_post_accumulate_grad_hook( functools.partial(zero_out_param_pretrained_grad, module=self) ) def make_position_indices( position_indices: Optional[torch.Tensor], seq_length: int, device: torch.device, max_seq_length: Optional[int], ) -> torch.Tensor: if position_indices is not None: position_indices = position_indices.to(dtype=torch.int64) else: position_indices = torch.arange(seq_length, dtype=torch.int64, device=device).view(1, -1) if not torch.max(position_indices) < max_seq_length: raise IndexError( f"position_indices={position_indices} contains index out of bounds of num_embeddings={max_seq_length}" ) return position_indices class RotaryPositionalEncoding(ConfigurableModule): """ Rotary Positional Embeddings (RoPE) from https://arxiv.org/abs/2104.09864 Reference implementations: - https://github.com/meta-llama/llama/blob/main/llama/model.py#L80 - transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding - transformers.models.llama.modeling_llama.LlamaRotaryEmbedding If cached=True, we cache the embeddings for each position up to `num_embeddings` """ def __init__(self, config: RotaryPositionalEncodingConfig): super().__init__(config) inv_freq = 1.0 / self.config.base ** ( torch.arange(0, self.config.embedding_dim, 2, dtype=torch.float32) / self.config.embedding_dim ) self.register_buffer("inv_freq", inv_freq, persistent=False) self._build_cache() def _build_cache(self) -> None: if not self.config.cached: return position_indices = torch.arange(self.config.num_embeddings, dtype=torch.float32) indices_inv_freq = torch.einsum("i, j -> ij", position_indices, self.inv_freq) sin = torch.sin(indices_inv_freq) cos = torch.cos(indices_inv_freq) self.register_buffer("sin_cache", sin, persistent=False) self.register_buffer("cos_cache", cos, persistent=False) def forward( self, tokens: torch.Tensor, position_indices: Optional[torch.Tensor] = None, apply: bool = True, ) -> torch.Tensor: """ Args: tokens: torch.Tensor of shape [B, ..., S, head_dim], where `...` might be any number of dims position_indices: torch.Tensor of shape [B | 1, S]. 
The indices of tokens within the sequence apply: If True, apply the positional embedding on tokens and return the result Returns: torch.Tensor of the same shape as `tokens` with positional embedding applied on tokens """ assert apply, f"{self.__class__} does not support applying embeddings externally" position_indices = make_position_indices( position_indices, seq_length=tokens.shape[-2], device=tokens.device, max_seq_length=self.config.num_embeddings, ) if self.config.cached: sin = self.sin_cache[position_indices] cos = self.cos_cache[position_indices] sin = torch.cat([sin, sin], dim=-1) cos = torch.cat([cos, cos], dim=-1) else: inv_freq = self.inv_freq.view(1, -1, 1).to(dtype=torch.float32) position_indices = position_indices.to(dtype=torch.float32).unsqueeze(1) with warnings.catch_warnings(): warnings.filterwarnings( "ignore", message="In CPU autocast, but the target dtype is not supported. Disabling autocast.", ) with torch.autocast(device_type=tokens.device.type, dtype=torch.float32): freqs = (inv_freq @ position_indices).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) (sin, cos) = (torch.sin(emb), torch.cos(emb)) (sin, cos) = (sin.to(dtype=tokens.dtype), cos.to(dtype=tokens.dtype)) sin = expand_dims(sin, tokens.ndim, order=[1, -1, 1, 1]) cos = expand_dims(cos, tokens.ndim, order=[1, -1, 1, 1]) tokens = tokens * cos + self._rotate_invert_half(tokens) * sin return tokens @staticmethod def _rotate_invert_half(x: torch.Tensor) -> torch.Tensor: x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) EAGER_ATTN = "eager" SDPA_ATTN = "sdpa" FLASH_ATTN = "flash_attention_2" def is_full_attn(attn_mask: Optional[torch.Tensor]) -> bool: """ Return True if attn_mask doesn't contain any masked out positions, False otherwise """ if attn_mask is None: return True if attn_mask.dtype == torch.bool: return torch.all(attn_mask == 1).item() if attn_mask.dtype.is_floating_point: return torch.all(attn_mask == 0).item() raise TypeError(f"Unrecognized dtype {attn_mask.dtype}") @torch.no_grad() def make_attn_mask_causal(attn_mask: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor: """ Args: attn_mask: 4D tensor of shape [B | 1, 1, query_seq_len, kv_seq_len] (i.e. [B | 1, 1, L, S]) of float dtype (NOT bool!). Masked positions contain the value `torch.finfo(dtype).min` cache_position: torch.Tensor of type torch.int64 and shape [query_seq_len]. Contained values are index positions of the query tokens in the sequence. During training, this would usually be torch.arange(query_seq_len), but during generate, this would usually be a tensor sequence with 1 element indicating the position of the token currently being generated Returns: torch.Tensor of the same shape as attn_mask. 
Contains zero at unmasked positions and `torch.finfo(dtype).min` at masked positions """ if attn_mask.dtype.is_floating_point: mask_value = torch.finfo(attn_mask.dtype).min elif attn_mask.dtype == torch.bool: mask_value = 0 else: raise TypeError(f"Unsupported mask type {attn_mask.dtype}") (_, _, query_seq_length, kv_seq_length) = attn_mask.shape causal_mask = torch.ones(attn_mask.shape, dtype=torch.bool, device=attn_mask.device) causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask = causal_mask * ( torch.arange(kv_seq_length, device=cache_position.device).view(1, -1) > cache_position.view(-1, 1) ).view(*[1] * (causal_mask.ndim - 2), query_seq_length, kv_seq_length) causal_attn_mask = attn_mask.masked_fill_(causal_mask, mask_value) return causal_attn_mask def update_attn_mask( attn_mask: Optional[torch.Tensor], attn_implementation: str, query_seq_length: int, kv_seq_length: int, cache_position: Optional[torch.Tensor], cache: Optional[transformers.Cache], batch_size: int, causal: bool, dtype: torch.dtype, device: torch.device, output_attentions: bool = False, ) -> Optional[torch.Tensor]: """ Update attn_mask such that it's compatible with the attention implementation. Meant to be used with barrel.components.nn.layers.attention.MultiheadAttention and its derivatives Args: attn_mask: dtype torch.bool, torch.float32, torch.float16 or torch.bfloat16 and shape one of: - [B, kv_seq_length] (i.e. [B, S]) - [B, 1, query_seq_length, kv_seq_length] (i.e. [B, 1, L, S]) - [1, 1, query_seq_length, kv_seq_length] (i.e. [L, S]) If bool, False values indicate masked positions. If float, must contain only 0.0 and torch.finfo(dtype).min If attn_mask is None, full-bidirectional attention is assumed. The output might be None or a tensor. Refer to the return value documentation attn_implementation: One of [FLASH_ATTN, SDPA_ATTN, EAGER_ATTN] query_seq_length: The query sequence length (L) kv_seq_length: The key-value sequence length (S) cache_position: dtype torch.int64, shape [query_seq_len]. Used only when causal=True. Contained values are index positions of the query tokens in the sequence. During training, this would usually be torch.arange(query_seq_len), but during generate, this would usually be a tensor sequence with 1 element indicating the position of the token currently being generated. If None, default `cache_positions` are autocomputed from `query_seq_length` and cache size cache: Optional cache. Usually not None when running generate at inference. batch_size: Batch size of the generated attention mask causal: If True, make the attn_mask causal -> all non-causal positions are masked out, regardless of their attn_mask values. When using flash attention or SDPA and `causal == False`, make sure to pass `causal` to the attention operation, in case this function delegates causal masking dtype: dtype of the output attention mask. Must be the dtype of the attn computation device: device of the output attention mask output_attentions: If True, the attention operation is required to output attention weights Returns: - `None` in either of these cases: - `attn_mask` doesn't contain any masked out positions and causal=False - `attn_implementation in [FLASH_ATTN, SDPA_ATTN]` and `attn_mask` doesn't contain any masked out positions. If causal=True, we instead rely on the causal argument to flash attention or `torch.nn.functional.scaled_dot_product_attention`. 
This happens only if the cache is empty and cache_position is None - `attn_mask` if `attn_implementation == FLASH_ATTN` and `attn_mask` can't be ignored TODO(FLASH) - torch.Tensor of shape [B, 1, query_length, kv_seq_length] (i.e. [B, 1, L, S]) and type `dtype`. Contains zero at unmasked positions and `torch.finfo(dtype).min` at masked positions. """ assert attn_implementation in [FLASH_ATTN, SDPA_ATTN, EAGER_ATTN] assert dtype.is_floating_point, dtype if torch.jit.is_tracing() or torch.jit.is_scripting() or torch.compiler.is_compiling(): raise NotImplementedError("Complete correctness not confirmed yet") if isinstance(cache, transformers.StaticCache): if attn_mask is not None and attn_mask.shape[-1] != cache.get_max_cache_shape(): raise NotImplementedError("Complete correctness not confirmed yet") full_attn = is_full_attn(attn_mask) past_seen_tokens = cache.get_seq_length() if cache is not None else 0 if full_attn and not causal: return None if ( full_attn and causal and attn_implementation in [SDPA_ATTN, FLASH_ATTN] and past_seen_tokens == 0 and cache_position is None ): return None past_seen_tokens = cache.get_seq_length() if cache is not None else 0 static_cache = isinstance(cache, transformers.StaticCache) if static_cache and kv_seq_length < cache.get_max_cache_shape(): kv_seq_length = cache.get_max_cache_shape() elif attn_mask is not None: assert kv_seq_length == attn_mask.shape[-1], f"{kv_seq_length}, {attn_mask.shape}" output_mask = make_4d_float_attn_mask( attn_mask=attn_mask, query_seq_length=query_seq_length, kv_seq_length=kv_seq_length, dtype=dtype, device=device, batch_size=batch_size, ) if causal: cache_position = ( torch.arange(past_seen_tokens, past_seen_tokens + query_seq_length, device=device) if cache_position is None else cache_position ) output_mask = make_attn_mask_causal(output_mask, cache_position) if ( attn_implementation == SDPA_ATTN and attn_mask is not None and attn_mask.device.type == "cuda" and not output_attentions ): output_mask = unmask_unattended(output_mask, mask_value=torch.finfo(dtype).min) return output_mask def expand_kv_heads(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ The equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). Convert hidden_states from [batch, num_kv_heads, seqlen, head_dim] -> [batch, num_attention_heads, seqlen, head_dim] """ (batch, num_kv_heads, slen, head_dim) = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim) class MultiheadAttention(torch.nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper Different implementation from torch.nn.MultiheadAttention to support: - Easy switch between EAGER_ATTN, SDPA_ATTN and FLASH_ATTN - Number of key-value heads different from query heads - Key-value cache during forward, in the same way as transformers. 
      Useful for generation or cross-attention to projected keys and values
    - Ability to apply positional encodings to key and value after input linear projection
    - Different linear projection output size

    Adapted from transformers.models.llama.modeling_llama.LlamaAttention
    """

    def __init__(
        self,
        in_features: int,
        num_heads: int,
        head_dim: Optional[int] = None,
        out_features: Optional[int] = None,
        key_features: Optional[int] = None,
        value_features: Optional[int] = None,
        num_kv_heads: Optional[int] = None,
        bias: bool = False,
        dropout: float = 0.0,
        cache_layer: Optional[int] = None,
        query_position_embed: Optional[Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]] = None,
        key_position_embed: Optional[Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]] = None,
    ):
        """
        Args:
            in_features: Input dimension for query linear projection.
            num_heads: Number of heads for query
            head_dim: Head dimension. If None, defaults to `in_features // num_heads`
            out_features: Output dimension for the output linear layer. If None, defaults to `in_features`
            key_features: Input dimension for key linear projection. If None, defaults to `in_features`
            value_features: Input dimension for value linear projection. If None, defaults to `in_features`
            num_kv_heads: Number of heads for keys and values. If None, defaults to `num_heads`
            cache_layer: Index of the layer in the cache. Needed only when `cache` is passed to the
                `forward()` call, usually during generation or when the projected keys and values
                need to be cached during training. Can be omitted when `cache_layer` is passed to `forward`
            query_position_embed: Callable that takes the linearly projected query and an optional
                positional index in the sequence and returns the query with positional embeddings applied.
                Note these embeddings are applied after linear projection. If you want to apply
                embeddings before the linear projection, do so before calling the forward method and use
                the default value for `query_position_embed`, which is a simple pass-through.
                Note you can also pass torch.nn.Module
            key_position_embed: Callable that takes the linearly projected key and an optional
                positional index in the sequence and returns the key with positional embeddings applied.
                Note these embeddings are applied after linear projection. If you want to apply
                embeddings before the linear projection, do so before calling the forward method and use
                the default value for `key_position_embed`, which is a simple pass-through.
Note you can also pass torch.nn.Module """ super().__init__() self.in_features = in_features self.key_features = key_features or in_features self.value_features = value_features or in_features self.bias = bias self.out_features = out_features or in_features self.num_heads = num_heads self.head_dim = head_dim or in_features // num_heads self.num_kv_heads = num_kv_heads or num_heads self.dropout = dropout self.query_position_embed = query_position_embed self.key_position_embed = key_position_embed self.cache_layer = cache_layer self.q_proj = torch.nn.Linear(self.in_features, self.num_heads * self.head_dim, bias=self.bias) self.k_proj = torch.nn.Linear(self.key_features, self.num_kv_heads * self.head_dim, bias=self.bias) self.v_proj = torch.nn.Linear(self.value_features, self.num_kv_heads * self.head_dim, bias=self.bias) self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.out_features, bias=self.bias) def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, query_position_indices: Optional[torch.Tensor] = None, key_position_indices: Optional[torch.Tensor] = None, cache: Optional[transformers.Cache] = None, cache_layer: Optional[int] = None, output_attentions: bool = False, cache_kwargs: Dict[str, Any] = {}, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Args: query: Query embedding of shape [B, L, in_features] key: Key embedding of shape [B, S, key_features] value: Value embedding of shape [B, S, value_features] attn_mask: dtype torch.bool or same dtype as query/key/value and shape one of: - [B, S] - [B | 1, 1 | num_heads, L, S] If bool, False values indicate masked positions (opposite of torch.nn.MultiheadAttention) If float, must contain only 0.0 and torch.finfo(dtype).min If attn_mask is None, full-bidirectional attention or causal attention is used depdening on the value of `is_causal`. is_causal: If True, apply additional causal masking to `attn_mask` query_position_indices: torch.Tensor of shape [1 | B, L] containing the indices of the `query` tokens within the entire sequence. Passed through to query_position_embed. If None and `cache` is not None, indices are autogenerated [0, 1, ..., L] and offset by `cache_size` key_position_indices: Same as `query_position_indices`, but applied to key cache: transformers.Cache containing cached key-value pairs. The linearly projected `key` and `value` passed to this function get added to the cache and concatenated after the key-value pairs in the cache and then attention is computed on the concatenated sequence. This is most commonly used at inference when generating auto-regressively or when one needs to cross attend to the keys and values outside this module forward pass. cache_layer: Index of the layer in the cache. Needed only when `cache` is passed to the `forward()` call, usually during generation or when the projected keys and values need to be cached during training. Can be omitted when `cache_layer` was passed to `__init__` output_attentions: If True, output also the attention weights. Otherwise output None. Note that only the eager implementation of MultiheadAttention supports this. 
cache_kwargs: kwargs directly passed to `cache.update()` Returns: Tuple with entries: - Attention block output: torch.Tensor of shape [B, L, out_features] - Optional attention weights if `output_attentions=True`, shape [B, num_heads, L, S] """ batch_size = query.shape[0] query_states = self.q_proj(query) key_states = self.k_proj(key) value_states = self.v_proj(value) query_states = query_states.view( batch_size, query_states.shape[1], self.num_heads, self.head_dim ).transpose(1, 2) key_states = key_states.view( batch_size, key_states.shape[1], self.num_kv_heads, self.head_dim ).transpose(1, 2) value_states = value_states.view( batch_size, value_states.shape[1], self.num_kv_heads, self.head_dim ).transpose(1, 2) (query_states, key_states) = self._maybe_apply_positional_embeddings( query_states=query_states, key_states=key_states, query_position_indices=query_position_indices, key_position_indices=key_position_indices, cache=cache, ) (key_states, value_states) = self._maybe_update_cache( key_states, value_states, cache_layer=cache_layer, cache=cache, cache_kwargs=cache_kwargs, ) key_states = expand_kv_heads(key_states, self.num_heads // self.num_kv_heads) value_states = expand_kv_heads(value_states, self.num_heads // self.num_kv_heads) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) attn_mask = update_attn_mask( attn_mask, attn_implementation=EAGER_ATTN, query_seq_length=query_states.shape[2], kv_seq_length=value_states.shape[2], cache_position=query_position_indices, cache=cache, batch_size=batch_size, causal=is_causal, dtype=query_states.dtype, device=query_states.device, output_attentions=output_attentions, ) if attn_mask is not None: attn_mask = attn_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + attn_mask attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( query_states.dtype ) attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) shape = (batch_size, self.num_heads, query.shape[1], self.head_dim) assert attn_output.shape == shape, f"{attn_output.shape} != {shape}" attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim) attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights def _maybe_apply_positional_embeddings( self, query_states: torch.Tensor, key_states: torch.Tensor, query_position_indices: Optional[torch.Tensor], key_position_indices: Optional[torch.Tensor], cache: Optional[transformers.Cache], ) -> Tuple[torch.Tensor, torch.Tensor]: device = query_states.device if self.query_position_embed is not None: if query_position_indices is None and cache is not None: query_position_indices = ( torch.arange(query_states.shape[-2], dtype=torch.int64, device=device).view(1, -1) + cache.get_seq_length() ) query_states = self.query_position_embed(query_states, position_indices=query_position_indices) if self.key_position_embed is not None: if key_position_indices is None and cache is not None: key_position_indices = ( torch.arange(key_states.shape[-2], dtype=torch.int64, device=device).view(1, -1) + cache.get_seq_length() ) key_states = self.key_position_embed(key_states, position_indices=key_position_indices) return query_states, key_states def _maybe_update_cache( self, key_states: torch.Tensor, value_states: torch.Tensor, cache_layer: 
Optional[int], cache: Optional[transformers.Cache], cache_kwargs: Dict[str, Any], ) -> Tuple[torch.Tensor, torch.Tensor]: if cache is not None: if cache_layer is None and self.cache_layer is None: raise RuntimeError("When cache != None, cache_layer has to be set") cache_layer = cache_layer if cache_layer is not None else self.cache_layer (key_states, value_states) = cache.update(key_states, value_states, cache_layer, cache_kwargs) return key_states, value_states class MultiheadFlashAttention2(MultiheadAttention): """ MultiheadAttention implemented using flash attention module. Inherits `MultiheadAttention` as the weights of the module stay untouched. The only change is on the forward pass where we call flash attention. """ def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, query_position_indices: Optional[torch.Tensor] = None, key_position_indices: Optional[torch.Tensor] = None, cache: Optional[transformers.Cache] = None, cache_layer: Optional[int] = None, output_attentions: bool = False, cache_kwargs: Dict[str, Any] = {}, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Args: query: Query embedding of shape [B, L, in_features] key: Key embedding of shape [B, S, key_features] value: Value embedding of shape [B, S, value_features] attn_mask: dtype torch.bool and shape [B, S]. If bool, False values indicate masked positions (opposite of torch.nn.MultiheadAttention) If attn_mask is None, full-bidirectional attention or causal attention is used depdening on the value of `is_causal`. NOTE: Doesn't support 4D attn_mask, unlike MultiheadAttention is_causal: If True, apply additional causal masking to `attn_mask` query_position_indices: torch.Tensor of shape [1 | B, L] containing the indices of the `query` tokens within the entire sequence. Passed through to query_position_embed. If None and `cache` is not None, indices are autogenerated [0, 1, ..., L] and offset by `cache_size` key_position_indices: Same as `query_position_indices`, but applied to key cache: transformers.Cache containing cached key-value pairs. The linearly projected `key` and `value` passed to this function get added to the cache and concatenated after the key-value pairs in the cache and then attention is computed on the concatenated sequence. This is most commonly used at inference when generating auto-regressively or when one needs to cross attend to the keys and values outside this module forward pass. cache_layer: Index of the layer in the cache. Needed only when `cache` is passed to the `forward()` call, usually during generation or when the projected keys and values need to be cached during training. Can be omitted when `cache_layer` was passed to `__init__` output_attentions: If True, output also the attention weights. Otherwise output None. Note that only the eager implementation of MultiheadAttention supports this. cache_kwargs: kwargs directly passed to `cache.update()` Returns: Tuple with entries: - Attention block output: torch.Tensor of shape [B, L, out_features] - Optional attention weights if `output_attentions=True`, shape [B, num_heads, L, S] """ if isinstance(cache, transformers.StaticCache): raise ValueError( "transformers.StaticCache not compatible with flash attention. Use `sdpa` instead (for now)." 
) assert output_attentions is False, f"{self.__class__} doesn't support output_attentions=True" batch_size = query.shape[0] query_states = self.q_proj(query) key_states = self.k_proj(key) value_states = self.v_proj(value) query_states = query_states.view( batch_size, query_states.shape[1], self.num_heads, self.head_dim ).transpose(1, 2) key_states = key_states.view( batch_size, key_states.shape[1], self.num_kv_heads, self.head_dim ).transpose(1, 2) value_states = value_states.view( batch_size, value_states.shape[1], self.num_kv_heads, self.head_dim ).transpose(1, 2) (query_states, key_states) = self._maybe_apply_positional_embeddings( query_states=query_states, key_states=key_states, query_position_indices=query_position_indices, key_position_indices=key_position_indices, cache=cache, ) (key_states, value_states) = self._maybe_update_cache( key_states, value_states, cache_layer=cache_layer, cache=cache, cache_kwargs=cache_kwargs, ) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) attn_mask = update_attn_mask( attn_mask, attn_implementation=FLASH_ATTN, query_seq_length=query_states.shape[2], kv_seq_length=value_states.shape[2], cache_position=query_position_indices, cache=cache, batch_size=batch_size, causal=is_causal, dtype=query_states.dtype, device=query_states.device, output_attentions=output_attentions, ) raise NotImplementedError("Correctness not yet confirmed") attn_output = transformers.modeling_flash_attention_utils._flash_attention_forward( query_states=query_states, key_states=key_states, value_states=value_states, attention_mask=attn_mask, query_length=query.shape[1], position_ids=None, dropout=self.dropout if self.training else 0.0, sliding_window=None, use_top_left_mask=False, is_causal=is_causal, deterministic=True, ) size = (batch_size, self.num_heads, query.shape[1], self.head_dim) if attn_output.size() != size: raise ValueError(f"`attn_output` should be of size {size}, but is {attn_output.size()}") shape = (batch_size, self.num_heads, query.shape[1], self.head_dim) assert attn_output.shape == shape, f"{attn_output.shape} != {shape}" attn_output = attn_output.reshape(batch_size, query.shape[1], -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, None class MultiheadSdpaAttention(MultiheadAttention): """ MultiheadAttention SDPA attention. Inherits `MultiheadAttention` as the weights of the module stay untouched. 
The only change is on the forward pass where we call `torch.nn.functional.scaled_dot_product_attention` """ def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, query_position_indices: Optional[torch.Tensor] = None, key_position_indices: Optional[torch.Tensor] = None, cache: Optional[transformers.Cache] = None, cache_layer: Optional[int] = None, output_attentions: bool = False, cache_kwargs: Dict[str, Any] = {}, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Args: query: Query embedding of shape [B, L, in_features] key: Key embedding of shape [B, S, key_features] value: Value embedding of shape [B, S, value_features] attn_mask: dtype torch.bool or same dtype as query/key/value and shape one of: - [B, S] - [B | 1, 1 | num_heads, L, S] If bool, False values indicate masked positions (opposite of torch.nn.MultiheadAttention) If float, must contain only 0.0 and torch.finfo(dtype).min If attn_mask is None, full-bidirectional attention or causal attention is used depending on the value of `is_causal`. is_causal: If True, apply additional causal masking to `attn_mask` query_position_indices: torch.Tensor of shape [1 | B, L] containing the indices of the `query` tokens within the entire sequence. Passed through to query_position_embed. If None and `cache` is not None, indices are autogenerated [0, 1, ..., L] and offset by `cache_size` key_position_indices: Same as `query_position_indices`, but applied to key cache: transformers.Cache containing cached key-value pairs. The linearly projected `key` and `value` passed to this function get added to the cache and concatenated after the key-value pairs in the cache and then attention is computed on the concatenated sequence. This is most commonly used at inference when generating auto-regressively or when one needs to cross attend to the keys and values outside this module forward pass. cache_layer: Index of the layer in the cache. Needed only when `cache` is passed to the `forward()` call, usually during generation or when the projected keys and values need to be cached during training. Can be omitted when `cache_layer` was passed to `__init__` output_attentions: If True, output also the attention weights. Otherwise output None. Note that only the eager implementation of MultiheadAttention supports this.
cache_kwargs: kwargs directly passed to `cache.update()` Returns: Tuple with entries: - Attention block output: torch.Tensor of shape [B, L, out_features] - Optional attention weights if `output_attentions=True`, shape [B, num_heads, L, S] """ assert output_attentions is False, f"{self.__class__} doesn't support output_attentions=True" batch_size = query.shape[0] query_states = self.q_proj(query) key_states = self.k_proj(key) value_states = self.v_proj(value) query_states = query_states.view( batch_size, query_states.shape[1], self.num_heads, self.head_dim ).transpose(1, 2) key_states = key_states.view( batch_size, key_states.shape[1], self.num_kv_heads, self.head_dim ).transpose(1, 2) value_states = value_states.view( batch_size, value_states.shape[1], self.num_kv_heads, self.head_dim ).transpose(1, 2) (query_states, key_states) = self._maybe_apply_positional_embeddings( query_states=query_states, key_states=key_states, query_position_indices=query_position_indices, key_position_indices=key_position_indices, cache=cache, ) (key_states, value_states) = self._maybe_update_cache( key_states, value_states, cache_layer=cache_layer, cache=cache, cache_kwargs=cache_kwargs, ) key_states = expand_kv_heads(key_states, self.num_heads // self.num_kv_heads) value_states = expand_kv_heads(value_states, self.num_heads // self.num_kv_heads) attn_mask = update_attn_mask( attn_mask, attn_implementation=SDPA_ATTN, query_seq_length=query_states.shape[2], kv_seq_length=value_states.shape[2], cache_position=query_position_indices, cache=cache, batch_size=batch_size, causal=is_causal, dtype=query_states.dtype, device=query_states.device, output_attentions=output_attentions, ) if attn_mask is not None: attn_mask = attn_mask[:, :, :, : key_states.shape[-2]] attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, ) shape = (batch_size, self.num_heads, query.shape[1], self.head_dim) assert attn_output.shape == shape, f"{attn_output.shape} != {shape}" attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(batch_size, query.shape[1], self.num_heads * self.head_dim) attn_output = self.o_proj(attn_output) return attn_output, None ATTN_TYPES = { EAGER_ATTN: MultiheadAttention, SDPA_ATTN: MultiheadSdpaAttention, FLASH_ATTN: MultiheadFlashAttention2, } def make_activation(activation: str | Type[torch.nn.Module], **kwargs) -> torch.nn.Module: if isinstance(activation, str): TorchActivation: Type[torch.nn.Module] = getattr(torch.nn, activation) else: TorchActivation: Type[torch.nn.Module] = activation assert issubclass(TorchActivation, torch.nn.Module), TorchActivation return TorchActivation(**kwargs) class PiZeroMLP(torch.nn.Module): def __init__(self, feature_size: int, hidden_size: int, activation: str): super().__init__() self.gate_proj = torch.nn.Linear(feature_size, hidden_size, bias=False) self.up_proj = torch.nn.Linear(feature_size, hidden_size, bias=False) self.down_proj = torch.nn.Linear(hidden_size, feature_size, bias=False) self.activation = make_activation(activation, approximate="tanh") def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x)) class PiZeroFlowMatchingDecoderBlock(ConfigurableModule): def __init__(self, config: PiZeroFlowMatchingDecoderBlockConfig, **attn_kwargs): super().__init__(config) self.norm_in = GemmaRMSNorm(self.config.feature_size, eps=1e-06) 
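# The self-attention implementation (eager, SDPA or flash) is selected from ATTN_TYPES based on
# `config.attn_implementation`. `attn_kwargs` forwards per-block extras such as the shared rotary
# position embeddings and the block's `cache_layer`, which PiZeroFlowMatchingDecoder supplies when
# constructing its blocks.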
self.self_attn = ATTN_TYPES[self.config.attn_implementation]( in_features=self.config.feature_size, num_heads=self.config.num_heads, head_dim=self.config.head_dim, num_kv_heads=self.config.num_kv_heads, **attn_kwargs, ) self.mlp = PiZeroMLP( feature_size=self.config.feature_size, hidden_size=self.config.hidden_size, activation=self.config.activation, ) self.norm_out = GemmaRMSNorm(self.config.feature_size, eps=1e-06) def forward( self, query: torch.Tensor, attn_mask: torch.Tensor, cache: transformers.Cache, attn_kwargs: Dict[str, Any], ) -> torch.Tensor: """ Args: query: torch.Tensor of shape [B, L, token_size]. The query sequence; on the first decoding step this is [robot state tokens, noised control tokens], on subsequent steps only the noised control tokens attn_mask: torch.Tensor of shape [B, 1, L, L+S] and dtype torch.bool, where S is the VLM sequence length cache: Cache that contains only the VLM tokens during training and VLM + past query tokens during generation attn_kwargs: Extra kwargs forwarded to the self-attention module, e.g. query/key position indices and cache kwargs Returns: torch.Tensor of same shape as query [B, L, token_size] """ residual = x = query x = self.norm_in(x) (x, _) = self.self_attn( query=x, key=x, value=x, attn_mask=attn_mask, is_causal=False, cache=cache, **attn_kwargs, ) x = residual + x residual = x x = self.norm_out(x) x = self.mlp(x) x = residual + x return x class PiZeroFlowMatchingDecoder(ConfigurableModule): """PiZero Flow Matching control decoder""" def __init__(self, config: PiZeroFlowMatchingDecoderConfig): super().__init__(config) query_position_embed = RotaryPositionalEncoding(config=self.config.block_config.position_embed_config) key_position_embed = RotaryPositionalEncoding(config=self.config.block_config.position_embed_config) self.blocks = torch.nn.ModuleList( [ PiZeroFlowMatchingDecoderBlock( self.config.block_config, query_position_embed=query_position_embed, key_position_embed=key_position_embed, cache_layer=i, ) for i in range(self.config.num_blocks) ] ) self.norm = GemmaRMSNorm(self.config.block_config.feature_size, eps=1e-06) def forward( self, control_tokens: torch.Tensor, robot_state_tokens: torch.Tensor, llm_kv_tokens: List[Tuple[torch.Tensor, torch.Tensor]], attn_mask: Optional[torch.Tensor], cache: Optional[transformers.Cache] = None, ) -> torch.Tensor: """ Args: control_tokens: torch.Tensor of shape [B, N, token_size], contains sequence of controls robot_state_tokens: torch.Tensor of shape [B, num_state_tokens, token_size] llm_kv_tokens: List of linearly projected key-value pairs from LLM, right before attention operation.
Each tensor is of the shape [B, num_kv_heads, S, head_dim] attn_mask: One of - shape [B, S], dtype torch.bool -> padding attention mask for LLM tokens - shape [B, 1, L, S], dtype torch.bool -> full attention mask for LLM tokens Returns: torch.Tensor, shape [B, N, token_size] """ assert ( len(llm_kv_tokens) == self.config.num_blocks ), f"{len(llm_kv_tokens)} != {self.config.num_blocks}" is_step_zero = cache.get_seq_length() == 0 if cache is not None else True vlm_seq_len = attn_mask.shape[-1] device = attn_mask.device if cache is None: cache = transformers.DynamicCache() if is_step_zero: position_indices = torch.arange(vlm_seq_len, dtype=torch.int64, device=device) for block_index, kv_tokens in enumerate(llm_kv_tokens): (key_states, value_states) = kv_tokens cache.update( key_states, value_states, block_index, cache_kwargs={"cache_position": position_indices}, ) num_control_tokens = control_tokens.shape[1] num_robot_state_tokens = robot_state_tokens.shape[1] attn_mask = self._build_attn_mask( num_control_tokens=num_control_tokens, num_robot_state_tokens=num_robot_state_tokens, attn_mask=attn_mask, ) if is_step_zero: tokens = torch.cat([robot_state_tokens, control_tokens], axis=1) query_position_indices = key_position_indices = vlm_seq_len + torch.arange( tokens.shape[1], dtype=torch.int64, device=device ).view(1, -1) else: tokens = control_tokens attn_mask = attn_mask[:, :, -control_tokens.shape[1] :] query_position_indices = key_position_indices = ( vlm_seq_len + num_robot_state_tokens + torch.arange(tokens.shape[1], dtype=torch.int64, device=device).view(1, -1) ) for block in self.blocks: tokens = block( query=tokens, attn_mask=attn_mask, cache=cache, attn_kwargs={ "query_position_indices": query_position_indices, "key_position_indices": key_position_indices, "cache_kwargs": {"cache_position": key_position_indices.view(-1)}, }, ) if is_step_zero: (_, control_tokens) = torch.split(tokens, [num_robot_state_tokens, num_control_tokens], dim=1) else: control_tokens = tokens control_tokens = self.norm(control_tokens) return control_tokens @torch.no_grad() def _build_attn_mask( self, num_control_tokens: int, num_robot_state_tokens: int, attn_mask: torch.Tensor, ) -> torch.Tensor: """ Expand `attn_mask` (which is effectively a padding mask) to 4D such that: - robot state tokens and control tokens can't attend to padding tokens - robot state tokens can't attend to control tokens Note: We can't keep the mask in 2D as it doesn't allow masking of padding tokens from the VLM sequence. 
Furthermore, in a 2D mask you can't disable attention from robot state tokens to control tokens """ assert attn_mask.dtype == torch.bool, attn_mask.dtype assert attn_mask.ndim in [2, 4], attn_mask.shape device = attn_mask.device batch_size = attn_mask.shape[0] query_seq_len = num_robot_state_tokens + num_control_tokens vlm_seq_len = attn_mask.shape[-1] kv_seq_len = query_seq_len + vlm_seq_len cross_attn_mask = torch.ones( [batch_size, 1, query_seq_len, kv_seq_len], dtype=torch.bool, device=device ) if attn_mask.ndim == 2: attn_mask = attn_mask.view(batch_size, 1, 1, vlm_seq_len) else: attn_mask = torch.any(attn_mask, dim=-2, keepdim=True) cross_attn_mask[..., :vlm_seq_len] = attn_mask robot_state_query_indices = torch.arange( num_robot_state_tokens, dtype=torch.int64, device=device ).view(-1, 1) control_key_indices = ( torch.arange(num_control_tokens, dtype=torch.int64, device=device).view(1, -1) + vlm_seq_len + num_robot_state_tokens ) cross_attn_mask[:, :, robot_state_query_indices, control_key_indices] = 0 return cross_attn_mask @property def fsdp_wrap_modules(self) -> Set[torch.nn.Module]: return {module for module in self.modules() if isinstance(module, type(self.blocks[0]))} | {self.norm} def integrate_unitquat( qt: torch.Tensor, dq_dt: torch.Tensor, dt: float | torch.Tensor, body_frame: bool = True, half_cover: bool = True, ) -> torch.Tensor: """ Integrate a unit quaternion `qt` by the derivative `dq_dt` over the time interval `dt`. Args: qt: Unit quaternion, shape [..., 4] dq_dt: Derivative of the unit quaternion, shape [..., 4] dt: Time interval to integrate over, scalar or a tensor of shape () or [..., 1] body_frame: If True, the integration is done in the body frame (post-multiply), otherwise in the inertial frame (pre-multiply). half_cover: If True, the result is guaranteed to lie in the half space
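Note: The update is performed on the rotation manifold: the angular velocity is recovered as the vector part of 2 * quat_product(quat_conjugation(qt), dq_dt) in the body frame (or 2 * quat_product(dq_dt, quat_conjugation(qt)) in the inertial frame), the increment exp(omega * dt) is formed via `roma.rotvec_to_unitquat`, and the result is composed with `qt` on the matching side (post-multiplied for the body frame, pre-multiplied for the inertial frame).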
Returns: Integrated unit quaternion, shape [..., 4] """ assert qt.shape == dq_dt.shape, f"{qt.shape} != {dq_dt.shape}" assert is_quaternion(qt), f"{qt.shape} not a quaternion" if isinstance(dt, torch.Tensor): assert dt.ndim in (0, qt.ndim), f"dt.ndim = {dt.ndim} | {qt.ndim}" if body_frame: omega_q = 2.0 * roma.quat_product(roma.quat_conjugation(qt), dq_dt) else: omega_q = 2.0 * roma.quat_product(dq_dt, roma.quat_conjugation(qt)) omega = omega_q[..., :-1] dq = roma.rotvec_to_unitquat(omega * dt) if body_frame: qt = roma.quat_product(qt, dq) else: qt = roma.quat_product(dq, qt) if half_cover: qt = quaternion_half_cover(qt) return qt def rotmat_inverse(rotation: torch.Tensor) -> torch.Tensor: assert is_rotmat(rotation), f"Expected a rotation matrix, but got shape {rotation.shape}" rotmat = rotmat_as_3x3(rotation) rotmat = rotmat.transpose(-1, -2) if is_rotmat_9(rotation): rotmat = rotmat_as_9(rotmat) return rotmat def skew_symmetric_to_rotvec(skew_symmetric: torch.Tensor) -> torch.Tensor: """ Convert a skew-symmetric matrix to a rotation vector in a differentiable way. The input is assumed to have the layout [ [ 0, -z, y], [ z, 0, -x], [-y, x, 0], ] which corresponds to the rotation vector [x, y, z]. Args: skew_symmetric: Skew-symmetric matrix of shape [..., 3, 3] Returns: torch.Tensor of shape [..., 3] """ assert is_rotmat(skew_symmetric), skew_symmetric.shape rotvec = torch.stack( ( skew_symmetric[..., 2, 1] - skew_symmetric[..., 1, 2], skew_symmetric[..., 0, 2] - skew_symmetric[..., 2, 0], skew_symmetric[..., 1, 0] - skew_symmetric[..., 0, 1], ), dim=-1, ) rotvec = rotvec / 2.0 return rotvec def integrate_rotmat( rt: torch.Tensor, dr_dt: torch.Tensor, dt: float | torch.Tensor, body_frame: bool = True, ) -> torch.Tensor: """ Integrate a rotation matrix `rt` by the derivative `dr_dt` over the time interval `dt`. Args: rt: Rotation matrix, shape [..., 3, 3] or [..., 9] dr_dt: Derivative of the rotation matrix, same shape as `rt` dt: Time interval to integrate over, scalar or a tensor of shape (), [..., 1] or [..., 1, 1] body_frame: If True, the integration is done in the body frame (post-multiply), otherwise in the inertial frame (pre-multiply). Returns: Integrated rotation matrix, same shape as `rt` """ assert rt.shape == dr_dt.shape, f"{rt.shape} != {dr_dt.shape}" assert is_rotmat(rt), f"{rt.shape} not a rotation matrix" is_3x3 = is_rotmat_3x3(rt) if not is_3x3: rt = rotmat_as_3x3(rt) dr_dt = rotmat_as_3x3(dr_dt) if isinstance(dt, torch.Tensor): assert dt.ndim in ( 0, rt.ndim, rt.ndim - 1, ), f"dt.ndim = {dt.ndim} | {rt.ndim} | {rt.ndim - 1}" if dt.ndim == rt.ndim: assert dt.shape[-2:] == (1, 1), dt.shape dt = dt.squeeze(-1) if body_frame: omega = skew_symmetric_to_rotvec(rotmat_inverse(rt) @ dr_dt) else: omega = skew_symmetric_to_rotvec(dr_dt @ rotmat_inverse(rt)) dr = roma.rotvec_to_rotmat(omega * dt) if body_frame: rt = rt @ dr else: rt = dr @ rt if not is_3x3: rt = rotmat_as_9(rt) return rt def integrate_rotation( rt: torch.Tensor, dr_dt: torch.Tensor, dt: float | torch.Tensor, body_frame: bool = True, half_cover: bool = True, ) -> torch.Tensor: """ Integrate the rotation `rt` by the derivative `dr_dt` over the time interval `dt` on the SO(3) manifold.
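Dispatches to `integrate_unitquat` when `rt` is a quaternion (shape [..., 4]) and to `integrate_rotmat` when it is a rotation matrix (shape [..., 3, 3] or [..., 9]); `half_cover` only affects the quaternion path.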
""" if is_quaternion(rt): return integrate_unitquat(rt, dr_dt, dt, body_frame=body_frame, half_cover=half_cover) if is_rotmat(rt): return integrate_rotmat(rt, dr_dt, dt, body_frame=body_frame) raise NotImplementedError(f"integrate_rotation not yet implemented for format {rt.shape}") class PiZeroFlowMatchingModule(ConfigurableModule): def __init__(self, config: PiZeroFlowMatchingModuleConfig, control_tokenizer: EmptyTokenizer): super().__init__(config) del control_tokenizer self.noised_control_proj = NoisedControlProjector(self.config.noised_control_proj_config) self.robot_state_proj = RobotStateProjector(self.config.robot_state_proj_config) self.control_decoder = PiZeroFlowMatchingDecoder(config=self.config.control_decoder_config) self.output_proj = make_mlp( [self.config.token_size, 3 + self.config.rotation_components + 1], activation=torch.nn.GELU, activate_final=False, ) def forward( self, vlm_input: RoboticsFlowInput, vlm_output: VLMOutput, cache: Optional[transformers.Cache] = None, ) -> RoboticsOutput: robot_state_tokens = self.robot_state_proj(vlm_input) noised_tokens = self.noised_control_proj(vlm_input.flow_input) output_tokens = self.control_decoder( control_tokens=noised_tokens, robot_state_tokens=robot_state_tokens, llm_kv_tokens=vlm_output.llm_output.past_key_values, attn_mask=vlm_input.attn_mask, cache=cache, ) contols = self.output_proj(output_tokens) (translation, rotation, gripper) = torch.split( contols, [3, self.config.rotation_components, 1], dim=-1 ) return RoboticsOutput.make_empty().replace( translation=translation, rotation=rotation, gripper=gripper ) @torch.inference_mode() def generate( self, vlm_input: RoboticsFlowInput, vlm_output: VLMOutput, processor: PiZeroFlowMatchingProcessor, use_cache: bool = True, **kwargs, ) -> RoboticsOutput: del kwargs (batch_size, vlm_seq_len) = vlm_input.input_ids.shape[:2] device = vlm_input.input_ids.device if use_cache: max_cache_len = ( vlm_seq_len + processor.config.control_io_config.future_controls_sequence_length + processor.config.control_io_config.past_scalars_sequence_length ) cache = transformers.StaticCache( config=transformers.PretrainedConfig( head_dim=self.config.control_decoder_config.block_config.head_dim, num_key_value_heads=self.config.control_decoder_config.block_config.num_kv_heads, num_hidden_layers=self.config.control_decoder_config.num_blocks, ), max_batch_size=batch_size, max_cache_len=max_cache_len, device=device, ) else: cache = None flow_input: FlowInput = processor.sample_t0_input(batch_size=batch_size, device=device) step_size = 1 / processor.config.num_inference_steps translation = flow_input.translation_t0 rotation = flow_input.rotation_t0 gripper = flow_input.gripper_t0 vlm_input = vlm_input.replace( **{ "flow_input.timestep": flow_input.timestep, "flow_input.translation_t": translation, "flow_input.rotation_t": rotation, "flow_input.gripper_t": gripper, } ) for _ in range(processor.config.num_inference_steps): model_output: RoboticsOutput = self(vlm_input, vlm_output, cache) translation = translation + step_size * model_output.translation rotation = integrate_rotation(rt=rotation, dr_dt=model_output.rotation, dt=step_size) gripper = gripper + step_size * model_output.gripper timestep = vlm_input.flow_input.timestep + step_size if processor.config.rotation_format == RotationFormat.QUATERNION: rotation = quaternion_half_cover(rotation) vlm_input = vlm_input.replace( **{ "flow_input.timestep": timestep, "flow_input.translation_t": translation, "flow_input.rotation_t": rotation, "flow_input.gripper_t": 
gripper, } ) output = RoboticsOutput.make_empty().replace( translation=translation, rotation=rotation, gripper=gripper ) return output @property def fsdp_wrap_modules(self) -> Set[torch.nn.Module]: return self.control_decoder.fsdp_wrap_modules | { self, self.robot_state_proj, self.noised_control_proj, self.output_proj, } CANONICAL_TO_BRIDGE_ROTATION = np.array( [ [1, 0, 0], [0, np.cos(np.pi), -np.sin(np.pi)], [0, np.sin(np.pi), np.cos(np.pi)], ], dtype=np.float32, ) class SPEAR1(ConfigurableModule, transformers.PreTrainedModel): config_class: transformers.PretrainedConfig = SPEAR1Config def __init__(self, config: SPEAR1Config): super().__init__(config) self.vlm = PaliGemmaVLM(config=self.config.vlm_config) self.processor = PiZeroFlowMatchingProcessor( config=self.config.processor_config, vlm_processor=self.vlm.processor ) self.control_module = PiZeroFlowMatchingModule( config=self.config.control_module_config, control_tokenizer=self.processor.control_tokenizer, ) self.generation_config = transformers.GenerationConfig() def forward( self, inputs: RoboticsInput, use_cache: Optional[bool] = True, output_hidden_states: Optional[bool] = None, ) -> RoboticsOutput: del output_hidden_states vlm_output = self.vlm(inputs=inputs, use_cache=use_cache, output_hidden_states=True) control_output = self.control_module(vlm_input=inputs, vlm_output=vlm_output) output = control_output.replace(llm_output=vlm_output.llm_output) return output @torch.inference_mode() def generate( self, inputs: RoboticsInput, use_cache: Optional[bool] = True, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, ) -> RoboticsOutput: del output_hidden_states vlm_output = self.vlm( inputs=inputs, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=True, ) control_output = self.control_module.generate( vlm_input=inputs, vlm_output=vlm_output, processor=self.processor ) output = control_output.replace(llm_output=vlm_output.llm_output) return output def predict_action(self, inputs: Dict) -> Dict[str, np.ndarray]: images = inputs["images"] ee_translation = inputs["ee_translation"] ee_rotation = inputs["ee_rotation"] gripper = inputs["gripper"] num_resize_args = len(inspect.signature(self.processor.resize_image).parameters) # Resize images using the processor's resize_image method for camera_name, camera_image in images.items(): # Handle the different signatures resize_image - old one used to take only the image, # new one also takes the camera name if num_resize_args == 1: images[camera_name] = self.processor.resize_image(camera_image) elif num_resize_args == 2: images[camera_name] = self.processor.resize_image(camera_name, camera_image) else: raise ValueError(f"Unexpected number of arguments for resize_image: {num_resize_args}") # add batch dimension and wrap into list to match processor expected format images[camera_name] = [images[camera_name]] # add batch dimensions to state obs ee_translation = np.array(ee_translation, dtype=np.float32).reshape(1, 3) ee_rotation = np.array(ee_rotation, dtype=np.float32).reshape(1, 3, 3) @ CANONICAL_TO_BRIDGE_ROTATION gripper = np.array(gripper, dtype=np.float32).reshape(1, 1) joints = np.zeros((1, 7), dtype=np.float32) dataset_name = np.array([inputs["dataset_name"]]) chat = [f"{inputs['language_instruction']}", ""] model_input = self.processor.create_input( images=images, chat=chat, ee_pose_translation=ee_translation, ee_pose_rotation=ee_rotation, gripper=gripper, dataset_name=dataset_name, joints=joints, inference_mode=True, ) 
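# Add an extra leading dimension to every tensor and move it to the GPU; non-tensor entries
# (e.g. the chat strings) are passed through unchanged.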
model_input = model_input.apply( lambda x: x.unsqueeze(0).to("cuda") if isinstance(x, torch.Tensor) else x ) with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16): model_output = self.generate(model_input) control_plan = self.processor.policy_control_plan_from_model_output( model_output=model_output, dataset_name=dataset_name, valid_mask=torch.ones( model_output.gripper.shape[:2], dtype=torch.bool, device=model_output.gripper.device ), ) translation_m = control_plan.translation_m.to(dtype=torch.float32, device='cpu') rotation = control_plan.rotmat.to(dtype=torch.float32, device='cpu') gripper_prob = control_plan.gripper_prob.to(dtype=torch.float32, device='cpu') # Convert controls back to robot base frame if self.processor.config.eef_control_frame: # Get the robot base rotation matrix R_BE - the same as the robot EEF pose. # R_BE - converts from end-effector frame E to robot base frame B robot_base_rotmat = rotmat_as_3x3(model_input.ee_pose_rotation[:, -1:, ...]).cpu() # [B, 1, 3, 3] translation_m = torch.matmul( # [B, num_future_control_steps, 3] robot_base_rotmat, translation_m.unsqueeze(-1) ).squeeze(-1) rotation = rotmat_as_3x3( # [B, num_future_control_steps, 3, 3] torch.matmul(robot_base_rotmat, rotmat_as_3x3(rotation)) ) translation = translation_m # [B, num_future_control_steps, 3] rotation = rotmat_as_3x3(rotation) # [B, num_future_control_steps, 3, 3] gripper = gripper_prob # [B, num_future_control_steps, 1] translation = translation.squeeze(0).numpy() rotation = rotation.squeeze(0).numpy() gripper = gripper.squeeze(0).numpy() rotation = CANONICAL_TO_BRIDGE_ROTATION @ rotation @ CANONICAL_TO_BRIDGE_ROTATION.T return { "translation": translation, "rotation": rotation, "gripper": gripper, } @property def fsdp_wrap_modules(self) -> Set[torch.nn.Module]: return ( {self.vlm, self.control_module} | self.vlm.fsdp_wrap_modules | self.control_module.fsdp_wrap_modules )
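# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed by this module): one plausible way to drive
# `SPEAR1.predict_action` from a control loop. The checkpoint id, camera name, image resolution
# and observation values below are hypothetical placeholders; consult the released checkpoint
# and processor config for the real camera names and expected image format.
#
#   model = SPEAR1.from_pretrained("<org>/<spear1-checkpoint>").to("cuda").eval()
#   action = model.predict_action(
#       {
#           "images": {"exterior_cam": np.zeros((224, 224, 3), dtype=np.uint8)},
#           "ee_translation": np.zeros(3, dtype=np.float32),
#           "ee_rotation": np.eye(3, dtype=np.float32),
#           "gripper": np.zeros(1, dtype=np.float32),
#           "dataset_name": "bridge_dataset",
#           "language_instruction": "pick up the red block",
#       }
#   )
#   # `action` is a dict of numpy arrays: "translation" [T, 3], "rotation" [T, 3, 3] and
#   # "gripper" [T, 1], where T is the length of the predicted control plan.
# ---------------------------------------------------------------------------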