Upload modeling_pi0.py with huggingface_hub
Browse files- modeling_pi0.py +826 -0
modeling_pi0.py
ADDED
|
@@ -0,0 +1,826 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
π0: A Vision-Language-Action Flow Model for General Robot Control
|
| 19 |
+
|
| 20 |
+
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
|
| 21 |
+
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
| 22 |
+
|
| 23 |
+
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
| 24 |
+
|
| 25 |
+
Install pi0 extra dependencies:
|
| 26 |
+
```bash
|
| 27 |
+
pip install -e ".[pi0]"
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
|
| 31 |
+
```bash
|
| 32 |
+
python lerobot/scripts/train.py \
|
| 33 |
+
--policy.path=lerobot/pi0 \
|
| 34 |
+
--dataset.repo_id=danaaubakirova/koch_test
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
Example of finetuning the pi0 neural network with PaliGemma and expert Gemma
|
| 38 |
+
pretrained with VLM default parameters before pi0 finetuning:
|
| 39 |
+
```bash
|
| 40 |
+
python lerobot/scripts/train.py \
|
| 41 |
+
--policy.type=pi0 \
|
| 42 |
+
--dataset.repo_id=danaaubakirova/koch_test
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
Example of using the pi0 pretrained model outside LeRobot training framework:
|
| 46 |
+
```python
|
| 47 |
+
policy = Pi0Policy.from_pretrained("lerobot/pi0")
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
import math
|
| 53 |
+
from collections import deque
|
| 54 |
+
|
| 55 |
+
import torch
|
| 56 |
+
import torch.nn.functional as F # noqa: N812
|
| 57 |
+
from configuration_pi0 import PI0Config
|
| 58 |
+
from lerobot.common.constants import ACTION, OBS_ROBOT
|
| 59 |
+
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
| 60 |
+
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
| 61 |
+
from lerobot.common.utils.utils import get_safe_dtype
|
| 62 |
+
from paligemma_with_expert import (
|
| 63 |
+
PaliGemmaWithExpertConfig,
|
| 64 |
+
PaliGemmaWithExpertModel,
|
| 65 |
+
)
|
| 66 |
+
from torch import Tensor, nn
|
| 67 |
+
from transformers import AutoTokenizer
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def create_sinusoidal_pos_embedding(
|
| 71 |
+
time: torch.tensor,
|
| 72 |
+
dimension: int,
|
| 73 |
+
min_period: float,
|
| 74 |
+
max_period: float,
|
| 75 |
+
device="cpu",
|
| 76 |
+
) -> Tensor:
|
| 77 |
+
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
| 78 |
+
if dimension % 2 != 0:
|
| 79 |
+
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
| 80 |
+
|
| 81 |
+
if time.ndim != 1:
|
| 82 |
+
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
| 83 |
+
|
| 84 |
+
dtype = get_safe_dtype(torch.float64, device.type)
|
| 85 |
+
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
| 86 |
+
period = min_period * (max_period / min_period) ** fraction
|
| 87 |
+
|
| 88 |
+
# Compute the outer product
|
| 89 |
+
scaling_factor = 1.0 / period * 2 * math.pi
|
| 90 |
+
sin_input = scaling_factor[None, :] * time[:, None]
|
| 91 |
+
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
| 92 |
+
return pos_emb
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def sample_beta(alpha, beta, bsize, device):
|
| 96 |
+
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
|
| 97 |
+
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
|
| 98 |
+
return gamma1 / (gamma1 + gamma2)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def make_att_2d_masks(pad_masks, att_masks):
|
| 102 |
+
"""Copied from big_vision.
|
| 103 |
+
|
| 104 |
+
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
|
| 105 |
+
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
|
| 106 |
+
setup several types of attention, for example:
|
| 107 |
+
|
| 108 |
+
[[1 1 1 1 1 1]]: pure causal attention.
|
| 109 |
+
|
| 110 |
+
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
|
| 111 |
+
themselves and the last 3 tokens have a causal attention. The first
|
| 112 |
+
entry could also be a 1 without changing behaviour.
|
| 113 |
+
|
| 114 |
+
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
|
| 115 |
+
block can attend all previous blocks and all tokens on the same block.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
input_mask: bool[B, N] true if its part of the input, false if padding.
|
| 119 |
+
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
|
| 120 |
+
it and 0 where it shares the same attention mask as the previous token.
|
| 121 |
+
"""
|
| 122 |
+
if att_masks.ndim != 2:
|
| 123 |
+
raise ValueError(att_masks.ndim)
|
| 124 |
+
if pad_masks.ndim != 2:
|
| 125 |
+
raise ValueError(pad_masks.ndim)
|
| 126 |
+
|
| 127 |
+
cumsum = torch.cumsum(att_masks, dim=1)
|
| 128 |
+
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
|
| 129 |
+
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
| 130 |
+
att_2d_masks = att_2d_masks & pad_2d_masks
|
| 131 |
+
return att_2d_masks
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def resize_with_pad(img, width, height, pad_value=-1):
|
| 135 |
+
# assume no-op when width height fits already
|
| 136 |
+
if img.ndim != 4:
|
| 137 |
+
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
|
| 138 |
+
|
| 139 |
+
cur_height, cur_width = img.shape[2:]
|
| 140 |
+
|
| 141 |
+
ratio = max(cur_width / width, cur_height / height)
|
| 142 |
+
resized_height = int(cur_height / ratio)
|
| 143 |
+
resized_width = int(cur_width / ratio)
|
| 144 |
+
resized_img = F.interpolate(
|
| 145 |
+
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
pad_height = max(0, int(height - resized_height))
|
| 149 |
+
pad_width = max(0, int(width - resized_width))
|
| 150 |
+
|
| 151 |
+
# pad on left and top of image
|
| 152 |
+
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
| 153 |
+
return padded_img
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def pad_vector(vector, new_dim):
|
| 157 |
+
"""Can be (batch_size x sequence_length x features_dimension)
|
| 158 |
+
or (batch_size x features_dimension)
|
| 159 |
+
"""
|
| 160 |
+
if vector.shape[-1] == new_dim:
|
| 161 |
+
return vector
|
| 162 |
+
shape = list(vector.shape)
|
| 163 |
+
current_dim = shape[-1]
|
| 164 |
+
shape[-1] = new_dim
|
| 165 |
+
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
|
| 166 |
+
new_vector[..., :current_dim] = vector
|
| 167 |
+
return new_vector
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def normalize(x, min_val, max_val):
|
| 171 |
+
return (x - min_val) / (max_val - min_val)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def unnormalize(x, min_val, max_val):
|
| 175 |
+
return x * (max_val - min_val) + min_val
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def safe_arcsin(value):
|
| 179 |
+
# This ensures that the input stays within
|
| 180 |
+
# [−1,1] to avoid invalid values for arcsin
|
| 181 |
+
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def aloha_gripper_to_angular(value):
|
| 185 |
+
# Aloha transforms the gripper positions into a linear space. The following code
|
| 186 |
+
# reverses this transformation to be consistent with pi0 which is pretrained in
|
| 187 |
+
# angular space.
|
| 188 |
+
#
|
| 189 |
+
# These values are coming from the Aloha code:
|
| 190 |
+
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
|
| 191 |
+
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
|
| 192 |
+
|
| 193 |
+
# This is the inverse of the angular to linear transformation inside the Interbotix code.
|
| 194 |
+
def linear_to_radian(linear_position, arm_length, horn_radius):
|
| 195 |
+
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (
|
| 196 |
+
2 * horn_radius * linear_position
|
| 197 |
+
)
|
| 198 |
+
return safe_arcsin(value)
|
| 199 |
+
|
| 200 |
+
# The constants are taken from the Interbotix code.
|
| 201 |
+
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
|
| 202 |
+
|
| 203 |
+
# Normalize to [0, 1].
|
| 204 |
+
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
| 205 |
+
return normalize(value, min_val=0.4, max_val=1.5)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def aloha_gripper_from_angular(value):
|
| 209 |
+
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
|
| 210 |
+
# Note that the units are still angular but the range is different.
|
| 211 |
+
|
| 212 |
+
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
| 213 |
+
value = unnormalize(value, min_val=0.4, max_val=1.5)
|
| 214 |
+
|
| 215 |
+
# These values are coming from the Aloha code:
|
| 216 |
+
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
|
| 217 |
+
return normalize(value, min_val=-0.6213, max_val=1.4910)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def aloha_gripper_from_angular_inv(value):
|
| 221 |
+
# Directly inverts the gripper_from_angular function.
|
| 222 |
+
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
| 223 |
+
return normalize(value, min_val=0.4, max_val=1.5)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class PI0Policy(PreTrainedPolicy):
|
| 227 |
+
"""Wrapper class around PI0FlowMatching model to train and run inference within LeRobot."""
|
| 228 |
+
|
| 229 |
+
config_class = PI0Config
|
| 230 |
+
name = "pi0"
|
| 231 |
+
|
| 232 |
+
def __init__(
|
| 233 |
+
self,
|
| 234 |
+
config: PI0Config,
|
| 235 |
+
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
| 236 |
+
):
|
| 237 |
+
"""
|
| 238 |
+
Args:
|
| 239 |
+
config: Policy configuration class instance or None, in which case the default instantiation of
|
| 240 |
+
the configuration class is used.
|
| 241 |
+
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
| 242 |
+
that they will be passed with a call to `load_state_dict` before the policy is used.
|
| 243 |
+
"""
|
| 244 |
+
|
| 245 |
+
super().__init__(config)
|
| 246 |
+
config.validate_features()
|
| 247 |
+
self.config = config
|
| 248 |
+
|
| 249 |
+
# TODO: input / output features / normalizer for mutiple datasets
|
| 250 |
+
self.normalize_inputs = Normalize(
|
| 251 |
+
config.input_features, config.normalization_mapping, dataset_stats
|
| 252 |
+
)
|
| 253 |
+
self.normalize_targets = Normalize(
|
| 254 |
+
config.output_features, config.normalization_mapping, dataset_stats
|
| 255 |
+
)
|
| 256 |
+
self.unnormalize_outputs = Unnormalize(
|
| 257 |
+
config.output_features, config.normalization_mapping, dataset_stats
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
# self.language_tokenizer = AutoTokenizer.from_pretrained("/cpfs01/shared/optimal/vla_next/pretrained/pi0", local_files_only=True)
|
| 261 |
+
self.language_tokenizer = None
|
| 262 |
+
self.model = PI0FlowMatching(config)
|
| 263 |
+
|
| 264 |
+
self.reset()
|
| 265 |
+
|
| 266 |
+
def reset(self):
|
| 267 |
+
"""This should be called whenever the environment is reset."""
|
| 268 |
+
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
| 269 |
+
|
| 270 |
+
def get_optim_params(self) -> dict:
|
| 271 |
+
return self.parameters()
|
| 272 |
+
|
| 273 |
+
@torch.no_grad
|
| 274 |
+
def select_action(
|
| 275 |
+
self, batch: dict[str, Tensor], noise: Tensor | None = None
|
| 276 |
+
) -> Tensor:
|
| 277 |
+
"""Select a single action given environment observations.
|
| 278 |
+
|
| 279 |
+
This method wraps `select_actions` in order to return one action at a time for execution in the
|
| 280 |
+
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
| 281 |
+
queue is empty.
|
| 282 |
+
"""
|
| 283 |
+
self.eval()
|
| 284 |
+
|
| 285 |
+
if self.config.adapt_to_pi_aloha:
|
| 286 |
+
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
| 287 |
+
|
| 288 |
+
batch = self.normalize_inputs(batch)
|
| 289 |
+
|
| 290 |
+
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
| 291 |
+
# querying the policy.
|
| 292 |
+
images, img_masks = self.prepare_images(batch)
|
| 293 |
+
state = self.prepare_state(batch)
|
| 294 |
+
lang_tokens, lang_masks = self.prepare_language(batch)
|
| 295 |
+
|
| 296 |
+
actions = self.model.sample_actions(
|
| 297 |
+
images, img_masks, lang_tokens, lang_masks, state, noise=noise
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
# Unpad actions
|
| 301 |
+
original_action_dim = self.config.action_feature.shape[0]
|
| 302 |
+
actions = actions[:, :, :original_action_dim]
|
| 303 |
+
|
| 304 |
+
actions = self.unnormalize_outputs({"action": actions})["action"]
|
| 305 |
+
|
| 306 |
+
if self.config.adapt_to_pi_aloha:
|
| 307 |
+
actions = self._pi_aloha_encode_actions(actions)
|
| 308 |
+
return actions
|
| 309 |
+
|
| 310 |
+
def forward(
|
| 311 |
+
self, batch: dict[str, Tensor], noise=None, time=None
|
| 312 |
+
) -> tuple[Tensor, dict[str, Tensor]]:
|
| 313 |
+
"""Do a full training forward pass to compute the loss"""
|
| 314 |
+
if self.config.adapt_to_pi_aloha:
|
| 315 |
+
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
| 316 |
+
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
| 317 |
+
|
| 318 |
+
batch = self.normalize_inputs(batch)
|
| 319 |
+
batch = self.normalize_targets(batch)
|
| 320 |
+
|
| 321 |
+
images, img_masks = self.prepare_images(batch)
|
| 322 |
+
state = self.prepare_state(batch)
|
| 323 |
+
lang_tokens, lang_masks = self.prepare_language(batch)
|
| 324 |
+
actions = self.prepare_action(batch)
|
| 325 |
+
actions_is_pad = batch.get("action_is_pad")
|
| 326 |
+
|
| 327 |
+
loss_dict = {}
|
| 328 |
+
losses = self.model.forward(
|
| 329 |
+
images, img_masks, lang_tokens, lang_masks, state, actions, noise, time
|
| 330 |
+
)
|
| 331 |
+
# loss_dict["losses_after_forward"] = losses.detach().mean().item()
|
| 332 |
+
|
| 333 |
+
if actions_is_pad is not None:
|
| 334 |
+
in_episode_bound = ~actions_is_pad
|
| 335 |
+
losses = losses * in_episode_bound.unsqueeze(-1)
|
| 336 |
+
# loss_dict["losses_after_in_ep_bound"] = losses.detach().mean().item()
|
| 337 |
+
|
| 338 |
+
# Remove padding
|
| 339 |
+
losses = losses[:, :, : self.config.max_action_dim]
|
| 340 |
+
# loss_dict["losses_after_rm_padding"] = losses.detach().mean().item()
|
| 341 |
+
|
| 342 |
+
# For backward pass
|
| 343 |
+
loss = losses.mean()
|
| 344 |
+
# For logging
|
| 345 |
+
loss_dict["l2_loss"] = loss.item()
|
| 346 |
+
|
| 347 |
+
return loss, loss_dict
|
| 348 |
+
|
| 349 |
+
def prepare_images(self, batch):
|
| 350 |
+
"""Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
|
| 351 |
+
convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
|
| 352 |
+
"""
|
| 353 |
+
images = []
|
| 354 |
+
img_masks = []
|
| 355 |
+
|
| 356 |
+
present_img_keys = [key for key in self.config.image_features if key in batch]
|
| 357 |
+
missing_img_keys = [
|
| 358 |
+
key for key in self.config.image_features if key not in batch
|
| 359 |
+
]
|
| 360 |
+
|
| 361 |
+
if len(present_img_keys) == 0:
|
| 362 |
+
raise ValueError(
|
| 363 |
+
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
# Preprocess image features present in the batch
|
| 367 |
+
for key in present_img_keys:
|
| 368 |
+
img = batch[key]
|
| 369 |
+
|
| 370 |
+
if self.config.resize_imgs_with_padding is not None:
|
| 371 |
+
img = resize_with_pad(
|
| 372 |
+
img, *self.config.resize_imgs_with_padding, pad_value=0
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
| 376 |
+
img = img * 2.0 - 1.0
|
| 377 |
+
|
| 378 |
+
bsize = img.shape[0]
|
| 379 |
+
device = img.device
|
| 380 |
+
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
| 381 |
+
images.append(img)
|
| 382 |
+
img_masks.append(mask)
|
| 383 |
+
|
| 384 |
+
# Create image features not present in the batch
|
| 385 |
+
# as fully 0 padded images.
|
| 386 |
+
for num_empty_cameras in range(len(missing_img_keys)):
|
| 387 |
+
if num_empty_cameras >= self.config.empty_cameras:
|
| 388 |
+
break
|
| 389 |
+
img = torch.ones_like(img) * -1
|
| 390 |
+
mask = torch.zeros_like(mask)
|
| 391 |
+
images.append(img)
|
| 392 |
+
img_masks.append(mask)
|
| 393 |
+
|
| 394 |
+
return images, img_masks
|
| 395 |
+
|
| 396 |
+
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
| 397 |
+
"""Tokenize the text input"""
|
| 398 |
+
device = batch[OBS_ROBOT].device
|
| 399 |
+
tasks = batch["task"]
|
| 400 |
+
|
| 401 |
+
# PaliGemma prompt has to end with a new line
|
| 402 |
+
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
| 403 |
+
|
| 404 |
+
tokenized_prompt = self.language_tokenizer.__call__(
|
| 405 |
+
tasks,
|
| 406 |
+
padding="max_length",
|
| 407 |
+
padding_side="right",
|
| 408 |
+
max_length=self.config.tokenizer_max_length,
|
| 409 |
+
return_tensors="pt",
|
| 410 |
+
truncation=True,
|
| 411 |
+
)
|
| 412 |
+
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
| 413 |
+
lang_masks = tokenized_prompt["attention_mask"].to(
|
| 414 |
+
device=device, dtype=torch.bool
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
return lang_tokens, lang_masks
|
| 418 |
+
|
| 419 |
+
def _pi_aloha_decode_state(self, state):
|
| 420 |
+
# Flip the joints.
|
| 421 |
+
for motor_idx in [1, 2, 8, 9]:
|
| 422 |
+
state[:, motor_idx] *= -1
|
| 423 |
+
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
| 424 |
+
for motor_idx in [6, 13]:
|
| 425 |
+
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
|
| 426 |
+
return state
|
| 427 |
+
|
| 428 |
+
def _pi_aloha_encode_actions(self, actions):
|
| 429 |
+
# Flip the joints.
|
| 430 |
+
for motor_idx in [1, 2, 8, 9]:
|
| 431 |
+
actions[:, :, motor_idx] *= -1
|
| 432 |
+
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
| 433 |
+
for motor_idx in [6, 13]:
|
| 434 |
+
actions[:, :, motor_idx] = aloha_gripper_from_angular(
|
| 435 |
+
actions[:, :, motor_idx]
|
| 436 |
+
)
|
| 437 |
+
return actions
|
| 438 |
+
|
| 439 |
+
def _pi_aloha_encode_actions_inv(self, actions):
|
| 440 |
+
# Flip the joints again.
|
| 441 |
+
for motor_idx in [1, 2, 8, 9]:
|
| 442 |
+
actions[:, :, motor_idx] *= -1
|
| 443 |
+
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
| 444 |
+
for motor_idx in [6, 13]:
|
| 445 |
+
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(
|
| 446 |
+
actions[:, :, motor_idx]
|
| 447 |
+
)
|
| 448 |
+
return actions
|
| 449 |
+
|
| 450 |
+
def prepare_state(self, batch):
|
| 451 |
+
"""Pad state"""
|
| 452 |
+
state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim)
|
| 453 |
+
return state
|
| 454 |
+
|
| 455 |
+
def prepare_action(self, batch):
|
| 456 |
+
"""Pad action"""
|
| 457 |
+
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
| 458 |
+
return actions
|
| 459 |
+
|
| 460 |
+
def _save_pretrained(self, save_directory) -> None:
|
| 461 |
+
super()._save_pretrained(save_directory)
|
| 462 |
+
print(f"Saving the language tokenizer to {save_directory} ...")
|
| 463 |
+
self.language_tokenizer.save_pretrained(save_directory)
|
| 464 |
+
|
| 465 |
+
print(f"Copying config and model to {save_directory} ...")
|
| 466 |
+
import shutil
|
| 467 |
+
|
| 468 |
+
files = [
|
| 469 |
+
"pi0/configuration_pi0.py",
|
| 470 |
+
"pi0/flex_attention.py",
|
| 471 |
+
"pi0/modeling_pi0.py",
|
| 472 |
+
"pi0/paligemma_with_expert.py",
|
| 473 |
+
]
|
| 474 |
+
try:
|
| 475 |
+
for file in files:
|
| 476 |
+
shutil.copy(file, save_directory)
|
| 477 |
+
except Exception:
|
| 478 |
+
print("Failed to copy files to save_directory")
|
| 479 |
+
|
| 480 |
+
@classmethod
|
| 481 |
+
def from_pretrained(
|
| 482 |
+
cls,
|
| 483 |
+
pretrained_name_or_path,
|
| 484 |
+
**kwargs,
|
| 485 |
+
):
|
| 486 |
+
policy = super().from_pretrained(pretrained_name_or_path, **kwargs)
|
| 487 |
+
print(f"Loading the language tokenizer from {pretrained_name_or_path} ...")
|
| 488 |
+
policy.language_tokenizer = AutoTokenizer.from_pretrained(
|
| 489 |
+
pretrained_name_or_path
|
| 490 |
+
)
|
| 491 |
+
return policy
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
class PI0FlowMatching(nn.Module):
|
| 495 |
+
"""
|
| 496 |
+
π0: A Vision-Language-Action Flow Model for General Robot Control
|
| 497 |
+
|
| 498 |
+
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
|
| 499 |
+
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
| 500 |
+
|
| 501 |
+
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
| 502 |
+
┌──────────────────────────────┐
|
| 503 |
+
│ actions ──────────► noise
|
| 504 |
+
│ ▲ │ │
|
| 505 |
+
│ ┌┴─────┐ │ ┌┴─────┐
|
| 506 |
+
│ kv cache │Gemma │ │ │Gemma │
|
| 507 |
+
│ ┌──────────►│Expert│ │ │Expert│ 4
|
| 508 |
+
│ │ │ │ │ │ │
|
| 509 |
+
│ ┌┴─────▲───┐ │x 10 │ │ │x 10 │
|
| 510 |
+
│ │ │ └▲──▲──┘ │ └▲──▲─-┘
|
| 511 |
+
│ │PaliGemma │ │ │ │ │ │
|
| 512 |
+
│ │ │ │ robot state │ │ robot state
|
| 513 |
+
│ │ │ noise │ vision
|
| 514 |
+
│ └▲──▲──▲───┘ │
|
| 515 |
+
│ │ │ │
|
| 516 |
+
│ │ image(s) │
|
| 517 |
+
│ language tokens │
|
| 518 |
+
└──────────────────────────────┘
|
| 519 |
+
"""
|
| 520 |
+
|
| 521 |
+
def __init__(self, config):
|
| 522 |
+
super().__init__()
|
| 523 |
+
self.config = config
|
| 524 |
+
|
| 525 |
+
paligemma_with_export_config = PaliGemmaWithExpertConfig(
|
| 526 |
+
freeze_vision_encoder=self.config.freeze_vision_encoder,
|
| 527 |
+
train_expert_only=self.config.train_expert_only,
|
| 528 |
+
attention_implementation=self.config.attention_implementation,
|
| 529 |
+
paligemma_config=self.config.paligemma_config,
|
| 530 |
+
gemma_expert_config=self.config.gemma_expert_config,
|
| 531 |
+
)
|
| 532 |
+
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
| 533 |
+
paligemma_with_export_config
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
# Projections are float32
|
| 537 |
+
self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
|
| 538 |
+
self.action_in_proj = nn.Linear(
|
| 539 |
+
self.config.max_action_dim, self.config.proj_width
|
| 540 |
+
)
|
| 541 |
+
self.action_out_proj = nn.Linear(
|
| 542 |
+
self.config.proj_width, self.config.max_action_dim
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
self.action_time_mlp_in = nn.Linear(
|
| 546 |
+
self.config.proj_width * 2, self.config.proj_width
|
| 547 |
+
)
|
| 548 |
+
self.action_time_mlp_out = nn.Linear(
|
| 549 |
+
self.config.proj_width, self.config.proj_width
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
self.set_requires_grad()
|
| 553 |
+
|
| 554 |
+
def set_requires_grad(self):
|
| 555 |
+
for params in self.state_proj.parameters():
|
| 556 |
+
params.requires_grad = self.config.train_state_proj
|
| 557 |
+
|
| 558 |
+
def sample_noise(self, shape, device):
|
| 559 |
+
noise = torch.normal(
|
| 560 |
+
mean=0.0,
|
| 561 |
+
std=1.0,
|
| 562 |
+
size=shape,
|
| 563 |
+
dtype=torch.float32,
|
| 564 |
+
device=device,
|
| 565 |
+
)
|
| 566 |
+
return noise
|
| 567 |
+
|
| 568 |
+
def sample_time(self, bsize, device):
|
| 569 |
+
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
| 570 |
+
time = time_beta * 0.999 + 0.001
|
| 571 |
+
return time.to(dtype=torch.float32, device=device)
|
| 572 |
+
|
| 573 |
+
def embed_prefix(
|
| 574 |
+
self, images, img_masks, lang_tokens, lang_masks
|
| 575 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 576 |
+
"""Embed images with SigLIP and language tokens with embedding layer to prepare
|
| 577 |
+
for PaliGemma transformer processing.
|
| 578 |
+
"""
|
| 579 |
+
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
|
| 580 |
+
embs = []
|
| 581 |
+
pad_masks = []
|
| 582 |
+
att_masks = []
|
| 583 |
+
|
| 584 |
+
# TODO: remove for loop
|
| 585 |
+
for (
|
| 586 |
+
img,
|
| 587 |
+
img_mask,
|
| 588 |
+
) in zip(images, img_masks, strict=False):
|
| 589 |
+
img_emb = self.paligemma_with_expert.embed_image(img)
|
| 590 |
+
img_emb = img_emb.to(dtype=torch.bfloat16)
|
| 591 |
+
|
| 592 |
+
# Normalize image embeddings
|
| 593 |
+
img_emb_dim = img_emb.shape[-1]
|
| 594 |
+
img_emb = img_emb * torch.tensor(
|
| 595 |
+
img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device
|
| 596 |
+
)
|
| 597 |
+
|
| 598 |
+
bsize, num_img_embs = img_emb.shape[:2]
|
| 599 |
+
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
|
| 600 |
+
|
| 601 |
+
embs.append(img_emb)
|
| 602 |
+
pad_masks.append(img_mask)
|
| 603 |
+
|
| 604 |
+
# Create attention masks so that image tokens attend to each other
|
| 605 |
+
att_masks += [0] * num_img_embs
|
| 606 |
+
|
| 607 |
+
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
| 608 |
+
|
| 609 |
+
# Normalize language embeddings
|
| 610 |
+
lang_emb_dim = lang_emb.shape[-1]
|
| 611 |
+
lang_emb = lang_emb * math.sqrt(lang_emb_dim)
|
| 612 |
+
|
| 613 |
+
embs.append(lang_emb)
|
| 614 |
+
pad_masks.append(lang_masks)
|
| 615 |
+
|
| 616 |
+
# full attention between image and language inputs
|
| 617 |
+
num_lang_embs = lang_emb.shape[1]
|
| 618 |
+
att_masks += [0] * num_lang_embs
|
| 619 |
+
|
| 620 |
+
embs = torch.cat(embs, dim=1)
|
| 621 |
+
pad_masks = torch.cat(pad_masks, dim=1)
|
| 622 |
+
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
| 623 |
+
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
| 624 |
+
|
| 625 |
+
return embs, pad_masks, att_masks
|
| 626 |
+
|
| 627 |
+
def embed_suffix(self, state, noisy_actions, timestep):
|
| 628 |
+
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
|
| 629 |
+
embs = []
|
| 630 |
+
pad_masks = []
|
| 631 |
+
att_masks = []
|
| 632 |
+
|
| 633 |
+
# Embed state
|
| 634 |
+
state_emb = self.state_proj(state)
|
| 635 |
+
state_emb = state_emb.to(dtype=torch.bfloat16)
|
| 636 |
+
embs.append(state_emb[:, None, :])
|
| 637 |
+
bsize = state_emb.shape[0]
|
| 638 |
+
dtype = state_emb.dtype
|
| 639 |
+
device = state_emb.device
|
| 640 |
+
|
| 641 |
+
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
|
| 642 |
+
pad_masks.append(state_mask)
|
| 643 |
+
|
| 644 |
+
# Set attention masks so that image and language inputs do not attend to state or actions
|
| 645 |
+
att_masks += [1]
|
| 646 |
+
|
| 647 |
+
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
| 648 |
+
time_emb = create_sinusoidal_pos_embedding(
|
| 649 |
+
timestep,
|
| 650 |
+
self.config.proj_width,
|
| 651 |
+
min_period=4e-3,
|
| 652 |
+
max_period=4.0,
|
| 653 |
+
device=device,
|
| 654 |
+
)
|
| 655 |
+
time_emb = time_emb.type(dtype=dtype)
|
| 656 |
+
|
| 657 |
+
# Fuse timestep + action information using an MLP
|
| 658 |
+
action_emb = self.action_in_proj(noisy_actions)
|
| 659 |
+
|
| 660 |
+
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
| 661 |
+
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
| 662 |
+
|
| 663 |
+
action_time_emb = self.action_time_mlp_in(action_time_emb)
|
| 664 |
+
action_time_emb = F.silu(action_time_emb) # swish == silu
|
| 665 |
+
action_time_emb = self.action_time_mlp_out(action_time_emb)
|
| 666 |
+
|
| 667 |
+
# Add to input tokens
|
| 668 |
+
embs.append(action_time_emb)
|
| 669 |
+
|
| 670 |
+
bsize, action_time_dim = action_time_emb.shape[:2]
|
| 671 |
+
action_time_mask = torch.ones(
|
| 672 |
+
bsize, action_time_dim, dtype=torch.bool, device=device
|
| 673 |
+
)
|
| 674 |
+
pad_masks.append(action_time_mask)
|
| 675 |
+
|
| 676 |
+
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
| 677 |
+
att_masks += [1] + ([0] * (self.config.n_action_steps - 1))
|
| 678 |
+
|
| 679 |
+
embs = torch.cat(embs, dim=1)
|
| 680 |
+
pad_masks = torch.cat(pad_masks, dim=1)
|
| 681 |
+
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
|
| 682 |
+
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
| 683 |
+
|
| 684 |
+
return embs, pad_masks, att_masks
|
| 685 |
+
|
| 686 |
+
def forward(
|
| 687 |
+
self,
|
| 688 |
+
images,
|
| 689 |
+
img_masks,
|
| 690 |
+
lang_tokens,
|
| 691 |
+
lang_masks,
|
| 692 |
+
state,
|
| 693 |
+
actions,
|
| 694 |
+
noise=None,
|
| 695 |
+
time=None,
|
| 696 |
+
) -> Tensor:
|
| 697 |
+
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
| 698 |
+
if noise is None:
|
| 699 |
+
noise = self.sample_noise(actions.shape, actions.device)
|
| 700 |
+
|
| 701 |
+
if time is None:
|
| 702 |
+
time = self.sample_time(actions.shape[0], actions.device)
|
| 703 |
+
time_expanded = time[:, None, None]
|
| 704 |
+
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
| 705 |
+
u_t = noise - actions
|
| 706 |
+
|
| 707 |
+
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
| 708 |
+
images, img_masks, lang_tokens, lang_masks
|
| 709 |
+
)
|
| 710 |
+
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
|
| 711 |
+
state, x_t, time
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
| 715 |
+
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
| 716 |
+
|
| 717 |
+
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
| 718 |
+
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
| 719 |
+
|
| 720 |
+
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
| 721 |
+
attention_mask=att_2d_masks,
|
| 722 |
+
position_ids=position_ids,
|
| 723 |
+
past_key_values=None,
|
| 724 |
+
inputs_embeds=[prefix_embs, suffix_embs],
|
| 725 |
+
use_cache=False,
|
| 726 |
+
fill_kv_cache=False,
|
| 727 |
+
)
|
| 728 |
+
suffix_out = suffix_out[:, -self.config.n_action_steps :]
|
| 729 |
+
# Original openpi code, upcast attention output
|
| 730 |
+
suffix_out = suffix_out.to(dtype=torch.float32)
|
| 731 |
+
v_t = self.action_out_proj(suffix_out)
|
| 732 |
+
|
| 733 |
+
losses = F.mse_loss(u_t, v_t, reduction="none")
|
| 734 |
+
return losses
|
| 735 |
+
|
| 736 |
+
def sample_actions(
|
| 737 |
+
self, images, img_masks, lang_tokens, lang_masks, state, noise=None
|
| 738 |
+
) -> Tensor:
|
| 739 |
+
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
| 740 |
+
bsize = state.shape[0]
|
| 741 |
+
device = state.device
|
| 742 |
+
|
| 743 |
+
if noise is None:
|
| 744 |
+
actions_shape = (
|
| 745 |
+
bsize,
|
| 746 |
+
self.config.n_action_steps,
|
| 747 |
+
self.config.max_action_dim,
|
| 748 |
+
)
|
| 749 |
+
noise = self.sample_noise(actions_shape, device)
|
| 750 |
+
|
| 751 |
+
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
| 752 |
+
images, img_masks, lang_tokens, lang_masks
|
| 753 |
+
)
|
| 754 |
+
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
| 755 |
+
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
| 756 |
+
|
| 757 |
+
# Compute image and language key value cache
|
| 758 |
+
_, past_key_values = self.paligemma_with_expert.forward(
|
| 759 |
+
attention_mask=prefix_att_2d_masks,
|
| 760 |
+
position_ids=prefix_position_ids,
|
| 761 |
+
past_key_values=None,
|
| 762 |
+
inputs_embeds=[prefix_embs, None],
|
| 763 |
+
use_cache=self.config.use_cache,
|
| 764 |
+
fill_kv_cache=True,
|
| 765 |
+
)
|
| 766 |
+
|
| 767 |
+
dt = -1.0 / self.config.num_steps
|
| 768 |
+
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
| 769 |
+
|
| 770 |
+
x_t = noise
|
| 771 |
+
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
| 772 |
+
while time >= -dt / 2:
|
| 773 |
+
expanded_time = time.expand(bsize)
|
| 774 |
+
v_t = self.denoise_step(
|
| 775 |
+
state,
|
| 776 |
+
prefix_pad_masks,
|
| 777 |
+
past_key_values,
|
| 778 |
+
x_t,
|
| 779 |
+
expanded_time,
|
| 780 |
+
)
|
| 781 |
+
|
| 782 |
+
# Euler step
|
| 783 |
+
x_t += dt * v_t
|
| 784 |
+
time += dt
|
| 785 |
+
return x_t
|
| 786 |
+
|
| 787 |
+
def denoise_step(
|
| 788 |
+
self,
|
| 789 |
+
state,
|
| 790 |
+
prefix_pad_masks,
|
| 791 |
+
past_key_values,
|
| 792 |
+
x_t,
|
| 793 |
+
timestep,
|
| 794 |
+
):
|
| 795 |
+
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
| 796 |
+
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
|
| 797 |
+
state, x_t, timestep
|
| 798 |
+
)
|
| 799 |
+
|
| 800 |
+
suffix_len = suffix_pad_masks.shape[1]
|
| 801 |
+
batch_size = prefix_pad_masks.shape[0]
|
| 802 |
+
prefix_len = prefix_pad_masks.shape[1]
|
| 803 |
+
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
|
| 804 |
+
batch_size, suffix_len, prefix_len
|
| 805 |
+
)
|
| 806 |
+
|
| 807 |
+
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
| 808 |
+
|
| 809 |
+
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
|
| 810 |
+
|
| 811 |
+
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
| 812 |
+
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
| 813 |
+
|
| 814 |
+
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
| 815 |
+
attention_mask=full_att_2d_masks,
|
| 816 |
+
position_ids=position_ids,
|
| 817 |
+
past_key_values=past_key_values,
|
| 818 |
+
inputs_embeds=[None, suffix_embs],
|
| 819 |
+
use_cache=self.config.use_cache,
|
| 820 |
+
fill_kv_cache=False,
|
| 821 |
+
)
|
| 822 |
+
suffix_out = outputs_embeds[1]
|
| 823 |
+
suffix_out = suffix_out[:, -self.config.n_action_steps :]
|
| 824 |
+
suffix_out = suffix_out.to(dtype=torch.float32)
|
| 825 |
+
v_t = self.action_out_proj(suffix_out)
|
| 826 |
+
return v_t
|