|
|
import collections |
|
|
import collections.abc |
|
|
from typing import Any, Dict, List, Optional, Tuple |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
from .common_spear import ( |
|
|
Config, |
|
|
HFConfigMixin, |
|
|
Normalization, |
|
|
ResizeMode, |
|
|
RotationFormat, |
|
|
) |
|
|
|
|
|
|
|
|
class InputSequencingConfig(Config):
    """Controls how many past observations form one model input, and how far apart in time they are.

    Per-state history:
        past_scalars_sequence_length: number of past scalar state samples (e.g. actions,
            poses) that make up a single robot state. Must be >= 1.
        past_frames_sequence_length: number of past images that make up a single robot
            state. Must be >= 1.
        past_scalars_stride_sec: time interval, in seconds, between consecutive scalar
            samples. Must be >= 0 when set; None means the dataset's default
            collection frequency is used.
        past_frames_stride_sec: same as past_scalars_stride_sec, but for images.

    Batch-level sequencing:
        sequence_frames: number of temporally consecutive states in one example; this
            sets the size of the sequence dimension of the batch. Must be >= 1.
        sequence_frames_stride_sec: time interval, in seconds, between consecutive states
            along that sequence dimension. Must be >= 0 when set.

    How `sequence_frames` differs from the `past_*` / `future_*` lengths:
        The `past_*` fields here, together with `future_controls_sequence_length` and
        `future_frames_sequence_length` on OutputSequencingConfig, describe a SINGLE
        dataset example ("state"). The two `past_*_sequence_length` values are assumed
        to be the minimum number of observations that make up one state.
        `sequence_frames` instead applies to the whole learning process: the sequence
        dimension is treated much like the batch dimension, except that its points are
        temporally aligned. Unlike the `past_*` history, in supervised learning a label
        is loaded for every point along the sequence dimension, and the loss is usually
        computed over the whole sequence.

    TODO: sequences may only be useful in a few rare cases (e.g. sequence-modelling
    problems); this is not yet confirmed. Kept for now, but could be removed if shown
    to be unnecessary.
    """

    past_scalars_sequence_length: int = 1
    past_frames_sequence_length: int = 1
    past_scalars_stride_sec: Optional[float] = None
    past_frames_stride_sec: Optional[float] = None
    sequence_frames: int = 1
    sequence_frames_stride_sec: Optional[float] = None

    def __post_init__(self):
        super().__post_init__()
        # Every length / count must cover at least one element.
        for count in (
            self.past_scalars_sequence_length,
            self.past_frames_sequence_length,
            self.sequence_frames,
        ):
            assert count >= 1, count
        # Strides are optional; when given they must not be negative.
        for stride_sec in (
            self.past_frames_stride_sec,
            self.past_scalars_stride_sec,
            self.sequence_frames_stride_sec,
        ):
            if stride_sec is not None:
                assert stride_sec >= 0.0, stride_sec

    def assert_same_past(self) -> None:
        """Assert that the image history and the scalar history share stride and length.

        Raises:
            AssertionError: if the two strides differ (checked first) or the two
                sequence lengths differ.
        """
        for frames_value, scalars_value in (
            (self.past_frames_stride_sec, self.past_scalars_stride_sec),
            (self.past_frames_sequence_length, self.past_scalars_sequence_length),
        ):
            assert frames_value == scalars_value, f"{frames_value} != {scalars_value}"
|
|
|
|
|
|
|
|
class OutputSequencingConfig(Config):
    """Controls what the model predicts into the future and how far apart those predictions are.

    future_controls_sequence_length: number of future control steps the model predicts.
        Must be >= 1.
    future_controls_sequence_stride_sec: time interval, in seconds, between consecutive
        predicted control steps. Must be >= 0 when set; None means the dataset's
        default collection frequency is used.
    future_frames_sequence_length: number of future frames the model predicts. Only
        relevant to networks that learn some kind of world model. Must be >= 1.
    future_frames_sequence_stride_sec: same as future_controls_sequence_stride_sec, but
        for predicted frames.
    future_control_offset_sec: time interval, in seconds, between the last observation
        and the first predicted control. Acts as a "causality" hyperparameter: in
        environments whose dynamics make an action's observable effect appear slightly
        later, it lets the model predict controls a little further ahead. Must be >= 0.
    """

    future_controls_sequence_length: int = 1
    future_controls_sequence_stride_sec: Optional[float] = None
    future_frames_sequence_length: int = 1
    future_frames_sequence_stride_sec: Optional[float] = None
    future_control_offset_sec: float = 0.0

    def __post_init__(self):
        super().__post_init__()
        # Prediction horizons must contain at least one step.
        for horizon in (self.future_controls_sequence_length, self.future_frames_sequence_length):
            assert horizon >= 1, horizon
        assert self.future_control_offset_sec >= 0.0, self.future_control_offset_sec
        # Optional strides are only validated when explicitly provided.
        for stride_sec in (
            self.future_controls_sequence_stride_sec,
            self.future_frames_sequence_stride_sec,
        ):
            if stride_sec is not None:
                assert stride_sec >= 0.0, stride_sec
