
Core

cax.core.cs

Complex system module.

This module defines the abstract interface for complex systems simulated in CAX. A complex system encapsulates state transition dynamics over discrete time steps and a rendering routine to visualize states.

Subclasses must implement _step, which performs a single-step transition, and render, which converts a state to an RGB image. The public __call__ method runs multi-step evolution by scanning _step with nnx.scan inside an nnx.jit-compiled function.
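The contract above can be sketched without Flax. `MiniSystem` below is a hypothetical plain-Python stand-in for `ComplexSystem`. It is not the CAX class and keeps no `nnx` state. A subclass supplies `_step` and `render`, and the base `__call__` scans `_step` with `jax.lax.scan`. `Counter` is a made-up toy system used only to exercise the interface.

```python
import jax
import jax.numpy as jnp


class MiniSystem:
    """Hypothetical stand-in for ComplexSystem (plain Python, no Flax nnx state)."""

    def _step(self, state, input=None):
        """Single-step transition; subclasses must override."""
        raise NotImplementedError

    def render(self, state):
        """Convert a state to a uint8 RGB image; subclasses must override."""
        raise NotImplementedError

    def __call__(self, state, input=None, *, num_steps=1):
        # Apply _step num_steps times; the state is the scan carry and
        # the same input is reused at every step.
        def body(carry, _):
            return self._step(carry, input), None

        state, _ = jax.lax.scan(body, state, None, length=num_steps)
        return state


class Counter(MiniSystem):
    """Hypothetical toy system: every cell grows by 1 (or by `input`) per step."""

    def _step(self, state, input=None):
        return state + (1.0 if input is None else input)

    def render(self, state):
        # Map values in [0, 10] to grayscale, then to uint8 RGB of shape (..., 3).
        gray = jnp.clip(state / 10.0, 0.0, 1.0)
        rgb = jnp.stack([gray, gray, gray], axis=-1)
        return (rgb * 255).astype(jnp.uint8)


system = Counter()
final = system(jnp.zeros((2, 2)), num_steps=10)  # every cell is 10.0
image = system.render(final)  # shape (2, 2, 3), all 255
```

In the real class the driver is `nnx.scan` under `nnx.jit`, so module state (parameters, sown intermediates) is threaded through the loop as well; the stand-in only threads the state array.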

ComplexSystem

Bases: Module

Base class for complex systems.

This class specifies the minimal interface for systems that evolve a State over time. It provides a JIT-compiled multi-step driver via __call__ that wraps the subclass-defined single-step transition _step.
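Because `num_steps` is used as the scan length, it has to be a concrete Python integer at trace time. This is why `__call__` lists it, together with `input_in_axis` and `sow`, in `static_argnames`. The sketch below reproduces that pattern with plain `jax.jit` and `jax.lax.scan` instead of `nnx.jit`/`nnx.scan`. `evolve` and its update rule are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("num_steps",))
def evolve(state, num_steps=1):
    """Hypothetical multi-step driver with a toy update rule."""

    def body(carry, _):
        return 0.5 * carry + 1.0, None  # single step: x -> x / 2 + 1

    # `length` must be a concrete Python int, hence num_steps is static.
    final, _ = jax.lax.scan(body, state, None, length=num_steps)
    return final


x3 = evolve(jnp.array(0.0), num_steps=3)  # 0 -> 1 -> 1.5 -> 1.75
x4 = evolve(jnp.array(0.0), num_steps=4)  # new num_steps value: recompiles
```

A consequence of static arguments is that each previously unseen value, for example a different `num_steps`, triggers a fresh compilation, so it can pay to keep such values fixed across calls.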

Subclasses typically compose perception and update modules and may store hyperparameters and learned parameters within the Flax nnx.Module state.
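To illustrate a step built from a perception stage and an update stage, here is a hypothetical Game of Life step written as two plain functions. `perceive` counts neighbours with periodic `jnp.roll` shifts and `update` applies the B3/S23 rule. In an actual CAX system these stages would more likely be `nnx.Module` attributes than free functions; all names here are assumptions.

```python
import jax.numpy as jnp


def perceive(state):
    """Hypothetical perception: count the 8 neighbours of each cell (periodic boundary)."""
    total = jnp.zeros_like(state)
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            if dy == 0 and dx == 0:
                continue
            total = total + jnp.roll(state, shift=(dy, dx), axis=(0, 1))
    return total


def update(state, neighbours):
    """Hypothetical update: Conway's Game of Life rule (B3/S23)."""
    born = (state == 0) & (neighbours == 3)
    survive = (state == 1) & ((neighbours == 2) | (neighbours == 3))
    return (born | survive).astype(state.dtype)


def step(state):
    # A _step implementation would compose the two stages like this.
    return update(state, perceive(state))


# A horizontal "blinker" flips to vertical after one step and back after two.
blinker = jnp.zeros((5, 5), dtype=jnp.int32).at[2, 1:4].set(1)
after_one = step(blinker)
```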

Source code in src/cax/core/cs.py
class ComplexSystem(nnx.Module):
	"""Base class for complex systems.

	This class specifies the minimal interface for systems that evolve a `State` over time.
	It provides a JIT-compiled multi-step driver via `__call__` that wraps the subclass-defined
	single-step transition `_step`.

	Subclasses typically compose perception and update modules and may store hyperparameters
	and learned parameters within the Flax `nnx.Module` state.

	"""

	def _step(self, state: State, input: Input | None = None, *, sow: bool = False) -> State:
		"""Step the system by a single time step.

		Implementations should be side-effect free with respect to the provided `state` argument
		(unless leveraging Flax `sow`/`nnx.Intermediate` mechanics) and return the next state.
		Shapes and dtypes of `state` are system-specific but should be stable across steps.

		Args:
			state: Current state.
			input: Optional input.
			sow: Whether to sow intermediate values.

		Returns:
			Next state.

		"""
		raise NotImplementedError

	@partial(nnx.jit, static_argnames=("num_steps", "input_in_axis", "sow"))
	def __call__(
		self,
		state: State,
		input: Input | None = None,
		*,
		num_steps: int = 1,
		input_in_axis: int | None = None,
		sow: bool = False,
	) -> State:
		"""Step the system for multiple time steps.

		This method wraps `_step` inside a JAX scan for efficiency and JIT-compiles the loop.
		If `input` is time-varying, set `input_in_axis` to the axis containing the time
		dimension so that each step receives the corresponding slice of input.

		Args:
			state: Current state.
			input: Optional input.
			num_steps: Number of steps.
			input_in_axis: Axis for input if provided for each step.
			sow: Whether to sow intermediate values.

		Returns:
			Final state after `num_steps` applications of `_step`.

		"""
		state_axes = nnx.StateAxes({nnx.Intermediate: 0, ...: nnx.Carry})
		state = nnx.scan(
			lambda cs, state, input: cs._step(state, input, sow=sow),
			in_axes=(state_axes, nnx.Carry, input_in_axis),
			out_axes=nnx.Carry,
			length=num_steps,
		)(self, state, input)

		return state

	@nnx.jit
	def render(self, state: State, **kwargs: Any) -> Array:
		"""Render state to RGB image.

		Implementations should return values in the range `[0, 255]` with dtype `uint8` and
		shape `(..., 3)` for RGB. For systems that naturally produce RGBA, either drop the alpha
		channel or composite it over a background in this method.

		Args:
			state: A state.
			**kwargs: Additional rendering-specific keyword arguments.

		Returns:
			An RGB image with dtype `uint8` and shape `(..., 3)`.

		"""
		raise NotImplementedError

__call__(state, input=None, *, num_steps=1, input_in_axis=None, sow=False)

Step the system for multiple time steps.

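This method wraps _step in nnx.scan for efficiency and JIT-compiles the loop with nnx.jit. If input is time-varying, set input_in_axis to the axis that holds the time dimension so that each step receives the corresponding slice of input; that axis should have size num_steps. With input_in_axis=None, every step receives the full input unchanged.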
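The effect of `input_in_axis` can be mimicked in plain JAX. In this sketch the helper `run` is hypothetical, not CAX code: a non-`None` axis is moved to the front so that `jax.lax.scan` feeds one slice per step, while `None` passes the whole input to every step. This mirrors how `nnx.scan` treats an integer versus a `None` entry in `in_axes`.

```python
import jax
import jax.numpy as jnp


def run(step_fn, state, input, num_steps, input_in_axis=None):
    """Plain-JAX sketch of per-step input handling (hypothetical, not the CAX code)."""
    if input_in_axis is None:
        # Broadcast: every step sees the same input.
        xs = None

        def body(carry, _):
            return step_fn(carry, input), None
    else:
        # Time-varying: move the time axis to the front; scan slices one entry per step.
        # lax.scan raises an error if that axis does not have size num_steps.
        xs = jnp.moveaxis(input, input_in_axis, 0)

        def body(carry, x):
            return step_fn(carry, x), None

    final, _ = jax.lax.scan(body, state, xs, length=num_steps)
    return final


def add(state, x):
    return state + x


# Time is on axis 1: shape (2, 3) for num_steps=3.
inputs = jnp.array([[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]])
per_step = run(add, jnp.zeros(2), inputs, num_steps=3, input_in_axis=1)  # [6., 60.]
shared = run(add, jnp.zeros(2), jnp.ones(2), num_steps=3)  # [3., 3.]
```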

Parameters:

Name           Type          Default   Description
state          State         required  Current state.
input          Input | None  None      Optional input.
num_steps      int           1         Number of steps.
input_in_axis  int | None    None      Axis of input sliced once per step; None reuses input at every step.
sow            bool          False     Whether to sow intermediate values.

Returns:

Type   Description
State  Final state after num_steps applications of _step.
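With sow=True, values recorded during _step are stored as nnx.Intermediate state, and because the scan maps that collection to axis 0 they should end up stacked with one entry per step. The plain-JAX analogue below does not use `sow` at all; it only illustrates the resulting layout by returning a per-step quantity as the scan's second output, which `jax.lax.scan` stacks along a new leading axis of length `num_steps`. `step_with_trace` and `run_and_trace` are hypothetical.

```python
import jax
import jax.numpy as jnp


def step_with_trace(state):
    """Hypothetical step that also returns a per-step diagnostic."""
    new_state = 0.5 * state + 1.0
    return new_state, new_state.mean()


def run_and_trace(state, num_steps):
    def body(carry, _):
        # The second return value is collected by lax.scan at every step.
        return step_with_trace(carry)

    return jax.lax.scan(body, state, None, length=num_steps)


final, means = run_and_trace(jnp.zeros(4), num_steps=3)
# means has shape (3,): one entry per step, [1.0, 1.5, 1.75]
```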


render(state, **kwargs)

Render state to RGB image.

Implementations should return values in the range [0, 255] with dtype uint8 and shape (..., 3) for RGB. For systems that naturally produce RGBA, either drop the alpha channel or composite it over a background in this method.
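For the RGBA case, a minimal compositing sketch follows. It assumes, hypothetically, a float state in `[0, 1]` with four channels and applies the standard "over" operator against a solid background before scaling to `uint8`. `render_rgba` is not a CAX function.

```python
import jax.numpy as jnp


def render_rgba(state, background=(1.0, 1.0, 1.0)):
    """Hypothetical renderer: composite an RGBA float state in [0, 1] over a solid background.

    `state` has shape (..., 4); the result is uint8 RGB with shape (..., 3).
    """
    rgb, alpha = state[..., :3], state[..., 3:4]
    bg = jnp.asarray(background, dtype=state.dtype)
    composite = rgb * alpha + bg * (1.0 - alpha)  # "over" operator
    composite = jnp.clip(composite, 0.0, 1.0)
    return jnp.round(composite * 255).astype(jnp.uint8)


# An opaque red pixel and a fully transparent pixel (which shows the white background).
out = render_rgba(jnp.array([[1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0]]))
```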

Parameters:

Name      Type   Default   Description
state     State  required  A state.
**kwargs  Any    {}        Additional rendering-specific keyword arguments.

Returns:

Type   Description
Array  An RGB image with dtype uint8 and shape (..., 3).
