Skip to content

Perceive

Perceive

cax.core.perceive.perceive

Perceive base module.

Perceive

Bases: Module

Base class for perception modules.

Subclasses implement neighborhood gathering or convolutional transforms that map a state to a perception. Perceptions are PyTrees; commonly arrays shaped (..., *spatial_dims, perception_size).

Source code in src/cax/core/perceive/perceive.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class Perceive[State](nnx.Module):
	"""Base class for perception modules.

	Subclasses implement neighborhood gathering or convolutional transforms that map a state
	to a perception. Perceptions are PyTrees; commonly arrays shaped
	`(..., *spatial_dims, perception_size)`.
	"""

	def __call__(self, state: State) -> Perception:
		"""Process the current state to produce a perception.

		Args:
			state: Current state.

		Returns:
			Perception derived from `state`.

		"""
		raise NotImplementedError

__call__(state)

Process the current state to produce a perception.

Parameters:

Name Type Description Default
state State

Current state.

required

Returns:

Type Description
Perception

Perception derived from state.

Source code in src/cax/core/perceive/perceive.py
18
19
20
21
22
23
24
25
26
27
28
def __call__(self, state: State) -> Perception:
	"""Process the current state to produce a perception.

	Args:
		state: Current state.

	Returns:
		Perception derived from `state`.

	Raises:
		NotImplementedError: Always, in the base class; concrete subclasses must
			override this method.

	"""
	# Name the concrete class so a missing override is easy to locate.
	raise NotImplementedError(f"{type(self).__name__} must implement __call__.")

cax.core.perceive.moore_perceive

Moore perceive module.

MoorePerceive

Bases: NeighborhoodPerceive

Moore perceive class.

This class implements perception based on the Moore neighborhood. The Moore neighborhood includes all cells within Chebyshev distance radius of the central cell — i.e., the full hypercube of side length 2 * radius + 1 excluding the center.

Source code in src/cax/core/perceive/moore_perceive.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class MoorePerceive(NeighborhoodPerceive):
	"""Perception over the Moore neighborhood.

	The Moore neighborhood of radius `r` is every cell whose Chebyshev distance from the
	central cell is at most `r`, i.e. the axis-aligned hypercube of side `2 * r + 1`
	centered on that cell. The shifts produced here exclude the center itself; whether the
	center appears in the output is controlled by `include_center`.
	"""

	def __init__(
		self,
		num_spatial_dims: int,
		radius: int,
		*,
		padding: str = "CIRCULAR",
		include_center: bool = True,
		reduce_fn: Callable[..., Array] | None = None,
	):
		"""Build a Moore-neighborhood perception module.

		Args:
			num_spatial_dims: Number of spatial axes of the state.
			radius: Largest Chebyshev distance from the center that still counts as a
				neighbor.
			padding: Boundary handling, one of "CIRCULAR" (periodic/torus), "ZERO"
				(zero-padded), "REFLECT" (mirror) or "EDGE" (clamp to boundary).
			include_center: If True, the center cell is included in the output.
			reduce_fn: Optional reduction over the stacked neighbors. When None, neighbors
				are concatenated along the channel axis; otherwise it is invoked as
				`reduce_fn(stacked_neighbors, axis=0)` and its result is concatenated with
				the center when `include_center` is True.

		"""
		super().__init__(
			num_spatial_dims=num_spatial_dims,
			radius=radius,
			padding=padding,
			include_center=include_center,
			reduce_fn=reduce_fn,
		)

	def _get_shifts(self) -> list[tuple[int, ...]]:
		"""List every offset of the Moore hypercube except the origin.

		Returns:
			Offset tuples in `itertools.product` order, each component in
			`[-radius, radius]`, with the all-zero offset removed.

		"""
		offsets = range(-self.radius, self.radius + 1)
		# `any(shift)` is False only for the all-zero (center) offset.
		return [shift for shift in product(offsets, repeat=self.num_spatial_dims) if any(shift)]

__init__(num_spatial_dims, radius, *, padding='CIRCULAR', include_center=True, reduce_fn=None)

Initialize Moore perceive.

Parameters:

Name Type Description Default
num_spatial_dims int

Number of spatial dimensions.

required
radius int

Chebyshev distance defining the Moore neighborhood extent.

required
padding str

Boundary condition mode. One of "CIRCULAR" (periodic/torus), "ZERO" (zero-padded), "REFLECT" (mirror), or "EDGE" (clamp to boundary).

'CIRCULAR'
include_center bool

Whether to include the center cell in the output.

True
reduce_fn Callable[..., Array] | None

Optional reduction function applied over the neighbor axis. If None, neighbors are concatenated along the channel axis. If provided, it is called as reduce_fn(stacked_neighbors, axis=0) and the result is concatenated with the center (if include_center is True).

None
Source code in src/cax/core/perceive/moore_perceive.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def __init__(
	self,
	num_spatial_dims: int,
	radius: int,
	*,
	padding: str = "CIRCULAR",
	include_center: bool = True,
	reduce_fn: Callable[..., Array] | None = None,
):
	"""Build a Moore-neighborhood perception module.

	All arguments are forwarded unchanged to the parent constructor.

	Args:
		num_spatial_dims: Number of spatial axes of the state.
		radius: Largest Chebyshev distance from the center that still counts as a neighbor.
		padding: Boundary handling, one of "CIRCULAR" (periodic/torus), "ZERO"
			(zero-padded), "REFLECT" (mirror) or "EDGE" (clamp to boundary).
		include_center: If True, the center cell is included in the output.
		reduce_fn: Optional reduction over the stacked neighbors. When None, neighbors are
			concatenated along the channel axis; otherwise it is invoked as
			`reduce_fn(stacked_neighbors, axis=0)` and its result is concatenated with the
			center when `include_center` is True.

	"""
	super().__init__(
		num_spatial_dims=num_spatial_dims,
		radius=radius,
		padding=padding,
		include_center=include_center,
		reduce_fn=reduce_fn,
	)

__call__(state)

Apply neighborhood perception to the input state.

The input is assumed to have shape (..., *spatial_dims, channel_size).

When `reduce_fn` is None, neighbors are concatenated along the channel axis:

- With `include_center=True`: shape `(..., *spatial, channel_size * (N + 1))`
- With `include_center=False`: shape `(..., *spatial, channel_size * N)`

where N is the number of neighbor shifts.

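When `reduce_fn` is provided, neighbors are stacked along a new leading axis and reduced with `reduce_fn(stacked, axis=0)`. Assuming the reduction removes only that stacked axis:

- With `include_center=True`: shape `(..., *spatial, 2 * channel_size)`
- With `include_center=False`: shape `(..., *spatial, channel_size)`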

Parameters:

Name Type Description Default
state Array

State of the cellular automaton.

required

Returns:

Type Description
Perception

Neighborhood perception.

Source code in src/cax/core/perceive/neighborhood_perceive.py
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def __call__(self, state: Array) -> Perception:
	"""Apply neighborhood perception to the input state.

	The input is assumed to have shape `(..., *spatial_dims, channel_size)`: any number of
	leading batch axes, then `num_spatial_dims` spatial axes, then one channel axis.

	The spatial axes are padded by `radius` according to `padding`. For each offset
	`shift` returned by `self._get_shifts()`, the gathered neighbor at spatial position `x`
	is the state value at position `x - shift` (out-of-grid positions come from the
	padding). For circular padding this matches `jnp.roll(state, shift, axis=spatial_axes)`.

	When `reduce_fn` is None, neighbors are concatenated along the channel axis:
		- With `include_center=True`: shape `(..., *spatial, channel_size * (N + 1))`
		- With `include_center=False`: shape `(..., *spatial, channel_size * N)`
		where N is the number of neighbor shifts.

	When `reduce_fn` is provided, neighbors are stacked along a new leading axis and reduced
	with `reduce_fn(stacked, axis=0)`. Assuming the reduction removes only that axis:
		- With `include_center=True`: shape `(..., *spatial, 2 * channel_size)`
		- With `include_center=False`: shape `(..., *spatial, channel_size)`

	Args:
		state: State of the cellular automaton.

	Returns:
		Neighborhood perception.

	Raises:
		ValueError: If `state` has fewer than `num_spatial_dims + 1` dimensions.

	"""
	num_spatial_dims = self.num_spatial_dims
	radius = self.radius

	# A negative `spatial_start` would build a pad-width list of the wrong length and fail
	# with an obscure error inside `jnp.pad`; reject it here with an actionable message.
	spatial_start = state.ndim - num_spatial_dims - 1
	if spatial_start < 0:
		raise ValueError(
			f"Expected a state with at least {num_spatial_dims + 1} dimensions "
			f"(..., *spatial_dims, channel_size), got shape {state.shape}."
		)
	spatial_sizes = state.shape[spatial_start : spatial_start + num_spatial_dims]

	# Pad only the spatial axes by `radius` so that every shifted window stays in bounds.
	pad_mode = _PADDING_TO_PAD_MODE[self.padding]
	pad_widths = [(0, 0)] * spatial_start + [(radius, radius)] * num_spatial_dims + [(0, 0)]
	padded = jnp.pad(state, pad_widths, mode=pad_mode)

	# Batch axes and the channel axis are taken whole; only spatial axes are windowed.
	leading = [slice(None)] * spatial_start
	neighbors = []
	for shift in self._get_shifts():
		window = [
			slice(radius - offset, radius - offset + size)
			for offset, size in zip(shift, spatial_sizes)
		]
		neighbors.append(padded[tuple(leading + window + [slice(None)])])

	if self.reduce_fn is not None:
		reduced = self.reduce_fn(jnp.stack(neighbors, axis=0), axis=0)
		if self.include_center:
			return jnp.concatenate([state, reduced], axis=-1)
		return reduced

	if self.include_center:
		return jnp.concatenate([state, *neighbors], axis=-1)
	return jnp.concatenate(neighbors, axis=-1)

cax.core.perceive.von_neumann_perceive

Von Neumann perceive module.

VonNeumannPerceive

Bases: NeighborhoodPerceive

Von Neumann perceive class.

This class implements perception based on the Von Neumann neighborhood. The Von Neumann neighborhood includes all cells within Manhattan distance radius of the central cell.

Source code in src/cax/core/perceive/von_neumann_perceive.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
class VonNeumannPerceive(NeighborhoodPerceive):
	"""Perception over the Von Neumann neighborhood.

	The Von Neumann neighborhood of radius `r` consists of every cell whose Manhattan (L1)
	distance from the central cell is at most `r`. The shifts produced here exclude the
	center itself; whether the center appears in the output is controlled by
	`include_center`.
	"""

	def __init__(
		self,
		num_spatial_dims: int,
		radius: int,
		*,
		padding: str = "CIRCULAR",
		include_center: bool = True,
		reduce_fn: Callable[..., Array] | None = None,
	):
		"""Build a Von Neumann-neighborhood perception module.

		Args:
			num_spatial_dims: Number of spatial axes of the state.
			radius: Largest Manhattan distance from the center that still counts as a
				neighbor.
			padding: Boundary handling, one of "CIRCULAR" (periodic/torus), "ZERO"
				(zero-padded), "REFLECT" (mirror) or "EDGE" (clamp to boundary).
			include_center: If True, the center cell is included in the output.
			reduce_fn: Optional reduction over the stacked neighbors. When None, neighbors
				are concatenated along the channel axis; otherwise it is invoked as
				`reduce_fn(stacked_neighbors, axis=0)` and its result is concatenated with
				the center when `include_center` is True.

		"""
		super().__init__(
			num_spatial_dims=num_spatial_dims,
			radius=radius,
			padding=padding,
			include_center=include_center,
			reduce_fn=reduce_fn,
		)

	def _get_shifts(self) -> list[tuple[int, ...]]:
		"""List every non-zero offset whose Manhattan length is at most `radius`.

		Returns:
			Offset tuples in `itertools.product` order over `[-radius, radius]` per axis,
			keeping only those with `0 < |shift|_1 <= radius`.

		"""
		span = range(-self.radius, self.radius + 1)
		shifts = []
		for candidate in product(span, repeat=self.num_spatial_dims):
			distance = sum(abs(component) for component in candidate)
			# Skip the center (distance 0) and the hypercube corners outside the L1 ball.
			if distance == 0 or distance > self.radius:
				continue
			shifts.append(candidate)
		return shifts

__init__(num_spatial_dims, radius, *, padding='CIRCULAR', include_center=True, reduce_fn=None)

Initialize Von Neumann perceive.

Parameters:

Name Type Description Default
num_spatial_dims int

Number of spatial dimensions.

required
radius int

Manhattan distance defining the Von Neumann neighborhood extent.

required
padding str

Boundary condition mode. One of "CIRCULAR" (periodic/torus), "ZERO" (zero-padded), "REFLECT" (mirror), or "EDGE" (clamp to boundary).

'CIRCULAR'
include_center bool

Whether to include the center cell in the output.

True
reduce_fn Callable[..., Array] | None

Optional reduction function applied over the neighbor axis. If None, neighbors are concatenated along the channel axis. If provided, it is called as reduce_fn(stacked_neighbors, axis=0) and the result is concatenated with the center (if include_center is True).

None
Source code in src/cax/core/perceive/von_neumann_perceive.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def __init__(
	self,
	num_spatial_dims: int,
	radius: int,
	*,
	padding: str = "CIRCULAR",
	include_center: bool = True,
	reduce_fn: Callable[..., Array] | None = None,
):
	"""Build a Von Neumann-neighborhood perception module.

	All arguments are forwarded unchanged to the parent constructor.

	Args:
		num_spatial_dims: Number of spatial axes of the state.
		radius: Largest Manhattan distance from the center that still counts as a neighbor.
		padding: Boundary handling, one of "CIRCULAR" (periodic/torus), "ZERO"
			(zero-padded), "REFLECT" (mirror) or "EDGE" (clamp to boundary).
		include_center: If True, the center cell is included in the output.
		reduce_fn: Optional reduction over the stacked neighbors. When None, neighbors are
			concatenated along the channel axis; otherwise it is invoked as
			`reduce_fn(stacked_neighbors, axis=0)` and its result is concatenated with the
			center when `include_center` is True.

	"""
	super().__init__(
		num_spatial_dims=num_spatial_dims,
		radius=radius,
		padding=padding,
		include_center=include_center,
		reduce_fn=reduce_fn,
	)

__call__(state)

Apply neighborhood perception to the input state.

The input is assumed to have shape (..., *spatial_dims, channel_size).

When `reduce_fn` is None, neighbors are concatenated along the channel axis:

- With `include_center=True`: shape `(..., *spatial, channel_size * (N + 1))`
- With `include_center=False`: shape `(..., *spatial, channel_size * N)`

where N is the number of neighbor shifts.

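When `reduce_fn` is provided, neighbors are stacked along a new leading axis and reduced with `reduce_fn(stacked, axis=0)`. Assuming the reduction removes only that stacked axis:

- With `include_center=True`: shape `(..., *spatial, 2 * channel_size)`
- With `include_center=False`: shape `(..., *spatial, channel_size)`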

Parameters:

Name Type Description Default
state Array

State of the cellular automaton.

required

Returns:

Type Description
Perception

Neighborhood perception.

Source code in src/cax/core/perceive/neighborhood_perceive.py
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def __call__(self, state: Array) -> Perception:
	"""Apply neighborhood perception to the input state.

	The input is assumed to have shape `(..., *spatial_dims, channel_size)`: any number of
	leading batch axes, then `num_spatial_dims` spatial axes, then one channel axis.

	The spatial axes are padded by `radius` according to `padding`. For each offset
	`shift` returned by `self._get_shifts()`, the gathered neighbor at spatial position `x`
	is the state value at position `x - shift` (out-of-grid positions come from the
	padding). For circular padding this matches `jnp.roll(state, shift, axis=spatial_axes)`.

	When `reduce_fn` is None, neighbors are concatenated along the channel axis:
		- With `include_center=True`: shape `(..., *spatial, channel_size * (N + 1))`
		- With `include_center=False`: shape `(..., *spatial, channel_size * N)`
		where N is the number of neighbor shifts.

	When `reduce_fn` is provided, neighbors are stacked along a new leading axis and reduced
	with `reduce_fn(stacked, axis=0)`. Assuming the reduction removes only that axis:
		- With `include_center=True`: shape `(..., *spatial, 2 * channel_size)`
		- With `include_center=False`: shape `(..., *spatial, channel_size)`

	Args:
		state: State of the cellular automaton.

	Returns:
		Neighborhood perception.

	Raises:
		ValueError: If `state` has fewer than `num_spatial_dims + 1` dimensions.

	"""
	num_spatial_dims = self.num_spatial_dims
	radius = self.radius

	# A negative `spatial_start` would build a pad-width list of the wrong length and fail
	# with an obscure error inside `jnp.pad`; reject it here with an actionable message.
	spatial_start = state.ndim - num_spatial_dims - 1
	if spatial_start < 0:
		raise ValueError(
			f"Expected a state with at least {num_spatial_dims + 1} dimensions "
			f"(..., *spatial_dims, channel_size), got shape {state.shape}."
		)
	spatial_sizes = state.shape[spatial_start : spatial_start + num_spatial_dims]

	# Pad only the spatial axes by `radius` so that every shifted window stays in bounds.
	pad_mode = _PADDING_TO_PAD_MODE[self.padding]
	pad_widths = [(0, 0)] * spatial_start + [(radius, radius)] * num_spatial_dims + [(0, 0)]
	padded = jnp.pad(state, pad_widths, mode=pad_mode)

	# Batch axes and the channel axis are taken whole; only spatial axes are windowed.
	leading = [slice(None)] * spatial_start
	neighbors = []
	for shift in self._get_shifts():
		window = [
			slice(radius - offset, radius - offset + size)
			for offset, size in zip(shift, spatial_sizes)
		]
		neighbors.append(padded[tuple(leading + window + [slice(None)])])

	if self.reduce_fn is not None:
		reduced = self.reduce_fn(jnp.stack(neighbors, axis=0), axis=0)
		if self.include_center:
			return jnp.concatenate([state, reduced], axis=-1)
		return reduced

	if self.include_center:
		return jnp.concatenate([state, *neighbors], axis=-1)
	return jnp.concatenate(neighbors, axis=-1)

cax.core.perceive.conv_perceive

Convolution perceive module.

ConvPerceive

Bases: Perceive[Array]

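Convolution perceive class: computes the perception by applying a learnable `nnx.Conv` to the state, optionally followed by an element-wise activation function.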

Source code in src/cax/core/perceive/conv_perceive.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
class ConvPerceive(Perceive[Array]):
	"""Convolution perceive class.

	Computes the perception by applying a single learnable `nnx.Conv` to the state,
	optionally followed by an element-wise activation function.
	"""

	def __init__(
		self,
		channel_size: int,
		perception_size: int,
		*,
		kernel_size: int | tuple[int, ...] = (3, 3),
		padding: str = "SAME",
		feature_group_count: int = 1,
		use_bias: bool = False,
		activation_fn: Callable | None = None,
		rngs: nnx.Rngs,
	):
		"""Initialize convolution perceive.

		Args:
			channel_size: Number of input channels (the convolution's `in_features`).
			perception_size: Number of output perception features (`out_features`).
			kernel_size: Size of the convolutional kernel, forwarded to `nnx.Conv`.
			padding: Padding mode forwarded to `nnx.Conv`.
			feature_group_count: Number of feature groups, forwarded to `nnx.Conv`.
			use_bias: Whether the convolution has a bias term.
			activation_fn: Optional activation applied element-wise to the convolution
				output. If None, the raw convolution output is returned.
			rngs: Random number generator streams used to initialize the convolution
				parameters.

		"""
		self.conv = nnx.Conv(
			in_features=channel_size,
			out_features=perception_size,
			kernel_size=kernel_size,
			padding=padding,
			feature_group_count=feature_group_count,
			use_bias=use_bias,
			rngs=rngs,
		)
		self.activation_fn = activation_fn

	def __call__(self, state: Array) -> Perception:
		"""Apply convolutional perception to the input state.

		Inputs are expected to have shape `(..., *spatial_dims, channel_size)` where `spatial_dims`
		is a tuple of `num_spatial_dims` dimensions and `channel_size` is the number of channels.
		The output shape is `(..., *spatial_dims, perception_size)`. If `activation_fn` is provided,
		it is applied element-wise to the convolution output. If `activation_fn` is `None`, the
		convolution output is returned as is.

		Args:
			state: State of the cellular automaton.

		Returns:
			The perceived state after applying convolutional layers.

		"""
		perception = self.conv(state)
		# Compare against None explicitly: a callable object can define a falsy
		# `__bool__`/`__len__` and must still be applied.
		if self.activation_fn is None:
			return perception
		return self.activation_fn(perception)

__init__(channel_size, perception_size, *, kernel_size=(3, 3), padding='SAME', feature_group_count=1, use_bias=False, activation_fn=None, rngs)

Initialize convolution perceive.

Parameters:

Name Type Description Default
channel_size int

Number of input channels.

required
perception_size int

Number of output perception features.

required
kernel_size int | tuple[int, ...]

Size of the convolutional kernel.

(3, 3)
padding str

Padding mode forwarded to `nnx.Conv` (for example `"SAME"`).

'SAME'
feature_group_count int

Number of feature groups.

1
use_bias bool

Whether to use bias in convolutional layers.

False
activation_fn Callable | None

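Optional activation function applied element-wise to the convolution output. If None, the convolution output is returned unchanged.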

None
rngs Rngs

Random number generator streams (`nnx.Rngs`) used to initialize the convolution parameters.

required
Source code in src/cax/core/perceive/conv_perceive.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def __init__(
	self,
	channel_size: int,
	perception_size: int,
	*,
	kernel_size: int | tuple[int, ...] = (3, 3),
	padding: str = "SAME",
	feature_group_count: int = 1,
	use_bias: bool = False,
	activation_fn: Callable | None = None,
	rngs: nnx.Rngs,
):
	"""Build the convolution that maps a state to its perception.

	Args:
		channel_size: Channel count of the incoming state (the convolution's `in_features`).
		perception_size: Channel count of the produced perception (`out_features`).
		kernel_size: Spatial extent of the convolution kernel, passed to `nnx.Conv`.
		padding: Padding mode passed to `nnx.Conv`.
		feature_group_count: Number of feature groups, passed to `nnx.Conv`.
		use_bias: Whether the convolution carries a bias term.
		activation_fn: Optional activation, stored for use after the convolution.
		rngs: Random number generator streams used to initialize the convolution.

	"""
	conv_layer = nnx.Conv(
		in_features=channel_size,
		out_features=perception_size,
		kernel_size=kernel_size,
		padding=padding,
		feature_group_count=feature_group_count,
		use_bias=use_bias,
		rngs=rngs,
	)
	self.conv = conv_layer
	self.activation_fn = activation_fn

__call__(state)

Apply convolutional perception to the input state.

Inputs are expected to have shape (..., *spatial_dims, channel_size) where spatial_dims is a tuple of num_spatial_dims dimensions and channel_size is the number of channels. The output shape is (..., *spatial_dims, perception_size). If activation_fn is provided, it is applied element-wise to the convolution output. If activation_fn is None, the convolution output is returned as is.

Parameters:

Name Type Description Default
state Array

State of the cellular automaton.

required

Returns:

Type Description
Perception

The perceived state after applying convolutional layers.

Source code in src/cax/core/perceive/conv_perceive.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def __call__(self, state: Array) -> Perception:
	"""Apply convolutional perception to the input state.

	Inputs are expected to have shape `(..., *spatial_dims, channel_size)` where `spatial_dims`
	is a tuple of `num_spatial_dims` dimensions and `channel_size` is the number of channels.
	The output shape is `(..., *spatial_dims, perception_size)`. If `activation_fn` is provided,
	it is applied element-wise to the convolution output. If `activation_fn` is `None`, the
	convolution output is returned as is.

	Args:
		state: State of the cellular automaton.

	Returns:
		The perceived state after applying convolutional layers.

	"""
	perception = self.conv(state)
	# Compare against None explicitly: a callable object can define a falsy
	# `__bool__`/`__len__` and must still be applied.
	if self.activation_fn is None:
		return perception
	return self.activation_fn(perception)

cax.core.perceive.kernels

Kernel utilities for perception modules.

Each function returns a small spatial kernel suitable for neighborhood aggregation or finite-difference style operations. Kernels use channel-last layout and a support of size 3 along each spatial dimension.

identity_kernel(*, num_dims)

Create an identity kernel for the given number of dimensions.

The kernel has value 1 at the central position and 0 elsewhere.

Parameters:

Name Type Description Default
num_dims int

Number of dimensions for the kernel.

required

Returns:

Type Description
Array

Array with shape num_dims * (3,) + (1,).

Source code in src/cax/core/perceive/kernels.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def identity_kernel(*, num_dims: int) -> Array:
	"""Return a kernel that passes through only the central cell.

	Args:
		num_dims: Number of spatial dimensions of the kernel.

	Returns:
		Array of shape `num_dims * (3,) + (1,)` that is 1 at the center and 0 elsewhere.

	"""
	window = (3,) * num_dims
	middle = (1,) * num_dims
	# Trailing `None` adds the single output-feature axis (channel-last layout).
	return jnp.zeros(window).at[middle].set(1.0)[..., None]

neighbors_kernel(*, num_dims)

Create a neighbors kernel for the given number of dimensions.

This kernel is 1 - identity_kernel, selecting all neighbors and excluding the center.

Parameters:

Name Type Description Default
num_dims int

Number of dimensions for the kernel.

required

Returns:

Type Description
Array

Array with shape num_dims * (3,) + (1,).

Source code in src/cax/core/perceive/kernels.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def neighbors_kernel(*, num_dims: int) -> Array:
	"""Return a kernel selecting every cell of the 3-wide window except the center.

	The result equals `1 - identity_kernel(num_dims=num_dims)`.

	Args:
		num_dims: Number of spatial dimensions of the kernel.

	Returns:
		Array of shape `num_dims * (3,) + (1,)` that is 0 at the center and 1 elsewhere.

	"""
	window = (3,) * num_dims
	middle = (1,) * num_dims
	# Start from all ones and knock out the center; trailing `None` adds the feature axis.
	return jnp.ones(window).at[middle].set(0.0)[..., None]

grad_kernel(*, num_dims, normalize=True)

Create a gradient kernel for the given number of dimensions.

Parameters:

Name Type Description Default
num_dims int

Number of dimensions for the kernel.

required
normalize bool

Whether to L1-normalize each axis kernel.

True

Returns:

Type Description
Array

Array with shape num_dims * (3,) + (num_dims,), one kernel per spatial axis.

Source code in src/cax/core/perceive/kernels.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def grad_kernel(*, num_dims: int, normalize: bool = True) -> Array:
	"""Build one separable (Sobel-style) gradient kernel per spatial axis.

	The kernel for axis `i` is the outer product of the central-difference profile
	`[-1, 0, 1]` along axis `i` and the smoothing profile `[1, 2, 1]` along every other axis.

	Args:
		num_dims: Number of spatial dimensions of the kernel.
		normalize: If True, divide each axis kernel by the sum of its absolute values.

	Returns:
		Array of shape `num_dims * (3,) + (num_dims,)`; the last axis indexes the
		differentiated spatial axis.

	"""
	derivative = jnp.array([-1, 0, 1])
	smoothing = jnp.array([1, 2, 1])

	def _profile(target_axis: int, along: int) -> Array:
		# Reshape the 1-D profile so it broadcasts only along axis `along`.
		shape = [1] * num_dims
		shape[along] = -1
		base = derivative if along == target_axis else smoothing
		return base.reshape(shape)

	per_axis = []
	for target_axis in range(num_dims):
		kernel = jnp.ones([3] * num_dims)
		for along in range(num_dims):
			kernel = kernel * _profile(target_axis, along)
		if normalize:
			kernel = kernel / jnp.sum(jnp.abs(kernel))
		per_axis.append(kernel)

	return jnp.stack(per_axis, axis=-1)

grad2_kernel(*, num_dims, normalize=True)

Create a second-order (Laplacian) kernel.

Parameters:

Name Type Description Default
num_dims int

Number of dimensions for the kernel.

required
normalize bool

Whether to L1-normalize the kernel.

True

Returns:

Type Description
Array

Array with shape num_dims * (3,) + (1,).

Source code in src/cax/core/perceive/kernels.py
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def grad2_kernel(*, num_dims: int, normalize: bool = True) -> Array:
	"""Build the discrete Laplacian stencil on a 3-wide window.

	The center holds `-2 * num_dims` and the two axis-aligned neighbors along each spatial
	axis hold `1`; every other entry (e.g. diagonal corners) is 0.

	Args:
		num_dims: Number of spatial dimensions of the kernel.
		normalize: If True, divide the kernel by the sum of its absolute values.

	Returns:
		Array of shape `num_dims * (3,) + (1,)`.

	"""
	middle = (1,) * num_dims
	laplacian = jnp.zeros([3] * num_dims).at[middle].set(-2.0 * num_dims)

	# Each axis contributes +1 on both sides of the center (window indices 0 and 2).
	for axis in range(num_dims):
		for side in (0, 2):
			position = middle[:axis] + (side,) + middle[axis + 1 :]
			laplacian = laplacian.at[position].set(1.0)

	if normalize:
		laplacian = laplacian / jnp.sum(jnp.abs(laplacian))

	return laplacian[..., None]