spear1-franka / common_spear.py
giu-alb's picture
Super-squash branch 'main' using huggingface_hub
a8bf2f3 verified
import collections.abc
import dataclasses
import enum
import inspect
import types
from collections.abc import Mapping as MappingABC
from functools import cached_property
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
)
import torch
import transformers
class StrEnum(str, enum.Enum):
    """
    Lightweight stand-in for backports.strenum.StrEnum: every member is also a `str`.

    Members print as their raw value, compare equal to plain strings holding the same
    value, and hash like that value.
    """

    def __new__(cls, value):
        # Anything that is not a string is delegated to the next `__new__` in the MRO
        if not isinstance(value, str):
            return super().__new__(cls, value)
        member = str.__new__(cls, value)
        member._value_ = value
        return member

    def __str__(self):
        return str(self.value)

    def __hash__(self):
        # Hash on the value so hashing agrees with the string-aware `__eq__`
        return hash(self.value)

    def __eq__(self, other):
        # Strings are compared against the member's value; everything else is delegated
        if not isinstance(other, str):
            return super().__eq__(other)
        return self.value == other

    @classmethod
    def _missing_(cls, value):
        # Fallback lookup by string value. Returning None lets the enum machinery raise its
        # usual error for unknown values.
        if not isinstance(value, str):
            return None
        return next((member for member in cls if member.value == value), None)
class _cached_classproperty:
def __init__(self, func):
self.func = func
self._values = {}
def __get__(self, obj, klass):
if klass not in self._values.keys():
self._values[klass] = self.func.__get__(obj, klass)()
return self._values[klass]
def cached_classproperty(func):
    """
    Decorator turning `func` into a class-level property whose value is cached per class.

    Plain functions are wrapped in `classmethod` first; `classmethod` / `staticmethod` objects are
    used as they are.
    """
    bindable = func if isinstance(func, (classmethod, staticmethod)) else classmethod(func)
    return _cached_classproperty(bindable)
