Skip to content

Update

Update

cax.core.update.update

Update base module.

Update

Bases: Module

Base class for update modules.

Subclasses implement transforms mapping a state and a perception (and optional input) to the next state.

Source code in src/cax/core/update/update.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class Update(nnx.Module):
	"""Common interface for update modules.

	An update module turns the current state and a perception (plus an optional input)
	into the next state. This base class only fixes the call signature; the actual
	update rule is supplied by subclasses.
	"""

	def __call__(self, state: State, perception: Perception, input: Input | None = None) -> State:
		"""Compute the next state from the current state, perception, and input.

		The base class defines no update rule, so subclasses must override this method.

		Args:
			state: Current state.
			perception: Current perception.
			input: Optional input.

		Returns:
			Next state.

		Raises:
			NotImplementedError: Always, because the base class has no update logic.

		"""
		raise NotImplementedError()

__call__(state, perception, input=None)

Process the current state, perception, and input to produce a new state.

This method should be implemented by subclasses to define specific update logic.

Parameters:

Name Type Description Default
state State

Current state.

required
perception Perception

Current perception.

required
input Input | None

Optional input.

None

Returns:

Type Description
State

Next state.

Source code in src/cax/core/update/update.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def __call__(self, state: State, perception: Perception, input: Input | None = None) -> State:
	"""Compute the next state from the current state, perception, and input.

	The base class defines no update rule, so subclasses must override this method.

	Args:
		state: Current state.
		perception: Current perception.
		input: Optional input.

	Returns:
		Next state.

	Raises:
		NotImplementedError: Always, because the base class has no update logic.

	"""
	raise NotImplementedError()

cax.core.update.mlp_update

MLP update module.

MLPUpdate

Bases: Update

MLP update class.

Maps a perception (and optional input) to the next state using pointwise convolutional layers (kernel size 1) applied independently at each spatial position.

Source code in src/cax/core/update/mlp_update.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
class MLPUpdate(Update):
	"""MLP update class.

	Maps a perception (and optional input) to the next state using pointwise convolutional
	layers (kernel size 1) applied independently at each spatial position.

	"""

	def __init__(
		self,
		num_spatial_dims: int,
		channel_size: int,
		perception_size: int,
		hidden_layer_sizes: tuple[int, ...],
		*,
		activation_fn: Callable = nnx.relu,
		zeros_init: bool = False,
		rngs: nnx.Rngs,
	):
		"""Initialize MLP update.

		Args:
			num_spatial_dims: Number of spatial dimensions.
			channel_size: Number of channels in the output.
			perception_size: Number of input features of the first layer.
				NOTE(review): when an `input` is passed to `__call__` it is concatenated to
				the perception, so this presumably has to include the input channels as
				well -- confirm against callers.
			hidden_layer_sizes: Sizes of hidden layers. Any sequence of ints is accepted;
				an empty sequence yields a single layer.
			activation_fn: Activation function applied after every layer except the last.
			zeros_init: Whether to use zeros initialization for the weights of the last layer.
			rngs: Random number generators passed to each layer.

		"""
		# Normalize to a tuple so that lists also work with the tuple concatenations below.
		hidden_layer_sizes = tuple(hidden_layer_sizes)
		layer_in_features = (perception_size,) + hidden_layer_sizes
		layer_out_features = hidden_layer_sizes + (channel_size,)
		# Only the last layer may be zero-initialized; hidden layers keep the default.
		last_kernel_init = initializers.zeros_init() if zeros_init else default_kernel_init
		layer_kernel_inits = [default_kernel_init] * len(hidden_layer_sizes) + [last_kernel_init]
		self.layers = nnx.List(
			[
				nnx.Conv(
					in_features,
					out_features,
					kernel_size=num_spatial_dims * (1,),
					kernel_init=kernel_init,
					rngs=rngs,
				)
				for in_features, out_features, kernel_init in zip(
					layer_in_features, layer_out_features, layer_kernel_inits
				)
			]
		)
		self.activation_fn = activation_fn

	def __call__(self, state: State, perception: Perception, input: Input | None = None) -> State:
		"""Process the current state, perception, and input to produce a new state.

		If input is provided, it is concatenated to the perception along the last axis
		before being passed through the layers. The activation function is applied after
		every layer except the last one.

		Args:
			state: Current state. Its value is not read here; the result is computed from
				the perception (and input) only.
			perception: Current perception.
			input: Optional input.

		Returns:
			Next state.

		"""
		if input is not None:
			perception = jnp.concatenate([perception, input], axis=-1)

		for layer in self.layers[:-1]:
			perception = self.activation_fn(layer(perception))
		# The last layer is applied without an activation.
		state = self.layers[-1](perception)
		return state