|
|
|
|
|
|
|
|
class ControlDataIOConfig(InputSequencingConfig, OutputSequencingConfig):
    """Input- and output-sequencing settings for control data, combined in one config.

    Declares no fields of its own. Each parent's ``__post_init__`` calls
    ``super().__post_init__()`` before running its own checks. With this MRO
    (Input -> Output -> Config), the output-side checks therefore run before the
    input-side ones.
    """
|
|
|
|
|
|
|
|
class ControlTokenizerConfig(Config):
    """Base configuration for control tokenizers; declares no fields of its own."""
|
|
|
|
|
|
|
|
class EmptyTokenizerConfig(ControlTokenizerConfig):
    """Control-tokenizer configuration with no settings (the default used by VLAMProcessorConfig)."""
|
|
|
|
|
|
|
|
class VLAMProcessorConfig(Config):
    """Processor settings shared by the VLAM processors.

    Covers control-data sequencing, normalization of observations and controls, the
    rotation format, image resizing, the control tokenizer, and the paths of the
    statistics files.

    Translation normalizations (``translation_norm`` and ``obs_translation_norm``) are
    either a ``Normalization`` value or an explicit statistics dict. A dict must have
    exactly the keys {"low", "high"} or {"mean", "std"}, each mapping to 3 values.
    ``joints_norm`` must always be such a dict, with 7 values per key.

    Raises:
        AssertionError: from ``__post_init__`` when a statistics dict is malformed.
    """

    control_io_config: ControlDataIOConfig = ControlDataIOConfig()
    # NOTE(review): obs_* presumably normalize observed state and the unprefixed
    # fields the predicted controls — confirm against the processor implementation.
    obs_translation_norm: Normalization | Dict[str, Tuple[float, float, float]] = Normalization.NONE
    obs_rotation_norm: Normalization = Normalization.NONE
    translation_norm: Normalization | Dict[str, Tuple[float, float, float]] = Normalization.NONE
    rotation_norm: Normalization = Normalization.NONE
    # Default: a symmetric [-pi, pi] range for each of the 7 entries.
    joints_norm: Dict[str, Tuple[float, ...]] = {
        "low": (-np.pi,) * 7,
        "high": (np.pi,) * 7,
    }
    rotation_format: RotationFormat = RotationFormat.QUATERNION
    eef_control_frame: bool = False
    delta_controls: bool = False
    image_resize: ResizeMode = ResizeMode.SMART
    control_tokenizer_config: EmptyTokenizerConfig = EmptyTokenizerConfig()
    # Relative YAML paths; presumably resolved against the repository root — confirm.
    control_stats_path: str = "barrel/pipes/vlams/types/control_stats.yaml"
    observation_stats_path: str = "barrel/pipes/vlams/types/observation_stats.yaml"

    def __post_init__(self):
        super().__post_init__()

        def _check_stats(stats, expected_len):
            # Explicit statistics are either a (low, high) range or a (mean, std)
            # pair, with `expected_len` values under each key.
            assert all(len(value) == expected_len for value in stats.values()), stats
            assert set(stats.keys()) in ({"low", "high"}, {"mean", "std"}), stats

        # Translation normalization may be a Normalization value or an explicit dict;
        # only the dict form needs structural checks. The observation variant has the
        # same type as the control one, so it is validated the same way (previously it
        # was not validated at all).
        for translation_stats in (self.translation_norm, self.obs_translation_norm):
            if isinstance(translation_stats, collections.abc.Mapping):
                _check_stats(translation_stats, 3)

        # Joint statistics are always an explicit dict with 7 values per key.
        assert isinstance(self.joints_norm, collections.abc.Mapping), type(self.joints_norm)
        _check_stats(self.joints_norm, 7)
