Core

cax.core.cs

Complex system module.

This module defines the abstract interface for complex systems simulated in CAX. A complex system encapsulates state transition dynamics over discrete time steps and a rendering routine to visualize states.

Subclasses must implement _step, which performs a single-step transition, and render, which converts a state to an RGB image. The public __call__ method handles multi-step evolution by scanning _step with nnx.scan under nnx.jit.
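To make the split between a single-step transition and a rendering routine concrete, here is a minimal pure-JAX sketch that does not use CAX itself. The names `life_step` and `life_render` are hypothetical stand-ins for a subclass's `_step` and `render`, and Conway's Game of Life on a periodic grid is an assumed example system, not something this module defines.

```python
import jax.numpy as jnp


def life_step(state):
    """One Game of Life transition on a periodic (H, W) grid of 0/1 cells."""
    # Count the eight neighbours by rolling the grid in every direction.
    neighbors = sum(
        jnp.roll(state, (dy, dx), axis=(0, 1))
        for dy in (-1, 0, 1)
        for dx in (-1, 0, 1)
        if (dy, dx) != (0, 0)
    )
    born = (state == 0) & (neighbors == 3)
    survive = (state == 1) & ((neighbors == 2) | (neighbors == 3))
    return (born | survive).astype(state.dtype)


def life_render(state):
    """Map a 0/1 grid to a uint8 RGB image of shape (H, W, 3)."""
    gray = (state * 255).astype(jnp.uint8)
    return jnp.stack([gray, gray, gray], axis=-1)


# A horizontal "blinker" oscillates with period 2.
blinker = jnp.zeros((5, 5), dtype=jnp.int32).at[2, 1:4].set(1)
after_one = life_step(blinker)
after_two = life_step(after_one)
frame = life_render(after_two)
```

The transition is a pure function of the state, and rendering is kept separate from the dynamics. That mirrors the division of labour described above.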

ComplexSystem

Bases: Module

Base class for complex systems.

This class specifies the minimal interface for systems that evolve a State over time. It provides a JIT-compiled multi-step driver via __call__ that wraps the subclass-defined single-step transition _step.

Subclasses typically compose perception and update modules and may store hyperparameters and learned parameters within the Flax nnx.Module state.

Attributes:

remat (bool): If True, applies gradient checkpointing (rematerialization) to the scan body, trading compute for memory during backpropagation through long step sequences. Subclasses can set this as a class variable or instance attribute.
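The effect of rematerializing a scan body can be sketched in plain JAX. This is an analogue under stated assumptions, not the CAX code path: `jax.checkpoint` and `jax.lax.scan` stand in for `nnx.remat` and `nnx.scan`, and `step`, `rollout`, and `loss` are hypothetical names. It checks only that gradients agree with and without checkpointing; memory use is not measured.

```python
import jax
import jax.numpy as jnp


def step(state, _):
    """One hypothetical transition; the second argument is the unused scan input."""
    return state + 0.1 * jnp.sin(state), None


def rollout(state, num_steps, remat=False):
    # With remat=True, intermediates inside `step` are recomputed during the
    # backward pass instead of being stored for every iteration.
    body = jax.checkpoint(step) if remat else step
    final_state, _ = jax.lax.scan(body, state, None, length=num_steps)
    return final_state


def loss(state, remat):
    return jnp.sum(rollout(state, 20, remat) ** 2)


x = jnp.linspace(-1.0, 1.0, 8)
grad_plain = jax.grad(loss)(x, False)
grad_remat = jax.grad(loss)(x, True)
```

Both gradients should match up to floating-point tolerance, because checkpointing changes what is stored for the backward pass and not what is computed.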

Source code in src/cax/core/cs.py
class ComplexSystem[State, Input](nnx.Module):
	"""Base class for complex systems.

	This class specifies the minimal interface for systems that evolve a `State` over time.
	It provides a JIT-compiled multi-step driver via `__call__` that wraps the subclass-defined
	single-step transition `_step`.

	Subclasses typically compose perception and update modules and may store hyperparameters
	and learned parameters within the Flax `nnx.Module` state.

	Attributes:
		remat: If True, applies gradient checkpointing (rematerialization) to the scan body,
			trading compute for memory during backpropagation through long step sequences.
			Subclasses can set this as a class variable or instance attribute.

	"""

	remat: bool = False

	def _step(self, state: State, input: Input | None = None, *, sow: bool = False) -> State:
		"""Step the system by a single time step.

		Implementations should be side-effect free with respect to the provided `state` argument
		(unless leveraging Flax `sow`/`nnx.Intermediate` mechanics) and return the next state.
		Shapes and dtypes of `state` are system-specific but should be stable across steps.

		Args:
			state: Current state.
			input: Optional input.
			sow: Whether to sow intermediate values.

		Returns:
			Next state.

		"""
		raise NotImplementedError

	@nnx.jit(static_argnames=("num_steps", "input_in_axis", "sow"))
	def __call__(
		self,
		state: State,
		input: Input | None = None,
		*,
		num_steps: int = 1,
		input_in_axis: int | None = None,
		sow: bool = False,
	) -> State:
		"""Step the system for multiple time steps.

		This method wraps `_step` inside a JAX scan for efficiency and JIT-compiles the loop.
		If `input` is time-varying, set `input_in_axis` to the axis containing the time
		dimension so that each step receives the corresponding slice of input.

		When `remat` is enabled, the scan body is wrapped with `nnx.remat` to reduce memory
		usage during backpropagation at the cost of recomputing intermediates.

		Args:
			state: Current state.
			input: Optional input.
			num_steps: Number of steps.
			input_in_axis: Axis for input if provided for each step.
			sow: Whether to sow intermediate values.

		Returns:
			Final state after `num_steps` applications of `_step`.

		"""

		def step_fn(cs: ComplexSystem, state: State, input: Input | None) -> State:
			return cs._step(state, input, sow=sow)

		if self.remat:
			step_fn = nnx.remat(step_fn)

		state_axes = nnx.StateAxes({nnx.Intermediate: 0, ...: nnx.Carry})
		state = nnx.scan(
			step_fn,
			in_axes=(state_axes, nnx.Carry, input_in_axis),
			out_axes=nnx.Carry,
			length=num_steps,
		)(self, state, input)

		return state

	@nnx.jit
	def render(self, state: State, **kwargs: Any) -> Array:
		"""Render state to RGB image.

		Implementations should return values in the range `[0, 255]` with dtype `uint8` and
		shape `(..., 3)` for RGB. For systems that naturally produce RGBA, either drop the alpha
		channel or composite it over a background in this method.

		Args:
			state: A state.
			**kwargs: Additional rendering-specific keyword arguments.

		Returns:
			An RGB image with dtype `uint8` and shape `(..., 3)`.

		"""
		raise NotImplementedError

__call__(state, input=None, *, num_steps=1, input_in_axis=None, sow=False)

Step the system for multiple time steps.

This method wraps _step in a scan (nnx.scan in the source listing) and JIT-compiles the loop. If input is time-varying, set input_in_axis to the axis that holds the time dimension so that each step receives the corresponding slice of input. With the default input_in_axis=None, every step receives the same input.

When remat is enabled, the scan body is wrapped with nnx.remat to reduce memory usage during backpropagation at the cost of recomputing intermediates.
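The difference between a constant input and a time-varying input can be illustrated with a pure-JAX analogue. This is a sketch under assumptions and not the CAX implementation: `jax.lax.scan` stands in for `nnx.scan`, and `step` and `rollout` are hypothetical names. Only the `input_in_axis` parameter name is taken from the signature above.

```python
import jax
import jax.numpy as jnp


def step(state, inp):
    """Hypothetical transition: add the input to the state."""
    return state + inp


def rollout(state, inp, num_steps, input_in_axis=None):
    if input_in_axis is None:
        # Constant input: every step sees the same `inp`.
        def body(carry, _):
            return step(carry, inp), None

        xs = None
    else:
        # Time-varying input: move the time axis to the front so that the
        # scan hands one slice to each step.
        xs = jnp.moveaxis(inp, input_in_axis, 0)

        def body(carry, x):
            return step(carry, x), None

    final_state, _ = jax.lax.scan(body, state, xs, length=num_steps)
    return final_state


# Same input of ones applied four times.
constant_result = rollout(jnp.zeros(3), jnp.ones(3), num_steps=4)

# One input row per step, with time along axis 0.
per_step_inputs = jnp.arange(12.0).reshape(4, 3)
varying_result = rollout(jnp.zeros(3), per_step_inputs, num_steps=4, input_in_axis=0)
```

In the time-varying case the length of the chosen axis has to equal the number of steps, since each step consumes exactly one slice.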

Parameters:

state (State): Current state. Required.
input (Input | None): Optional input. Default: None.
num_steps (int): Number of steps. Default: 1.
input_in_axis (int | None): Axis for input if provided for each step. Default: None.
sow (bool): Whether to sow intermediate values. Default: False.

Returns:

State: Final state after num_steps applications of _step.
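The source listing maps nnx.Intermediate to axis 0 in the scan's state axes, which suggests that values sown inside _step are collected once per step along a leading time axis. The pure-JAX analogue below illustrates that pattern only. It does not use Flax's sow mechanism, and `step_with_intermediate` and `rollout` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def step_with_intermediate(state):
    """Hypothetical transition that also exposes a per-step diagnostic."""
    next_state = 0.5 * state
    intermediate = jnp.sum(next_state)
    return next_state, intermediate


def rollout(state, num_steps):
    def body(carry, _):
        # The first return value is carried to the next step; the second is
        # stacked along a new leading axis by the scan.
        return step_with_intermediate(carry)

    final_state, stacked = jax.lax.scan(body, state, None, length=num_steps)
    return final_state, stacked


final_state, stacked = rollout(jnp.ones(4), num_steps=3)
```

Here `stacked` has one entry per step, while `final_state` holds only the last state. This is the same distinction as between carried state and collected intermediates.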

Source code in src/cax/core/cs.py
@nnx.jit(static_argnames=("num_steps", "input_in_axis", "sow"))
def __call__(
	self,
	state: State,
	input: Input | None = None,
	*,
	num_steps: int = 1,
	input_in_axis: int | None = None,
	sow: bool = False,
) -> State:
	"""Step the system for multiple time steps.

	This method wraps `_step` inside a JAX scan for efficiency and JIT-compiles the loop.
	If `input` is time-varying, set `input_in_axis` to the axis containing the time
	dimension so that each step receives the corresponding slice of input.

	When `remat` is enabled, the scan body is wrapped with `nnx.remat` to reduce memory
	usage during backpropagation at the cost of recomputing intermediates.

	Args:
		state: Current state.
		input: Optional input.
		num_steps: Number of steps.
		input_in_axis: Axis for input if provided for each step.
		sow: Whether to sow intermediate values.

	Returns:
		Final state after `num_steps` applications of `_step`.

	"""

	def step_fn(cs: ComplexSystem, state: State, input: Input | None) -> State:
		return cs._step(state, input, sow=sow)

	if self.remat:
		step_fn = nnx.remat(step_fn)

	state_axes = nnx.StateAxes({nnx.Intermediate: 0, ...: nnx.Carry})
	state = nnx.scan(
		step_fn,
		in_axes=(state_axes, nnx.Carry, input_in_axis),
		out_axes=nnx.Carry,
		length=num_steps,
	)(self, state, input)

	return state

render(state, **kwargs)

Render state to RGB image.

Implementations should return values in the range [0, 255] with dtype uint8 and shape (..., 3) for RGB. For systems that naturally produce RGBA, either drop the alpha channel or composite it over a background in this method.
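As an illustration of the compositing option, the sketch below converts an RGBA array to a `uint8` RGB image. It assumes float RGBA values in `[0, 1]` with straight (non-premultiplied) alpha and a uniform background. The helper name `rgba_to_rgb_uint8` is hypothetical and not part of CAX.

```python
import jax.numpy as jnp


def rgba_to_rgb_uint8(rgba, background=1.0):
    """Composite float RGBA in [0, 1] over a uniform background.

    Returns a uint8 array of shape (..., 3) with values in [0, 255].
    """
    rgba = jnp.clip(rgba, 0.0, 1.0)
    rgb, alpha = rgba[..., :3], rgba[..., 3:4]
    composited = rgb * alpha + background * (1.0 - alpha)
    return jnp.round(composited * 255.0).astype(jnp.uint8)


# Fully transparent pixels show the white background.
transparent_image = rgba_to_rgb_uint8(jnp.zeros((2, 2, 4)))

# A fully opaque red pixel stays red.
red_image = rgba_to_rgb_uint8(jnp.array([[[1.0, 0.0, 0.0, 1.0]]]))
```

Dropping the alpha channel instead would amount to returning `rgba[..., :3]` scaled to `[0, 255]`, which ignores transparency entirely.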

Parameters:

state (State): A state. Required.
**kwargs (Any): Additional rendering-specific keyword arguments. Default: {}.

Returns:

Array: An RGB image with dtype uint8 and shape (..., 3).

Source code in src/cax/core/cs.py
@nnx.jit
def render(self, state: State, **kwargs: Any) -> Array:
	"""Render state to RGB image.

	Implementations should return values in the range `[0, 255]` with dtype `uint8` and
	shape `(..., 3)` for RGB. For systems that naturally produce RGBA, either drop the alpha
	channel or composite it over a background in this method.

	Args:
		state: A state.
		**kwargs: Additional rendering-specific keyword arguments.

	Returns:
		An RGB image with dtype `uint8` and shape `(..., 3)`.

	"""
	raise NotImplementedError