__init__(num_spatial_dims, channel_size, perception_size, hidden_layer_sizes, *, activation_fn=nnx.relu, zeros_init=False, rngs)

Initialize MLP update.

Parameters:

Name Type Description Default
num_spatial_dims int

Number of spatial dimensions.

required
channel_size int

Number of channels in the output.

required
perception_size int

Size of the input perception.

required
hidden_layer_sizes tuple[int, ...]

Sizes of hidden layers.

required
activation_fn Callable

Activation function to use.

relu
zeros_init bool

Whether to use zeros initialization for the weights of the last layer.

False
rngs Rngs

Random number generators (`nnx.Rngs`) passed to each layer.

required
Source code in src/cax/core/update/mlp_update.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def __init__(
	self,
	num_spatial_dims: int,
	channel_size: int,
	perception_size: int,
	hidden_layer_sizes: tuple[int, ...],
	*,
	activation_fn: Callable = nnx.relu,
	zeros_init: bool = False,
	rngs: nnx.Rngs,
):
	"""Initialize MLP update.

	Args:
		num_spatial_dims: Number of spatial dimensions.
		channel_size: Number of channels in the output.
		perception_size: Number of input features of the first layer.
		hidden_layer_sizes: Sizes of hidden layers. Any sequence of ints is accepted;
			an empty sequence yields a single layer.
		activation_fn: Activation function to use.
		zeros_init: Whether to use zeros initialization for the weights of the last layer.
		rngs: Random number generators passed to each layer.

	"""
	# Normalize to a tuple so that lists also work with the tuple concatenations below.
	hidden_layer_sizes = tuple(hidden_layer_sizes)
	layer_in_features = (perception_size,) + hidden_layer_sizes
	layer_out_features = hidden_layer_sizes + (channel_size,)
	# Only the last layer may be zero-initialized; hidden layers keep the default.
	last_kernel_init = initializers.zeros_init() if zeros_init else default_kernel_init
	layer_kernel_inits = [default_kernel_init] * len(hidden_layer_sizes) + [last_kernel_init]
	self.layers = nnx.List(
		[
			nnx.Conv(
				in_features,
				out_features,
				kernel_size=num_spatial_dims * (1,),
				kernel_init=kernel_init,
				rngs=rngs,
			)
			for in_features, out_features, kernel_init in zip(
				layer_in_features, layer_out_features, layer_kernel_inits
			)
		]
	)
	self.activation_fn = activation_fn

__call__(state, perception, input=None)

Process the current state, perception, and input to produce a new state.

If input is provided, it is concatenated to the perception along the channel axis before being passed through the layers.

Parameters:

Name Type Description Default
state State

Current state.

required
perception Perception

Current perception.

required
input Input | None

Optional input.

None

Returns:

Type Description
State

Next state.

Source code in src/cax/core/update/mlp_update.py
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def __call__(self, state: State, perception: Perception, input: Input | None = None) -> State:
	"""Run the perception (and optional input) through the layers to get the next state.

	When an input is given, it is first joined to the perception along the last axis.
	The activation function follows every layer except the final one.

	Args:
		state: Current state. Its value is not read here.
		perception: Current perception.
		input: Optional input.

	Returns:
		Next state.

	"""
	features = perception if input is None else jnp.concatenate([perception, input], axis=-1)

	for hidden_layer in self.layers[:-1]:
		features = self.activation_fn(hidden_layer(features))
	# The final layer's output is returned without an activation.
	return self.layers[-1](features)

cax.core.update.residual_update

Residual update module.

ResidualUpdate

Bases: MLPUpdate

Residual update class.

Extends the MLP update with a residual connection and cell dropout applied to the update.

