|
|
import collections.abc |
|
|
import dataclasses |
|
|
import enum |
|
|
import inspect |
|
|
import types |
|
|
from collections.abc import Mapping as MappingABC |
|
|
from functools import cached_property |
|
|
from typing import ( |
|
|
Any, |
|
|
Callable, |
|
|
Dict, |
|
|
Iterable, |
|
|
List, |
|
|
Mapping, |
|
|
Optional, |
|
|
Sequence, |
|
|
Tuple, |
|
|
Type, |
|
|
Union, |
|
|
) |
|
|
|
|
|
import torch |
|
|
import transformers |
|
|
|
|
|
|
|
|
class StrEnum(str, enum.Enum):
    """
    A minimal drop-in replacement for backports.strenum.StrEnum.

    Members are real `str` instances: they compare equal to, and hash the same
    as, their plain string value, and `str(member)` yields that value.
    """

    def __new__(cls, value):
        # Non-string values are delegated to the next `__new__` in the MRO.
        if not isinstance(value, str):
            return super().__new__(cls, value)
        # String values become the str payload and the enum value unchanged.
        member = str.__new__(cls, value)
        member._value_ = value
        return member

    def __str__(self):
        return str(self.value)

    def __eq__(self, other):
        # Any string (including other members) is compared against the value.
        return self.value == other if isinstance(other, str) else super().__eq__(other)

    def __hash__(self):
        # Kept in sync with `__eq__`: hash like the underlying value.
        return hash(self.value)

    @classmethod
    def _missing_(cls, value):
        """Fallback lookup hook: return the member whose value equals `value`, else None."""
        if not isinstance(value, str):
            return None
        return next((member for member in cls if member.value == value), None)
