
Perceive

cax.core.perceive.perceive

Perceive base module.

Perceive

Bases: nnx.Module

Base class for perception modules.

Subclasses implement neighborhood gathering or convolutional transforms that map a state to a perception. Perceptions are PyTrees; commonly arrays shaped (..., *spatial_dims, perception_size).

Source code in src/cax/core/perceive/perceive.py
class Perceive(nnx.Module):
	"""Base class for perception modules.

	Subclasses implement neighborhood gathering or convolutional transforms that map a state
	to a perception. Perceptions are PyTrees; commonly arrays shaped
	`(..., *spatial_dims, perception_size)`.
	"""

	def __call__(self, state: State) -> Perception:
		"""Process the current state to produce a perception.

		Args:
			state: Current state.

		Returns:
			Perception derived from `state`.

		"""
		raise NotImplementedError

__call__(state)

Process the current state to produce a perception.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `State` | Current state. | required |

Returns:

| Type | Description |
| --- | --- |
| `Perception` | Perception derived from `state`. |

Source code in src/cax/core/perceive/perceive.py
def __call__(self, state: State) -> Perception:
	"""Process the current state to produce a perception.

	Args:
		state: Current state.

	Returns:
		Perception derived from `state`.

	"""
	raise NotImplementedError
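
As an illustrative sketch (not part of the library), a subclass only needs to implement __call__. The hypothetical IdentityPerceive below returns the state unchanged:

import jax.numpy as jnp

from cax.core.perceive.perceive import Perceive


class IdentityPerceive(Perceive):
	"""Hypothetical perceive module that passes the state through unchanged."""

	def __call__(self, state):
		# The perception is simply the state itself.
		return state


perceive = IdentityPerceive()
state = jnp.ones((8, 8, 3))  # (*spatial_dims, channel_size)
assert (perceive(state) == state).all()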

cax.core.perceive.moore_perceive

Moore perceive module.

MoorePerceive

Bases: Perceive

Moore perceive class.

This class implements perception based on the Moore neighborhood. The Moore neighborhood includes all cells within Chebyshev distance radius of the central cell, i.e. cells whose offset along every dimension is at most radius.

Source code in src/cax/core/perceive/moore_perceive.py
class MoorePerceive(Perceive):
	"""Moore perceive class.

	This class implements perception based on the Moore neighborhood.
	The Moore neighborhood includes all cells within Chebyshev distance `radius` of the
	central cell, i.e. cells whose offset along every dimension is at most `radius`.
	"""

	def __init__(self, num_spatial_dims: int, radius: int):
		"""Initialize Moore perceive.

		Args:
			num_spatial_dims: Number of spatial dimensions.
			radius: Radius for Chebyshev distance to compute the Moore neighborhood.

		"""
		self.num_spatial_dims = num_spatial_dims
		self.radius = radius

	def __call__(self, state: State) -> Perception:
		"""Apply Moore perception to the input state.

		The input is assumed to have shape `(..., *spatial_dims, channel_size)` where `spatial_dims`
		is a tuple of `num_spatial_dims` dimensions and `channel_size` is the number of channels.
		This method concatenates the central cell and all neighbors within the Moore neighborhood
		along the channel axis, yielding an output with shape `(..., *spatial_dims, channel_size * N)`,
		where `N = (2 * radius + 1) ** num_spatial_dims`.

		Args:
			state: State of the cellular automaton.

		Returns:
			The Moore neighborhood for each state, with the central cell first.

		"""
		# Init neighbors
		neighbors = [state]

		# Get Moore shifts
		moore_shifts = [
			shift
			for shift in product(range(-self.radius, self.radius + 1), repeat=self.num_spatial_dims)
			if shift != (0,) * self.num_spatial_dims
		]

		# Compute the neighbors
		for shift in moore_shifts:
			neighbors.append(
				jnp.roll(state, shift, axis=tuple(range(-self.num_spatial_dims - 1, -1)))
			)

		return jnp.concatenate(neighbors, axis=-1)

__init__(num_spatial_dims, radius)

Initialize Moore perceive.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `num_spatial_dims` | `int` | Number of spatial dimensions. | required |
| `radius` | `int` | Radius for Chebyshev distance to compute the Moore neighborhood. | required |
Source code in src/cax/core/perceive/moore_perceive.py
def __init__(self, num_spatial_dims: int, radius: int):
	"""Initialize Moore perceive.

	Args:
		num_spatial_dims: Number of spatial dimensions.
		radius: Radius for Chebyshev distance to compute the Moore neighborhood.

	"""
	self.num_spatial_dims = num_spatial_dims
	self.radius = radius

__call__(state)

Apply Moore perception to the input state.

The input is assumed to have shape (..., *spatial_dims, channel_size) where spatial_dims is a tuple of num_spatial_dims dimensions and channel_size is the number of channels. This method concatenates the central cell and all neighbors within the Moore neighborhood along the channel axis, yielding an output with shape (..., *spatial_dims, channel_size * N), where N = (2 * radius + 1) ** num_spatial_dims.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `State` | State of the cellular automaton. | required |

Returns:

| Type | Description |
| --- | --- |
| `Perception` | The Moore neighborhood for each state, with the central cell first. |

Source code in src/cax/core/perceive/moore_perceive.py
def __call__(self, state: State) -> Perception:
	"""Apply Moore perception to the input state.

	The input is assumed to have shape `(..., *spatial_dims, channel_size)` where `spatial_dims`
	is a tuple of `num_spatial_dims` dimensions and `channel_size` is the number of channels.
	This method concatenates the central cell and all neighbors within the Moore neighborhood
	along the channel axis, yielding an output with shape `(..., *spatial_dims, channel_size * N)`,
	where `N = (2 * radius + 1) ** num_spatial_dims`.

	Args:
		state: State of the cellular automaton.

	Returns:
		The Moore neighborhood for each state, with the central cell first.

	"""
	# Init neighbors
	neighbors = [state]

	# Get Moore shifts
	moore_shifts = [
		shift
		for shift in product(range(-self.radius, self.radius + 1), repeat=self.num_spatial_dims)
		if shift != (0,) * self.num_spatial_dims
	]

	# Compute the neighbors
	for shift in moore_shifts:
		neighbors.append(
			jnp.roll(state, shift, axis=tuple(range(-self.num_spatial_dims - 1, -1)))
		)

	return jnp.concatenate(neighbors, axis=-1)
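
Usage sketch: with num_spatial_dims=2 and radius=1, the neighborhood has N = (2 * 1 + 1) ** 2 = 9 positions, so a state with channel_size channels yields a perception with 9 * channel_size channels, central cell first.

import jax

from cax.core.perceive.moore_perceive import MoorePerceive

perceive = MoorePerceive(num_spatial_dims=2, radius=1)

channel_size = 3
state = jax.random.normal(jax.random.key(0), (16, 16, channel_size))
perception = perceive(state)

# N = (2 * radius + 1) ** num_spatial_dims = 9 neighborhood positions.
assert perception.shape == (16, 16, 9 * channel_size)
# The central cell's channels come first in the concatenation.
assert (perception[..., :channel_size] == state).all()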

cax.core.perceive.von_neumann_perceive

Von Neumann perceive module.

VonNeumannPerceive

Bases: Perceive

Von Neumann perceive class.

This class implements perception based on the Von Neumann neighborhood. The Von Neumann neighborhood includes cells within a specified Manhattan distance of the central cell.

Source code in src/cax/core/perceive/von_neumann_perceive.py
class VonNeumannPerceive(Perceive):
	"""Von Neumann perceive class.

	This class implements perception based on the Von Neumann neighborhood.
	The Von Neumann neighborhood includes cells within a specified Manhattan distance of the central
	cell.
	"""

	def __init__(self, num_spatial_dims: int, radius: int):
		"""Initialize Von Neumann perceive.

		Args:
			num_spatial_dims: Number of spatial dimensions.
			radius: Radius for Manhattan distance to compute the Von Neumann neighborhood.

		"""
		self.num_spatial_dims = num_spatial_dims
		self.radius = radius

	def __call__(self, state: State) -> Perception:
		"""Apply Von Neumann perception to the state.

		The input is assumed to have shape `(..., *spatial_dims, channel_size)` where `spatial_dims`
		is a tuple of `num_spatial_dims` dimensions and `channel_size` is the number of channels.
		This method concatenates the central cell and all neighbors within the Von Neumann
		neighborhood (Manhattan distance `<= radius`) along the channel axis. The number of
		concatenated positions equals the Delannoy number:
			`N = sum_{i=0..min(num_spatial_dims, radius)} 2^i * binom(num_spatial_dims, i) * binom(radius, i)`.

		Args:
			state: State of the cellular automaton.

		Returns:
			The Von Neumann neighborhood for each state, with the central cell first.

		"""
		# Init neighbors
		neighbors = [state]

		# Get Moore shifts
		moore_shifts = product(range(-self.radius, self.radius + 1), repeat=self.num_spatial_dims)

		# Get Von Neumann shifts by filtering Moore shifts with Manhattan distance <= radius
		von_neumann_shifts = [
			shift for shift in moore_shifts if 0 < sum(map(abs, shift)) <= self.radius
		]

		# Compute the neighbors
		for shift in von_neumann_shifts:
			neighbors.append(
				jnp.roll(state, shift, axis=tuple(range(-self.num_spatial_dims - 1, -1)))
			)
		return jnp.concatenate(neighbors, axis=-1)

__init__(num_spatial_dims, radius)

Initialize Von Neumann perceive.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `num_spatial_dims` | `int` | Number of spatial dimensions. | required |
| `radius` | `int` | Radius for Manhattan distance to compute the Von Neumann neighborhood. | required |
Source code in src/cax/core/perceive/von_neumann_perceive.py
def __init__(self, num_spatial_dims: int, radius: int):
	"""Initialize Von Neumann perceive.

	Args:
		num_spatial_dims: Number of spatial dimensions.
		radius: Radius for Manhattan distance to compute the Von Neumann neighborhood.

	"""
	self.num_spatial_dims = num_spatial_dims
	self.radius = radius

__call__(state)

Apply Von Neumann perception to the state.

The input is assumed to have shape (..., *spatial_dims, channel_size) where spatial_dims is a tuple of num_spatial_dims dimensions and channel_size is the number of channels. This method concatenates the central cell and all neighbors within the Von Neumann neighborhood (Manhattan distance <= radius) along the channel axis. The number of concatenated positions equals the Delannoy number: N = sum_{i=0..min(num_spatial_dims, radius)} 2^i * binom(num_spatial_dims, i) * binom(radius, i).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `State` | State of the cellular automaton. | required |

Returns:

| Type | Description |
| --- | --- |
| `Perception` | The Von Neumann neighborhood for each state, with the central cell first. |

Source code in src/cax/core/perceive/von_neumann_perceive.py
def __call__(self, state: State) -> Perception:
	"""Apply Von Neumann perception to the state.

	The input is assumed to have shape `(..., *spatial_dims, channel_size)` where `spatial_dims`
	is a tuple of `num_spatial_dims` dimensions and `channel_size` is the number of channels.
	This method concatenates the central cell and all neighbors within the Von Neumann
	neighborhood (Manhattan distance `<= radius`) along the channel axis. The number of
	concatenated positions equals the Delannoy number:
		`N = sum_{i=0..min(num_spatial_dims, radius)} 2^i * binom(num_spatial_dims, i) * binom(radius, i)`.

	Args:
		state: State of the cellular automaton.

	Returns:
		The Von Neumann neighborhood for each state, with the central cell first.

	"""
	# Init neighbors
	neighbors = [state]

	# Get Moore shifts
	moore_shifts = product(range(-self.radius, self.radius + 1), repeat=self.num_spatial_dims)

	# Get Von Neumann shifts by filtering Moore shifts with Manhattan distance <= radius
	von_neumann_shifts = [
		shift for shift in moore_shifts if 0 < sum(map(abs, shift)) <= self.radius
	]

	# Compute the neighbors
	for shift in von_neumann_shifts:
		neighbors.append(
			jnp.roll(state, shift, axis=tuple(range(-self.num_spatial_dims - 1, -1)))
		)
	return jnp.concatenate(neighbors, axis=-1)
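
Usage sketch: in 2D with radius=1 the Delannoy count gives N = 5 (the center plus its four axis-aligned neighbors), so the perception has 5 * channel_size channels. The delannoy helper below is illustrative, not part of the library.

from math import comb

import jax

from cax.core.perceive.von_neumann_perceive import VonNeumannPerceive


def delannoy(n: int, r: int) -> int:
	"""Count lattice points within Manhattan distance r in n dimensions."""
	return sum(2**i * comb(n, i) * comb(r, i) for i in range(min(n, r) + 1))


perceive = VonNeumannPerceive(num_spatial_dims=2, radius=1)

channel_size = 3
state = jax.random.normal(jax.random.key(0), (16, 16, channel_size))
perception = perceive(state)

assert delannoy(2, 1) == 5
assert perception.shape == (16, 16, delannoy(2, 1) * channel_size)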

cax.core.perceive.conv_perceive

Convolution perceive module.

ConvPerceive

Bases: Perceive

Convolution perceive class.

Source code in src/cax/core/perceive/conv_perceive.py
class ConvPerceive(Perceive):
	"""Convolution perceive class."""

	def __init__(
		self,
		channel_size: int,
		perception_size: int,
		*,
		kernel_size: int | tuple[int, ...] = (3, 3),
		padding: str = "SAME",
		feature_group_count: int = 1,
		use_bias: bool = False,
		activation_fn: Callable | None = None,
		rngs: nnx.Rngs,
	):
		"""Initialize convolution perceive.

		Args:
			channel_size: Number of input channels.
			perception_size: Number of output perception features.
			kernel_size: Size of the convolutional kernel.
			padding: Padding to use.
			feature_group_count: Number of feature groups.
			use_bias: Whether to use bias in convolutional layers.
			activation_fn: Activation function to use.
			rngs: Random number generators used to initialize the convolution parameters.

		"""
		self.conv = nnx.Conv(
			in_features=channel_size,
			out_features=perception_size,
			kernel_size=kernel_size,
			padding=padding,
			feature_group_count=feature_group_count,
			use_bias=use_bias,
			rngs=rngs,
		)
		self.activation_fn = activation_fn

	def __call__(self, state: State) -> Perception:
		"""Apply convolutional perception to the input state.

		Inputs are expected to have shape `(..., *spatial_dims, channel_size)` where `spatial_dims`
		is a tuple of `num_spatial_dims` dimensions and `channel_size` is the number of channels.
		The output shape is `(..., *spatial_dims, perception_size)`. If `activation_fn` is provided,
		it is applied element-wise to the convolution output. If `activation_fn` is `None`, the
		convolution output is returned as is.

		Args:
			state: State of the cellular automaton.

		Returns:
			The perceived state after applying convolutional layers.

		"""
		perception = self.conv(state)
		return self.activation_fn(perception) if self.activation_fn else perception

__init__(channel_size, perception_size, *, kernel_size=(3, 3), padding='SAME', feature_group_count=1, use_bias=False, activation_fn=None, rngs)

Initialize convolution perceive.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `channel_size` | `int` | Number of input channels. | required |
| `perception_size` | `int` | Number of output perception features. | required |
| `kernel_size` | `int \| tuple[int, ...]` | Size of the convolutional kernel. | `(3, 3)` |
| `padding` | `str` | Padding to use. | `'SAME'` |
| `feature_group_count` | `int` | Number of feature groups. | `1` |
| `use_bias` | `bool` | Whether to use bias in convolutional layers. | `False` |
| `activation_fn` | `Callable \| None` | Activation function to use. | `None` |
| `rngs` | `Rngs` | Random number generators used to initialize the convolution parameters. | required |
Source code in src/cax/core/perceive/conv_perceive.py
def __init__(
	self,
	channel_size: int,
	perception_size: int,
	*,
	kernel_size: int | tuple[int, ...] = (3, 3),
	padding: str = "SAME",
	feature_group_count: int = 1,
	use_bias: bool = False,
	activation_fn: Callable | None = None,
	rngs: nnx.Rngs,
):
	"""Initialize convolution perceive.

	Args:
		channel_size: Number of input channels.
		perception_size: Number of output perception features.
		kernel_size: Size of the convolutional kernel.
		padding: Padding to use.
		feature_group_count: Number of feature groups.
		use_bias: Whether to use bias in convolutional layers.
		activation_fn: Activation function to use.
		rngs: Random number generators used to initialize the convolution parameters.

	"""
	self.conv = nnx.Conv(
		in_features=channel_size,
		out_features=perception_size,
		kernel_size=kernel_size,
		padding=padding,
		feature_group_count=feature_group_count,
		use_bias=use_bias,
		rngs=rngs,
	)
	self.activation_fn = activation_fn

__call__(state)

Apply convolutional perception to the input state.

Inputs are expected to have shape (..., *spatial_dims, channel_size) where spatial_dims is a tuple of num_spatial_dims dimensions and channel_size is the number of channels. The output shape is (..., *spatial_dims, perception_size). If activation_fn is provided, it is applied element-wise to the convolution output. If activation_fn is None, the convolution output is returned as is.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `State` | State of the cellular automaton. | required |

Returns:

| Type | Description |
| --- | --- |
| `Perception` | The perceived state after applying convolutional layers. |

Source code in src/cax/core/perceive/conv_perceive.py
def __call__(self, state: State) -> Perception:
	"""Apply convolutional perception to the input state.

	Inputs are expected to have shape `(..., *spatial_dims, channel_size)` where `spatial_dims`
	is a tuple of `num_spatial_dims` dimensions and `channel_size` is the number of channels.
	The output shape is `(..., *spatial_dims, perception_size)`. If `activation_fn` is provided,
	it is applied element-wise to the convolution output. If `activation_fn` is `None`, the
	convolution output is returned as is.

	Args:
		state: State of the cellular automaton.

	Returns:
		The perceived state after applying convolutional layers.

	"""
	perception = self.conv(state)
	return self.activation_fn(perception) if self.activation_fn else perception
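
Usage sketch, with jax.nn.relu chosen here as an example activation: a 2D state with 16 channels is mapped to 48 perception features by a learnable 3x3 convolution.

import jax
from flax import nnx

from cax.core.perceive.conv_perceive import ConvPerceive

perceive = ConvPerceive(
	channel_size=16,
	perception_size=48,
	kernel_size=(3, 3),
	activation_fn=jax.nn.relu,
	rngs=nnx.Rngs(0),
)

state = jax.random.normal(jax.random.key(0), (32, 32, 16))
perception = perceive(state)
assert perception.shape == (32, 32, 48)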

cax.core.perceive.kernels

Kernel utilities for perception modules.

Each function returns a small spatial kernel suitable for neighborhood aggregation or finite-difference style operations. Kernels use channel-last layout and a support of size 3 along each spatial dimension.

identity_kernel(ndim)

Create an identity kernel for the given number of dimensions.

The kernel has value 1 at the central position and 0 elsewhere.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `ndim` | `int` | Number of dimensions for the kernel. | required |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Array with shape `ndim * (3,) + (1,)`. |

Source code in src/cax/core/perceive/kernels.py
def identity_kernel(ndim: int) -> Array:
	"""Create an identity kernel for the given number of dimensions.

	The kernel has value 1 at the central position and 0 elsewhere.

	Args:
		ndim: Number of dimensions for the kernel.

	Returns:
		Array with shape `ndim * (3,) + (1,)`.

	"""
	kernel = jnp.zeros(ndim * (3,))
	center_idx = ndim * (1,)
	kernel = kernel.at[center_idx].set(1.0)
	return jnp.expand_dims(kernel, axis=-1)

neighbors_kernel(ndim)

Create a neighbors kernel for the given number of dimensions.

This kernel is 1 - identity_kernel, selecting all neighbors and excluding the center.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `ndim` | `int` | Number of dimensions for the kernel. | required |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Array with shape `ndim * (3,) + (1,)`. |

Source code in src/cax/core/perceive/kernels.py
def neighbors_kernel(ndim: int) -> Array:
	"""Create a neighbors kernel for the given number of dimensions.

	This kernel is `1 - identity_kernel`, selecting all neighbors and excluding the center.

	Args:
		ndim: Number of dimensions for the kernel.

	Returns:
		Array with shape `ndim * (3,) + (1,)`.

	"""
	kernel = identity_kernel(ndim)
	return 1.0 - kernel
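
Illustrative sketch: because the center is zeroed out, weighting a binary board with this kernel counts live neighbors, the core quantity in Game-of-Life-style rules. Here the count is evaluated directly at the central cell of a 3x3 board.

import jax.numpy as jnp

from cax.core.perceive.kernels import neighbors_kernel

kernel = neighbors_kernel(2)  # shape (3, 3, 1): ones everywhere except the center

board = jnp.array(
	[
		[1.0, 0.0, 1.0],
		[0.0, 1.0, 0.0],
		[0.0, 1.0, 0.0],
	]
)
# Weighted sum over the 3x3 patch excludes the central cell.
live_neighbors = jnp.sum(board * kernel[..., 0])
assert live_neighbors == 3.0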

grad_kernel(ndim, *, normalize=True)

Create a gradient kernel for the given number of dimensions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `ndim` | `int` | Number of dimensions for the kernel. | required |
| `normalize` | `bool` | Whether to L1-normalize each axis kernel. | `True` |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Array with shape `ndim * (3,) + (ndim,)`, one kernel per spatial axis. |

Source code in src/cax/core/perceive/kernels.py
def grad_kernel(ndim: int, *, normalize: bool = True) -> Array:
	"""Create a gradient kernel for the given number of dimensions.

	Args:
		ndim: Number of dimensions for the kernel.
		normalize: Whether to L1-normalize each axis kernel.

	Returns:
		Array with shape `ndim * (3,) + (ndim,)`, one kernel per spatial axis.

	"""
	grad = jnp.array([-1, 0, 1])
	smooth = jnp.array([1, 2, 1])

	kernels = []
	for i in range(ndim):
		kernel = jnp.ones([3] * ndim)

		for j in range(ndim):
			axis_kernel = smooth if i != j else grad
			kernel = kernel * axis_kernel.reshape([-1 if k == j else 1 for k in range(ndim)])

		kernels.append(kernel)

	if normalize:
		kernels = [kernel / jnp.sum(jnp.abs(kernel)) for kernel in kernels]

	return jnp.stack(kernels, axis=-1)
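
A common pattern, e.g. in neural cellular automata, stacks identity_kernel and grad_kernel into one fixed filter bank. The sketch below applies the stacked kernels to a single-channel 2D state with jax.lax.conv_general_dilated in channel-last layout; ConvPerceive could equally host such fixed kernels.

import jax
import jax.numpy as jnp

from cax.core.perceive.kernels import grad_kernel, identity_kernel

# Stack identity and x/y gradient kernels: shape (3, 3, 3) for ndim=2.
kernel = jnp.concatenate([identity_kernel(2), grad_kernel(2)], axis=-1)

state = jax.random.normal(jax.random.key(0), (16, 16, 1))
perception = jax.lax.conv_general_dilated(
	state[None],            # add a batch dimension -> NHWC
	kernel[:, :, None, :],  # HWIO: (3, 3, in=1, out=3)
	window_strides=(1, 1),
	padding="SAME",
	dimension_numbers=("NHWC", "HWIO", "NHWC"),
)[0]
assert perception.shape == (16, 16, 3)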

grad2_kernel(ndim, normalize=True)

Create a second-order (Laplacian) kernel.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `ndim` | `int` | Number of dimensions for the kernel. | required |
| `normalize` | `bool` | Whether to L1-normalize the kernel. | `True` |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Array with shape `ndim * (3,) + (1,)`. |

Source code in src/cax/core/perceive/kernels.py
def grad2_kernel(ndim: int, normalize: bool = True) -> Array:
	"""Create a second-order (Laplacian) kernel.

	Args:
		ndim: Number of dimensions for the kernel.
		normalize: Whether to L1-normalize the kernel.

	Returns:
		Array with shape `ndim * (3,) + (1,)`.

	"""
	kernel = jnp.zeros([3] * ndim)
	center = tuple(1 for _ in range(ndim))
	kernel = kernel.at[center].set(-2.0 * ndim)

	for axis in range(ndim):
		for offset in (-1, 1):
			idx = list(center)
			idx[axis] += offset
			kernel = kernel.at[tuple(idx)].set(1.0)

	if normalize:
		kernel = kernel / jnp.sum(jnp.abs(kernel))

	return jnp.expand_dims(kernel, axis=-1)
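
Sanity-check sketch: the stencil's entries sum to zero, so constant regions produce no response, and adding dt * laplacian(u) to a field performs one explicit diffusion step. The laplacian helper below is hypothetical, applying the stencil with periodic boundaries via jnp.roll.

import jax.numpy as jnp

from cax.core.perceive.kernels import grad2_kernel

kernel = grad2_kernel(2)  # shape (3, 3, 1)
assert jnp.isclose(kernel.sum(), 0.0)  # zero response on constant inputs


def laplacian(u):
	"""Hypothetical helper: apply the stencil with wraparound boundaries."""
	out = jnp.zeros_like(u)
	for di in (-1, 0, 1):
		for dj in (-1, 0, 1):
			out = out + kernel[1 + di, 1 + dj, 0] * jnp.roll(u, (di, dj), axis=(0, 1))
	return out


u = jnp.zeros((8, 8)).at[4, 4].set(1.0)
u = u + 0.1 * laplacian(u)  # heat spreads out from the point source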