# Copyright (C) Istituto Italiano di Tecnologia (IIT). All rights reserved.
from dataclasses import dataclass
from typing import Union, Sequence
import casadi as cs
import numpy.typing as npt
from adam.core.spatial_math import (
ArrayLike,
ArrayLikeFactory,
SpatialMath as _SpatialMath,
)
@dataclass
class CasadiLike(ArrayLike):
    """Wrapper class for CasADi SX/DM with ArrayLike ops.

    Every operator returns a new ``CasadiLike``. The other operand may be a
    wrapper (any object exposing an ``array`` attribute) or an unwrapped value
    such as a plain number, which is handed to CasADi unchanged. Accepting
    unwrapped values is what makes the reflected operators (``2 * x``,
    ``1 + x``...) usable, since Python only calls them when the left operand
    is not a wrapper.
    """

    # Wrapped CasADi matrix.
    array: Union[cs.SX, cs.DM]

    @staticmethod
    def _unwrap(operand):
        """Return ``operand.array`` when it exists, otherwise ``operand`` itself."""
        return getattr(operand, "array", operand)

    @staticmethod
    def _aligned_sum(lhs, rhs):
        """Add two raw operands, transposing ``rhs`` when it is ``lhs`` flipped.

        Args:
            lhs: Left raw operand; the result keeps its orientation.
            rhs: Right raw operand.

        Returns:
            The raw sum (not wrapped).

        Raises:
            ValueError: If both operands are 2-D and their shapes are neither
                equal, scalar, nor a vector and its transpose.
        """
        lhs_shape = getattr(lhs, "shape", None)
        rhs_shape = getattr(rhs, "shape", None)
        # Operands without a 2-D shape (e.g. plain Python numbers) are left to
        # CasADi's own scalar handling.
        if (
            lhs_shape is None
            or rhs_shape is None
            or len(lhs_shape) != 2
            or len(rhs_shape) != 2
        ):
            return lhs + rhs
        lhs_shape, rhs_shape = tuple(lhs_shape), tuple(rhs_shape)
        # Scalars always broadcast in CasADi
        if lhs_shape == rhs_shape or lhs_shape == (1, 1) or rhs_shape == (1, 1):
            return lhs + rhs
        # Column + row (or row + column) of the same length: align rhs to lhs.
        # This single test covers both orientations.
        if lhs_shape == (rhs_shape[1], rhs_shape[0]) and 1 in rhs_shape:
            return lhs + rhs.T
        raise ValueError(f"Shape mismatch for add: {lhs_shape} + {rhs_shape}")

    def __matmul__(self, other: Union["CasadiLike", npt.ArrayLike]) -> "CasadiLike":
        """Return the matrix product ``self @ other``."""
        return CasadiLike(cs.mtimes(self.array, self._unwrap(other)))

    def __rmatmul__(self, other: Union["CasadiLike", npt.ArrayLike]) -> "CasadiLike":
        """Return the matrix product ``other @ self``."""
        return CasadiLike(cs.mtimes(self._unwrap(other), self.array))

    def __mul__(self, other: Union["CasadiLike", npt.ArrayLike]) -> "CasadiLike":
        """Return the elementwise product ``self * other``."""
        return CasadiLike(self.array * self._unwrap(other))

    def __rmul__(self, other: Union["CasadiLike", npt.ArrayLike]) -> "CasadiLike":
        """Return the elementwise product ``other * self``."""
        return CasadiLike(self._unwrap(other) * self.array)

    def __truediv__(self, other: Union["CasadiLike", npt.ArrayLike]) -> "CasadiLike":
        """Return the elementwise quotient ``self / other``."""
        return CasadiLike(self.array / self._unwrap(other))

    def __add__(self, other: Union["CasadiLike", npt.ArrayLike]) -> "CasadiLike":
        """Return ``self + other``, aligning a transposed vector to ``self``.

        Raises:
            ValueError: If the two shapes cannot be aligned.
        """
        return CasadiLike(self._aligned_sum(self.array, self._unwrap(other)))

    def __radd__(self, other: Union["CasadiLike", npt.ArrayLike]) -> "CasadiLike":
        """Return ``other + self``, aligning a transposed vector to ``other``.

        Raises:
            ValueError: If the two shapes cannot be aligned.
        """
        return CasadiLike(self._aligned_sum(self._unwrap(other), self.array))

    def __sub__(self, other: Union["CasadiLike", npt.ArrayLike]) -> "CasadiLike":
        """Return the elementwise difference ``self - other``."""
        return CasadiLike(self.array - self._unwrap(other))

    def __rsub__(self, other: Union["CasadiLike", npt.ArrayLike]) -> "CasadiLike":
        """Return the elementwise difference ``other - self``."""
        return CasadiLike(self._unwrap(other) - self.array)

    def __neg__(self) -> "CasadiLike":
        """Return the elementwise negation of the wrapped array."""
        return CasadiLike(-self.array)

    # --- indexing / shape / transpose ---

    def __getitem__(self, idx) -> "CasadiLike":
        """Index the wrapped array, tolerating N-D style indices.

        CasADi is 2-D, so ``Ellipsis`` and ``None`` entries are dropped from a
        tuple index and, if more than two entries remain, only the last two
        (row, col) are used. An index that reduces to nothing returns ``self``.
        """
        if idx is Ellipsis:
            return self
        if isinstance(idx, tuple):
            # Identity checks only: entries may be objects whose == is not a bool.
            idx = tuple(i for i in idx if i is not Ellipsis and i is not None)
            if not idx:
                return self
            if len(idx) == 1:
                idx = idx[0]
            elif len(idx) > 2:
                # Keep only last two indices (row, col)
                idx = idx[-2:]
        return CasadiLike(self.array[idx])

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape reported by the wrapped array."""
        return self.array.shape

    @property
    def ndim(self) -> int:
        """Number of entries in the wrapped array's shape."""
        return len(self.array.shape)

    @property
    def T(self) -> "CasadiLike":
        """Transpose of the wrapped array, as a new wrapper."""
        return CasadiLike(self.array.T)