|
|
|
|
|
|
|
|
class _cached_classproperty: |
|
|
def __init__(self, func): |
|
|
self.func = func |
|
|
self._values = {} |
|
|
|
|
|
def __get__(self, obj, klass): |
|
|
if klass not in self._values.keys(): |
|
|
self._values[klass] = self.func.__get__(obj, klass)() |
|
|
return self._values[klass] |
|
|
|
|
|
|
|
|
def cached_classproperty(func):
    """
    Decorator turning `func` into a per-class cached class-level property.

    Plain functions are wrapped in `classmethod` first; objects that already are
    a `classmethod` or `staticmethod` are used unchanged.
    """
    already_bound = isinstance(func, (classmethod, staticmethod))
    wrapped = func if already_bound else classmethod(func)
    return _cached_classproperty(wrapped)
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class Dataclass:
    """
    Base class for the dataclasses in this module.

    Adds dict-like access (`keys`, `values`, `items`), a functional `replace`,
    element-wise `apply`, and tensor-style indexing via `__getitem__`.
    Field iteration follows `fields`, i.e. it is sorted by field name.
    """

    def __post_init__(self):
        pass

    @classmethod
    def make_empty(cls) -> "Dataclass":
        """Build an instance whose fields are None, or empty instances for fields typed as a Dataclass subclass."""
        empty_kwargs = {}
        for name, annotation in cls.types.items():
            is_nested = inspect.isclass(annotation) and issubclass(annotation, Dataclass)
            empty_kwargs[name] = annotation.make_empty() if is_nested else None
        return cls(**empty_kwargs)

    @cached_classproperty
    def fields(cls) -> Tuple[dataclasses.Field, ...]:
        """Returns a tuple of the Field objects, sorted by field name"""
        ordered = sorted(dataclasses.fields(cls), key=lambda x: x.name)
        return tuple(ordered)

    @cached_classproperty
    def types(cls) -> Dict[str, type]:
        """Mapping of field name to the field's declared type."""
        field_types = {}
        for field in cls.fields:
            field_types[field.name] = field.type
        return field_types

    def as_json(self, recursive: bool = True) -> dict:
        """Return the fields as a dict; nested Dataclass values are converted too when `recursive`."""
        payload = {}
        for key, value in self.items():
            if recursive and isinstance(value, Dataclass):
                payload[key] = value.as_json()
            else:
                payload[key] = value
        return payload

    @classmethod
    def keys(cls) -> List[str]:
        """Field names, in the order of `fields`."""
        return [field.name for field in cls.fields]

    def values(self):
        """Field values, in the order of `fields`."""
        return [getattr(self, field.name) for field in self.fields]

    def items(self, recursive: bool = False):
        """
        Yield `(name, value)` pairs.

        With `recursive=True`, nested Dataclass values are flattened and their
        entries are yielded under dotted names (`"outer.inner"`).
        """
        names = self.keys()
        field_values = self.values()
        for key, value in zip(names, field_values, strict=True):
            if not (recursive and isinstance(value, Dataclass)):
                yield (key, value)
                continue
            for subkey, subvalue in value.items(recursive=True):
                yield (f"{key}.{subkey}", subvalue)

    def replace(self, **kwargs):
        """
        Return a new instance of Dataclass with the kwargs overwritten.

        Dotted keys (`"a.b"`) are expanded into nested dicts first. For a field
        typed as a Dataclass subclass, a dict value is applied to the current
        nested value through its own `replace`; any other value is assigned as is.

        Raises:
            KeyError: if a key does not name a field of this class.
        """
        overrides = maybe_chained_keys_to_nested_dict(kwargs)
        field_types = self.types
        data = self.as_json(recursive=False)
        for key, value in overrides.items():
            value_type = field_types.get(key, None)
            if value_type is None:
                raise KeyError(f"Dataclass {self.__class__} does not have a field {key}")
            # Optional[X] is treated like X when deciding whether the field is nested.
            value_type = get_maybe_optional_type(value_type)
            is_nested = inspect.isclass(value_type) and issubclass(value_type, Dataclass)
            if is_nested and isinstance(value, dict):
                data[key] = data[key].replace(**value)
            else:
                data[key] = value
        return self.__class__(**data)

    def apply(self, fcn: Callable, recursive: bool = True, skip_nones: bool = False) -> "Dataclass":
        """
        Return a new instance with `fcn` applied to every field value.

        Args:
            fcn: Callable applied to each value.
            recursive: When True, dict/list/tuple fields are rebuilt with `fcn`
                applied to each element, and nested Dataclass fields are
                processed through their own `apply`.
            skip_nones: When True, fields that are None stay None and `fcn` is
                not called for them.
        """

        def fcn_wrapper(value: Any) -> Any:
            if skip_nones and value is None:
                return None
            if not recursive:
                return fcn(value)
            # NOTE: container elements go straight to `fcn`; they are not
            # recursed into and `skip_nones` is not applied to them.
            if isinstance(value, dict):
                mapped = {k: fcn(v) for (k, v) in value.items()}
                return type(value)(**mapped)
            if isinstance(value, (list, tuple)):
                return type(value)([fcn(v) for v in value])
            if isinstance(value, Dataclass):
                return value.apply(fcn, recursive=True, skip_nones=skip_nones)
            return fcn(value)

        transformed = {}
        for key, value in self.items():
            transformed[key] = fcn_wrapper(value)
        return self.__class__(**transformed)

    def __getitem__(self, index) -> "Dataclass":
        """
        Index every tensor held by the instance with `index`.

        None values are kept as None.

        Raises:
            ValueError: if a value is neither None nor a torch.Tensor.
        """

        def extract(obj):
            if obj is None:
                return None
            if not isinstance(obj, torch.Tensor):
                raise ValueError(f"Cannot slice {obj.__class__.__name__} object")
            return obj[index]

        return self.apply(extract)
