# spear1-franka / configuration_spear.py
import collections.abc
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from .common_spear import (
Config,
HFConfigMixin,
Normalization,
ResizeMode,
RotationFormat,
)
class InputSequencingConfig(Config):
"""
past_frames_sequence_length: number of past images needed in a single robot state
past_scalars_sequence_length: number of past scalar state data, e.g. actions, poses, etc,
needed in a single robot state
past_frames_stride_sec: sampling rate, determines how far apart in time each point in the sequence
is. If None, ignored and takes the default data collection frequency from the dataset
past_scalars_stride_sec: similar to past_frames_stride_sec
sequence_frames: number of temporally-sequential points in a single example in the batch
sequence_frames_stride_sec: sampling rate
Understanding sequence_frames:
TODO: sequences are possibly useful in some rare cases, maybe sequence modeling problems,
but yet to be confirmed. Keeping for now, but could be removed if proved unnecessary
- past_scalars_sequence_length, past_frames_sequence_length, future_controls_sequence_length,
future_frames_sequence_length are hyperparameters refering to a SINGLE dataset example / 'state'.
It is assumed that `past_scalars_sequence_length` and `past_frames_sequence_length` are the min
number of observations that comprise a single 'state'
- sequence_frames is a hyperparameter refering to the entire learning process. It controls the size
of the sequence dimension in the batch. It's treated similarly to the batch dimension, with the
difference that points in the sequence dimensions are temporally aligned. Unlike `past_*`
attributes, in supervised learning a label is loaded for every point in the sequence dimension
and the loss usually computed over the entire sequence dimension.
"""
past_scalars_sequence_length: int = 1
past_frames_sequence_length: int = 1
past_scalars_stride_sec: Optional[float] = None
past_frames_stride_sec: Optional[float] = None
sequence_frames: int = 1
sequence_frames_stride_sec: Optional[float] = None
def __post_init__(self):
super().__post_init__()
assert self.past_scalars_sequence_length >= 1, self.past_scalars_sequence_length
assert self.past_frames_sequence_length >= 1, self.past_frames_sequence_length
assert self.sequence_frames >= 1, self.sequence_frames
if self.past_frames_stride_sec is not None:
assert self.past_frames_stride_sec >= 0.0, self.past_frames_stride_sec
if self.past_scalars_stride_sec is not None:
assert self.past_scalars_stride_sec >= 0.0, self.past_scalars_stride_sec
if self.sequence_frames_stride_sec is not None:
assert self.sequence_frames_stride_sec >= 0.0, self.sequence_frames_stride_sec
def assert_same_past(self) -> None:
assert (
self.past_frames_stride_sec == self.past_scalars_stride_sec
), f"{self.past_frames_stride_sec} != {self.past_scalars_stride_sec}"
assert (
self.past_frames_sequence_length == self.past_scalars_sequence_length
), f"{self.past_frames_sequence_length} != {self.past_scalars_sequence_length}"
class OutputSequencingConfig(Config):
"""
future_controls_sequence_length: number of control steps in the future the model predicts
future_frames_sequence_length: number of future frames the model predicts
(only relevant for neural networks that learn some sort of a world model)
    future_controls_sequence_stride_sec / future_frames_sequence_stride_sec: sampling rate
        that determines how far apart in time each point in the sequence is. If None, it is
        ignored and the default data collection frequency of the dataset is used
    future_control_offset_sec: time interval between the last observation and the first
        point at which control is predicted. Serves as a 'causality hyperparameter' that
        allows predicting controls slightly further into the future in environments whose
        dynamics make the observed effects of an action appear slightly later
"""
future_controls_sequence_length: int = 1
future_controls_sequence_stride_sec: Optional[float] = None
future_frames_sequence_length: int = 1
future_frames_sequence_stride_sec: Optional[float] = None
future_control_offset_sec: float = 0.0
def __post_init__(self):
super().__post_init__()
assert self.future_controls_sequence_length >= 1, self.future_controls_sequence_length
assert self.future_frames_sequence_length >= 1, self.future_frames_sequence_length
assert self.future_control_offset_sec >= 0.0, self.future_control_offset_sec
if self.future_controls_sequence_stride_sec is not None:
assert self.future_controls_sequence_stride_sec >= 0.0, self.future_controls_sequence_stride_sec
if self.future_frames_sequence_stride_sec is not None:
assert self.future_frames_sequence_stride_sec >= 0.0, self.future_frames_sequence_stride_sec
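

# Illustrative sketch (not part of the original API): how future_control_offset_sec and
# the stride combine into prediction timestamps, under the same keyword-construction
# assumption as above.
def _example_future_control_times() -> List[float]:
    cfg = OutputSequencingConfig(
        future_controls_sequence_length=4,
        future_controls_sequence_stride_sec=0.1,
        future_control_offset_sec=0.05,
    )
    # The first control is predicted future_control_offset_sec after the last
    # observation (t = 0.0); later controls follow at the configured stride:
    # t = 0.05, 0.15, 0.25, 0.35.
    return [
        cfg.future_control_offset_sec + cfg.future_controls_sequence_stride_sec * i
        for i in range(cfg.future_controls_sequence_length)
    ]
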
class ControlDataIOConfig(InputSequencingConfig, OutputSequencingConfig):
    pass


class ControlTokenizerConfig(Config):
    pass


class EmptyTokenizerConfig(ControlTokenizerConfig):
    pass


class VLAMProcessorConfig(Config):
control_io_config: ControlDataIOConfig = ControlDataIOConfig()
obs_translation_norm: Normalization | Dict[str, Tuple[float, float, float]] = Normalization.NONE
obs_rotation_norm: Normalization = Normalization.NONE
translation_norm: Normalization | Dict[str, Tuple[float, float, float]] = Normalization.NONE
rotation_norm: Normalization = Normalization.NONE
joints_norm: Dict[str, Tuple[float, ...]] = {
"low": (-np.pi,) * 7,
"high": (np.pi,) * 7,
}
rotation_format: RotationFormat = RotationFormat.QUATERNION
eef_control_frame: bool = False
delta_controls: bool = False
image_resize: ResizeMode = ResizeMode.SMART
control_tokenizer_config: EmptyTokenizerConfig = EmptyTokenizerConfig()
control_stats_path: str = "barrel/pipes/vlams/types/control_stats.yaml"
observation_stats_path: str = "barrel/pipes/vlams/types/observation_stats.yaml"
def __post_init__(self):
super().__post_init__()
if isinstance(self.translation_norm, collections.abc.Mapping):
assert all((len(value) == 3 for value in self.translation_norm.values())), self.translation_norm
assert set(self.translation_norm.keys()) in (
{"low", "high"},
{"mean", "std"},
), self.translation_norm
assert isinstance(self.joints_norm, collections.abc.Mapping), type(self.joints_norm)
assert all((len(value) == 7 for value in self.joints_norm.values())), self.joints_norm
assert set(self.joints_norm.keys()) in (
{"low", "high"},
{"mean", "std"},
), self.joints_norm
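

# Illustrative sketch (not part of the original API): a mapping-based translation
# normalization that satisfies the __post_init__ checks above. The numeric bounds are
# made-up placeholders, not real workspace statistics.
def _example_vlam_processor_config() -> VLAMProcessorConfig:
    return VLAMProcessorConfig(
        # Per-component (x, y, z) bounds; keys must be {"low", "high"} or {"mean", "std"}
        # and every value must have exactly 3 components.
        translation_norm={
            "low": (-0.5, -0.5, 0.0),
            "high": (0.5, 0.5, 1.0),
        },
        rotation_format=RotationFormat.QUATERNION,
    )
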
class RegressionProcessorConfig(VLAMProcessorConfig):
    pass


class PiZeroFlowProcessorConfig(RegressionProcessorConfig):
num_inference_steps: int
r0_distribution: str = "uniform"
timestep_distribution: str
distribution_hyperparams: Dict[str, Any] = {}
sig_min: float = 0.001
def __post_init__(self):
super().__post_init__()
assert self.r0_distribution in ["normal", "uniform"]
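

# Illustrative sketch (not part of the original pipeline): drawing the flow-matching
# noise sample according to r0_distribution. How sig_min and timestep_distribution are
# consumed lives in the processor/model code, not in this config.
def _example_sample_r0(cfg: PiZeroFlowProcessorConfig, shape: Tuple[int, ...]) -> np.ndarray:
    if cfg.r0_distribution == "normal":
        return np.random.standard_normal(shape)
    # "uniform" is the only other value allowed by __post_init__.
    return np.random.uniform(0.0, 1.0, size=shape)
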
class VLMConfig(Config):
    pass


class VLMProcessorConfig(Config):
    pass


class ImageSizeConfig(Config):
width: int
height: int
def to_dict(self):
return {"width": self.width, "height": self.height}
class PaliGemmaProcessorConfig(Config):
image_token: str = "<image>"
image_sizes: Dict[str, ImageSizeConfig] = {"main": ImageSizeConfig(width=224, height=224)}
max_language_tokens: int = 75
def __post_init__(self):
super().__post_init__()
self.image_sizes = {
camera_name: (
ImageSizeConfig(**camera_image_size)
if not isinstance(camera_image_size, ImageSizeConfig)
else camera_image_size
)
for camera_name, camera_image_size in self.image_sizes.items()
}
for camera_name, camera_image_size in self.image_sizes.items():
assert camera_image_size.height % 14 == 0, f"{camera_name}: {camera_image_size}"
assert camera_image_size.width % 14 == 0, f"{camera_name}: {camera_image_size}"
@property
def num_image_tokens(self) -> Dict[str, int]:
return {
camera_name: camera_image_size.height // 14 * (camera_image_size.width // 14)
for (camera_name, camera_image_size) in self.image_sizes.items()
}
@property
def is_single_image_size(self) -> bool:
return (
len(self.image_sizes) == 1
or len(set(((image_size.height, image_size.width) for image_size in self.image_sizes.values())))
== 1
)
@property
def camera_names(self) -> List[str]:
return list(self.image_sizes.keys())
def to_dict(self) -> Dict[str, Any]:
base_dict = {
"image_token": self.image_token,
"max_language_tokens": self.max_language_tokens,
}
base_dict["image_sizes"] = {
camera_name: camera_image_size.to_dict()
for camera_name, camera_image_size in self.image_sizes.items()
}
return base_dict
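

# Illustrative check (not part of the original API): PaliGemma's vision tower uses
# 14x14 patches, so the default 224x224 "main" camera yields (224 // 14) ** 2 = 256
# image tokens.
def _example_num_image_tokens() -> Dict[str, int]:
    cfg = PaliGemmaProcessorConfig()
    assert cfg.num_image_tokens == {"main": 256}
    return cfg.num_image_tokens
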
class PaliGemmaVLMConfig(Config):
model_id: str = "google/paligemma-3b-mix-224"
attn_implementation: str = "flash_attention_2"
processor_config: PaliGemmaProcessorConfig
lm_head: bool = False
paligemma_3d_config: Dict[str, Any] = {}
depth_tokens: int = 0
train_only_depth_tokens: bool = False
mean_resizing: bool = False
def __post_init__(self):
super().__post_init__()
if self.train_only_depth_tokens:
assert self.depth_tokens > 0, self.depth_tokens
if self.paligemma_3d_config.get("mask_prob", 0.0) != 0.0:
raise NotImplementedError(
f"Masking is deprecated, but got mask_prob={self.paligemma_3d_config['mask_prob']}"
)
@property
def paligemma_3d_config_dict(self) -> Dict[str, Any]:
if len(self.paligemma_3d_config) == 0:
return {}
config = dict(self.paligemma_3d_config)
config["depth_config"] = dict(config["depth_config"])
config["depth_config"]["image_sizes"] = {
camera_name: camera_image_size.to_dict()
for camera_name, camera_image_size in self.processor_config.image_sizes.items()
}
return config
@property
def with_depth(self) -> bool:
return len(self.paligemma_3d_config) > 0
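

# Illustrative sketch (not part of the original API): a minimal 3D/depth setup. The
# "depth_config" key is required by paligemma_3d_config_dict; its real schema lives in
# the model code, so the empty dict and depth_tokens value here are placeholders.
def _example_paligemma_vlm_config() -> PaliGemmaVLMConfig:
    cfg = PaliGemmaVLMConfig(
        processor_config=PaliGemmaProcessorConfig(),
        paligemma_3d_config={"depth_config": {}},
        depth_tokens=128,
    )
    # paligemma_3d_config_dict copies the processor's image sizes into depth_config.
    assert cfg.paligemma_3d_config_dict["depth_config"]["image_sizes"] == {
        "main": {"width": 224, "height": 224}
    }
    return cfg
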
class FourierFeaturesConfig(Config):
num_features: int = 256
learnable_features: bool = False
max_period: float = 10000.0
layers: List[int] = [256, 512, 256]
activation: str = "SiLU"
norm: Optional[str] = None
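

# Illustrative sketch (not part of the original model code): the sinusoidal Fourier
# features this config parameterizes, assuming the common transformer-style formulation
# with frequencies geometrically spaced down from 1 to 1 / max_period.
def _example_fourier_features(t: np.ndarray, cfg: FourierFeaturesConfig) -> np.ndarray:
    # t: 1-D array of timesteps; returns an array of shape (len(t), num_features).
    half = cfg.num_features // 2
    freqs = np.exp(-np.log(cfg.max_period) * np.arange(half) / half)
    angles = t[:, None] * freqs[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
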
class NoisedControlProjectorConfig(Config):
time_embed: FourierFeaturesConfig
layers: List[int] = []
activation: str = "SiLU"
norm: Optional[str] = None
class RobotStateProjectorConfig(Config):
layers: List[int] = []
mode: str = "none"
activation: str = "GELU"
fourier: bool = False
def __post_init__(self):
super().__post_init__()
assert self.mode in [
"ee_pose",
"ee_pose_gripper",
"ee_pose_joints",
"joints",
"all",
"none",
], self.mode
class RotaryPositionalEncodingConfig(Config):
num_embeddings: int
embedding_dim: int
base: int = 10000
cached: bool = True
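

# Illustrative sketch (not part of the original model code): the inverse frequencies
# used by standard rotary position embeddings (RoPE), one per pair of embedding
# dimensions.
def _example_rope_inv_freq(cfg: RotaryPositionalEncodingConfig) -> np.ndarray:
    return 1.0 / (
        cfg.base ** (np.arange(0, cfg.embedding_dim, 2) / cfg.embedding_dim)
    )
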
class PiZeroFlowMatchingDecoderBlockConfig(Config):
feature_size: int
head_dim: int = 128
num_heads: int = 32
num_kv_heads: int = 1
hidden_size: int
activation: str = "GELU"
norm: str = "RMSNorm"
dropout: float = 0.0
attn_implementation: str = "sdpa"
position_embed_config: RotaryPositionalEncodingConfig
class PiZeroFlowMatchingDecoderConfig(Config):
num_blocks: int
block_config: PiZeroFlowMatchingDecoderBlockConfig
class PiZeroFlowMatchingModuleConfig(Config):
token_size: int = 1024
noised_control_proj_config: NoisedControlProjectorConfig
robot_state_proj_config: RobotStateProjectorConfig
control_decoder_config: PiZeroFlowMatchingDecoderConfig
rotation_components: int = 3
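

# Illustrative sketch (not part of the original API): assembling the nested control
# module config. All sizes below are made-up placeholders under the same
# keyword-construction assumption as the earlier examples.
def _example_control_module_config() -> PiZeroFlowMatchingModuleConfig:
    block = PiZeroFlowMatchingDecoderBlockConfig(
        feature_size=1024,
        hidden_size=4096,
        position_embed_config=RotaryPositionalEncodingConfig(
            num_embeddings=512,
            embedding_dim=128,  # matches the default head_dim
        ),
    )
    return PiZeroFlowMatchingModuleConfig(
        noised_control_proj_config=NoisedControlProjectorConfig(
            time_embed=FourierFeaturesConfig(),
        ),
        robot_state_proj_config=RobotStateProjectorConfig(mode="ee_pose_gripper"),
        control_decoder_config=PiZeroFlowMatchingDecoderConfig(
            num_blocks=18,
            block_config=block,
        ),
    )
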
class SPEAR1Config(HFConfigMixin, Config):
model_type: str = "spear1"
processor_config: PiZeroFlowProcessorConfig
vlm_config: PaliGemmaVLMConfig
control_module_config: PiZeroFlowMatchingModuleConfig
def __init__(self, **kwargs):
if "auto_map" not in kwargs:
kwargs["auto_map"] = {
"AutoConfig": "configuration_spear.SPEAR1Config",
"AutoModel": "modeling_spear.SPEAR1",
}
super().__init__(**kwargs)
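

# Illustrative sketch (not part of the original file): wiring the full SPEAR1 config.
# num_inference_steps=10 and timestep_distribution="uniform" are hypothetical values;
# the released checkpoints define their own.
def _example_spear1_config() -> SPEAR1Config:
    return SPEAR1Config(
        processor_config=PiZeroFlowProcessorConfig(
            num_inference_steps=10,
            timestep_distribution="uniform",
        ),
        vlm_config=PaliGemmaVLMConfig(processor_config=PaliGemmaProcessorConfig()),
        control_module_config=_example_control_module_config(),
    )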