|
|
|
|
|
|
|
|
class RegressionProcessorConfig(VLAMProcessorConfig):
    """Processor configuration that currently adds nothing to VLAMProcessorConfig.

    PiZeroFlowProcessorConfig extends it.
    """
|
|
|
|
|
|
|
|
class PiZeroFlowProcessorConfig(RegressionProcessorConfig):
    """Processor configuration for the PiZero flow model.

    ``num_inference_steps`` and ``timestep_distribution`` are declared without
    defaults and are presumably required at construction.
    ``r0_distribution`` must be "normal" or "uniform".

    Raises:
        AssertionError: from ``__post_init__`` if ``r0_distribution`` is not one of
            the allowed values; the offending value is the assertion message.
    """

    num_inference_steps: int
    r0_distribution: str = "uniform"
    timestep_distribution: str
    # Extra hyperparameters for the distributions above — presumably; confirm with the model code.
    distribution_hyperparams: Dict[str, Any] = {}
    # NOTE(review): presumably the minimum sigma of the flow path — confirm.
    sig_min: float = 0.001

    def __post_init__(self):
        super().__post_init__()
        # Report the offending value, as every other check in this module does.
        assert self.r0_distribution in ["normal", "uniform"], self.r0_distribution
|
|
|
|
|
|
|
|
class VLMConfig(Config):
    """Placeholder VLM configuration; declares no fields beyond those inherited from Config."""
|
|
|
|
|
|
|
|
class VLMProcessorConfig(Config):
    """Placeholder VLM processor configuration; declares no fields beyond those inherited from Config."""
|
|
|
|
|
|
|
|
class ImageSizeConfig(Config):
    """Image width and height (presumably in pixels) for one camera.

    Both fields have no default.
    """

    width: int
    height: int

    def to_dict(self):
        """Return the size as a plain dict with keys "width" then "height"."""
        return dict(width=self.width, height=self.height)
