adam.jax.jax_like

Classes

JaxLike

Wrapper class for JAX types.

JaxLikeFactory

Generic factory that takes (a) a Like class and (b) an xp namespace.

SpatialMath

A drop-in SpatialMath that implements sin/cos/outer/concat/skew with the Array API.

Module Contents

class adam.jax.jax_like.JaxLike

Bases: adam.core.array_api_math.ArrayAPILike

Wrapper class for JAX types.

array: jax.numpy.array

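This page does not document what ArrayAPILike provides beyond the array attribute. As a rough, hypothetical illustration of the wrapper pattern (not adam's implementation), the sketch below stores a backend array in an array field and re-wraps the result of each operation. ArrayWrapper is an invented name, and NumPy stands in for jax.numpy so the example runs without JAX.

```python
from dataclasses import dataclass
from typing import Any

import numpy as np


# eq=False: the generated __eq__ would call bool() on an element-wise array
# comparison, which raises for multi-element arrays.
@dataclass(eq=False)
class ArrayWrapper:
    """Hypothetical JaxLike-style wrapper; not adam's ArrayAPILike."""

    array: Any  # the wrapped backend array (a NumPy array here, a JAX array for JaxLike)

    @staticmethod
    def _unwrap(other: Any) -> Any:
        # Accept either another wrapper or a raw array/scalar.
        return other.array if isinstance(other, ArrayWrapper) else other

    def __add__(self, other: Any) -> "ArrayWrapper":
        return ArrayWrapper(self.array + self._unwrap(other))

    def __mul__(self, other: Any) -> "ArrayWrapper":
        return ArrayWrapper(self.array * self._unwrap(other))

    def __matmul__(self, other: Any) -> "ArrayWrapper":
        return ArrayWrapper(self.array @ self._unwrap(other))

    def __getitem__(self, idx: Any) -> "ArrayWrapper":
        return ArrayWrapper(self.array[idx])

    @property
    def T(self) -> "ArrayWrapper":
        return ArrayWrapper(self.array.T)


# Usage: every operation returns a new wrapper around the backend result.
eye = ArrayWrapper(np.eye(2))
m = ArrayWrapper(np.array([[1.0, 2.0], [3.0, 4.0]]))
result = eye @ m + m  # wraps 2 * m.array
```

Keeping the raw array in a single field means the wrapper adds no copies; each operator simply delegates to the backend and re-wraps.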
class adam.jax.jax_like.JaxLikeFactory(spec: adam.core.array_api_math.ArraySpec | None = None)

Bases: adam.core.array_api_math.ArrayAPIFactory

Generic factory that takes (a) a Like class and (b) an xp namespace (array_api_compat.* if available, otherwise the library module itself).
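The namespace rule above (array_api_compat.* if available, otherwise the library module) can be sketched as follows. resolve_namespace, SimpleFactory and Box are hypothetical names, not adam's API, and the zeros/eye/asarray methods are assumptions chosen only to show how a factory would wrap freshly created arrays in its Like class.

```python
import importlib
from types import ModuleType
from typing import Any, Callable


def resolve_namespace(lib_name: str) -> ModuleType:
    """Prefer array_api_compat.<lib_name>; otherwise fall back to the library module."""
    try:
        return importlib.import_module(f"array_api_compat.{lib_name}")
    except ImportError:  # array_api_compat missing, or no wrapper for this library
        return importlib.import_module(lib_name)


class SimpleFactory:
    """Hypothetical factory pairing a Like (wrapper) class with an xp namespace."""

    def __init__(self, like_cls: Callable[[Any], Any], xp: ModuleType) -> None:
        self.like_cls = like_cls
        self.xp = xp

    def zeros(self, *shape: int) -> Any:
        return self.like_cls(self.xp.zeros(shape))

    def eye(self, n: int) -> Any:
        return self.like_cls(self.xp.eye(n))

    def asarray(self, data: Any) -> Any:
        return self.like_cls(self.xp.asarray(data))


class Box:
    """Minimal stand-in Like class for the usage example."""

    def __init__(self, array: Any) -> None:
        self.array = array


# Usage: resolves to array_api_compat.numpy when that package is installed,
# and to numpy itself otherwise.
factory = SimpleFactory(Box, resolve_namespace("numpy"))
identity = factory.eye(3)
```

Keeping the Like class and the namespace as two independent parameters is what lets one factory serve several backends: only the pair passed in changes.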

class adam.jax.jax_like.SpatialMath(spec: adam.core.array_api_math.ArraySpec | None = None)

Bases: adam.core.array_api_math.ArrayAPISpatialMath

A drop-in SpatialMath that implements sin/cos/outer/concat/skew with the Array API.

Works for NumPy, PyTorch, and JAX; CasADi should keep its own subclass.
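To show how such primitives compose, the sketch below builds skew and outer from generic array operations and combines them with sin/cos into Rodrigues' rotation formula, R = cos θ I + sin θ [k]× + (1 − cos θ) k kᵀ for a unit axis k. The xp argument stands in for an Array API namespace. The function names and signatures are hypothetical, not SpatialMath's actual methods, and NumPy is used so the example runs as-is.

```python
import numpy as np


def skew(xp, v):
    """Map a (..., 3) vector to its (..., 3, 3) cross-product matrix [v]x."""
    x, y, z = v[..., 0], v[..., 1], v[..., 2]
    zero = xp.zeros_like(x)
    rows = [
        xp.stack([zero, -z, y], axis=-1),
        xp.stack([z, zero, -x], axis=-1),
        xp.stack([-y, x, zero], axis=-1),
    ]
    return xp.stack(rows, axis=-2)


def outer(a, b):
    """Batched outer product via broadcasting: (..., n), (..., m) -> (..., n, m)."""
    return a[..., :, None] * b[..., None, :]


def rodrigues(xp, axis, theta):
    """Rotation matrix for angle theta about axis (normalized here)."""
    k = axis / xp.sqrt(xp.sum(axis * axis))
    c, s = xp.cos(theta), xp.sin(theta)
    return c * xp.eye(3) + s * skew(xp, k) + (1 - c) * outer(k, k)


# Usage: a quarter turn about z maps the x axis onto the y axis.
R = rodrigues(np, np.array([0.0, 0.0, 1.0]), np.pi / 2)
```

Because skew only indexes, negates and stacks along the last axes, it handles a single vector and a batch of shape (..., 3) with the same code.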

solve(A: adam.core.array_api_math.ArrayAPILike, B: adam.core.array_api_math.ArrayAPILike) → adam.core.array_api_math.ArrayAPILike

Overrides solve to handle JAX’s batched solve API correctly.

For batched solves, JAX requires b to have shape (..., N, M), not just (..., N). This override follows JAX’s recommendation for vector right-hand sides: solve(a, b[..., None]).squeeze(-1).
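The reshaping described above can be sketched as a standalone helper. batched_vector_solve is a hypothetical name, not adam's method; xp is any namespace exposing linalg.solve. NumPy is used so the example runs without JAX, and jax.numpy should accept the same calls. Adding a trailing axis turns each right-hand side of shape (..., N) into an N × 1 matrix, which is unambiguous under the (..., N, M) convention, and squeeze(-1) removes that axis again.

```python
import numpy as np


def batched_vector_solve(xp, A, b):
    """Solve A[i] @ x[i] = b[i] for A of shape (..., N, N) and b of shape (..., N)."""
    # b[..., None] has shape (..., N, 1): a batch of single-column right-hand sides.
    x = xp.linalg.solve(A, b[..., None])
    # Drop the helper column so x has the same shape as b.
    return x.squeeze(-1)


# Usage: a batch of two 2x2 systems.
A = np.array([[[2.0, 0.0], [0.0, 4.0]], [[1.0, 1.0], [0.0, 1.0]]])
b = np.array([[2.0, 8.0], [3.0, 1.0]])
x = batched_vector_solve(np, A, b)
```

Here the first system gives x = [1, 2] and the second gives x = [2, 1], each returned with the same (..., N) shape as its right-hand side.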