from dataclasses import dataclass
from types import ModuleType
from typing import Any, Callable, Optional

import array_api_compat as aac

from adam.core.spatial_math import ArrayLike, ArrayLikeFactory, SpatialMath


@dataclass(frozen=True)
class ArraySpec:
xp: ModuleType # array API namespace (compat-wrapped if needed)
dtype: Optional[Any] # xp.float32, torch.float32, jnp.float32, etc.
device: Optional[Any] # xp device object (torch device, jax device, "cpu", ...)
def spec_from_reference(ref: Any) -> ArraySpec:
    # Prefer the compat-wrapped namespace when available (useful for PyTorch and NumPy).
    # JAX has no array-api-compat wrapper, so fall back to use_compat=False for JAX arrays.
try:
xp = aac.array_namespace(ref, use_compat=True)
except ValueError as e:
if "JAX does not have an array-api-compat wrapper" in str(e):
xp = aac.array_namespace(ref, use_compat=False)
else:
raise
dtype = getattr(ref, "dtype", None)
# aac.device(x) provides spec-like device, including a CPU device for NumPy.
try:
device = aac.device(ref)
except Exception:
device = getattr(ref, "device", None)
return ArraySpec(xp=xp, dtype=dtype, device=device)
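
# Illustrative sketch, not part of adam's public API (assumes NumPy is installed):
# spec_from_reference is typically used to mirror the backend, dtype and device of a
# user-provided array when allocating new arrays. The helper name below is hypothetical.
def _example_spec_from_numpy():
    import numpy as np

    ref = np.zeros(3, dtype=np.float64)
    spec = spec_from_reference(ref)
    # spec.xp is the (compat-wrapped) numpy namespace; dtype/device mirror `ref`.
    return spec.xp.zeros((2, 3), dtype=spec.dtype, device=spec.device)
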
def xp_getter(*xs: Any):
    """Return the array API namespace shared by the given arrays."""
    # Note: relies on array_api_compat's default compat behaviour here,
    # unlike spec_from_reference, which forces use_compat=True.
    return aac.array_namespace(*xs)
@dataclass
[docs]
class ArrayAPILike(ArrayLike):
"""Generic Array-API-style wrapper used by NumPy/JAX/Torch backends."""
def __getitem__(self, idx) -> "ArrayAPILike":
return self.__class__(self.array[idx])
@property
def shape(self):
return self.array.shape
def reshape(self, *args):
xp = xp_getter(self.array)
return xp.reshape(self.array, *args)
@property
def T(self) -> "ArrayAPILike":
if getattr(self.array, "ndim", 0) == 0:
return self.__class__(self.array)
xp = xp_getter(self.array)
        return self.__class__(xp.swapaxes(self.array, 0, -1))
def __matmul__(self, other):
xp = xp_getter(self.array, other.array)
return self.__class__(xp.matmul(self.array, other.array))
def __rmatmul__(self, other) -> "ArrayAPILike":
xp = xp_getter(self.array, other.array)
return self.__class__(xp.matmul(other.array, self.array))
def __mul__(self, other) -> "ArrayAPILike":
xp = xp_getter(self.array, other.array)
return self.__class__(xp.multiply(self.array, other.array))
def __rmul__(self, other) -> "ArrayAPILike":
xp = xp_getter(self.array, other.array)
return self.__class__(xp.multiply(self.array, other.array))
def __truediv__(self, other) -> "ArrayAPILike":
xp = xp_getter(self.array, other.array)
return self.__class__(xp.divide(self.array, other.array))
def __add__(self, other) -> "ArrayAPILike":
xp = xp_getter(self.array, other.array)
return self.__class__(xp.add(self.array, other.array))
    def __radd__(self, other) -> "ArrayAPILike":
        xp = xp_getter(self.array, other.array)
        return self.__class__(xp.add(other.array, self.array))
    def __sub__(self, other) -> "ArrayAPILike":
        xp = xp_getter(self.array, other.array)
        return self.__class__(xp.subtract(self.array, other.array))
def __rsub__(self, other) -> "ArrayAPILike":
xp = xp_getter(self.array, other.array)
return self.__class__(xp.squeeze(other.array) - xp.squeeze(self.array))
def __neg__(self) -> "ArrayAPILike":
return self.__class__(-self.array)
@property
def ndim(self):
return self.array.ndim
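
# Sketch of the wrapper semantics (assuming `Like` stands for the concrete wrapper class
# in use, i.e. ArrayAPILike itself or a backend-specific subclass): every operator above
# unwraps `.array`, resolves the shared namespace via xp_getter, and re-wraps the result,
# so composite expressions stay inside the wrapper type, e.g.
#
#     a = Like(xp.eye(3))
#     b = Like(xp.ones((3, 3)))
#     c = (a @ b + b) * b        # each step returns a new Like instance
#     c.T.shape                  # -> (3, 3)
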
class ArrayAPIFactory(ArrayLikeFactory):
"""
Generic factory. Give it (a) a Like class and (b) an xp namespace
(array_api_compat.* if available; otherwise the library module).
"""
    def __init__(self, like_cls, xp, *, dtype=None, device=None):
        self._like = like_cls
        self._xp = xp
        self._dtype = dtype
        self._device = device
def zeros(self, *shape) -> ArrayAPILike:
# Handle tuple concatenation like H.shape[:-2] + (1, 4)
if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
final_shape = shape[0]
else:
final_shape = shape
x = self._xp.zeros(final_shape, dtype=self._dtype, device=self._device)
return self._like(x)
    def eye(self, *shape) -> ArrayAPILike:
        # Accept eye(n), eye(b1, ..., n) or eye((b1, ..., n)) for batched identities.
        if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
            shape = tuple(shape[0])
        batch, n = tuple(shape[:-1]), shape[-1]
        identity = self._xp.eye(n, dtype=self._dtype, device=self._device)
        if not batch:
            return self._like(identity)
        return self._like(self._xp.broadcast_to(identity, batch + (n, n)))
    def asarray(self, x) -> ArrayAPILike:
        # Torch tensors (detected via their `requires_grad_` method) are converted with
        # .to() so any autograd graph is preserved. This could live in the torch-specific
        # Like class, but keeping it here makes the special case more visible.
if getattr(x, "requires_grad_", False):
return self._like(x.to(device=self._device, dtype=self._dtype))
return self._like(self._xp.asarray(x, dtype=self._dtype, device=self._device))
def zeros_like(self, x: ArrayAPILike) -> ArrayAPILike:
return self._like(self._xp.zeros_like(x.array, dtype=x.array.dtype))
def ones_like(self, x: ArrayAPILike) -> ArrayAPILike:
return self._like(self._xp.ones_like(x.array, dtype=x.array.dtype))
def tile(self, x: ArrayAPILike, reps: tuple) -> ArrayAPILike:
return self._like(self._xp.tile(x.array, reps))
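
# Usage sketch (assumes array_api_compat and NumPy are installed, and that the Like class
# passed in is concrete): the factory binds a wrapper class, a namespace and a default
# dtype/device, so every allocation helper agrees on the backend, e.g.
#
#     import array_api_compat.numpy as xp_np
#     factory = ArrayAPIFactory(ArrayAPILike, xp_np, dtype=xp_np.float64, device="cpu")
#     factory.zeros(6, 6)                # 6x6 zeros wrapped in the Like class
#     factory.eye((32, 4))               # batch of 32 stacked 4x4 identities
#     factory.asarray([1.0, 2.0, 3.0])   # converted with the bound dtype and device
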
class ArrayAPISpatialMath(SpatialMath):
"""A drop-in SpatialMath that implements sin/cos/outer/concat/skew with the Array API.
Works for NumPy, PyTorch, and JAX; CasADi should keep its own subclass.
"""
def __init__(self, factory, xp_getter: Callable[..., Any] = xp_getter):
super().__init__(factory)
        self._xp_getter = xp_getter
def _xp(self, *xs: Any):
return self._xp_getter(*xs)
def sin(self, x):
xp = self._xp(x.array)
x = x.array
return self.factory.asarray(xp.sin(x))
def cos(self, x):
xp = self._xp(x.array)
x = x.array
return self.factory.asarray(xp.cos(x))
def skew(self, x):
xp = self._xp(x.array)
a = x.array
# if x is batched (shape (B, 3, 1)), remove the last dimension
if a.ndim >= 2 and a.shape[-1] == 1:
a = a[..., 0]
x0, x1, x2 = a[..., 0], a[..., 1], a[..., 2]
z = x0 * 0
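        # Rows of the hat operator: skew(v) is skew-symmetric and satisfies
        # skew(v) @ w == cross(v, w) for 3-vectors (batched along leading axes).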
row0 = xp.stack([z, -x2, x1], axis=-1)
row1 = xp.stack([x2, z, -x0], axis=-1)
row2 = xp.stack([-x1, x0, z], axis=-1)
return self.factory.asarray(xp.stack([row0, row1, row2], axis=-2)) # (...,3,3)
def outer(self, x, y):
xp = self._xp(x.array, y.array)
a = x.array
b = y.array
# normalize to (...,3)
if a.ndim >= 2 and a.shape[-2] == 3 and a.shape[-1] == 1:
a = a[..., :, 0]
if b.ndim >= 2 and b.shape[-2] == 3 and b.shape[-1] == 1:
b = b[..., :, 0]
# (...,3,1) @ (...,1,3) -> (...,3,3)
A = a[..., :, None]
B = b[..., None, :]
return self.factory.asarray(xp.matmul(A, B))
def vertcat(self, *x):
xp = self._xp(*[xi.array for xi in x])
return self.factory.asarray(xp.vstack([xi.array for xi in x]))
def horzcat(self, *x):
xp = self._xp(*[xi.array for xi in x])
return self.factory.asarray(xp.hstack([xi.array for xi in x]))
def stack(self, x, axis=0):
xp = self._xp(x[0].array)
return self.factory.asarray(xp.stack([xi.array for xi in x], axis=axis))
def concatenate(self, x, axis=0):
xp = self._xp(x[0].array)
return self.factory.asarray(xp.concatenate([xi.array for xi in x], axis=axis))
def swapaxes(self, x: ArrayAPILike, axis1: int, axis2: int) -> ArrayAPILike:
xp = self._xp(x.array)
return self.factory.asarray(xp.swapaxes(x.array, axis1, axis2))
def expand_dims(self, x: ArrayAPILike, axis: int) -> ArrayAPILike:
xp = self._xp(x.array)
return self.factory.asarray(xp.expand_dims(x.array, axis=axis))
def transpose(self, x: ArrayAPILike, dims: tuple) -> ArrayAPILike:
xp = self._xp(x.array)
return self.factory.asarray(xp.permute_dims(x.array, dims))
def inv(self, x: ArrayAPILike) -> ArrayAPILike:
xp = self._xp(x.array)
return self.factory.asarray(xp.linalg.inv(x.array))
def mtimes(self, A: ArrayAPILike, B: ArrayAPILike) -> ArrayAPILike:
xp = self._xp(A.array, B.array)
return self.factory.asarray(xp.matmul(A.array, B.array))
def solve(self, A: ArrayAPILike, B: ArrayAPILike) -> ArrayAPILike:
xp = self._xp(A.array, B.array)
return self.factory.asarray(xp.linalg.solve(A.array, B.array))
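

# End-to-end sketch, illustrative only (assumes NumPy and array_api_compat are installed
# and that ArrayAPILike is usable directly as the wrapper class; the helper name below is
# hypothetical): wiring a NumPy-backed factory into ArrayAPISpatialMath and evaluating a
# couple of the primitives above.
def _example_numpy_spatial_math():
    import numpy as np
    import array_api_compat.numpy as xp_np

    factory = ArrayAPIFactory(ArrayAPILike, xp_np, dtype=xp_np.float64, device="cpu")
    sm = ArrayAPISpatialMath(factory)
    v = factory.asarray(np.array([1.0, 2.0, 3.0]))
    S = sm.skew(v)        # (3, 3) skew-symmetric matrix with S @ w == v x w
    R = sm.outer(v, v)    # (3, 3) outer product
    return sm.mtimes(S, factory.eye(3)), R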