|
|
|
|
|
|
|
|
class PaliGemmaProcessorConfig(Config):
    """Text and image processing settings for the PaliGemma VLM.

    image_token: image placeholder string (default "<image>"); presumably inserted
        into the prompt — confirm.
    image_sizes: per-camera image size, keyed by camera name. Values that are not
        already ImageSizeConfig are expanded as keyword arguments into one.
        Both sides of every size must be a multiple of 14.
    max_language_tokens: language-token budget (default 75).
    """

    image_token: str = "<image>"
    image_sizes: Dict[str, ImageSizeConfig] = {"main": ImageSizeConfig(width=224, height=224)}
    max_language_tokens: int = 75

    def __post_init__(self):
        super().__post_init__()
        # Normalize every entry to an ImageSizeConfig instance.
        coerced_sizes = {}
        for camera_name, raw_size in self.image_sizes.items():
            if isinstance(raw_size, ImageSizeConfig):
                coerced_sizes[camera_name] = raw_size
            else:
                coerced_sizes[camera_name] = ImageSizeConfig(**raw_size)
        self.image_sizes = coerced_sizes

        # 14 is presumably the vision encoder's patch size — sizes must tile evenly.
        for camera_name, camera_image_size in self.image_sizes.items():
            assert camera_image_size.height % 14 == 0, f"{camera_name}: {camera_image_size}"
            assert camera_image_size.width % 14 == 0, f"{camera_name}: {camera_image_size}"

    @property
    def num_image_tokens(self) -> Dict[str, int]:
        """Image-token count per camera: (height // 14) * (width // 14)."""
        patch_px = 14
        token_counts = {}
        for camera_name, size in self.image_sizes.items():
            token_counts[camera_name] = (size.height // patch_px) * (size.width // patch_px)
        return token_counts

    @property
    def is_single_image_size(self) -> bool:
        """True if all cameras share one (height, width); False when there are no cameras."""
        if len(self.image_sizes) == 1:
            return True
        distinct_sizes = {(size.height, size.width) for size in self.image_sizes.values()}
        return len(distinct_sizes) == 1

    @property
    def camera_names(self) -> List[str]:
        """Camera names in the insertion order of ``image_sizes``."""
        return [*self.image_sizes]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict.

        Keys are "image_token", "max_language_tokens", "image_sizes", in that order.
        "image_sizes" maps each camera name to its size's ``to_dict()`` output.
        """
        return {
            "image_token": self.image_token,
            "max_language_tokens": self.max_language_tokens,
            "image_sizes": {name: size.to_dict() for name, size in self.image_sizes.items()},
        }
|
|
|
|
|
|
|
|
class PaliGemmaVLMConfig(Config):
    """Configuration of the PaliGemma vision-language backbone.

    paligemma_3d_config: optional extra settings. When non-empty, ``with_depth`` is
        True and ``paligemma_3d_config_dict`` requires a "depth_config" entry that
        ``dict()`` can copy. A non-zero "mask_prob" is rejected.
    depth_tokens / train_only_depth_tokens: enabling train_only_depth_tokens requires
        depth_tokens > 0.
    processor_config: has no default.

    Raises:
        AssertionError: from ``__post_init__`` if train_only_depth_tokens is set while
            depth_tokens <= 0.
        NotImplementedError: from ``__post_init__`` if a non-zero mask_prob is given.
    """

    model_id: str = "google/paligemma-3b-mix-224"
    attn_implementation: str = "flash_attention_2"
    processor_config: PaliGemmaProcessorConfig
    lm_head: bool = False
    paligemma_3d_config: Dict[str, Any] = {}
    depth_tokens: int = 0
    train_only_depth_tokens: bool = False
    mean_resizing: bool = False

    def __post_init__(self):
        super().__post_init__()
        if self.train_only_depth_tokens:
            # Training only the depth tokens needs at least one depth token.
            assert self.depth_tokens > 0, self.depth_tokens
        masking_requested = self.paligemma_3d_config.get("mask_prob", 0.0) != 0.0
        if masking_requested:
            raise NotImplementedError(
                f"Masking is deprecated, but got mask_prob={self.paligemma_3d_config['mask_prob']}"
            )

    @property
    def paligemma_3d_config_dict(self) -> Dict[str, Any]:
        """Shallow copy of paligemma_3d_config with depth_config["image_sizes"] filled in.

        The image sizes come from ``processor_config.image_sizes``, serialized via
        ``to_dict()``. The stored config is not mutated. Returns ``{}`` when
        paligemma_3d_config is empty.

        Raises:
            KeyError: if paligemma_3d_config is non-empty but has no "depth_config".
        """
        if not self.paligemma_3d_config:
            return {}
        depth_config = dict(self.paligemma_3d_config["depth_config"])
        depth_config["image_sizes"] = {
            camera_name: size.to_dict()
            for camera_name, size in self.processor_config.image_sizes.items()
        }
        return {**self.paligemma_3d_config, "depth_config": depth_config}

    @property
    def with_depth(self) -> bool:
        """True when any paligemma_3d_config settings were provided."""
        return bool(self.paligemma_3d_config)
|
|
|
|
|
|
|
|
class FourierFeaturesConfig(Config):
    """Settings for a Fourier-feature embedding.

    Used as ``NoisedControlProjectorConfig.time_embed``. Field meanings below beyond
    their types and defaults are inferred from names and should be confirmed against
    the module that consumes this config.
    """

    num_features: int = 256
    # Whether the features are trainable — presumably; confirm.
    learnable_features: bool = False
    # NOTE(review): presumably the longest sinusoid period, as in timestep embeddings — confirm.
    max_period: float = 10000.0
    # Layer widths applied after the features — presumably; confirm.
    layers: List[int] = [256, 512, 256]
    # Activation given by name (default "SiLU").
    activation: str = "SiLU"
    # Optional normalization given by name; None presumably disables it — confirm.
    norm: Optional[str] = None
|
|
|
|
|
|
|
|
class NoisedControlProjectorConfig(Config):
    """Settings for the noised-control projector.

    ``time_embed`` (no default) configures its Fourier-feature time embedding.
    """

    time_embed: FourierFeaturesConfig
    # Empty by default; presumably means no extra layers — confirm.
    layers: List[int] = []
    activation: str = "SiLU"
    norm: Optional[str] = None
|
|
|
|
|
|
|
|
class RobotStateProjectorConfig(Config):
    """Settings for the robot-state projector.

    mode must be one of "ee_pose", "ee_pose_gripper", "ee_pose_joints", "joints",
    "all" or "none" (default "none"). It presumably selects which parts of the robot
    state are projected — confirm.

    Raises:
        AssertionError: from ``__post_init__`` for an unknown mode; the message is
            the offending mode.
    """

    layers: List[int] = []
    mode: str = "none"
    activation: str = "GELU"
    fourier: bool = False

    def __post_init__(self):
        super().__post_init__()
        allowed_modes = ("ee_pose", "ee_pose_gripper", "ee_pose_joints", "joints", "all", "none")
        assert self.mode in allowed_modes, self.mode
|
|
|
|
|
|
|
|
class RotaryPositionalEncodingConfig(Config):
    """Settings for rotary positional encoding.

    Used as ``PiZeroFlowMatchingDecoderBlockConfig.position_embed_config``.
    num_embeddings and embedding_dim have no defaults.
    """

    # NOTE(review): presumably the maximum number of positions — confirm.
    num_embeddings: int
    embedding_dim: int
    # Presumably the rotary frequency base — confirm.
    base: int = 10000
    # Presumably enables caching of precomputed tables — confirm.
    cached: bool = True
|
|
|
|
|
|
|
|
class PiZeroFlowMatchingDecoderBlockConfig(Config):
    """Settings for a block of the PiZero flow-matching control decoder.

    Used as ``PiZeroFlowMatchingDecoderConfig.block_config``. feature_size,
    hidden_size and position_embed_config have no defaults.
    """

    feature_size: int
    head_dim: int = 128
    num_heads: int = 32
    # Fewer key/value heads than query heads (1 by default); presumably grouped/multi-query attention — confirm.
    num_kv_heads: int = 1
    hidden_size: int
    activation: str = "GELU"
    norm: str = "RMSNorm"
    dropout: float = 0.0
    # Attention backend by name (default "sdpa"); presumably forwarded to the attention layer — confirm.
    attn_implementation: str = "sdpa"
    position_embed_config: RotaryPositionalEncodingConfig
|
|
|
|
|
|
|
|
class PiZeroFlowMatchingDecoderConfig(Config):
    """Settings for the PiZero flow-matching control decoder.

    Both fields have no default. Presumably the decoder stacks ``num_blocks`` blocks
    that all share ``block_config`` — confirm against the decoder implementation.
    """

    num_blocks: int
    block_config: PiZeroFlowMatchingDecoderBlockConfig
|
|
|
|
|
|
|
|
class PiZeroFlowMatchingModuleConfig(Config):
    """Groups the sub-configurations of the PiZero flow-matching control module.

    Used as ``SPEAR1Config.control_module_config``. The three ``*_config`` fields
    have no defaults.
    """

    token_size: int = 1024
    noised_control_proj_config: NoisedControlProjectorConfig
    robot_state_proj_config: RobotStateProjectorConfig
    control_decoder_config: PiZeroFlowMatchingDecoderConfig
    # NOTE(review): presumably the number of rotation components in the control representation — confirm.
    rotation_components: int = 3
|
|
|
|
|
|
|
|
class SPEAR1Config(HFConfigMixin, Config):
    """Top-level SPEAR-1 configuration: processor, VLM and control-module settings.

    Unless the caller passes its own ``auto_map``, a default mapping is installed. It
    maps "AutoConfig" to configuration_spear.SPEAR1Config and "AutoModel" to
    modeling_spear.SPEAR1, which Hugging Face Auto* classes use to locate custom code.
    """

    model_type: str = "spear1"
    processor_config: PiZeroFlowProcessorConfig
    vlm_config: PaliGemmaVLMConfig
    control_module_config: PiZeroFlowMatchingModuleConfig

    def __init__(self, **kwargs):
        # Any caller-supplied auto_map (even None) takes precedence over the default.
        kwargs.setdefault(
            "auto_map",
            {
                "AutoConfig": "configuration_spear.SPEAR1Config",
                "AutoModel": "modeling_spear.SPEAR1",
            },
        )
        super().__init__(**kwargs)
|
|
|