@dataclasses.dataclass
class Dataclass:
    """
    Base class for the dataclasses in this module.

    Adds dict-like access (`keys` / `values` / `items`), conversion to plain dicts, functional
    updates (`replace`, `apply`) and tensor-style indexing on top of `dataclasses.dataclass`.
    Fields are always enumerated in alphabetical order of their names.
    """

    def __post_init__(self):
        pass

    @classmethod
    def make_empty(cls) -> "Dataclass":
        """
        Build an instance in which fields annotated with a Dataclass subclass are themselves made
        empty and every other field is None.
        """
        empty_kwargs = {}
        for name, annotation in cls.types.items():
            if inspect.isclass(annotation) and issubclass(annotation, Dataclass):
                empty_kwargs[name] = annotation.make_empty()
            else:
                empty_kwargs[name] = None
        return cls(**empty_kwargs)

    @cached_classproperty
    def fields(cls) -> Tuple[dataclasses.Field, ...]:
        """Returns a tuple of the Field objects, sorted by field name"""
        ordered = list(dataclasses.fields(cls))
        ordered.sort(key=lambda field: field.name)
        return tuple(ordered)

    @cached_classproperty
    def types(cls) -> Dict[str, type]:
        """Mapping of field name -> declared field type"""
        declared = {}
        for field in cls.fields:
            declared[field.name] = field.type
        return declared

    def as_json(self, recursive: bool = True) -> dict:
        """Return the fields as a dict; nested Dataclass values are converted too when `recursive`."""
        result = {}
        for name, value in self.items():
            if recursive and isinstance(value, Dataclass):
                result[name] = value.as_json()
            else:
                result[name] = value
        return result

    @classmethod
    def keys(cls) -> List[str]:
        """Field names, in the same order as `fields`"""
        names = []
        for field in cls.fields:
            names.append(field.name)
        return names

    def values(self):
        """Field values, in the same order as `keys()`"""
        contents = []
        for field in self.fields:
            contents.append(getattr(self, field.name))
        return contents

    def items(self, recursive: bool = False):
        """
        Yield `(name, value)` pairs. With `recursive=True`, nested Dataclass values are flattened
        into dotted names such as "outer.inner".
        """
        names = self.keys()
        contents = self.values()
        for key, value in zip(names, contents, strict=True):
            if not (recursive and isinstance(value, Dataclass)):
                yield (key, value)
                continue
            for subkey, subvalue in value.items(recursive=True):
                yield (f"{key}.{subkey}", subvalue)

    def replace(self, **kwargs):
        """
        Return a new instance of Dataclass with the kwargs overwritten.

        Dotted keys ("outer.inner") are expanded into nested dicts first. A dict given for a field
        whose (optional-stripped) type is a Dataclass subclass is applied to the existing nested
        value through its own `replace`; any other value overwrites the field directly.

        Raises:
            KeyError: If a key does not name a field of this dataclass.
        """
        overrides = maybe_chained_keys_to_nested_dict(kwargs)
        data = self.as_json(recursive=False)
        for key, value in overrides.items():
            declared = self.types.get(key, None)
            if declared is None:
                raise KeyError(f"Dataclass {self.__class__} does not have a field {key}")
            declared = get_maybe_optional_type(declared)
            is_nested = inspect.isclass(declared) and issubclass(declared, Dataclass)
            if is_nested and isinstance(value, dict):
                data[key] = data[key].replace(**value)
            else:
                data[key] = value
        return self.__class__(**data)

    def apply(self, fcn: Callable, recursive: bool = True, skip_nones: bool = False) -> "Dataclass":
        """
        Return a new instance with `fcn` applied to every field value.

        With `recursive=True`, dict / list / tuple fields are rebuilt with `fcn` applied to each of
        their direct elements, and nested Dataclass fields are processed through their own `apply`.
        With `skip_nones=True`, fields that are None stay None without calling `fcn`.
        """

        def transform(value: Any) -> Any:
            if skip_nones and value is None:
                return None
            if recursive:
                if isinstance(value, dict):
                    # Elements are passed straight to `fcn`; there is no deeper recursion here
                    return type(value)(**{k: fcn(v) for (k, v) in value.items()})
                if isinstance(value, (list, tuple)):
                    return type(value)([fcn(v) for v in value])
                if isinstance(value, Dataclass):
                    return value.apply(fcn, recursive=True, skip_nones=skip_nones)
            return fcn(value)

        updated = {}
        for key, value in self.items():
            updated[key] = transform(value)
        return self.__class__(**updated)

    def __getitem__(self, index) -> "Dataclass":
        """
        Index every tensor held by the instance with `index`; None values are kept as None.

        Raises:
            ValueError: If a value is neither None nor a torch.Tensor.
        """

        def extract(obj):
            if isinstance(obj, torch.Tensor):
                return obj[index]
            if obj is None:
                return None
            raise ValueError(f"Cannot slice {obj.__class__.__name__} object")

        return self.apply(extract)
