Update
cax.core.update.update
Update base module.
Update
Bases: Module
Base class for update modules.
Subclasses implement transforms mapping a state and a perception (and optional input) to the next state.
Source code in src/cax/core/update/update.py
__call__(state, perception, input=None)
Process the current state, perception, and input to produce a new state.
This method should be implemented by subclasses to define specific update logic.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | State | Current state. | required |
| perception | Perception | Current perception. | required |
| input | Input \| None | Optional input. | None |

Returns:

| Type | Description |
|---|---|
| State | Next state. |
Source code in src/cax/core/update/update.py
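The contract described for this base class can be illustrated without the library itself. The snippet below is a minimal stand-in, not the cax implementation: UpdateSketch and BlendUpdate are hypothetical names, and plain JAX arrays stand in for the State, Perception, and Input types.

```python
import jax.numpy as jnp


class UpdateSketch:
    """Hypothetical stand-in for the Update base class (not the cax class)."""

    def __call__(self, state, perception, input=None):
        raise NotImplementedError


class BlendUpdate(UpdateSketch):
    """Toy rule: move each cell's state halfway toward its perception."""

    def __call__(self, state, perception, input=None):
        # The optional input simply shifts the target in this toy rule.
        target = perception if input is None else perception + input
        return 0.5 * (state + target)


state = jnp.zeros((4, 4, 3))       # (height, width, channels)
perception = jnp.ones((4, 4, 3))   # same shape as the state, for simplicity
next_state = BlendUpdate()(state, perception)
```

The only thing a subclass has to provide in this sketch is the mapping from (state, perception, input) to a next state with the state's shape.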
cax.core.update.mlp_update
MLP update module.
MLPUpdate
Bases: Update
MLP update class.
Maps a perception (and optional input) to the next state using pointwise convolutional layers (kernel size 1) applied independently at each spatial position.
Source code in src/cax/core/update/mlp_update.py
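A kernel-size-1 convolution performs the same computation as a dense layer applied to each cell's channel vector, which is why the update acts independently at each spatial position. The sketch below checks this with raw JAX primitives. It does not use the cax or Flax layers, and the shapes and variable names are chosen only for illustration.

```python
import jax
import jax.numpy as jnp

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (1, 5, 5, 8))  # (batch, height, width, perception channels)
w = jax.random.normal(key_w, (1, 1, 8, 4))  # 1x1 kernel: (kh, kw, in channels, out channels)

# Kernel-size-1 convolution over the grid.
conv_out = jax.lax.conv_general_dilated(
    x,
    w,
    window_strides=(1, 1),
    padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)

# The same weights applied as a dense layer to every cell separately.
dense_out = x @ w[0, 0]

same = bool(jnp.allclose(conv_out, dense_out, atol=1e-4))
```

Because no cell reads its neighbors here, all spatial mixing has to come from the perception that is passed in.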
__init__(num_spatial_dims, channel_size, perception_size, hidden_layer_sizes, *, activation_fn=nnx.relu, zeros_init=False, rngs)
Initialize MLP update.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_spatial_dims | int | Number of spatial dimensions. | required |
| channel_size | int | Number of channels in the output. | required |
| perception_size | int | Size of the input perception. | required |
| hidden_layer_sizes | tuple[int, ...] | Sizes of hidden layers. | required |
| activation_fn | Callable | Activation function to use. | relu |
| zeros_init | bool | Whether to use zeros initialization for the weights of the last layer. | False |
| rngs | Rngs | RNG key. | required |
Source code in src/cax/core/update/mlp_update.py
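One common motivation for zero-initializing the last layer is that the module then outputs zeros before any training, so a residual-style update initially leaves the state unchanged. The sketch below illustrates that effect with a hand-rolled per-cell MLP. The helpers init_layers and apply_layers are hypothetical, and the zero biases and the absence of a final activation are assumptions rather than documented behavior.

```python
import jax
import jax.numpy as jnp


def init_layers(key, sizes, zeros_init=False):
    """Create (weight, bias) pairs; optionally zero the last layer's weights."""
    layers = []
    keys = jax.random.split(key, len(sizes) - 1)
    for i, (k, n_in, n_out) in enumerate(zip(keys, sizes[:-1], sizes[1:])):
        is_last = i == len(sizes) - 2
        if zeros_init and is_last:
            w = jnp.zeros((n_in, n_out))
        else:
            w = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        layers.append((w, jnp.zeros((n_out,))))  # biases start at zero (assumption)
    return layers


def apply_layers(layers, x):
    """Apply the layers to the channel axis of every cell."""
    for w, b in layers[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = layers[-1]
    return x @ w + b  # no activation after the last layer (assumption)


perception = jax.random.normal(jax.random.PRNGKey(1), (6, 6, 12))
# Mirrors perception_size=12, hidden_layer_sizes=(32,), channel_size=4.
layers = init_layers(jax.random.PRNGKey(0), (12, 32, 4), zeros_init=True)
initial_output = apply_layers(layers, perception)
```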
__call__(state, perception, input=None)
Process the current state, perception, and input to produce a new state.
If input is provided, it is concatenated to the perception along the channel axis before being passed through the layers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | State | Current state. | required |
| perception | Perception | Current perception. | required |
| input | Input \| None | Optional input. | None |

Returns:

| Type | Description |
|---|---|
| State | Next state. |
Source code in src/cax/core/update/mlp_update.py
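The concatenation step can be sketched as follows. This is an illustration under assumptions, not the library code: mlp_update_sketch is a hypothetical function, the layers are plain (weight, bias) pairs, and the first layer is assumed to be sized for the perception channels plus the input channels.

```python
import jax
import jax.numpy as jnp


def mlp_update_sketch(layers, perception, input=None):
    """Per-cell MLP; an optional input is appended along the channel axis."""
    if input is not None:
        perception = jnp.concatenate([perception, input], axis=-1)
    x = perception
    for w, b in layers[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = layers[-1]
    return x @ w + b


key_p, key_i, key_1, key_2 = jax.random.split(jax.random.PRNGKey(0), 4)
perception = jax.random.normal(key_p, (8, 8, 12))  # 12 perception channels
extra_input = jax.random.normal(key_i, (8, 8, 2))  # 2 input channels

# Assumption: the first layer expects 12 + 2 = 14 channels.
layers = [
    (jax.random.normal(key_1, (14, 32)) * 0.1, jnp.zeros((32,))),
    (jax.random.normal(key_2, (32, 4)) * 0.1, jnp.zeros((4,))),
]
next_state = mlp_update_sketch(layers, perception, extra_input)
```

In this sketch the input must share the perception's spatial shape, since the two are joined only along the last axis.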
cax.core.update.residual_update
Residual update module.
ResidualUpdate
Bases: MLPUpdate
Residual update class.
Extends the MLP update with a residual connection and cell dropout applied to the update.
Source code in src/cax/core/update/residual_update.py
__init__(num_spatial_dims, channel_size, perception_size, hidden_layer_sizes, *, activation_fn=nnx.relu, step_size=1.0, cell_dropout_rate=0.0, zeros_init=False, rngs)
Initialize the ResidualUpdate module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_spatial_dims | int | Number of spatial dimensions. | required |
| channel_size | int | Number of channels in the state. | required |
| perception_size | int | Size of the perception input. | required |
| hidden_layer_sizes | tuple[int, ...] | Sizes of hidden layers in the MLP. | required |
| activation_fn | Callable | Activation function to use. | relu |
| step_size | float | Step size for the residual update. | 1.0 |
| cell_dropout_rate | float | Dropout rate for cell updates. | 0.0 |
| zeros_init | bool | Whether to use zeros initialization for the weights of the last layer. | False |
| rngs | Rngs | RNG key. | required |
Source code in src/cax/core/update/residual_update.py
__call__(state, perception, input=None)
Process the current state, perception, and input to produce a new state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | State | Current state. | required |
| perception | Perception | Current perception. | required |
| input | Input \| None | Optional input. | None |

Returns:

| Type | Description |
|---|---|
| State | Next state. |
Source code in src/cax/core/update/residual_update.py
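A residual connection with cell dropout can be sketched as next_state = state + step_size * dropout(update). The snippet below is one plausible reading, not the library code: residual_step is a hypothetical function, the dropout mask is assumed to be drawn once per cell and shared across channels, and the rescaling of kept updates by 1 / keep_prob follows standard dropout and is also an assumption.

```python
import jax
import jax.numpy as jnp


def residual_step(key, state, update, step_size=1.0, cell_dropout_rate=0.0):
    """Add a scaled update to the state, dropping the update for random cells."""
    keep_prob = 1.0 - cell_dropout_rate
    # One Bernoulli draw per cell, shared by all channels of that cell.
    keep = jax.random.bernoulli(key, keep_prob, state.shape[:-1] + (1,))
    # Assumption: kept updates are rescaled by 1 / keep_prob, as in standard dropout.
    update = jnp.where(keep, update / keep_prob, 0.0)
    return state + step_size * update


state = jnp.zeros((8, 8, 4))
update = jnp.ones((8, 8, 4))
next_state = residual_step(
    jax.random.PRNGKey(0), state, update, step_size=0.5, cell_dropout_rate=0.5
)
```

With cell_dropout_rate=0.0 the function reduces to a plain residual step, state + step_size * update.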
cax.core.update.nca_update
Neural Cellular Automata update module.
NCAUpdate
Bases: ResidualUpdate
Neural Cellular Automata update class.
Builds on the residual update and applies an alive mask so that only active cells update.
Source code in src/cax/core/update/nca_update.py
__init__(channel_size, perception_size, hidden_layer_sizes, *, activation_fn=nnx.relu, step_size=1.0, cell_dropout_rate=0.0, kernel_size=(3, 3), alive_threshold=0.1, zeros_init=False, rngs)
Initialize NCA update.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| channel_size | int | Number of input channels. | required |
| perception_size | int | Size of the perception. | required |
| hidden_layer_sizes | tuple[int, ...] | Sizes of hidden layers. | required |
| activation_fn | Callable | Activation function to use. | relu |
| step_size | float | Step size for the update. | 1.0 |
| cell_dropout_rate | float | Dropout rate for cells. | 0.0 |
| kernel_size | Sequence[int] | Size of the convolutional kernel. | (3, 3) |
| alive_threshold | float | Threshold for determining if a cell is alive. | 0.1 |
| zeros_init | bool | Whether to use zeros initialization for the weights of the last layer. | False |
| rngs | Rngs | RNG key. | required |
Source code in src/cax/core/update/nca_update.py
__call__(state, perception, input=None)
Process the current state, perception, and input to produce a new state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | State | Current state. | required |
| perception | Perception | Current perception. | required |
| input | Input \| None | Optional input. | None |

Returns:

| Type | Description |
|---|---|
| State | Next state. |
Source code in src/cax/core/update/nca_update.py
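How an alive mask gates an update can be sketched as below. This is a hedged illustration rather than the library's logic: masked_step, alive_fn, and update_fn are hypothetical names, the last channel is assumed to carry the alive value, and checking the mask both before and after the update is an assumption borrowed from common neural cellular automata practice.

```python
import jax.numpy as jnp


def masked_step(state, update_fn, alive_fn):
    """Apply update_fn, then zero every cell that is not alive before and after."""
    alive_before = alive_fn(state)
    candidate = update_fn(state)
    alive_after = alive_fn(candidate)
    return jnp.where(alive_before & alive_after, candidate, 0.0)


# Hypothetical helpers for the demo: the last channel is treated as the
# alive value, and a cell counts as alive when that value exceeds 0.1.
def alive_fn(state):
    return state[..., -1:] > 0.1


def update_fn(state):
    return state + 0.05


state = jnp.zeros((5, 5, 4)).at[2, 2, :].set(1.0)  # a single live seed cell
next_state = masked_step(state, update_fn, alive_fn)
```

In this demo only the seed cell keeps its updated value, and every other cell is reset to zero.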
get_alive_mask(state)
Generate a mask of alive cells based on the current state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | State | Current state. | required |

Returns:

| Type | Description |
|---|---|
| Array | A boolean mask indicating which cells are alive. |
Source code in src/cax/core/update/nca_update.py
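One way to compute such a mask is to max-pool the alive value over each cell's neighborhood and compare the result with a threshold. The sketch below does this with jax.lax.reduce_window. It rests on assumptions: alive_mask_sketch is a hypothetical function, the last channel is taken as the alive value, and kernel_size is treated as the pooling window, which may not match how the library uses that parameter.

```python
import jax
import jax.numpy as jnp


def alive_mask_sketch(state, kernel_size=(3, 3), alive_threshold=0.1):
    """Mark a cell alive if any cell in its neighborhood exceeds the threshold."""
    alive = state[..., -1:]  # assumption: the last channel holds the alive value
    pooled = jax.lax.reduce_window(
        alive,
        -jnp.inf,
        jax.lax.max,
        window_dimensions=tuple(kernel_size) + (1,),
        window_strides=(1,) * alive.ndim,
        padding="SAME",
    )
    return pooled > alive_threshold


state = jnp.zeros((5, 5, 4)).at[2, 2, -1].set(1.0)  # one live cell in the center
mask = alive_mask_sketch(state)
```

Pooling over the neighborhood lets cells next to a live cell count as alive, so a pattern can grow outward by one cell per step.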
state_to_alive(state)
Extract the 'alive' component from the state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | State | Input state. | required |

Returns:

| Type | Description |
|---|---|
| State | The 'alive' component of the state. |
Source code in src/cax/core/update/nca_update.py