Source code in src/cax/core/update/residual_update.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class ResidualUpdate(MLPUpdate):
	"""Residual update class.

	Extends the MLP update with a residual connection and cell dropout applied to the update.
	"""

	def __init__(
		self,
		num_spatial_dims: int,
		channel_size: int,
		perception_size: int,
		hidden_layer_sizes: tuple[int, ...],
		*,
		activation_fn: Callable = nnx.relu,
		step_size: float = 1.0,
		cell_dropout_rate: float = 0.0,
		zeros_init: bool = False,
		rngs: nnx.Rngs,
	):
		"""Initialize the ResidualUpdate module.

		Args:
			num_spatial_dims: Number of spatial dimensions.
			channel_size: Number of channels in the state.
			perception_size: Size of the perception input.
			hidden_layer_sizes: Sizes of hidden layers in the MLP.
			activation_fn: Activation function to use.
			step_size: Factor the update is multiplied by before it is added to the state.
			cell_dropout_rate: Dropout rate for cell updates. The dropout mask is shared
				along the last axis (`broadcast_dims=(-1,)`).
			zeros_init: Whether to use zeros initialization for the weights of the last layer.
			rngs: Random number generators used by the MLP layers and the dropout.

		"""
		super().__init__(
			num_spatial_dims=num_spatial_dims,
			channel_size=channel_size,
			perception_size=perception_size,
			hidden_layer_sizes=hidden_layer_sizes,
			activation_fn=activation_fn,
			zeros_init=zeros_init,
			rngs=rngs,
		)
		self.dropout = nnx.Dropout(rate=cell_dropout_rate, broadcast_dims=(-1,), rngs=rngs)
		self.step_size = step_size

	def __call__(self, state: State, perception: Perception, input: Input | None = None) -> State:
		"""Process the current state, perception, and input to produce a new state.

		The MLP output is treated as an update: dropout is applied to it, it is scaled by
		`step_size`, and the result is added to the current state.

		Args:
			state: Current state. It is not modified; a new array is returned.
			perception: Current perception.
			input: Optional input.

		Returns:
			Next state.

		"""
		update = super().__call__(state, perception, input)
		update = self.dropout(update)
		# Out-of-place addition: `state += ...` would write into the caller's buffer when
		# `state` is a mutable array (e.g. a NumPy array).
		return state + self.step_size * update

__init__(num_spatial_dims, channel_size, perception_size, hidden_layer_sizes, *, activation_fn=nnx.relu, step_size=1.0, cell_dropout_rate=0.0, zeros_init=False, rngs)

Initialize the ResidualUpdate module.

Parameters:

Name Type Description Default
num_spatial_dims int

Number of spatial dimensions.

required
channel_size int

Number of channels in the state.

required
perception_size int

Size of the perception input.

required
hidden_layer_sizes tuple[int, ...]

Sizes of hidden layers in the MLP.

required
activation_fn Callable

Activation function to use.

relu
step_size float

Step size for the residual update.

1.0
cell_dropout_rate float

Dropout rate for cell updates.

0.0
zeros_init bool

Whether to use zeros initialization for the weights of the last layer.

False
rngs Rngs

Random number generators (`nnx.Rngs`) used by the MLP layers and the dropout.

required
Source code in src/cax/core/update/residual_update.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def __init__(
	self,
	num_spatial_dims: int,
	channel_size: int,
	perception_size: int,
	hidden_layer_sizes: tuple[int, ...],
	*,
	activation_fn: Callable = nnx.relu,
	step_size: float = 1.0,
	cell_dropout_rate: float = 0.0,
	zeros_init: bool = False,
	rngs: nnx.Rngs,
):
	"""Set up the MLP layers, the dropout applied to the update, and the step size.

	Args:
		num_spatial_dims: Number of spatial dimensions.
		channel_size: Number of channels in the state.
		perception_size: Size of the perception input.
		hidden_layer_sizes: Sizes of hidden layers in the MLP.
		activation_fn: Activation function to use.
		step_size: Step size for the residual update.
		cell_dropout_rate: Dropout rate for cell updates. The dropout mask is shared
			along the last axis (`broadcast_dims=(-1,)`).
		zeros_init: Whether to use zeros initialization for the weights of the last layer.
		rngs: Random number generators used by the MLP layers and the dropout.

	"""
	# The layers themselves are built by the parent initializer.
	super().__init__(
		num_spatial_dims,
		channel_size,
		perception_size,
		hidden_layer_sizes,
		activation_fn=activation_fn,
		zeros_init=zeros_init,
		rngs=rngs,
	)
	self.dropout = nnx.Dropout(rate=cell_dropout_rate, broadcast_dims=(-1,), rngs=rngs)
	self.step_size = step_size

__call__(state, perception, input=None)

Process the current state, perception, and input to produce a new state.

Parameters:

Name Type Description Default
state State

Current state.

required
perception Perception

Current perception.

required
input Input | None

Optional input.

None

Returns:

Type Description
State

Next state.