class Config:
    """
    Attribute-based configuration container driven by class-level type annotations.

    Subclasses declare their fields as annotated class attributes. On construction every annotated
    field, including those declared on parent `Config` classes, receives a default. Keyword
    arguments are then applied on top. A value given for a field annotated with a `Config`
    subclass must be a mapping; it is expanded into an instance of that nested config class.
    """

    def __init__(self, **kwargs):
        """
        Args:
            **kwargs: Attribute values. Keys naming nested-config fields must map to dict-like
                objects; any other key is set on the instance unchanged.

        Raises:
            ValueError: If a nested-config field is given a value that is not a mapping.
        """
        self._apply_defaults()
        self._set_attributes(**kwargs)
        super().__init__()
        self.__post_init__()

    def _apply_defaults(self):
        """
        Initializes all annotated fields with defaults or sensible instances.

        Annotations are gathered from every `Config` class in the MRO (base classes first, so a
        subclass's annotation wins), keeping this method consistent with `keys()` / `items()`:
        fields declared only on a parent class also receive a default instead of being left unset.
        """
        annotations: Dict[str, Any] = {}
        for klass in reversed(self.__class__.__mro__):
            # Only Config classes contribute; unrelated mixins keep managing their own attributes
            if issubclass(klass, Config):
                annotations.update(getattr(klass, "__annotations__", {}))

        for key, type_hint in annotations.items():
            # Already available: class-level default, or set earlier on the instance
            if hasattr(self, key):
                continue
            if inspect.isclass(type_hint) and issubclass(type_hint, Config):
                # Nested config: default-construct it
                setattr(self, key, type_hint())
            elif hasattr(type_hint, "__origin__") and type_hint.__origin__ in (
                dict,
                Dict,
                MappingABC,
            ):
                # Mapping-typed field: start from an empty dict
                setattr(self, key, {})
            else:
                setattr(self, key, None)

    def _set_attributes(self, **kwargs):
        """
        Set `kwargs` as attributes, instantiating nested configs from mappings.

        Raises:
            ValueError: If the value for a nested-config field is not a mapping.
        """
        subconfig_types = self._subconfig_types
        for key, value in kwargs.items():
            if key in subconfig_types:
                if not isinstance(value, Mapping):
                    raise ValueError(
                        f"{self.__class__.__name__}.{key} expects dict-like object for nested config, but got: {value}"
                    )
                setattr(self, key, subconfig_types[key](**value))
            else:
                setattr(self, key, value)

    def keys(self) -> List[str]:
        """Get all annotated keys including those from parent classes."""
        all_keys = {}
        # Walk through MRO in reverse to respect inheritance order
        for cls in reversed(self.__class__.__mro__):
            if cls is object:
                continue
            all_keys.update(getattr(cls, "__annotations__", {}))
        return list(all_keys.keys())

    def items(self) -> Iterable[Tuple[str, Any]]:
        """Yield `(key, value)` for every key returned by `keys()`."""
        for key in self.keys():
            yield (key, getattr(self, key))

    @cached_classproperty
    def _subconfig_types(cls) -> dict[str, Type]:
        """
        Mapping of field name -> nested `Config` class, including fields inherited from bases.

        Precedence follows normal attribute resolution: the class's own annotations override those
        of its bases, and earlier bases override later ones.
        """
        collected: dict[str, Type] = {}
        # Lowest priority first so that higher-priority entries overwrite them
        for base in reversed(cls.__bases__):
            if issubclass(base, Config):
                collected.update(base._subconfig_types)
        for key, value in getattr(cls, "__annotations__", {}).items():
            if inspect.isclass(value) and issubclass(value, Config):
                collected[key] = value
        return collected

    def __post_init__(self):
        pass

    def as_json(self) -> dict:
        """
        Return the config as a dict, converting nested configs recursively.

        Non-empty sequences / mappings are converted element-wise when their first element / first
        value is a `Config`; every other value is stored unchanged.
        """
        data = {}
        for key, value in self.items():
            if isinstance(value, Config):
                data[key] = value.as_json()
            elif (
                isinstance(value, collections.abc.Sequence)
                and len(value) > 0
                and isinstance(value[0], Config)
            ):
                data[key] = [v.as_json() for v in value]
            elif (
                isinstance(value, collections.abc.Mapping)
                and len(value) > 0
                and isinstance(next(iter(value.values())), Config)
            ):
                data[key] = {k: v.as_json() for k, v in value.items()}
            else:
                data[key] = value
        return data
class HFConfigMixin(transformers.PretrainedConfig):
    """
    Mixin that lets a `Config` subclass also act as a Hugging Face `PretrainedConfig`.

    Usage:
        class SPEAR1Config(HFConfigMixin, Config):
            model_type = "spear1"
            processor_config: PaliGemmaProcessorConfig
            ...
    """

    def __init__(self, **kwargs):
        # Step 1: run the PretrainedConfig side of the MRO with all kwargs.
        super().__init__(**kwargs)
        # Step 2: run Config's initializer explicitly (not through super()) with the same kwargs,
        # so annotated fields get their defaults and nested configs are built from the provided
        # mappings. The concrete class is expected to also inherit from Config, as in the usage
        # example above.
        Config.__init__(self, **kwargs)  # type: ignore[name-defined]

    def to_dict(self) -> Dict[str, Any]:
        """
        Combine the PretrainedConfig serialization with `Config.as_json()`.

        The parent `to_dict()` result is taken as the base and then overlaid with the output of
        `Config.as_json(self)`, so keys produced by the Config side win on conflict.
        """
        hf_dict = super().to_dict()
        # Call Config.as_json directly; going through self.to_dict() would recurse into this method
        config_dict = Config.as_json(self)  # type: ignore[name-defined]
        return {**hf_dict, **config_dict}

    @classmethod
    def from_dict(
        cls: Type["HFConfigMixin"],
        config_dict: Dict[str, Any],
        **kwargs,
    ) -> "HFConfigMixin":
        """
        Build an instance by passing `config_dict` to the class constructor, which also
        instantiates nested configs.

        If `return_unused_kwargs=True` is passed, returns `(instance, {})`, otherwise the instance.

        NOTE(review): apart from `return_unused_kwargs`, the extra `kwargs` are neither applied to
        the instance nor reported back as unused - confirm this is intended.
        """
        want_unused = kwargs.pop("return_unused_kwargs", False)
        instance = cls(**config_dict)
        if not want_unused:
            return instance
        # Everything in config_dict was handed to the constructor, so nothing is reported as unused
        return instance, {}