class CasadiLikeFactory(ArrayLikeFactory):
    """ArrayLikeFactory for CasADi. Drops batch dims (>2) since CasADi is 2-D only."""

    def __init__(self, xp: Union[cs.SX, cs.DM, None] = None):
        """Store the CasADi matrix class used to build arrays.

        Args:
            xp: Object used both as a constructor (``xp(value)``) and as a
                namespace for ``zeros``/``ones``/``eye``. Defaults to ``cs.SX``.
        """
        # NOTE: the base-class initializer is not invoked here.
        if xp is None:
            xp = cs.SX
        self._xp = xp

    def zeros(self, *x: npt.ArrayLike) -> CasadiLike:
        """Return a zero array.

        Accepts either ``zeros((..batch.., r, c))`` or ``zeros(r, c)``. Only
        the last two dimensions are kept; an empty shape becomes ``(1,)``.
        """
        given_as_sequence = len(x) == 1 and isinstance(x[0], (tuple, list))
        dims = tuple(x[0]) if given_as_sequence else tuple(x)
        # Slicing is a no-op for <= 2 entries and drops batch dims otherwise.
        dims = dims[-2:]
        if not dims:
            dims = (1,)
        return CasadiLike(self._xp.zeros(*dims))

    def eye(self, x: npt.ArrayLike) -> CasadiLike:
        """Return an identity matrix.

        Accepts ``eye(n)`` or ``eye((..batch.., n))``; for a tuple/list the
        last entry is the size.
        """
        if isinstance(x, (tuple, list)):
            size = x[-1]
        else:
            size = x
        return CasadiLike(self._xp.eye(int(size)))

    def asarray(self, x) -> CasadiLike:
        """
        Convert input to a CasadiLike array.

        Args:
            x: Input to convert. Can be:
                - Empty list: Returns a CasadiLike built from ``[]``
                - List containing at least one CasADi object (cs.SX, cs.DM):
                  the items are vertically concatenated with ``cs.vertcat``
                - List whose items are all lists/tuples: each inner sequence
                  becomes a row (``cs.horzcat``) and rows are stacked
                  (``cs.vertcat``)
                - Anything else (numbers, flat lists of numbers, non-list
                  inputs): passed directly to the backend constructor

        Returns:
            CasadiLike: A CasadiLike object wrapping the converted input.

        Examples:
            - [] -> CasadiLike with empty array
            - [sx1, sx2] -> CasadiLike with sx1 stacked on top of sx2
            - [[1, 2], [3, 4]] -> CasadiLike with 2x2 matrix
            - 5 or [1, 2, 3] -> CasadiLike with direct conversion
        """
        make = self._xp
        # Only lists get special treatment; everything else converts directly.
        if not isinstance(x, list):
            return CasadiLike(make(x))
        if not x:
            return CasadiLike(make([]))
        # Any CasADi entry present: stack all items vertically.
        if any(isinstance(item, (cs.SX, cs.DM)) for item in x):
            return CasadiLike(make(cs.vertcat(*x)))
        # Nested sequences: build rows first, then stack them.
        if all(isinstance(item, (list, tuple)) for item in x):
            rows = [cs.horzcat(*[make(entry) for entry in row]) for row in x]
            return CasadiLike(make(cs.vertcat(*rows)))
        return CasadiLike(make(x))

    @staticmethod
    def _rows_cols(x: CasadiLike):
        """Return ``(rows, cols)`` of ``x.array``, using ``(numel, 1)`` if its shape is not 2-D."""
        if len(x.array.shape) == 2:
            rows, cols = x.array.shape
        else:
            rows, cols = x.array.numel(), 1
        return rows, cols

    def zeros_like(self, x: CasadiLike) -> CasadiLike:
        """Return a zero array with the same rows/cols as ``x``."""
        rows, cols = self._rows_cols(x)
        return CasadiLike(self._xp.zeros(rows, cols))

    def ones_like(self, x: CasadiLike) -> CasadiLike:
        """Return an array of ones with the same rows/cols as ``x``."""
        rows, cols = self._rows_cols(x)
        return CasadiLike(self._xp.ones(rows, cols))

    def tile(self, x: CasadiLike, reps: tuple) -> CasadiLike:
        """Return ``x`` unchanged; ``reps`` is ignored (no batching in CasADi)."""
        return x