Source code in src/cax/core/update/residual_update.py
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def __call__(self, state: State, perception: Perception, input: Input | None = None) -> State:
	"""Process the current state, perception, and input to produce a new state.

	The parent's output is treated as an update: dropout is applied to it, it is scaled
	by `step_size`, and the result is added to the current state.

	Args:
		state: Current state. It is not modified; a new array is returned.
		perception: Current perception.
		input: Optional input.

	Returns:
		Next state.

	"""
	update = super().__call__(state, perception, input)
	update = self.dropout(update)
	# Out-of-place addition: `state += ...` would write into the caller's buffer when
	# `state` is a mutable array (e.g. a NumPy array).
	return state + self.step_size * update

cax.core.update.nca_update

Neural Cellular Automata update module.

NCAUpdate

Bases: ResidualUpdate

Neural Cellular Automata update class.

Builds on the residual update and applies an alive mask so that only active cells update.

Source code in src/cax/core/update/nca_update.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
class NCAUpdate(ResidualUpdate):
	"""Neural Cellular Automata update class.

	Builds on the residual update and applies an alive mask so that only active cells update.
	"""

	def __init__(
		self,
		channel_size: int,
		perception_size: int,
		hidden_layer_sizes: tuple[int, ...],
		*,
		activation_fn: Callable = nnx.relu,
		step_size: float = 1.0,
		cell_dropout_rate: float = 0.0,
		kernel_size: Sequence[int] = (3, 3),
		alive_threshold: float = 0.1,
		zeros_init: bool = False,
		rngs: nnx.Rngs,
	):
		"""Initialize NCA update.

		Args:
			channel_size: Number of channels in the state.
			perception_size: Size of the perception.
			hidden_layer_sizes: Sizes of hidden layers.
			activation_fn: Activation function to use.
			step_size: Step size for the update.
			cell_dropout_rate: Dropout rate for cells.
			kernel_size: Window shape of the max pooling used to compute the alive mask.
				Its length is also used as the number of spatial dimensions.
			alive_threshold: Threshold for determining if a cell is alive.
			zeros_init: Whether to use zeros initialization for the weights of the last layer.
			rngs: Random number generators forwarded to the parent initializer.

		"""
		super().__init__(
			num_spatial_dims=len(kernel_size),
			channel_size=channel_size,
			perception_size=perception_size,
			hidden_layer_sizes=hidden_layer_sizes,
			rngs=rngs,
			activation_fn=activation_fn,
			step_size=step_size,
			cell_dropout_rate=cell_dropout_rate,
			zeros_init=zeros_init,
		)
		# The pooling code concatenates the window shape with tuples, so a list (which the
		# `Sequence[int]` annotation allows) must be converted first.
		self.pool = partial(nnx.max_pool, window_shape=tuple(kernel_size), padding="SAME")
		self.alive_threshold = alive_threshold

	def __call__(self, state: State, perception: Perception, input: Input | None = None) -> State:
		"""Process the current state, perception, and input to produce a new state.

		The alive mask is computed both before and after the residual update. The updated
		state is multiplied by the conjunction of the two masks, so cells that are not
		alive in both are set to zero.

		Args:
			state: Current state.
			perception: Current perception.
			input: Optional input.

		Returns:
			Next state.

		"""
		alive_mask = self.get_alive_mask(state)
		state = super().__call__(state, perception, input)
		alive_mask &= self.get_alive_mask(state)
		return alive_mask * state

	def get_alive_mask(self, state: State) -> Array:
		"""Generate a mask of alive cells based on the current state.

		The alive component of the state is max-pooled with the configured window and
		compared against `alive_threshold`.

		Args:
			state: Current state.

		Returns:
			A boolean mask indicating which cells are alive.

		"""
		alive = state_to_alive(state)
		alive_mask: Array = self.pool(alive) > self.alive_threshold
		return alive_mask

__init__(channel_size, perception_size, hidden_layer_sizes, *, activation_fn=nnx.relu, step_size=1.0, cell_dropout_rate=0.0, kernel_size=(3, 3), alive_threshold=0.1, zeros_init=False, rngs)

Initialize NCA update.

Parameters:

Name Type Description Default
channel_size int

Number of channels in the state.

required
perception_size int

Size of the perception.

required
hidden_layer_sizes tuple[int, ...]

Sizes of hidden layers.

required
activation_fn Callable

Activation function to use.

relu
step_size float

Step size for the update.

1.0
cell_dropout_rate float

Dropout rate for cells.