class Configurable:
    """Base for objects that are constructed from, and keep a reference to, a `Config`."""

    def __init__(self, config: Config):
        # Stored privately; exposed read-only through the `config` property
        self._config = config

    @property
    def config(self) -> Config:
        """The config object given at construction time."""
        return self._config
class RotationFormat(StrEnum):
    """Selects the encoding used for rotations in the loaded batch"""

    # Member values are the lowercase strings used to select each format
    EULER = "euler"
    QUATERNION = "quaternion"
    ROTMAT = "rotmat"
class ResizeMode(StrEnum):
    """
    Available strategies for resizing images.
    """

    # The resizing logic itself lives elsewhere; this enum only names the modes
    MATCH_WIDTH = "match_width"
    MATCH_HEIGHT = "match_height"
    MATCH_MAX = "match_max"
    NAIVE = "naive"
    SMART = "smart"
    PAD = "pad"
    CROP = "crop"
class Normalization(StrEnum):
    """Kinds of action normalization"""

    # The normalization math is implemented elsewhere; this enum only names the options
    NONE = "none"
    BOUNDS = "bounds"
    BOUNDS_Q99 = "bounds_q99"
    MEAN = "mean"
def expand_dims(tensor: torch.Tensor, ndim: int, order: Sequence[int]) -> torch.Tensor:
    """
    Insert size-1 dimensions into `tensor` until it has `ndim` dimensions.

    Args:
        tensor: torch.Tensor of any shape
        ndim: Number of output dimensions. Must be >= `tensor.ndim`
        order: Sequence of size `tensor.ndim + 1` made of 1s and exactly one -1. The position of
            the -1 marks where the `ndim - tensor.ndim` new dimensions are inserted
    Returns:
        torch.Tensor with `ndim` dimensions, a view of `tensor` (`tensor` itself when it already
        has `ndim` dimensions)
    Ex:
        expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[1, -1, 1, 1]).shape -> [2, 1, 1, 3, 4]
        expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[-1, 1, 1, 1]).shape -> [1, 1, 2, 3, 4]
        expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[1, 1, 1, -1]).shape -> [2, 3, 4, 1, 1]
    """
    rank = tensor.ndim
    assert rank <= ndim, f"{tensor.ndim} > {ndim}; shape={tensor.shape}"
    assert len(order) == rank + 1, f"{len(order)} != {tensor.ndim + 1}; shape={tensor.shape}"
    markers = list(order)
    assert markers.count(-1) == 1, "Order must have exactly one value of -1"
    assert markers.count(1) == len(markers) - 1, "Order must have exactly len(order) - 1 values of 1"

    missing = ndim - rank
    if missing == 0:
        return tensor

    # Split the shape at the -1 marker and place the new singleton dimensions in between
    split_at = markers.index(-1)
    leading = tensor.shape[:split_at]
    trailing = tensor.shape[split_at:]
    return tensor.view([*leading, *([1] * missing), *trailing])