|
|
|
|
|
|
|
|
class Config:
    """
    Annotation-driven configuration container.

    Fields are declared as class-level annotations. On construction every
    annotated field (including those declared on parent classes) that has no
    value yet receives a default, then keyword arguments are assigned on top.
    Fields annotated with a `Config` subclass are nested configs and are built
    from dict-like keyword values.
    """

    def __init__(self, **kwargs):
        self._apply_defaults()
        self._set_attributes(**kwargs)
        super().__init__()
        self.__post_init__()

    @classmethod
    def _collect_annotations(cls) -> Dict[str, Any]:
        """Merge annotations over the MRO, bases first, so derived classes override."""
        merged: Dict[str, Any] = {}
        for klass in reversed(cls.__mro__):
            if klass is object:
                continue
            merged.update(getattr(klass, "__annotations__", {}))
        return merged

    def _apply_defaults(self):
        """
        Initializes all annotated fields with defaults or sensible instances.

        Every key reported by `keys()` is considered, so fields annotated only
        on a parent class are initialized as well. A field that already
        resolves through normal attribute lookup (instance value or class-level
        default) is left untouched. Otherwise: nested `Config` types are
        instantiated without arguments, dict/mapping generics become `{}`, and
        everything else becomes None.
        """
        for key, type_hint in self._collect_annotations().items():
            if hasattr(self, key):
                continue

            if inspect.isclass(type_hint) and issubclass(type_hint, Config):
                setattr(self, key, type_hint())
                continue

            origin = getattr(type_hint, "__origin__", None)
            if origin in (dict, Dict, MappingABC):
                setattr(self, key, {})
            else:
                setattr(self, key, None)

    def _set_attributes(self, **kwargs):
        """
        Assign keyword arguments as attributes.

        Values for nested-config fields must be dict-like and are expanded into
        the annotated `Config` type.

        Raises:
            ValueError: if a nested-config field receives a non-mapping value.
        """
        subconfig_types = self._subconfig_types
        for key, value in kwargs.items():
            if key not in subconfig_types:
                setattr(self, key, value)
                continue
            if not isinstance(value, Mapping):
                raise ValueError(
                    f"{self.__class__.__name__}.{key} expects dict-like object for nested config, but got: {value}"
                )
            setattr(self, key, subconfig_types[key](**value))

    def keys(self) -> List[str]:
        """Get all annotated keys including those from parent classes."""
        return list(self._collect_annotations())

    def items(self) -> Iterable[Tuple[str, Any]]:
        """Yield `(key, value)` for every annotated key."""
        for key in self.keys():
            yield (key, getattr(self, key))

    @cached_classproperty
    def _subconfig_types(cls) -> dict[str, Type]:
        """
        Mapping of field name to nested `Config` type.

        Built from the MRO-merged annotations, so the most derived annotation
        of a field decides whether (and as which type) it is a nested config.
        """
        return {
            key: value
            for (key, value) in cls._collect_annotations().items()
            if inspect.isclass(value) and issubclass(value, Config)
        }

    def __post_init__(self):
        pass

    def as_json(self) -> dict:
        """
        Return the annotated fields as a plain dict.

        Nested `Config` values are converted recursively. A non-empty sequence
        or mapping is converted element-wise when its first element/value is a
        `Config`; all other values are returned unchanged.
        """
        data = {}
        for key, value in self.items():
            if isinstance(value, Config):
                data[key] = value.as_json()
            elif (
                isinstance(value, collections.abc.Sequence)
                and len(value) > 0
                and isinstance(value[0], Config)
            ):
                data[key] = [v.as_json() for v in value]
            elif (
                isinstance(value, collections.abc.Mapping)
                and len(value) > 0
                and isinstance(next(iter(value.values())), Config)
            ):
                data[key] = {k: v.as_json() for k, v in value.items()}
            else:
                data[key] = value

        return data