class SpatialMath(_SpatialMath):
    """CasADi backend for SpatialMath. Keeps the same high-level API."""

    def __init__(self, spec=None):
        """Initialise the base class with a ``CasadiLikeFactory`` built from ``spec``.

        Args:
            spec: Forwarded to ``CasadiLikeFactory``; ``None`` selects its default.
        """
        super().__init__(CasadiLikeFactory(spec))

    @staticmethod
    def sin(x: CasadiLike) -> CasadiLike:
        """Return the elementwise sine of ``x``."""
        return CasadiLike(cs.sin(x.array))

    @staticmethod
    def cos(x: CasadiLike) -> CasadiLike:
        """Return the elementwise cosine of ``x``."""
        return CasadiLike(cs.cos(x.array))

    @staticmethod
    def skew(x: Union[CasadiLike, npt.ArrayLike]) -> CasadiLike:
        """Return ``cs.skew`` applied to ``x``.

        Args:
            x: A CasadiLike (unwrapped first) or a raw value passed as-is.

        Raises:
            ValueError: If the operand is an empty cs.SX/cs.DM.
        """
        a = x.array if isinstance(x, CasadiLike) else x
        # Only emptiness is validated here; any other shape checking is left
        # to cs.skew itself.
        if isinstance(a, (cs.SX, cs.DM)) and a.is_empty():
            raise ValueError("skew received empty array")
        return CasadiLike(cs.skew(a))

    @staticmethod
    def outer(x: CasadiLike, y: CasadiLike) -> CasadiLike:
        """Return the outer product of two vectors.

        Both operands are flattened to columns, so row and column vectors are
        accepted alike; the result has shape ``(x.numel, y.numel)``.
        """
        # Built from CasADi primitives rather than NumPy's outer: this keeps
        # the result a CasADi object and avoids coercing CasADi values through
        # NumPy, which is not dependable for symbolic expressions.
        col = cs.vec(x.array)
        row = cs.vec(y.array).T
        return CasadiLike(cs.mtimes(col, row))

    @staticmethod
    def vertcat(*x: CasadiLike) -> CasadiLike:
        """Stack the operands vertically with ``cs.vertcat``."""
        return CasadiLike(cs.vertcat(*[xi.array for xi in x]))

    @staticmethod
    def horzcat(*x: CasadiLike) -> CasadiLike:
        """Stack the operands horizontally with ``cs.horzcat``."""
        return CasadiLike(cs.horzcat(*[xi.array for xi in x]))

    @staticmethod
    def stack(x: Sequence[CasadiLike], axis: int = 0) -> CasadiLike:
        """Join arrays along ``axis`` without creating a new dimension.

        Args:
            x: Arrays to join.
            axis: 0 or -2 joins vertically; 1 or -1 joins horizontally.

        Raises:
            NotImplementedError: For any other axis.
        """
        arrs = [xi.array for xi in x]
        if axis in {-2, 0}:
            return CasadiLike(cs.vertcat(*arrs))
        if axis in {-1, 1}:
            return CasadiLike(cs.horzcat(*arrs))
        raise NotImplementedError(f"CasADi stack not implemented for axis={axis}")

    @staticmethod
    def concatenate(x: Sequence[CasadiLike], axis: int = 0) -> CasadiLike:
        """
        Concatenate a sequence of CasadiLike objects along a specified axis.

        Args:
            x (Sequence[CasadiLike]): Sequence of CasadiLike objects to concatenate
            axis (int, optional): Axis along which to concatenate. Defaults to 0.
                - 0 or -2: Vertical concatenation (stack rows)
                - 1 or -1: Horizontal concatenation (stack columns)

        Returns:
            CasadiLike: The concatenated result

        Raises:
            NotImplementedError: If axis is not in {-2, -1, 0, 1}

        Special Cases:
            - When axis=-1 (but not axis=1) and exactly 2 column vectors are
              provided, they are vertically stacked into a longer column.
            - For horizontal concatenation (axis=1 or -1), if the arrays do
              not all share the first array's row count, they are stacked
              vertically instead, after:
                * reshaping 1-D arrays to column vectors
                * transposing row vectors (1xn) to column vectors
                * keeping any other shape as-is
        """
        arrs = [xi.array for xi in x]
        # Friendly special-case: if axis == -1 and we have two column vectors, build a longer column
        if axis == -1 and len(arrs) == 2:
            a, b = arrs
            if (
                len(a.shape) == 2
                and a.shape[1] == 1
                and len(b.shape) == 2
                and b.shape[1] == 1
            ):
                return CasadiLike(cs.vertcat(a, b))
        if axis in {-2, 0}:  # vertical stack
            return CasadiLike(cs.vertcat(*arrs))
        if axis in {-1, 1}:
            if all(arr.shape[0] == arrs[0].shape[0] for arr in arrs):
                return CasadiLike(cs.horzcat(*arrs))
            # Row counts differ: reshape to columns and stack vertically
            cols = []
            for A in arrs:
                if len(A.shape) == 1:
                    cols.append(A.reshape((-1, 1)))
                elif len(A.shape) == 2 and A.shape[1] != 1 and A.shape[0] == 1:
                    cols.append(A.T)
                else:
                    cols.append(A)
            return CasadiLike(cs.vertcat(*cols))
        raise NotImplementedError(f"CasADi concatenate not implemented for axis={axis}")

    @staticmethod
    def swapaxes(x: CasadiLike, axis1: int, axis2: int) -> CasadiLike:
        """Swap two axes; in 2-D this is a transpose.

        Raises:
            NotImplementedError: Unless the pair is (0, 1) or (-1, -2) in
                either order.
        """
        # Only last-two or (0,1) swaps are meaningful in 2-D CasADi -> transpose.
        if (axis1, axis2) in {(-1, -2), (-2, -1), (0, 1), (1, 0)}:
            return CasadiLike(x.array.T)
        raise NotImplementedError(
            f"CasADi swapaxes not implemented for {axis1=}, {axis2=}"
        )

    def tile(self, x: CasadiLike, reps: tuple) -> CasadiLike:
        """Return ``x`` unchanged; ``reps`` is ignored (no-op for CasADi)."""
        return x

    def transpose(self, x: CasadiLike, dims: tuple) -> CasadiLike:
        """Return the transpose of ``x``; ``dims`` is ignored.

        Only 2-D is supported, so any request means "swap last two".
        """
        return CasadiLike(x.array.T)

    @staticmethod
    def expand_dims(x: CasadiLike, axis: int) -> CasadiLike:
        """Expand dimensions of a CasADi array.

        Args:
            x: Input array (CasadiLike)
            axis: Position where new axis is to be inserted

        Returns:
            CasadiLike: For ``axis == -1`` the array reshaped to a column
            ``(-1, 1)``; for any other axis, ``x`` itself (CasADi is 2D only).
        """
        if axis == -1:
            return CasadiLike(cs.reshape(x.array, (-1, 1)))
        return x

    @staticmethod
    def inv(x: CasadiLike) -> CasadiLike:
        """Matrix inversion for CasADi.

        Args:
            x: Matrix to invert (CasadiLike)

        Returns:
            CasadiLike: Inverse of x
        """
        return CasadiLike(cs.inv(x.array))

    @staticmethod
    def solve(A: CasadiLike, B: CasadiLike) -> CasadiLike:
        """Solve linear system Ax = B for x using CasADi.

        Args:
            A: Coefficient matrix (CasadiLike)
            B: Right-hand side vector or matrix (CasadiLike)

        Returns:
            CasadiLike: Solution x
        """
        return CasadiLike(cs.solve(A.array, B.array))

    @staticmethod
    def mtimes(A: CasadiLike, B: CasadiLike) -> CasadiLike:
        """Matrix-matrix multiplication for CasADi.

        Args:
            A: First matrix (CasadiLike)
            B: Second matrix (CasadiLike)

        Returns:
            CasadiLike: Result of A @ B
        """
        return CasadiLike(cs.mtimes(A.array, B.array))

    @staticmethod
    def mxv(m: CasadiLike, v: CasadiLike) -> CasadiLike:
        """Matrix-vector multiplication for CasADi.

        Args:
            m: Matrix (CasadiLike)
            v: Vector (CasadiLike)

        Returns:
            CasadiLike: The matrix product ``m @ v``; a *column* vector (n,1)
            when ``v`` is a column.
        """
        return CasadiLike(cs.mtimes(m.array, v.array))

    @staticmethod
    def vxs(v: CasadiLike, c: CasadiLike) -> CasadiLike:
        """
        Vector times scalar multiplication for CasADi.

        Args:
            v: Vector (CasadiLike)
            c: Scalar (CasadiLike)

        Returns:
            CasadiLike: Result of vector times scalar
        """
        return CasadiLike(v.array * c.array)