def merge_dicts_recursive(dict_1: Dict[str, Any], dict_2: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a new dict combining `dict_1` and `dict_2`, merging nested dicts recursively.

    For a key present in both inputs:
        1. If both values are exactly of type `dict`, they are merged recursively
        2. Otherwise a ValueError is raised
    """
    merged = dict(dict_1)
    for key, value in dict_2.items():
        if key not in merged:
            merged[key] = value
            continue
        # Clash: only two plain dicts can be reconciled
        both_plain_dicts = type(merged[key]) is dict and type(value) is dict
        if not both_plain_dicts:
            raise ValueError(f"Multiple values provided for key {key}: {merged[key]} and {value}")
        merged[key] = merge_dicts_recursive(merged[key], value)
    return merged
def maybe_chained_keys_to_nested_dict(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Expand dotted keys such as "key1.key2.key3" into nested dicts; keys without a dot are kept.

    Entries are combined with `merge_dicts_recursive`, so conflicting values for the same key
    raise ValueError.
    """
    nested: Dict[str, Any] = {}
    for key, value in data.items():
        # Split off the first path component; `dot` is empty when the key has no "."
        head, dot, tail = key.partition(".")
        if dot:
            value = maybe_chained_keys_to_nested_dict({tail: value})
        nested = merge_dicts_recursive(nested, {head: value})
    return nested
def annotation_is_union(type_value: Type) -> bool:
return getattr(type_value, "__origin__", None) is Union or type(type_value) is types.UnionType
def annotation_is_optional(type_value: Type) -> bool:
    """Tell whether `type_value` is a union of exactly two distinct members, one of them NoneType."""
    if not annotation_is_union(type_value):
        return False
    members = set(type_value.__args__)
    return len(members) == 2 and type(None) in members
def get_maybe_optional_type(type_value: Type[Optional[Any]]) -> Type[Any]:
    """
    Strip the `None` alternative from an optional annotation.

    Returns the non-None member for optional annotations and `type_value` unchanged otherwise.
    """
    if not annotation_is_optional(type_value):
        return type_value
    type_args = type_value.__args__
    # NoneType may sit in either position; return the other member
    return type_args[0] if type_args[1] is type(None) else type_args[1]
@dataclasses.dataclass
class RoboticsTarget(Dataclass):
    """Container of target tensors (token ids, translation, rotation, gripper and a validity mask)."""

    # Token id fields are optional and may be None
    control_tokens_ids: Optional[torch.Tensor]
    text_tokens_ids: Optional[torch.Tensor]
    translation: torch.Tensor
    rotation: torch.Tensor
    gripper: torch.Tensor
    valid_mask: torch.Tensor
@dataclasses.dataclass
class RoboticsControlPlan(Dataclass):
    """Control plan tensors; translation, rotmat and gripper must each be 3-dimensional."""

    translation_m: torch.Tensor
    rotmat: torch.Tensor
    gripper_prob: torch.Tensor
    valid_mask: torch.Tensor

    def __post_init__(self):
        super().__post_init__()
        # Rank check, in declaration order; the failing tensor's shape is the assertion message
        for name in ("translation_m", "rotmat", "gripper_prob"):
            value = getattr(self, name)
            assert value.ndim == 3, value.shape
@dataclasses.dataclass
class RoboticsInput(Dataclass):
    """Model input container: named image tensors, token ids, attention mask and state tensors."""

    # Image tensors keyed by name
    images: Dict[str, torch.Tensor]
    input_ids: torch.Tensor
    attn_mask: torch.Tensor
    ee_pose_translation: torch.Tensor
    ee_pose_rotation: torch.Tensor
    gripper: torch.Tensor
    joints: torch.Tensor
    # May be None
    control_tokens_ids: Optional[torch.Tensor]

    @property
    def inputs_embeds(self) -> Optional[torch.Tensor]:
        """Always None for this input type."""
        return None

    @property
    def past_key_values(self) -> Optional[List[torch.Tensor]]:
        """Always None for this input type."""
        return None

    @cached_property
    def multimodal_indices(self) -> torch.Tensor:
        """
        Indices of the batch examples which are multimodal. Here every example qualifies, so this
        is `arange(B)` with B = `input_ids.shape[0]`, as int64 on the device of `input_ids`.
        """
        batch_size = self.input_ids.shape[0]
        return torch.arange(batch_size, dtype=torch.int64, device=self.input_ids.device)

    @cached_property
    def unimodal_indices(self) -> torch.Tensor:
        """
        Indices of the batch examples which are unimodal. Here none qualifies, so this is an empty
        int64 tensor on the device of `input_ids`.
        """
        no_indices: List[int] = []
        return torch.tensor(no_indices, dtype=torch.int64, device=self.input_ids.device)
@dataclasses.dataclass
class FlowInput(Dataclass):
    """Tensors for the flow component: a timestep plus `_t` and `_t0` translation/rotation/gripper."""

    timestep: torch.Tensor
    # Values suffixed `_t` (presumably the state at `timestep` - TODO confirm)
    translation_t: torch.Tensor
    rotation_t: torch.Tensor
    gripper_t: torch.Tensor
    # Values suffixed `_t0` (presumably the state at the initial timestep - TODO confirm)
    translation_t0: torch.Tensor
    rotation_t0: torch.Tensor
    gripper_t0: torch.Tensor
@dataclasses.dataclass
class RoboticsFlowInput(RoboticsInput):
    """Input to the entire Robotics VLM: all `RoboticsInput` fields plus the flow tensors"""

    flow_input: FlowInput
@dataclasses.dataclass
class DiffusionInput(Dataclass):
    """Tensors for the diffusion component: a timestep plus noised translation/rotation/gripper."""

    timestep: torch.Tensor
    noised_translation: torch.Tensor
    noised_rotation: torch.Tensor
    noised_gripper: torch.Tensor
@dataclasses.dataclass
class LLMOutput(Dataclass):
    """Fork of transformers.modeling_outputs.CausalLMOutputWithPast"""

    input_ids: torch.Tensor
    logits: Optional[torch.Tensor]
    output_ids: Optional[torch.Tensor]
    loss: Optional[torch.Tensor]
    past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]
    hidden_states: List[torch.Tensor]
    text_indices: torch.Tensor
    image_indices: torch.Tensor

    @classmethod
    def from_transformers(
        cls,
        input_ids: torch.Tensor,
        llm_output: transformers.modeling_outputs.CausalLMOutputWithPast,
        text_indices: Optional[torch.Tensor],
        image_indices: Optional[torch.Tensor],
    ) -> "LLMOutput":
        """
        Wrap a transformers output object. `past_key_values` / `hidden_states` are copied into
        lists (empty lists when the source holds None) and `output_ids` starts as None.
        """
        past = llm_output.past_key_values
        hidden = llm_output.hidden_states
        return LLMOutput(
            input_ids=input_ids,
            logits=llm_output.logits,
            output_ids=None,
            loss=llm_output.loss,
            past_key_values=[] if past is None else list(past),
            hidden_states=[] if hidden is None else list(hidden),
            text_indices=text_indices,
            image_indices=image_indices,
        )

    def compress(self) -> "LLMOutput":
        """
        Return a copy without the heavy members, so it can be moved between CPU and GPU or
        concatenated much faster:
            - hidden_states and past_key_values are replaced by empty lists
            - loss, input_ids and logits are replaced by None
        When logits are present and `output_ids` is missing, or its second dimension differs from
        the number of text indices, `output_ids` is recomputed as the int64 argmax of the logits
        selected at `text_indices` along dim 1.
        """
        overrides: Dict[str, Any] = {
            "hidden_states": [],
            "past_key_values": [],
            "loss": None,
            "input_ids": None,
        }
        if self.logits is not None:
            overrides["logits"] = None
            needs_output_ids = (
                self.output_ids is None or self.output_ids.shape[1] != self.text_indices.shape[0]
            )
            if needs_output_ids:
                # Keep the predicted ids before the logits are dropped
                text_logits = torch.index_select(self.logits, dim=1, index=self.text_indices)
                overrides["output_ids"] = text_logits.argmax(dim=-1).to(dtype=torch.int64)
        return self.replace(**overrides)