0.0
kernel_size Sequence[int]

Window shape of the max pooling used to compute the alive mask; its length also sets the number of spatial dimensions.

(3, 3)
alive_threshold float

Threshold for determining if a cell is alive.

0.1
zeros_init bool

Whether to use zeros initialization for the weights of the last layer.

False
rngs Rngs

Random number generators (`nnx.Rngs`) forwarded to the parent initializer.

required
Source code in src/cax/core/update/nca_update.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def __init__(
	self,
	channel_size: int,
	perception_size: int,
	hidden_layer_sizes: tuple[int, ...],
	*,
	activation_fn: Callable = nnx.relu,
	step_size: float = 1.0,
	cell_dropout_rate: float = 0.0,
	kernel_size: Sequence[int] = (3, 3),
	alive_threshold: float = 0.1,
	zeros_init: bool = False,
	rngs: nnx.Rngs,
):
	"""Initialize NCA update.

	Args:
		channel_size: Number of channels, forwarded to the parent initializer.
		perception_size: Size of the perception.
		hidden_layer_sizes: Sizes of hidden layers.
		activation_fn: Activation function to use.
		step_size: Step size for the update.
		cell_dropout_rate: Dropout rate for cells.
		kernel_size: Window shape of the max pooling stored in `self.pool`. Its length is
			also used as the number of spatial dimensions.
		alive_threshold: Threshold for determining if a cell is alive.
		zeros_init: Whether to use zeros initialization for the weights of the last layer.
		rngs: Random number generators forwarded to the parent initializer.

	"""
	super().__init__(
		num_spatial_dims=len(kernel_size),
		channel_size=channel_size,
		perception_size=perception_size,
		hidden_layer_sizes=hidden_layer_sizes,
		rngs=rngs,
		activation_fn=activation_fn,
		step_size=step_size,
		cell_dropout_rate=cell_dropout_rate,
		zeros_init=zeros_init,
	)
	# The pooling code concatenates the window shape with tuples, so a list (which the
	# `Sequence[int]` annotation allows) must be converted first.
	self.pool = partial(nnx.max_pool, window_shape=tuple(kernel_size), padding="SAME")
	self.alive_threshold = alive_threshold

__call__(state, perception, input=None)

Process the current state, perception, and input to produce a new state.

Parameters:

Name Type Description Default
state State

Current state.

required
perception Perception

Current perception.

required
input Input | None

Optional input.

None

Returns:

Type Description
State

Next state.

Source code in src/cax/core/update/nca_update.py
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def __call__(self, state: State, perception: Perception, input: Input | None = None) -> State:
	"""Apply the parent update and mask the result with the alive cells.

	The alive mask is taken before and after the parent update; the updated state is
	multiplied by the conjunction of both masks.

	Args:
		state: Current state.
		perception: Current perception.
		input: Optional input.

	Returns:
		Next state.

	"""
	alive_before = self.get_alive_mask(state)
	updated_state = super().__call__(state, perception, input)
	alive_after = self.get_alive_mask(updated_state)
	# A cell survives only if it is alive both before and after the update.
	return (alive_before & alive_after) * updated_state

get_alive_mask(state)

Generate a mask of alive cells based on the current state.

Parameters:

Name Type Description Default
state State

Current state.

required

Returns:

Type Description
Array

A boolean mask indicating which cells are alive.

Source code in src/cax/core/update/nca_update.py
81
82
83
84
85
86
87
88
89
90
91
92
93
def get_alive_mask(self, state: State) -> Array:
	"""Compute which cells count as alive in the given state.

	The alive component of the state is passed through `self.pool` and the result is
	compared against `self.alive_threshold`.

	Args:
		state: Current state.

	Returns:
		A boolean mask indicating which cells are alive.

	"""
	pooled_alive: Array = self.pool(state_to_alive(state))
	return pooled_alive > self.alive_threshold

state_to_alive(state)

Extract the 'alive' component from the state.

Parameters:

Name Type Description Default
state State

Input state.

required

Returns:

Type Description
State

The 'alive' component of the state.

Source code in src/cax/core/update/nca_update.py
 96
 97
 98
 99
100
101
102
103
104
105
106
def state_to_alive(state: State) -> State:
	"""Return the 'alive' component of a state.

	The alive component is the last entry along the final axis; that axis is kept
	with length one.

	Args:
		state: Input state.

	Returns:
		The 'alive' component of the state.

	"""
	alive_component = state[..., -1:]
	return alive_component