Core
cax.core.cs
Complex system module.
This module defines the abstract interface for complex systems simulated in CAX. A complex system encapsulates state transition dynamics over discrete time steps and a rendering routine to visualize states.
Subclasses must implement _step for a single-step transition and render for converting
a state to an RGB image representation. The public __call__ method handles multi-step
evolution with JAX/Flax scanning utilities.
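The sketch below makes this contract concrete. It is not the CAX implementation. `SketchSystem` and `Counter` are hypothetical names, and the class is plain Python rather than a Flax `nnx.Module`. The multi-step driver is also an ordinary loop, whereas CAX uses a JIT-compiled JAX scan. Only the division of labor mirrors the text above: `_step` performs one transition, `render` produces RGB output, and `__call__` repeats the step.

```python
from abc import ABC, abstractmethod

import numpy as np


class SketchSystem(ABC):
    """Hypothetical stand-in for the ComplexSystem contract (no Flax, no JIT)."""

    @abstractmethod
    def _step(self, state, input=None):
        """Advance `state` by exactly one time step."""

    @abstractmethod
    def render(self, state, **kwargs):
        """Convert `state` to a uint8 RGB array of shape (..., 3)."""

    def __call__(self, state, input=None, *, num_steps=1):
        # Multi-step driver: a plain loop standing in for CAX's JAX scan.
        for _ in range(num_steps):
            state = self._step(state, input)
        return state


class Counter(SketchSystem):
    """Toy subclass: the state is a number that grows by 1 each step."""

    def _step(self, state, input=None):
        return state + 1

    def render(self, state, **kwargs):
        # Map the scalar to a single gray pixel, clipped to [0, 255].
        value = np.uint8(np.clip(state, 0, 255))
        return np.full((1, 1, 3), value, dtype=np.uint8)


system = Counter()
final = system(0, num_steps=5)
image = system.render(final)
```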
ComplexSystem
Bases: Module
Base class for complex systems.
This class specifies the minimal interface for systems that evolve a State over time.
It provides a JIT-compiled multi-step driver via __call__ that wraps the subclass-defined
single-step transition _step.
Subclasses typically compose perception and update modules and may store hyperparameters
and learned parameters within the Flax nnx.Module state.
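The sketch below illustrates how a system can be composed from a perception part and an update part. It builds a one-dimensional elementary cellular automaton in plain NumPy. The class names (`NeighborhoodPerceive`, `RuleUpdate`, `ElementaryCA`) are hypothetical and do not refer to CAX's own perception or update modules. In CAX these parts would presumably be Flax `nnx.Module` instances, and the loop would be replaced by the scanned `__call__`.

```python
import numpy as np


class NeighborhoodPerceive:
    """Hypothetical perception: stack (left, center, right) with periodic boundaries."""

    def __call__(self, state):
        # np.roll(state, 1)[i] == state[i - 1] (left neighbor); np.roll(state, -1) gives the right one.
        return np.stack([np.roll(state, 1), state, np.roll(state, -1)], axis=-1)


class RuleUpdate:
    """Hypothetical update: elementary CA rule given by its Wolfram number."""

    def __init__(self, rule_number):
        # Bit k of the rule number is the next value for neighborhood pattern k.
        self.table = np.array([(rule_number >> k) & 1 for k in range(8)], dtype=np.uint8)

    def __call__(self, state, perception):
        left, center, right = perception[..., 0], perception[..., 1], perception[..., 2]
        pattern = 4 * left + 2 * center + right  # neighborhood encoded as a 3-bit integer
        return self.table[pattern]


class ElementaryCA:
    """Hypothetical system composing a perception module and an update module."""

    def __init__(self, rule_number):
        self.perceive = NeighborhoodPerceive()
        self.update = RuleUpdate(rule_number)

    def _step(self, state, input=None):
        # One transition = perceive the neighborhood, then apply the update rule.
        return self.update(state, self.perceive(state))

    def __call__(self, state, *, num_steps=1):
        for _ in range(num_steps):
            state = self._step(state)
        return state


ca = ElementaryCA(rule_number=90)
state = np.array([0, 0, 0, 1, 0, 0, 0], dtype=np.uint8)
one = ca(state)
two = ca(state, num_steps=2)
```

Under rule 90, each cell becomes the XOR of its two neighbors. A single live cell therefore splits into two after one step.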
Attributes:
| Name | Type | Description |
|---|---|---|
| remat | bool | If True, applies gradient checkpointing (rematerialization) to the scan body, trading compute for memory during backpropagation through long step sequences. Subclasses can set this as a class variable or instance attribute. |
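Rematerialization trades recomputation for memory. The toy sketch below shows the idea on a plain sequence of states rather than on gradients. It keeps only every `every`-th state and recomputes any other state from the nearest earlier checkpoint. This is a conceptual analogy only. `nnx.remat`, which is built on JAX's checkpointing, instead avoids storing intermediate values of the scan body during the forward pass and recomputes them during backpropagation. The function names are hypothetical.

```python
def run_with_checkpoints(step, state, num_steps, every):
    """Run `num_steps` steps, keeping only every `every`-th state (plus the initial one)."""
    checkpoints = {0: state}
    for t in range(1, num_steps + 1):
        state = step(state)
        if t % every == 0:
            checkpoints[t] = state
    return state, checkpoints


def recover_state(step, checkpoints, t, every):
    """Recompute state t from the nearest earlier checkpoint (extra compute, less memory)."""
    start = (t // every) * every
    state = checkpoints[start]
    for _ in range(t - start):
        state = step(state)
    return state


def double_plus_one(x):
    return 2 * x + 1


final, checkpoints = run_with_checkpoints(double_plus_one, 0, num_steps=10, every=4)
state_6 = recover_state(double_plus_one, checkpoints, 6, every=4)
```

Here only 3 of the 11 states are held in memory. The cost is up to `every - 1` extra steps each time an unsaved state is needed.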
Source code in src/cax/core/cs.py
__call__(state, input=None, *, num_steps=1, input_in_axis=None, sow=False)
Step the system for multiple time steps.
This method wraps _step inside a JAX scan for efficiency and JIT-compiles the loop.
If input is time-varying, set input_in_axis to the axis containing the time
dimension so that each step receives the corresponding slice of input.
When remat is enabled, the scan body is wrapped with nnx.remat to reduce memory
usage during backpropagation at the cost of recomputing intermediates.
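The real method uses a JIT-compiled JAX scan. The loop below is a plain-NumPy analogue that reproduces only how `input_in_axis` selects the input each step sees. The helper names `multi_step` and `add_input` are hypothetical. In this sketch, the time axis must have at least `num_steps` entries, otherwise `np.take` raises an `IndexError`.

```python
import numpy as np


def multi_step(step, state, input=None, *, num_steps=1, input_in_axis=None):
    """Hypothetical loop mirroring per-step input slicing (not the CAX implementation)."""
    for t in range(num_steps):
        if input is not None and input_in_axis is not None:
            # Time-varying input: take the t-th slice along the time axis.
            step_input = np.take(input, t, axis=input_in_axis)
        else:
            # Static or absent input: every step sees the same value.
            step_input = input
        state = step(state, step_input)
    return state


def add_input(state, step_input):
    """Toy transition: add the current input to the state."""
    return state if step_input is None else state + step_input


# Two state components, each with its own 3-step time series on axis 1.
inputs = np.array([[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]])
final = multi_step(add_input, np.zeros(2), inputs, num_steps=3, input_in_axis=1)
```

Each component accumulates its own series (1 + 2 + 3 and 10 + 20 + 30). Passing the same array without `input_in_axis` would instead add the whole array at every step.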
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | State | Current state. | required |
| input | Input \| None | Optional input passed to each step. | None |
| num_steps | int | Number of steps to run. | 1 |
| input_in_axis | int \| None | Axis of input that holds the time dimension; each step receives the corresponding slice. If None, every step receives the same input. | None |
| sow | bool | Whether to sow intermediate values. | False |
Returns:
| Type | Description |
|---|---|
| State | Final state after the requested number of steps. |
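The `sow` flag follows Flax's convention of recording intermediate values on the module as a side channel rather than returning them. This page does not specify which values CAX records. The sketch below therefore uses a hypothetical `RecordingSystem` that appends the state after each step to a list attribute when `sow=True`. In either mode, the return value is just the final state.

```python
class RecordingSystem:
    """Hypothetical system whose state doubles every step."""

    def __init__(self):
        self.intermediates = []

    def _step(self, state):
        return 2 * state

    def __call__(self, state, *, num_steps=1, sow=False):
        self.intermediates = []  # reset the record for this call
        for _ in range(num_steps):
            state = self._step(state)
            if sow:
                # Record each intermediate state on the instance, not in the return value.
                self.intermediates.append(state)
        return state


system = RecordingSystem()
final = system(1, num_steps=3, sow=True)
recorded = list(system.intermediates)
```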
render(state, **kwargs)
Render state to RGB image.
Implementations should return values in the range [0, 255] with dtype uint8 and
shape (..., 3) for RGB. For systems that naturally produce RGBA, either drop the alpha
channel or composite it over a background in this method.
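One way to satisfy this contract is to alpha-composite over a solid background and then quantize. This assumes the state yields RGBA channels as floats in [0, 1]; how a state maps to color is system-specific. The helper `rgba_to_rgb_uint8` and its white default background are hypothetical choices, not part of CAX. The same operations exist in `jax.numpy` for use inside a JAX `render`.

```python
import numpy as np


def rgba_to_rgb_uint8(rgba, background=(1.0, 1.0, 1.0)):
    """Composite float RGBA in [0, 1] over a solid background and return uint8 RGB."""
    rgba = np.clip(np.asarray(rgba, dtype=np.float32), 0.0, 1.0)
    rgb, alpha = rgba[..., :3], rgba[..., 3:4]  # keep alpha's last axis for broadcasting
    bg = np.asarray(background, dtype=np.float32)
    # "Over" compositing: alpha * foreground + (1 - alpha) * background.
    composited = alpha * rgb + (1.0 - alpha) * bg
    return np.round(composited * 255.0).astype(np.uint8)


# An opaque red pixel and a fully transparent pixel, shape (1, 2, 4).
pixels = np.array([[[1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0]]])
image = rgba_to_rgb_uint8(pixels)
```

The opaque pixel keeps its color, while the transparent one shows the white background.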
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | State | State to render. | required |
| **kwargs | Any | Additional rendering-specific keyword arguments. | {} |
Returns:
| Type | Description |
|---|---|
| Array | An RGB image with dtype uint8, shape (..., 3), and values in [0, 255]. |