@dataclasses.dataclass
class RoboticsOutput(Dataclass):
    """Output container: optional action tensors, optional token logits/ids and the LLM output."""

    translation: Optional[torch.Tensor]
    rotation: Optional[torch.Tensor]
    gripper: Optional[torch.Tensor]
    token_logits: Optional[torch.Tensor]
    token_ids: Optional[torch.Tensor]
    llm_output: LLMOutput

    def compress(self) -> "RoboticsOutput":
        """
        Return a lighter copy to speed up transfer GPU <-> CPU: the LLM output is compressed and
        `token_logits` is dropped. If only logits were available, `token_ids` is filled with their
        argmax over the last dimension first.
        """
        updates: Dict[str, Any] = {
            "llm_output": self.llm_output.compress(),
            "token_logits": None,
        }
        if self.token_ids is None and self.token_logits is not None:
            updates["token_ids"] = self.token_logits.argmax(dim=-1)
        return self.replace(**updates)
@dataclasses.dataclass
class VLMOutput(Dataclass):
    """Output container: the LLM output, optional ViT tokens and an attention mask."""

    llm_output: LLMOutput
    vit_tokens: Optional[torch.Tensor]
    attn_mask: torch.Tensor

    def compress(self) -> "VLMOutput":
        """
        Return a copy whose `llm_output` has been compressed, to speed up transfer GPU <-> CPU.
        The other fields are carried over unchanged.
        """
        compressed_llm = self.llm_output.compress()
        return self.replace(llm_output=compressed_llm)
