Core
cax.core.cs
Complex system module.
This module defines the abstract interface for complex systems simulated in CAX. A complex system encapsulates state transition dynamics over discrete time steps and a rendering routine to visualize states.
Subclasses must implement _step for a single-step transition and render for converting
a state to an RGB image representation. The public __call__ method handles multi-step
evolution with JAX/Flax scanning utilities.
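The interface described above can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not the CAX implementation. `ToySystem` and `Decay` are hypothetical names, and the base class is an ordinary Python `abc.ABC` rather than a Flax `nnx.Module`. The multi-step loop uses `jax.lax.scan` but omits the JIT compilation that the real `__call__` performs.

```python
import abc

import jax
import jax.numpy as jnp


class ToySystem(abc.ABC):
    """Hypothetical stand-in for ComplexSystem (plain Python, not Flax nnx)."""

    @abc.abstractmethod
    def _step(self, state, input=None):
        """Advance `state` by one time step."""

    @abc.abstractmethod
    def render(self, state):
        """Convert `state` to a uint8 RGB image of shape (..., 3)."""

    def __call__(self, state, input=None, *, num_steps=1):
        # Apply `_step` num_steps times with lax.scan; no per-step output is collected.
        def body(carry, _):
            return self._step(carry, input), None

        final_state, _ = jax.lax.scan(body, state, None, length=num_steps)
        return final_state


class Decay(ToySystem):
    """Hypothetical system: every cell value is halved at each step."""

    def _step(self, state, input=None):
        return 0.5 * state

    def render(self, state):
        # Map [0, 1] grayscale to uint8 (truncating) and repeat it across 3 RGB channels.
        gray = (jnp.clip(state, 0.0, 1.0) * 255).astype(jnp.uint8)
        return jnp.repeat(gray[..., None], 3, axis=-1)


state = Decay()(jnp.ones((4, 4)), num_steps=3)  # every cell is 0.5 ** 3 = 0.125
image = Decay().render(state)  # shape (4, 4, 3), dtype uint8
```

Keeping the loop in the base class means a subclass only has to define the single-step rule and the rendering.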
ComplexSystem
Bases: Module
Base class for complex systems.
This class specifies the minimal interface for systems that evolve a State over time.
It provides a JIT-compiled multi-step driver via __call__ that wraps the subclass-defined
single-step transition _step.
Subclasses typically compose perception and update modules and may store hyperparameters
and learned parameters within the Flax nnx.Module state.
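To illustrate a transition composed from a perception stage and an update stage, here is a minimal sketch of a one-dimensional elementary cellular automaton (Rule 90) in plain JAX. The functions `perceive`, `update`, `step`, and `run` are hypothetical, and they are written as plain functions rather than Flax `nnx.Module` subclasses. Periodic boundaries are an assumption of this sketch.

```python
import jax
import jax.numpy as jnp


def perceive(state):
    """Hypothetical perception: stack each cell's left neighbour, itself, and right neighbour.

    Boundaries wrap around (periodic) because jnp.roll is used.
    """
    left = jnp.roll(state, 1)    # element i now holds state[i - 1]
    right = jnp.roll(state, -1)  # element i now holds state[i + 1]
    return jnp.stack([left, state, right], axis=-1)


def update(perception):
    """Hypothetical update: Rule 90, new cell = left XOR right."""
    return jnp.bitwise_xor(perception[..., 0], perception[..., 2])


def step(state):
    """One transition: perception followed by update."""
    return update(perceive(state))


def run(state, num_steps):
    """Apply `step` num_steps times with lax.scan."""
    final, _ = jax.lax.scan(lambda s, _: (step(s), None), state, None, length=num_steps)
    return final


seed = jnp.zeros(11, dtype=jnp.int32).at[5].set(1)  # single live cell in the middle
after_two = run(seed, 2)  # live cells at indices 3 and 7
```

Starting from a single live cell, repeated steps trace out the Sierpinski-triangle pattern characteristic of Rule 90. Separating perception from update lets either stage be swapped (for example, a learned update) without touching the other.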
Source code in src/cax/core/cs.py
__call__(state, input=None, *, num_steps=1, input_in_axis=None, sow=False)
Step the system for multiple time steps.
This method wraps _step inside a JAX scan for efficiency and JIT-compiles the loop.
If input is time-varying, set input_in_axis to the axis containing the time
dimension so that each step receives the corresponding slice of input.
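A sketch of how a time-varying input can be consumed one slice per step: move the time axis named by `input_in_axis` to the front and let `jax.lax.scan` iterate over it. `run_with_input` and `add_input` are hypothetical helpers. The real `__call__` may slice the input differently (for example through Flax's own scan transform), and requiring the time axis length to equal `num_steps` is an assumption of this sketch.

```python
import jax
import jax.numpy as jnp


def run_with_input(step, state, input, *, num_steps, input_in_axis=None):
    """Hypothetical driver mirroring the `input_in_axis` convention."""
    if input_in_axis is None:
        # Constant input: scan over a length only and pass the same input to every step.
        def body_const(carry, _):
            return step(carry, input), None

        final, _ = jax.lax.scan(body_const, state, None, length=num_steps)
        return final

    # Time-varying input: bring the time axis to the front so lax.scan iterates over it.
    xs = jnp.moveaxis(input, input_in_axis, 0)
    if xs.shape[0] != num_steps:
        raise ValueError(f"time axis has length {xs.shape[0]}, expected {num_steps}")

    def body_sliced(carry, x):
        return step(carry, x), None

    final, _ = jax.lax.scan(body_sliced, state, xs)
    return final


def add_input(state, x):
    """Hypothetical single step: accumulate the input into the state."""
    return state + x


inputs = jnp.arange(12.0).reshape(3, 4)  # time dimension is axis 1 (4 steps)
final = run_with_input(add_input, jnp.zeros(3), inputs, num_steps=4, input_in_axis=1)
# final == [6., 22., 38.], i.e. the row sums of `inputs`
```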
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | State | Current state. | required |
| `input` | Input \| None | Optional input. | None |
| `num_steps` | int | Number of steps to run. | 1 |
| `input_in_axis` | int \| None | Axis of `input` that indexes time, when a separate input slice is provided for each step. | None |
| `sow` | bool | Whether to sow intermediate values. | False |
Returns:
| Type | Description |
|---|---|
| State | Final state after `num_steps` steps. |
render(state, **kwargs)
Render state to RGB image.
Implementations should return values in the range [0, 255] with dtype uint8 and
shape (..., 3) for RGB. For systems that naturally produce RGBA, either drop the alpha
channel or composite it over a background in this method.
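For a system whose natural output is RGBA, the compositing option above could look like the following minimal sketch. It assumes float RGBA values in [0, 1] and a white background; `rgba_to_rgb_uint8` is a hypothetical helper, not part of CAX.

```python
import jax.numpy as jnp


def rgba_to_rgb_uint8(rgba, background=(1.0, 1.0, 1.0)):
    """Composite float RGBA of shape (..., 4) over a solid background; return uint8 RGB (..., 3)."""
    rgb, alpha = rgba[..., :3], rgba[..., 3:]  # alpha keeps a trailing axis for broadcasting
    bg = jnp.asarray(background, dtype=rgb.dtype)
    # "Over" compositing onto an opaque background.
    out = jnp.clip(alpha * rgb + (1.0 - alpha) * bg, 0.0, 1.0)
    # Scale to [0, 255], round to nearest, and cast to uint8.
    return jnp.round(out * 255).astype(jnp.uint8)


pixels = jnp.array([[1.0, 0.0, 0.0, 1.0],   # opaque red
                    [0.0, 0.0, 0.0, 0.0]])  # fully transparent
rgb = rgba_to_rgb_uint8(pixels)  # [[255, 0, 0], [255, 255, 255]]
```

Dropping the alpha channel instead would amount to `rgba[..., :3]` before the same scaling and cast.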
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | State | State to render. | required |
| `**kwargs` | Any | Additional rendering-specific keyword arguments. | {} |
Returns:
| Type | Description |
|---|---|
| Array | An RGB image with dtype `uint8` and shape `(..., 3)`. |