|
|
|
|
|
|
|
|
class HFConfigMixin(transformers.PretrainedConfig):
    """
    Bridge between your Config system and HF PretrainedConfig.

    Usage:
        class SPEAR1Config(HFConfigMixin, Config):
            model_type = "spear1"
            processor_config: PaliGemmaProcessorConfig
            ...
    """

    def __init__(self, **kwargs):
        # Run the HF initializer first, then the Config initializer, which
        # applies defaults and assigns the same kwargs (building nested configs).
        super().__init__(**kwargs)
        Config.__init__(self, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """
        Merge HF PretrainedConfig serialization and Config.as_json().

        The HF dict (`super().to_dict()`) is taken as the base so HF
        metadata/defaults are present; entries from `Config.as_json(self)` are
        layered on top, so annotated fields (nested configs, lists/dicts that
        are serialized recursively) take precedence on key clashes.
        """
        hf_payload = super().to_dict()
        config_payload = Config.as_json(self)
        return {**hf_payload, **config_payload}

    @classmethod
    def from_dict(
        cls: Type["HFConfigMixin"],
        config_dict: Dict[str, Any],
        **kwargs,
    ) -> "HFConfigMixin":
        """
        Construct by delegating to the class constructor, which instantiates nested configs.

        Returns the instance, or `(instance, {})` when `return_unused_kwargs`
        is passed as truthy.
        """
        # NOTE(review): every other entry of `kwargs` is ignored and the
        # "unused kwargs" dict is always empty - confirm this is intended.
        wants_unused = kwargs.pop("return_unused_kwargs", False)
        instance = cls(**config_dict)
        return (instance, {}) if wants_unused else instance
|
|
|
|
|
|
|
|
class Configurable:
    """Base for objects built from a `Config`; the config is exposed read-only via `config`."""

    def __init__(self, config: Config):
        # Stored privately; only the `config` property below exposes it.
        self._config = config

    @property
    def config(self) -> Config:
        """The config object passed at construction."""
        return self._config
|
|
|
|
|
|
|
|
class RotationFormat(StrEnum):
    """
    Determines how rotations will be encoded in the loaded batch.

    Being a `StrEnum`, each member compares equal to its string value.
    """

    EULER = "euler"
    QUATERNION = "quaternion"
    ROTMAT = "rotmat"
|
|
|
|
|
|
|
|
class ResizeMode(StrEnum):
    """
    Different modes for resizing images.

    Being a `StrEnum`, each member compares equal to its string value.
    """

    MATCH_WIDTH = "match_width"
    MATCH_HEIGHT = "match_height"
    MATCH_MAX = "match_max"
    NAIVE = "naive"
    SMART = "smart"
    PAD = "pad"
    CROP = "crop"
|
|
|
|
|
|
|
|
class Normalization(StrEnum):
    """
    Action normalization types.

    Being a `StrEnum`, each member compares equal to its string value.
    """

    NONE = "none"
    BOUNDS = "bounds"
    BOUNDS_Q99 = "bounds_q99"
    MEAN = "mean"
|
|
|
|
|
|
|
|
def expand_dims(tensor: torch.Tensor, ndim: int, order: Sequence[int]) -> torch.Tensor:
    """
    Insert size-1 dimensions into `tensor` until it has `ndim` dimensions.

    Args:
        tensor: torch.Tensor of any shape
        ndim: Number of output dimensions. Must be >= `tensor.ndim`
        order: Sequence of size `tensor.ndim + 1`. Contains only values of 1 and a single value of -1,
            indicating where the new `ndim - tensor.ndim` dimensions will be inserted
    Returns:
        torch.Tensor with `ndim` dimensions, a view of `tensor`. When `tensor`
        already has `ndim` dimensions it is returned as is.

    Ex:
        expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[1, -1, 1, 1]).shape -> [2, 1, 1, 3, 4]
        expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[-1, 1, 1, 1]).shape -> [1, 1, 2, 3, 4]
        expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[1, 1, 1, -1]).shape -> [2, 3, 4, 1, 1]
    """
    assert tensor.ndim <= ndim, f"{tensor.ndim} > {ndim}; shape={tensor.shape}"
    assert len(order) == tensor.ndim + 1, f"{len(order)} != {tensor.ndim + 1}; shape={tensor.shape}"
    order = list(order)
    assert order.count(-1) == 1, "Order must have exactly one value of -1"
    assert order.count(1) == len(order) - 1, "Order must have exactly len(order) - 1 values of 1"

    missing = ndim - tensor.ndim
    if missing == 0:
        return tensor

    # Splice the singleton dimensions in at the position marked by -1.
    target_shape = list(tensor.shape)
    position = order.index(-1)
    target_shape[position:position] = [1] * missing
    return tensor.view(target_shape)
|
|
|
|
|
|
|
|
def merge_dicts_recursive(dict_1: Dict[str, Any], dict_2: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a new dict containing the entries of `dict_1` and `dict_2`, merged recursively.

    Keys present in both inputs are resolved as follows:
        1. If both values are plain dicts, they are merged recursively
        2. Otherwise a ValueError is raised

    Neither input is modified.
    """
    merged = dict(dict_1)
    for key, value in dict_2.items():
        if key not in merged:
            merged[key] = value
            continue
        # Only exact `dict` instances on both sides can be combined.
        both_plain_dicts = type(merged[key]) is dict and type(value) is dict
        if not both_plain_dicts:
            raise ValueError(f"Multiple values provided for key {key}: {merged[key]} and {value}")
        merged[key] = merge_dicts_recursive(merged[key], value)
    return merged
|
|
|
|
|
|
|
|
def maybe_chained_keys_to_nested_dict(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Converts a dict with keys of the form "key1.key2.key3" to a nested dict.

    Keys without a dot are copied as they are. Entries that end up under the
    same key are combined with `merge_dicts_recursive`.
    """
    nested: Dict[str, Any] = {}
    for key, value in data.items():
        if "." in key:
            # Split off the first segment and expand the remainder recursively.
            head, _, tail = key.partition(".")
            entry = {head: maybe_chained_keys_to_nested_dict({tail: value})}
        else:
            entry = {key: value}
        nested = merge_dicts_recursive(nested, entry)
    return nested
|
|
|
|
|
|
|
|
def annotation_is_union(type_value: Type) -> bool: |
|
|
return getattr(type_value, "__origin__", None) is Union or type(type_value) is types.UnionType |
|
|
|
|
|
|
|
|
def annotation_is_optional(type_value: Type) -> bool:
    """Return True when `type_value` is a union of exactly two distinct members, one of which is NoneType."""
    if not annotation_is_union(type_value):
        return False
    members = set(type_value.__args__)
    return len(members) == 2 and type(None) in members
|
|
|
|
|
|
|
|
def get_maybe_optional_type(type_value: Type[Optional[Any]]) -> Type[Any]:
    """
    Strip the optional wrapper from an annotation.

    For an optional annotation (`Optional[X]` / `X | None`) the non-None member
    is returned; any other annotation is returned unchanged.
    """
    if not annotation_is_optional(type_value):
        return type_value
    type_args = type_value.__args__
    # NoneType may sit in either position of the two-member union.
    return type_args[0] if type_args[1] is type(None) else type_args[1]
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class RoboticsTarget(Dataclass):
    """Container of target tensors; the two token-id fields may be None."""

    # Optional token-id tensors.
    control_tokens_ids: Optional[torch.Tensor]
    text_tokens_ids: Optional[torch.Tensor]
    # Required tensors.
    translation: torch.Tensor
    rotation: torch.Tensor
    gripper: torch.Tensor
    valid_mask: torch.Tensor
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class RoboticsControlPlan(Dataclass):
    """
    Control plan tensors.

    `translation_m`, `rotmat` and `gripper_prob` are asserted to be
    3-dimensional on construction; `valid_mask` is not checked.
    """

    translation_m: torch.Tensor
    rotmat: torch.Tensor
    gripper_prob: torch.Tensor
    valid_mask: torch.Tensor

    def __post_init__(self):
        super().__post_init__()
        # On failure the assertion message is the offending tensor's shape.
        for checked in (self.translation_m, self.rotmat, self.gripper_prob):
            assert checked.ndim == 3, checked.shape
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class RoboticsInput(Dataclass):
    """Input tensors; `images` is a dict of tensors and `control_tokens_ids` may be None."""

    images: Dict[str, torch.Tensor]
    input_ids: torch.Tensor
    attn_mask: torch.Tensor
    ee_pose_translation: torch.Tensor
    ee_pose_rotation: torch.Tensor
    gripper: torch.Tensor
    joints: torch.Tensor
    control_tokens_ids: Optional[torch.Tensor]

    @property
    def inputs_embeds(self) -> Optional[torch.Tensor]:
        """Always None for this class."""
        return None

    @property
    def past_key_values(self) -> Optional[List[torch.Tensor]]:
        """Always None for this class."""
        return None

    @cached_property
    def multimodal_indices(self) -> torch.Tensor:
        """
        Returns a torch.Tensor containing the indices of the batch examples which are multimodal.
        Here that is every index along the first dimension of `input_ids`, as an int64
        tensor on the same device. The result is cached on the instance after the first access.
        """
        examples = self.input_ids.shape[0]
        return torch.arange(examples, dtype=torch.int64, device=self.input_ids.device)

    @cached_property
    def unimodal_indices(self) -> torch.Tensor:
        """
        Returns a torch.Tensor containing the indices of the batch examples which are unimodal.
        Here that is always an empty int64 tensor on the device of `input_ids`.
        The result is cached on the instance after the first access.
        """
        return torch.tensor([], dtype=torch.int64, device=self.input_ids.device)
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class FlowInput(Dataclass):
    """Container of a timestep tensor plus `*_t` and `*_t0` tensors for translation, rotation and gripper."""

    timestep: torch.Tensor
    # `_t` variants.
    translation_t: torch.Tensor
    rotation_t: torch.Tensor
    gripper_t: torch.Tensor
    # `_t0` variants.
    translation_t0: torch.Tensor
    rotation_t0: torch.Tensor
    gripper_t0: torch.Tensor
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class RoboticsFlowInput(RoboticsInput):
    """Input to the entire Robotics VLM: the `RoboticsInput` fields plus a `FlowInput`."""

    # Added on top of the inherited RoboticsInput fields.
    flow_input: FlowInput
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class DiffusionInput(Dataclass):
    """Container of a timestep tensor plus noised translation, rotation and gripper tensors."""

    timestep: torch.Tensor
    # Noised tensors.
    noised_translation: torch.Tensor
    noised_rotation: torch.Tensor
    noised_gripper: torch.Tensor
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class LLMOutput(Dataclass):
    """Fork of transformers.modeling_outputs.CausalLMOutputWithPast"""

    input_ids: torch.Tensor
    logits: Optional[torch.Tensor]
    output_ids: Optional[torch.Tensor]
    loss: Optional[torch.Tensor]
    past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]
    hidden_states: List[torch.Tensor]
    text_indices: torch.Tensor
    image_indices: torch.Tensor

    @classmethod
    def from_transformers(
        cls,
        input_ids: torch.Tensor,
        llm_output: transformers.modeling_outputs.CausalLMOutputWithPast,
        text_indices: Optional[torch.Tensor],
        image_indices: Optional[torch.Tensor],
    ) -> "LLMOutput":
        """
        Build an instance from a transformers causal-LM output.

        `logits` and `loss` are taken from `llm_output`; `past_key_values` and
        `hidden_states` are converted to lists (empty when missing);
        `output_ids` starts as None.

        Returns:
            An instance of `cls`, so subclasses get an instance of their own type.
        """
        past_key_values = llm_output.past_key_values
        hidden_states = llm_output.hidden_states
        return cls(
            input_ids=input_ids,
            logits=llm_output.logits,
            output_ids=None,
            loss=llm_output.loss,
            past_key_values=list(past_key_values) if past_key_values is not None else [],
            hidden_states=list(hidden_states) if hidden_states is not None else [],
            text_indices=text_indices,
            image_indices=image_indices,
        )

    def compress(self) -> "LLMOutput":
        """
        Compress the data contained in the class so it can be moved between CPU and GPU or concatenated
        much faster:
        - hidden_states - huge tensors; take a lot of CPU time to move across devices or concat
        - past_key_values - huge tensors; take a lot of CPU time to move across devices or concat
        - logits - huge last dimension; takes a lot of CPU time to move across devices or concat

        `loss` and `input_ids` are dropped as well. When logits are present and
        `output_ids` is missing (or its second dimension does not match the
        number of text indices), `output_ids` is recomputed as the argmax of
        the logits selected at `text_indices` along dim 1.

        Returns:
            A new instance; `self` is left unchanged.
        """
        overrides: Dict[str, Any] = {
            "hidden_states": [],
            "past_key_values": [],
            "loss": None,
            "input_ids": None,
        }
        if self.logits is not None:
            overrides["logits"] = None
            # NOTE(review): this path needs `text_indices` to be a tensor even though
            # `from_transformers` accepts None for it - confirm callers always provide it.
            stale_output_ids = self.output_ids is None or self.output_ids.shape[1] != self.text_indices.shape[0]
            if stale_output_ids:
                text_logits = torch.index_select(self.logits, dim=1, index=self.text_indices)
                overrides["output_ids"] = text_logits.argmax(dim=-1).to(dtype=torch.int64)
        return self.replace(**overrides)
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class RoboticsOutput(Dataclass):
    """Output tensors (each possibly None) together with the underlying `LLMOutput`."""

    translation: Optional[torch.Tensor]
    rotation: Optional[torch.Tensor]
    gripper: Optional[torch.Tensor]
    token_logits: Optional[torch.Tensor]
    token_ids: Optional[torch.Tensor]
    llm_output: LLMOutput

    def compress(self) -> "RoboticsOutput":
        """
        Compress output and drop unnecessary components to speed up transfer GPU <-> CPU.
        Note that LLM logits can be extremely expensive since their size is [B, S, vocab_size], which
        can reach millions or billions of values for large vocab_size

        `token_logits` is dropped; when `token_ids` is missing it is first derived
        from `token_logits` via argmax over the last dimension. `llm_output` is
        replaced by its own compressed version.
        """
        compressed_llm = self.llm_output.compress()
        derive_token_ids = self.token_logits is not None and self.token_ids is None
        if derive_token_ids:
            return self.replace(
                llm_output=compressed_llm,
                token_logits=None,
                token_ids=torch.argmax(self.token_logits, dim=-1),
            )
        return self.replace(llm_output=compressed_llm, token_logits=None)
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class VLMOutput(Dataclass):
    """An `LLMOutput` together with an optional `vit_tokens` tensor and an attention mask."""

    llm_output: LLMOutput
    vit_tokens: Optional[torch.Tensor]
    attn_mask: torch.Tensor

    def compress(self) -> "VLMOutput":
        """
        Compress output and drop unnecessary components to speed up transfer GPU <-> CPU.
        Note that LLM logits can be extremely expensive since their size is [B, S, vocab_size], which
        can reach millions or billions of values for large vocab_size

        Only `llm_output` changes: it is replaced by its compressed version.
        """
        compressed_llm = self.llm_output.compress()
        return self.replace(llm_output=compressed_llm)
|
|
|
|
|
|
|
|
def is_quaternion(quaternion: torch.Tensor) -> bool:
    """Shape-only check: True when the last dimension of `quaternion` has size 4."""
    last_dim = quaternion.shape[-1]
    return last_dim == 4
|
|
|
|
|
|
|
|
def quaternion_half_cover(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Flip quaternions so they cover only a half the space. If the q_w is negative, flip the quaternion.
    If q_w is 0, then choose such that the first non-zero component is positive. Note that geometrically,
    this doesn't correspond to a single hemisphere of the unit sphere. Follows
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.as_quat.html#scipy.spatial.transform.Rotation.as_quat

    q_w is read from the last position of the final dimension; the remaining
    components are examined in index order 0, 1, 2.
    """
    assert is_quaternion(quaternion), quaternion.shape
    # The flip mask is computed without gradient tracking.
    with torch.no_grad():
        negative = quaternion < 0
        zero = quaternion == 0
        w_neg, w_zero = negative[..., -1:], zero[..., -1:]
        c0_neg, c0_zero = negative[..., 0:1], zero[..., 0:1]
        c1_neg, c1_zero = negative[..., 1:2], zero[..., 1:2]
        c2_neg = negative[..., 2:3]
        # Nested form of: flip when the first non-zero component, starting at w, is negative.
        flip_condition = w_neg | (w_zero & (c0_neg | (c0_zero & (c1_neg | (c1_zero & c2_neg)))))
    return torch.where(flip_condition, -quaternion, quaternion)
|
|
|
|
|
|
|
|
def is_rotmat_3x3(rotmat: torch.Tensor) -> bool:
    """Shape-only check: True when the last two dimensions of `rotmat` are [3, 3]."""
    trailing = tuple(rotmat.shape[-2:])
    return trailing == (3, 3)
|
|
|
|
|
|
|
|
def is_rotmat_9(rotmat: torch.Tensor) -> bool:
    """Shape-only check: True when the last dimension of `rotmat` has size 9."""
    last_dim = rotmat.shape[-1]
    return last_dim == 9
|
|
|
|
|
|
|
|
def rotmat_as_9(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Convert any rotmat input to [..., 9] shape.

    A tensor already ending in a size-9 dimension is returned as is; a
    [..., 3, 3] tensor has its last two dimensions flattened.

    Raises:
        ValueError: if the shape matches neither layout.
    """
    already_flat = is_rotmat_9(rotmat)
    if already_flat:
        return rotmat
    if not is_rotmat_3x3(rotmat):
        raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a 3x3 rotation matrix")
    leading = rotmat.shape[:-2]
    return rotmat.reshape(*leading, 9)
|
|
|
|
|
|
|
|
def is_rotmat(rotmat: torch.Tensor) -> bool:
    """
    Checks if the tensor shape matches that of a rotmat ([..., 3, 3] or [..., 9]). However, it's not
    guaranteed the data is a valid rotmat. `is_orthonormal_rotmat` performs this additional check.
    NOTE: This might incorrectly return True if the underlying data is euler angles and accidentally
    `rotmat.shape[-2:] == [3, 3]`. This would happen very rarely, but use with caution
    """
    if is_rotmat_3x3(rotmat):
        return True
    return is_rotmat_9(rotmat)
|
|
|
|
|
|
|
|
def rotmat_as_3x3(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Convert any rotmat input to [..., 3, 3] shape.

    A tensor ending in a size-9 dimension has that dimension split into
    [3, 3]; a tensor already ending in [3, 3] is returned as is.

    Raises:
        ValueError: if the shape matches neither layout.
    """
    shape = rotmat.shape
    if shape[-1] == 9:
        leading = shape[:-1]
        return rotmat.reshape(*leading, 3, 3)
    if tuple(shape[-2:]) == (3, 3):
        return rotmat
    raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a 3x3 rotation matrix")
|
|
|