def is_quaternion(quaternion: torch.Tensor) -> bool:
    """Shape-only check: True when the last dimension of `quaternion` has size 4."""
    last_dim = quaternion.shape[-1]
    return last_dim == 4
def quaternion_half_cover(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Canonicalize the sign of quaternions so that q and -q map to the same output.

    The last component (q_w) decides first: a negative q_w flips the quaternion. If q_w is exactly
    0, components 0, 1 and 2 are inspected in that order and the first non-zero one must be
    positive. Note that geometrically this doesn't correspond to a single hemisphere of the unit
    sphere. Follows
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.as_quat.html#scipy.spatial.transform.Rotation.as_quat
    """
    assert is_quaternion(quaternion), quaternion.shape
    with torch.no_grad():
        # Keep the trailing dimension (size 1) so the mask broadcasts over all 4 components
        q0 = quaternion[..., 0:1]
        q1 = quaternion[..., 1:2]
        q2 = quaternion[..., 2:3]
        qw = quaternion[..., -1:]
        is_zero = quaternion == 0
        w_zero = is_zero[..., -1:]
        q0_zero = is_zero[..., 0:1]
        q1_zero = is_zero[..., 1:2]
        flip_condition = (
            (qw < 0)
            | (w_zero & (q0 < 0))
            | (w_zero & q0_zero & (q1 < 0))
            | (w_zero & q0_zero & q1_zero & (q2 < 0))
        )
    return torch.where(flip_condition, -quaternion, quaternion)
def is_rotmat_3x3(rotmat: torch.Tensor) -> bool:
    """Shape-only check: True when the last two dimensions of `rotmat` are [3, 3]."""
    trailing = tuple(rotmat.shape[-2:])
    return trailing == (3, 3)
def is_rotmat_9(rotmat: torch.Tensor) -> bool:
    """Shape-only check: True when the last dimension of `rotmat` has size 9."""
    last_dim = rotmat.shape[-1]
    return last_dim == 9
def rotmat_as_9(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Convert any rotmat input to [..., 9] shape

    Args:
        rotmat: Tensor of shape [..., 9] or [..., 3, 3]
    Returns:
        `rotmat` itself when its last dimension is already 9, otherwise the input with its two
        trailing [3, 3] dimensions flattened into one of size 9
    Raises:
        ValueError: If the shape is neither [..., 9] nor [..., 3, 3]
    """
    # Same shape tests as `is_rotmat_9` / `is_rotmat_3x3`, written inline as in `rotmat_as_3x3`
    if rotmat.shape[-1] == 9:
        return rotmat
    if rotmat.shape[-2:] == torch.Size([3, 3]):
        return rotmat.reshape(*rotmat.shape[:-2], 9)
    # The message names the actual target layout of this function
    raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a flattened [..., 9] rotation matrix")
def is_rotmat(rotmat: torch.Tensor) -> bool:
    """
    Shape-only test for a rotation matrix in either the [..., 3, 3] or the [..., 9] layout. The
    values are not inspected, so a True result does not guarantee a valid rotmat;
    `is_orthonormal_rotmat` performs this additional check.
    NOTE: This might incorrectly return True if the underlying data is euler angles and accidentally
    `rotmat.shape[-2:] == [3, 3]`. This would happen very rarely, but use with caution
    """
    if is_rotmat_3x3(rotmat):
        return True
    return is_rotmat_9(rotmat)
def rotmat_as_3x3(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Convert any rotmat input to [..., 3, 3] shape

    A last dimension of size 9 is unfolded into [3, 3]; an input whose two trailing dimensions are
    already [3, 3] is returned as is. Any other shape raises ValueError.
    """
    shape = rotmat.shape
    if shape[-1] == 9:
        return rotmat.reshape(*shape[:-1], 3, 3)
    if shape[-2:] == torch.Size([3, 3]):
        return rotmat
    raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a 3x3 rotation matrix")