atnikos commited on
Commit
777e3d5
·
1 Parent(s): 7d87cc1

fix demo changes

Browse files
Files changed (2) hide show
  1. geometry_utils.py +1 -1
  2. transform3d.py +915 -0
geometry_utils.py CHANGED
@@ -1,5 +1,5 @@
1
  import torch
2
- from transform3d import transform_body_pose, apply_rot_delta, remove_z_rot, get_z_rot, change_for
3
 
4
  def diffout2motion(diffout, normalizer):
5
 
 
1
  import torch
2
+ from transform3d import transform_body_pose, apply_rot_delta, get_z_rot, change_for
3
 
4
  def diffout2motion(diffout, normalizer):
5
 
transform3d.py ADDED
@@ -0,0 +1,915 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from einops import rearrange
2
+ import numpy as np
3
+ import torch
4
+ from torch import Tensor
5
+ from roma import rotmat_to_rotvec, rotvec_to_rotmat
6
+ from torch.nn.functional import pad
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
8
+ # Check PYTORCH3D_LICENCE before use
9
+
10
+ import functools
11
+ from typing import Optional
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+ import torch
17
+ from einops import rearrange
18
+
19
+
20
+ def rotate_trajectory(traj, rotZ, inverse=False):
21
+ '''
22
+ Rotate the trajectory of a given body
23
+ '''
24
+ if inverse:
25
+ # transpose
26
+ rotZ = rearrange(rotZ, "... i j -> ... j i")
27
+
28
+ vel = torch.diff(traj, dim=-2)
29
+ # 0 for the first one => keep the dimentionality
30
+ vel = torch.cat((0 * vel[..., [0], :], vel), dim=-2)
31
+ vel_local = torch.einsum("...kj,...k->...j", rotZ[..., :2, :2], vel[..., :2])
32
+ # Integrate the trajectory
33
+ traj_local = torch.cumsum(vel_local, dim=-2)
34
+ # First frame should be the same as before
35
+ traj_local = traj_local - traj_local[..., [0], :] + traj[..., [0], :]
36
+ return traj_local
37
+
38
+
39
+ def rotate_trans(trans, rotZ, inverse=False):
40
+ '''
41
+ Rotate the translation of a given body
42
+ '''
43
+
44
+ traj = trans[..., :2]
45
+ transZ = trans[..., 2]
46
+ traj_local = rotate_trajectory(traj, rotZ, inverse=inverse)
47
+ trans_local = torch.cat((traj_local, transZ[..., None]), axis=-1)
48
+ return trans_local
49
+
50
+
51
+ def rotate_body_canonic(rots, trans, offset=0.0):
52
+
53
+ '''
54
+ Rotate the whole body
55
+ '''
56
+
57
+ # rots, trans = data.rots.clone(), data.trans.clone()
58
+ global_poses = rots[..., 0, :, :]
59
+ global_euler = matrix_to_euler_angles(global_poses, "ZYX")
60
+ anglesZ, anglesY, anglesX = torch.unbind(global_euler, -1)
61
+ rotZ = _axis_angle_rotation("Z", anglesZ)
62
+
63
+ diff_mat_rotZ = rotZ[..., 1:, :, :] @ rotZ.transpose(-1, -2)[..., :-1, :, :]
64
+ vel_anglesZ = matrix_to_axis_angle(diff_mat_rotZ)[..., 2]
65
+ # padding "same"
66
+ vel_anglesZ = torch.cat((vel_anglesZ[..., [0]], vel_anglesZ), dim=-1)
67
+ # canonicalizing here
68
+ new_anglesZ = torch.cumsum(vel_anglesZ, -1) + offset
69
+ new_rotZ = _axis_angle_rotation("Z", new_anglesZ)
70
+
71
+ new_global_euler = torch.stack((new_anglesZ, anglesY, anglesX), -1)
72
+ new_global_orient = euler_angles_to_matrix(new_global_euler, "ZYX")
73
+
74
+ rots[:, 0] = new_global_orient
75
+ trans = rotate_trans(trans, rotZ[0], inverse=False)
76
+ trans = rotate_trans(trans, new_rotZ[0], inverse=True)
77
+ # trans = rotate_trans(trans, rotZ[0], inverse=True)
78
+
79
+ # from sinc.transforms.smpl import RotTransDatastruct
80
+ # return RotTransDatastruct(rots=rots, trans=trans)
81
+ return rots, trans
82
+
83
+
84
+
85
+
86
+
87
+
88
+ """
89
+ The transformation matrices returned from the functions in this file assume
90
+ the points on which the transformation will be applied are column vectors.
91
+ i.e. the R matrix is structured as
92
+
93
+ R = [
94
+ [Rxx, Rxy, Rxz],
95
+ [Ryx, Ryy, Ryz],
96
+ [Rzx, Rzy, Rzz],
97
+ ] # (3, 3)
98
+
99
+ This matrix can be applied to column vectors by post multiplication
100
+ by the points e.g.
101
+
102
+ points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
103
+ transformed_points = R * points
104
+
105
+ To apply the same matrix to points which are row vectors, the R matrix
106
+ can be transposed and pre multiplied by the points:
107
+
108
+ e.g.
109
+ points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
110
+ transformed_points = points * R.transpose(1, 0)
111
+ """
112
+
113
+
114
+ # Added
115
+ def matrix_of_angles(cos, sin, inv=False, dim=2):
116
+ assert dim in [2, 3]
117
+ sin = -sin if inv else sin
118
+ if dim == 2:
119
+ row1 = torch.stack((cos, -sin), axis=-1)
120
+ row2 = torch.stack((sin, cos), axis=-1)
121
+ return torch.stack((row1, row2), axis=-2)
122
+ elif dim == 3:
123
+ row1 = torch.stack((cos, -sin, 0*cos), axis=-1)
124
+ row2 = torch.stack((sin, cos, 0*cos), axis=-1)
125
+ row3 = torch.stack((0*sin, 0*cos, 1+0*cos), axis=-1)
126
+ return torch.stack((row1, row2, row3),axis=-2)
127
+
128
+
129
+ def quaternion_to_matrix(quaternions):
130
+ """
131
+ Convert rotations given as quaternions to rotation matrices.
132
+
133
+ Args:
134
+ quaternions: quaternions with real part first,
135
+ as tensor of shape (..., 4).
136
+
137
+ Returns:
138
+ Rotation matrices as tensor of shape (..., 3, 3).
139
+ """
140
+ r, i, j, k = torch.unbind(quaternions, -1)
141
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
142
+
143
+ o = torch.stack(
144
+ (
145
+ 1 - two_s * (j * j + k * k),
146
+ two_s * (i * j - k * r),
147
+ two_s * (i * k + j * r),
148
+ two_s * (i * j + k * r),
149
+ 1 - two_s * (i * i + k * k),
150
+ two_s * (j * k - i * r),
151
+ two_s * (i * k - j * r),
152
+ two_s * (j * k + i * r),
153
+ 1 - two_s * (i * i + j * j),
154
+ ),
155
+ -1,
156
+ )
157
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
158
+
159
+
160
+ def _copysign(a, b):
161
+ """
162
+ Return a tensor where each element has the absolute value taken from the,
163
+ corresponding element of a, with sign taken from the corresponding
164
+ element of b. This is like the standard copysign floating-point operation,
165
+ but is not careful about negative 0 and NaN.
166
+
167
+ Args:
168
+ a: source tensor.
169
+ b: tensor whose signs will be used, of the same shape as a.
170
+
171
+ Returns:
172
+ Tensor of the same shape as a with the signs of b.
173
+ """
174
+ signs_differ = (a < 0) != (b < 0)
175
+ return torch.where(signs_differ, -a, a)
176
+
177
+
178
+ def _sqrt_positive_part(x):
179
+ """
180
+ Returns torch.sqrt(torch.max(0, x))
181
+ but with a zero subgradient where x is 0.
182
+ """
183
+ ret = torch.zeros_like(x)
184
+ positive_mask = x > 0
185
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
186
+ return ret
187
+
188
+
189
+ def matrix_to_quaternion(matrix):
190
+ """
191
+ Convert rotations given as rotation matrices to quaternions.
192
+
193
+ Args:
194
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
195
+
196
+ Returns:
197
+ quaternions with real part first, as tensor of shape (..., 4).
198
+ """
199
+ if isinstance(matrix, np.ndarray):
200
+ matrix = torch.from_numpy(matrix)
201
+ if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
202
+ raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
203
+ m00 = matrix[..., 0, 0]
204
+ m11 = matrix[..., 1, 1]
205
+ m22 = matrix[..., 2, 2]
206
+ o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
207
+ x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
208
+ y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
209
+ z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
210
+ o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
211
+ o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
212
+ o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
213
+ return torch.stack((o0, o1, o2, o3), -1)
214
+
215
+
216
+ def _axis_angle_rotation(axis: str, angle):
217
+ """
218
+ Return the rotation matrices for one of the rotations about an axis
219
+ of which Euler angles describe, for each value of the angle given.
220
+
221
+ Args:
222
+ axis: Axis label "X" or "Y or "Z".
223
+ angle: any shape tensor of Euler angles in radians
224
+
225
+ Returns:
226
+ Rotation matrices as tensor of shape (..., 3, 3).
227
+ """
228
+
229
+ cos = torch.cos(angle)
230
+ sin = torch.sin(angle)
231
+ one = torch.ones_like(angle)
232
+ zero = torch.zeros_like(angle)
233
+
234
+ if axis == "X":
235
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
236
+ if axis == "Y":
237
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
238
+ if axis == "Z":
239
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
240
+
241
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
242
+
243
+
244
+ def euler_angles_to_matrix(euler_angles, convention: str):
245
+ """
246
+ Convert rotations given as Euler angles in radians to rotation matrices.
247
+
248
+ Args:
249
+ euler_angles: Euler angles in radians as tensor of shape (..., 3).
250
+ convention: Convention string of three uppercase letters from
251
+ {"X", "Y", and "Z"}.
252
+
253
+ Returns:
254
+ Rotation matrices as tensor of shape (..., 3, 3).
255
+ """
256
+ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
257
+ raise ValueError("Invalid input euler angles.")
258
+ if len(convention) != 3:
259
+ raise ValueError("Convention must have 3 letters.")
260
+ if convention[1] in (convention[0], convention[2]):
261
+ raise ValueError(f"Invalid convention {convention}.")
262
+ for letter in convention:
263
+ if letter not in ("X", "Y", "Z"):
264
+ raise ValueError(f"Invalid letter {letter} in convention string.")
265
+ matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
266
+ return functools.reduce(torch.matmul, matrices)
267
+
268
+
269
+ def _angle_from_tan(
270
+ axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
271
+ ):
272
+ """
273
+ Extract the first or third Euler angle from the two members of
274
+ the matrix which are positive constant times its sine and cosine.
275
+
276
+ Args:
277
+ axis: Axis label "X" or "Y or "Z" for the angle we are finding.
278
+ other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
279
+ convention.
280
+ data: Rotation matrices as tensor of shape (..., 3, 3).
281
+ horizontal: Whether we are looking for the angle for the third axis,
282
+ which means the relevant entries are in the same row of the
283
+ rotation matrix. If not, they are in the same column.
284
+ tait_bryan: Whether the first and third axes in the convention differ.
285
+
286
+ Returns:
287
+ Euler Angles in radians for each matrix in data as a tensor
288
+ of shape (...).
289
+ """
290
+
291
+ i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
292
+ if horizontal:
293
+ i2, i1 = i1, i2
294
+ even = (axis + other_axis) in ["XY", "YZ", "ZX"]
295
+ if horizontal == even:
296
+ return torch.atan2(data[..., i1], data[..., i2])
297
+ if tait_bryan:
298
+ return torch.atan2(-data[..., i2], data[..., i1])
299
+ return torch.atan2(data[..., i2], -data[..., i1])
300
+
301
+
302
+ def _index_from_letter(letter: str):
303
+ if letter == "X":
304
+ return 0
305
+ if letter == "Y":
306
+ return 1
307
+ if letter == "Z":
308
+ return 2
309
+
310
+
311
+ def matrix_to_euler_angles(matrix, convention: str):
312
+ """
313
+ Convert rotations given as rotation matrices to Euler angles in radians.
314
+
315
+ Args:
316
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
317
+ convention: Convention string of three uppercase letters.
318
+
319
+ Returns:
320
+ Euler angles in radians as tensor of shape (..., 3).
321
+ """
322
+ if len(convention) != 3:
323
+ raise ValueError("Convention must have 3 letters.")
324
+ if convention[1] in (convention[0], convention[2]):
325
+ raise ValueError(f"Invalid convention {convention}.")
326
+ for letter in convention:
327
+ if letter not in ("X", "Y", "Z"):
328
+ raise ValueError(f"Invalid letter {letter} in convention string.")
329
+ if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
330
+ raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
331
+ i0 = _index_from_letter(convention[0])
332
+ i2 = _index_from_letter(convention[2])
333
+ tait_bryan = i0 != i2
334
+ if tait_bryan:
335
+ central_angle = torch.asin(
336
+ matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
337
+ )
338
+ else:
339
+ central_angle = torch.acos(matrix[..., i0, i0])
340
+
341
+ o = (
342
+ _angle_from_tan(
343
+ convention[0], convention[1], matrix[..., i2], False, tait_bryan
344
+ ),
345
+ central_angle,
346
+ _angle_from_tan(
347
+ convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
348
+ ),
349
+ )
350
+ return torch.stack(o, -1)
351
+
352
+
353
+ def random_quaternions(
354
+ n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
355
+ ):
356
+ """
357
+ Generate random quaternions representing rotations,
358
+ i.e. versors with nonnegative real part.
359
+
360
+ Args:
361
+ n: Number of quaternions in a batch to return.
362
+ dtype: Type to return.
363
+ device: Desired device of returned tensor. Default:
364
+ uses the current device for the default tensor type.
365
+ requires_grad: Whether the resulting tensor should have the gradient
366
+ flag set.
367
+
368
+ Returns:
369
+ Quaternions as tensor of shape (N, 4).
370
+ """
371
+ o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
372
+ s = (o * o).sum(1)
373
+ o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
374
+ return o
375
+
376
+
377
+ def random_rotations(
378
+ n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
379
+ ):
380
+ """
381
+ Generate random rotations as 3x3 rotation matrices.
382
+
383
+ Args:
384
+ n: Number of rotation matrices in a batch to return.
385
+ dtype: Type to return.
386
+ device: Device of returned tensor. Default: if None,
387
+ uses the current device for the default tensor type.
388
+ requires_grad: Whether the resulting tensor should have the gradient
389
+ flag set.
390
+
391
+ Returns:
392
+ Rotation matrices as tensor of shape (n, 3, 3).
393
+ """
394
+ quaternions = random_quaternions(
395
+ n, dtype=dtype, device=device, requires_grad=requires_grad
396
+ )
397
+ return quaternion_to_matrix(quaternions)
398
+
399
+
400
+ def random_rotation(
401
+ dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
402
+ ):
403
+ """
404
+ Generate a single random 3x3 rotation matrix.
405
+
406
+ Args:
407
+ dtype: Type to return
408
+ device: Device of returned tensor. Default: if None,
409
+ uses the current device for the default tensor type
410
+ requires_grad: Whether the resulting tensor should have the gradient
411
+ flag set
412
+
413
+ Returns:
414
+ Rotation matrix as tensor of shape (3, 3).
415
+ """
416
+ return random_rotations(1, dtype, device, requires_grad)[0]
417
+
418
+
419
+ def standardize_quaternion(quaternions):
420
+ """
421
+ Convert a unit quaternion to a standard form: one in which the real
422
+ part is non negative.
423
+
424
+ Args:
425
+ quaternions: Quaternions with real part first,
426
+ as tensor of shape (..., 4).
427
+
428
+ Returns:
429
+ Standardized quaternions as tensor of shape (..., 4).
430
+ """
431
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
432
+
433
+
434
+ def quaternion_raw_multiply(a, b):
435
+ """
436
+ Multiply two quaternions.
437
+ Usual torch rules for broadcasting apply.
438
+
439
+ Args:
440
+ a: Quaternions as tensor of shape (..., 4), real part first.
441
+ b: Quaternions as tensor of shape (..., 4), real part first.
442
+
443
+ Returns:
444
+ The product of a and b, a tensor of quaternions shape (..., 4).
445
+ """
446
+ aw, ax, ay, az = torch.unbind(a, -1)
447
+ bw, bx, by, bz = torch.unbind(b, -1)
448
+ ow = aw * bw - ax * bx - ay * by - az * bz
449
+ ox = aw * bx + ax * bw + ay * bz - az * by
450
+ oy = aw * by - ax * bz + ay * bw + az * bx
451
+ oz = aw * bz + ax * by - ay * bx + az * bw
452
+ return torch.stack((ow, ox, oy, oz), -1)
453
+
454
+
455
+ def quaternion_multiply(a, b):
456
+ """
457
+ Multiply two quaternions representing rotations, returning the quaternion
458
+ representing their composition, i.e. the versor with nonnegative real part.
459
+ Usual torch rules for broadcasting apply.
460
+
461
+ Args:
462
+ a: Quaternions as tensor of shape (..., 4), real part first.
463
+ b: Quaternions as tensor of shape (..., 4), real part first.
464
+
465
+ Returns:
466
+ The product of a and b, a tensor of quaternions of shape (..., 4).
467
+ """
468
+ ab = quaternion_raw_multiply(a, b)
469
+ return standardize_quaternion(ab)
470
+
471
+
472
+ def quaternion_invert(quaternion):
473
+ """
474
+ Given a quaternion representing rotation, get the quaternion representing
475
+ its inverse.
476
+
477
+ Args:
478
+ quaternion: Quaternions as tensor of shape (..., 4), with real part
479
+ first, which must be versors (unit quaternions).
480
+
481
+ Returns:
482
+ The inverse, a tensor of quaternions of shape (..., 4).
483
+ """
484
+
485
+ return quaternion * quaternion.new_tensor([1, -1, -1, -1])
486
+
487
+
488
+ def quaternion_apply(quaternion, point):
489
+ """
490
+ Apply the rotation given by a quaternion to a 3D point.
491
+ Usual torch rules for broadcasting apply.
492
+
493
+ Args:
494
+ quaternion: Tensor of quaternions, real part first, of shape (..., 4).
495
+ point: Tensor of 3D points of shape (..., 3).
496
+
497
+ Returns:
498
+ Tensor of rotated points of shape (..., 3).
499
+ """
500
+ if point.shape[-1] != 3:
501
+ raise ValueError(f"Points are not in 3D, f{point.shape}.")
502
+ real_parts = point.new_zeros(point.shape[:-1] + (1,))
503
+ point_as_quaternion = torch.cat((real_parts, point), -1)
504
+ out = quaternion_raw_multiply(
505
+ quaternion_raw_multiply(quaternion, point_as_quaternion),
506
+ quaternion_invert(quaternion),
507
+ )
508
+ return out[..., 1:]
509
+
510
+
511
+ def axis_angle_to_matrix(axis_angle):
512
+ """
513
+ Convert rotations given as axis/angle to rotation matrices.
514
+
515
+ Args:
516
+ axis_angle: Rotations given as a vector in axis angle form,
517
+ as a tensor of shape (..., 3), where the magnitude is
518
+ the angle turned anticlockwise in radians around the
519
+ vector's direction.
520
+
521
+ Returns:
522
+ Rotation matrices as tensor of shape (..., 3, 3).
523
+ """
524
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
525
+
526
+
527
+ def matrix_to_axis_angle(matrix):
528
+ """
529
+ Convert rotations given as rotation matrices to axis/angle.
530
+
531
+ Args:
532
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
533
+
534
+ Returns:
535
+ Rotations given as a vector in axis angle form, as a tensor
536
+ of shape (..., 3), where the magnitude is the angle
537
+ turned anticlockwise in radians around the vector's
538
+ direction.
539
+ """
540
+ return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
541
+
542
+
543
+ def axis_angle_to_quaternion(axis_angle):
544
+ """
545
+ Convert rotations given as axis/angle to quaternions.
546
+
547
+ Args:
548
+ axis_angle: Rotations given as a vector in axis angle form,
549
+ as a tensor of shape (..., 3), where the magnitude is
550
+ the angle turned anticlockwise in radians around the
551
+ vector's direction.
552
+
553
+ Returns:
554
+ quaternions with real part first, as tensor of shape (..., 4).
555
+ """
556
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
557
+ half_angles = 0.5 * angles
558
+ eps = 1e-6
559
+ small_angles = angles.abs() < eps
560
+ sin_half_angles_over_angles = torch.empty_like(angles)
561
+ try:
562
+ sin_half_angles_over_angles[~small_angles] = (
563
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
564
+ )
565
+ except:
566
+ torch.save(axis_angle, f'before_convert_axis_angle.pt')
567
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
568
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
569
+ sin_half_angles_over_angles[small_angles] = (
570
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
571
+ )
572
+ quaternions = torch.cat(
573
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
574
+ )
575
+ return quaternions
576
+
577
+
578
+ def quaternion_to_axis_angle(quaternions):
579
+ """
580
+ Convert rotations given as quaternions to axis/angle.
581
+
582
+ Args:
583
+ quaternions: quaternions with real part first,
584
+ as tensor of shape (..., 4).
585
+
586
+ Returns:
587
+ Rotations given as a vector in axis angle form, as a tensor
588
+ of shape (..., 3), where the magnitude is the angle
589
+ turned anticlockwise in radians around the vector's
590
+ direction.
591
+ """
592
+ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
593
+ half_angles = torch.atan2(norms, quaternions[..., :1])
594
+ angles = 2 * half_angles
595
+ eps = 1e-6
596
+ small_angles = angles.abs() < eps
597
+ sin_half_angles_over_angles = torch.empty_like(angles)
598
+ sin_half_angles_over_angles[~small_angles] = (
599
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
600
+ )
601
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
602
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
603
+ sin_half_angles_over_angles[small_angles] = (
604
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
605
+ )
606
+ return quaternions[..., 1:] / sin_half_angles_over_angles
607
+
608
+
609
+ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
610
+ """
611
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
612
+ using Gram--Schmidt orthogonalisation per Section B of [1].
613
+ Args:
614
+ d6: 6D rotation representation, of size (*, 6)
615
+
616
+ Returns:
617
+ batch of rotation matrices of size (*, 3, 3)
618
+
619
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
620
+ On the Continuity of Rotation Representations in Neural Networks.
621
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
622
+ Retrieved from http://arxiv.org/abs/1812.07035
623
+ """
624
+
625
+ a1, a2 = d6[..., :3], d6[..., 3:]
626
+ b1 = F.normalize(a1, dim=-1)
627
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
628
+ b2 = F.normalize(b2, dim=-1)
629
+ b3 = torch.cross(b1, b2, dim=-1)
630
+ return torch.stack((b1, b2, b3), dim=-2)
631
+
632
+
633
+ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
634
+ """
635
+ Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
636
+ by dropping the last row. Note that 6D representation is not unique.
637
+ Args:
638
+ matrix: batch of rotation matrices of size (*, 3, 3)
639
+
640
+ Returns:
641
+ 6D rotation representation, of size (*, 6)
642
+
643
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
644
+ On the Continuity of Rotation Representations in Neural Networks.
645
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
646
+ Retrieved from http://arxiv.org/abs/1812.07035
647
+ """
648
+ return matrix[..., :2, :].clone().reshape(*matrix.shape[:-2], 6)
649
+
650
+
651
+ def rotate_trajectory(traj, rotZ, inverse=False):
652
+ if inverse:
653
+ # transpose
654
+ rotZ = rearrange(rotZ, "... i j -> ... j i")
655
+
656
+ vel = torch.diff(traj, dim=-2)
657
+ # 0 for the first one => keep the dimentionality
658
+ vel = torch.cat((0 * vel[..., [0], :], vel), dim=-2)
659
+ vel_local = torch.einsum("...kj,...k->...j", rotZ[..., :2, :2], vel[..., :2])
660
+ # Integrate the trajectory
661
+ traj_local = torch.cumsum(vel_local, dim=-2)
662
+ # First frame should be the same as before
663
+ traj_local = traj_local - traj_local[..., [0], :] + traj[..., [0], :]
664
+ return traj_local
665
+
666
+
667
+ def rotate_trans(trans, rotZ, inverse=False):
668
+ traj = trans[..., :2]
669
+ transZ = trans[..., 2]
670
+ traj_local = rotate_trajectory(traj, rotZ, inverse=inverse)
671
+ trans_local = torch.cat((traj_local, transZ[..., None]), axis=-1)
672
+ return trans_local
673
+
674
+
675
+
676
+ def canonicalize_rotations(global_orient, trans, angle=torch.pi/4):
677
+ global_euler = matrix_to_euler_angles(global_orient, "ZYX")
678
+ anglesZ, anglesY, anglesX = torch.unbind(global_euler, -1)
679
+
680
+ rotZ = _axis_angle_rotation("Z", anglesZ)
681
+
682
+ # remove the current rotation
683
+ # make it local
684
+ local_trans = rotate_trans(trans, rotZ)
685
+
686
+ # For information:
687
+ # rotate_joints(joints, rotZ) == joints_local
688
+
689
+ diff_mat_rotZ = rotZ[..., 1:, :, :] @ rotZ.transpose(-1, -2)[..., :-1, :, :]
690
+
691
+ vel_anglesZ = matrix_to_axis_angle(diff_mat_rotZ)[..., 2]
692
+ # padding "same"
693
+ vel_anglesZ = torch.cat((vel_anglesZ[..., [0]], vel_anglesZ), dim=-1)
694
+
695
+ # Compute new rotation:
696
+ # canonicalized
697
+ anglesZ = torch.cumsum(vel_anglesZ, -1)
698
+ anglesZ += angle
699
+ rotZ = _axis_angle_rotation("Z", anglesZ)
700
+
701
+ new_trans = rotate_trans(local_trans, rotZ, inverse=True)
702
+
703
+ new_global_euler = torch.stack((anglesZ, anglesY, anglesX), -1)
704
+ new_global_orient = euler_angles_to_matrix(new_global_euler, "ZYX")
705
+
706
+ return new_global_orient, new_trans
707
+
708
+
709
+ def rotate_motion_canonical(rotations, translation, transl_zero=True):
710
+ """
711
+ Must be of shape S x (Jx3)
712
+ """
713
+ rots_motion = rotations
714
+ trans_motion = translation
715
+ datum_len = rotations.shape[0]
716
+ rots_motion_rotmat = transform_body_pose(rots_motion.reshape(datum_len,
717
+ -1, 3),
718
+ 'aa->rot')
719
+ orient_R_can, trans_can = canonicalize_rotations(rots_motion_rotmat[:,
720
+ 0],
721
+ trans_motion)
722
+ rots_motion_rotmat_can = rots_motion_rotmat
723
+ rots_motion_rotmat_can[:, 0] = orient_R_can
724
+
725
+ rots_motion_aa_can = transform_body_pose(rots_motion_rotmat_can,
726
+ 'rot->aa')
727
+ rots_motion_aa_can = rearrange(rots_motion_aa_can, 'F J d -> F (J d)',
728
+ d=3)
729
+ if transl_zero:
730
+ translation_can = trans_can - trans_can[0]
731
+ else:
732
+ translation_can = trans_can
733
+
734
+ return rots_motion_aa_can, translation_can
735
+
736
+ def transform_body_pose(pose, formats):
737
+ """
738
+ various angle transformations, transforms input to torch.Tensor
739
+ input:
740
+ - pose: pose tensor
741
+ - formats: string denoting the input-output angle format
742
+ """
743
+ if isinstance(pose, np.ndarray):
744
+ pose = torch.from_numpy(pose)
745
+ if formats == "6d->aa":
746
+ j = pose.shape[-1] / 6
747
+ pose = rearrange(pose, '... (j d) -> ... j d', d=6)
748
+ pose = pose.squeeze(-2) # in case of only one angle
749
+ pose = rotation_6d_to_matrix(pose)
750
+ pose = matrix_to_axis_angle(pose)
751
+ if j > 1:
752
+ pose = rearrange(pose, '... j d -> ... (j d)')
753
+ elif formats == "aa->6d":
754
+ j = pose.shape[-1] / 3
755
+ pose = rearrange(pose, '... (j c) -> ... j c', c=3)
756
+ pose = pose.squeeze(-2) # in case of only one angle
757
+ # axis-angle to rotation matrix & drop last row
758
+ pose = matrix_to_rotation_6d(axis_angle_to_matrix(pose))
759
+ if j > 1:
760
+ pose = rearrange(pose, '... j d -> ... (j d)')
761
+ elif formats == "aa->rot":
762
+ j = pose.shape[-1] / 3
763
+ pose = rearrange(pose, '... (j c) -> ... j c', c=3)
764
+ pose = pose.squeeze(-2) # in case of only one angle
765
+ # axis-angle to rotation matrix & drop last row
766
+ pose = torch.clamp(axis_angle_to_matrix(pose), min=-1.0, max=1.0)
767
+ elif formats == "6d->rot":
768
+ j = pose.shape[-1] / 6
769
+ pose = rearrange(pose, '... (j d) -> ... j d', d=6)
770
+ pose = pose.squeeze(-2) # in case of only one angle
771
+ pose = torch.clamp(rotation_6d_to_matrix(pose), min=-1.0, max=1.0)
772
+ elif formats == "rot->aa":
773
+ # pose = rearrange(pose, '... (j d1 d2) -> ... j d1 d2', d1=3, d2=3)
774
+ pose = matrix_to_axis_angle(pose)
775
+ elif formats == "rot->6d":
776
+ # pose = rearrange(pose, '... (j d1 d2) -> ... j d1 d2', d1=3, d2=3)
777
+ pose = matrix_to_rotation_6d(pose)
778
+ else:
779
+ raise ValueError(f"specified conversion format is invalid: {formats}")
780
+ return pose
781
+
782
+ def apply_rot_delta(rots, deltas, in_format="6d", out_format="6d"):
783
+ """
784
+ rots needs to have same dimentionality as delta
785
+ """
786
+ assert rots.shape == deltas.shape
787
+ if in_format == "aa":
788
+ j = rots.shape[-1] / 3
789
+ elif in_format == "6d":
790
+ j = rots.shape[-1] / 6
791
+ else:
792
+ raise ValueError(f"specified conversion format is unsupported: {in_format}")
793
+ rots = transform_body_pose(rots, f"{in_format}->rot")
794
+ deltas = transform_body_pose(deltas, f"{in_format}->rot")
795
+ new_rots = torch.einsum("...ij,...jk->...ik", rots, deltas) # Ri+1=Ri@delta
796
+ new_rots = transform_body_pose(new_rots, f"rot->{out_format}")
797
+ if j > 1:
798
+ new_rots = rearrange(new_rots, '... j d -> ... (j d)')
799
+ return new_rots
800
+
801
+ def rot_diff(rots1, rots2=None, in_format="6d", out_format="6d"):
802
+ """
803
+ dim 0 is considered to be the time dimention, this is where the shift will happen
804
+ """
805
+ self_diff = False
806
+ if in_format == "aa":
807
+ j = rots1.shape[-1] / 3
808
+ elif in_format == "6d":
809
+ j = rots1.shape[-1] / 6
810
+ else:
811
+ raise ValueError(f"specified conversion format is unsupported: {in_format}")
812
+ rots1 = transform_body_pose(rots1, f"{in_format}->rot")
813
+ if rots2 is not None:
814
+ rots2 = transform_body_pose(rots2, f"{in_format}->rot")
815
+ else:
816
+ self_diff = True
817
+ rots2 = rots1
818
+ rots1 = rots1.roll(1, 0)
819
+
820
+ rots_diff = torch.einsum("...ij,...ik->...jk", rots1, rots2) # Ri.T@R_i+1
821
+ if self_diff:
822
+ rots_diff[0, ..., :, :] = torch.eye(3, device=rots1.device)
823
+
824
+ rots_diff = transform_body_pose(rots_diff, f"rot->{out_format}")
825
+ if j > 1:
826
+ rots_diff = rearrange(rots_diff, '... j d -> ... (j d)')
827
+ return rots_diff
828
+
829
+ def change_for(p, R, T=0, forward=True):
830
+ """
831
+ Change frame of reference for vector p
832
+ p: vector in original coordinate frame
833
+ R: rotation matrix of new coordinate frame ([x, y, z] format)
834
+ T: translation of new coordinate frame
835
+ Let angle R by a.
836
+ forward: rotates the coordinate frame by -a (True) or rotate the point
837
+ by +a.
838
+ """
839
+ if forward: # R.T @ (p_global - pelvis_translation)
840
+ return torch.einsum('...di,...d->...i', R, p - T)
841
+ else: # R @ (p_global - pelvis_translation)
842
+ return torch.einsum('...di,...i->...d', R, p) + T
843
+
844
+ def get_z_rot(rot_, in_format="6d"):
845
+ rot = rot_.clone().detach()
846
+ rot = transform_body_pose(rot, f"{in_format}->rot")
847
+ euler_z = matrix_to_euler_angles(rot, "ZYX")
848
+ euler_z[..., 1:] = 0.0
849
+ z_rot = torch.clamp(
850
+ euler_angles_to_matrix(euler_z, "ZYX"),
851
+ min=-1.0, max=1.0) # add zero XY euler angles
852
+ return z_rot
853
+
854
+ def remove_z_rot(pose, in_format="6d", out_format="6d"):
855
+ """
856
+ zero-out the global orientation around Z axis
857
+ """
858
+ assert out_format == "6d"
859
+ if isinstance(pose, np.ndarray):
860
+ pose = torch.from_numpy(pose)
861
+ # transform to matrix
862
+ pose = transform_body_pose(pose, f"{in_format}->rot")
863
+ pose = matrix_to_euler_angles(pose, "ZYX")
864
+ pose[..., 0] = 0
865
+ pose = matrix_to_rotation_6d(torch.clamp(
866
+ euler_angles_to_matrix(pose, "ZYX"),
867
+ min=-1.0, max=1.0))
868
+ return pose
869
+
870
+ def local_to_global_orient(body_orient: Tensor, poses: Tensor, parents: list,
871
+ input_format='aa', output_format='aa'):
872
+ """
873
+ Modified from aitviewer
874
+ Convert relative joint angles to global by unrolling the kinematic chain.
875
+ This function is fully differentiable ;)
876
+ :param poses: A tensor of shape (N, N_JOINTS*d) defining the relative poses in angle-axis format.
877
+ :param parents: A list of parents for each joint j, i.e. parent[j] is the parent of joint j.
878
+ :param output_format: 'aa' for axis-angle or 'rotmat' for rotation matrices.
879
+ :param input_format: 'aa' or 'rotmat' ...
880
+ :return: The global joint angles as a tensor of shape (N, N_JOINTS*DOF).
881
+ """
882
+ assert output_format in ['aa', 'rotmat']
883
+ assert input_format in ['aa', 'rotmat']
884
+ dof = 3 if input_format == 'aa' else 9
885
+ n_joints = poses.shape[-1] // dof + 1
886
+ if input_format == 'aa':
887
+ body_orient = rotvec_to_rotmat(body_orient)
888
+ local_oris = rotvec_to_rotmat(rearrange(poses, '... (j d) -> ... j d', d=3))
889
+ local_oris = torch.cat((body_orient[..., None, :, :], local_oris), dim=-3)
890
+ else:
891
+ # this part has not been tested
892
+ local_oris = torch.cat((body_orient[..., None, :, :], local_oris), dim=-3)
893
+ global_oris_ = []
894
+
895
+ # Apply the chain rule starting from the pelvis
896
+ for j in range(n_joints):
897
+ if parents[j] < 0:
898
+ # root
899
+ global_oris_.append(local_oris[..., j, :, :])
900
+ else:
901
+ parent_rot = global_oris_[parents[j]]
902
+ local_rot = local_oris[..., j, :, :]
903
+ global_oris_.append(torch.einsum('...ij,...jk->...ik', parent_rot, local_rot))
904
+ # global_oris[..., j, :, :] = torch.bmm(parent_rot, local_rot)
905
+ global_oris = torch.stack(global_oris_, dim=1)
906
+ # global_oris: ... x J x 3 x 3
907
+ # account for the body's root orientation
908
+ # global_oris = torch.einsum('...ij,...jk->...ik', body_orient[..., None, :, :], global_oris)
909
+
910
+ if output_format == 'aa':
911
+ return rotmat_to_rotvec(global_oris)
912
+ # res = global_oris.reshape((-1, n_joints * 3))
913
+ else:
914
+ return global_oris
915
+ # return res