Skip to content

Elementary Cellular Automata

cax.cs.elementary.cs.Elementary

Bases: ComplexSystem[Array, Array]

Elementary Cellular Automata class.

A one-dimensional cellular automaton where each cell evolves based on its current state and the states of its two immediate neighbors according to a Wolfram rule. The system supports all 256 possible rules and can simulate classic patterns such as Rule 30, Rule 110, and Rule 184.

Source code in src/cax/cs/elementary/cs.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
class Elementary(ComplexSystem[Array, Array]):
	"""Elementary Cellular Automata class.

	A one-dimensional cellular automaton where each cell evolves based on its current state and the
	states of its two immediate neighbors according to a Wolfram rule. The system supports all 256
	possible rules and can simulate classic patterns such as Rule 30, Rule 110, and Rule 184.
	"""

	def __init__(
		self,
		*,
		wolfram_code: Array,
		rngs: nnx.Rngs,
	):
		"""Initialize Elementary Cellular Automaton.

		Args:
			wolfram_code: Array of 8 binary values defining the Wolfram rule. Each element
				corresponds to the output for one of the 8 possible three-cell neighborhood
				configurations (111, 110, 101, 100, 011, 010, 001, 000).
			rngs: rng key. This constructor does not read it.

		"""
		self.perceive = ElementaryPerceive()
		self.update = ElementaryUpdate(wolfram_code=wolfram_code)

	def _step(self, state: Array, input: Array | None = None, *, sow: bool = False) -> Array:
		"""Advance the automaton by a single time step.

		Args:
			state: Current state.
			input: Optional input, forwarded unchanged to the update module.
			sow: Whether to record the new state as an ``nnx.Intermediate`` named ``"state"``.

		Returns:
			The state produced by the update module.

		"""
		perception = self.perceive(state)
		next_state = self.update(state, perception, input)

		if sow:
			self.sow(nnx.Intermediate, "state", next_state)

		return next_state

	@classmethod
	def wolfram_code_from_rule_number(cls, rule_number: int) -> Array:
		"""Create Wolfram code array from a rule number.

		Converts a Wolfram rule number (0-255) to its binary representation as an array
		of 8 floats, most significant bit first. For example, rule 30 becomes
		[0, 0, 0, 1, 1, 1, 1, 0].

		Args:
			rule_number: Integer between 0 and 255 representing the Wolfram rule.

		Returns:
			Array of shape (8,) containing binary values (0.0 or 1.0) representing
				the rule's lookup table.

		Raises:
			ValueError: If `rule_number` is outside the range 0-255.

		"""
		# Explicit raise instead of `assert`: asserts are removed under `python -O`, which
		# would let out-of-range rule numbers through silently.
		if not 0 <= rule_number < 256:
			raise ValueError(
				f"Wolfram rule number must be between 0 and 255, got {rule_number!r}."
			)

		# Bit 7 (most significant) comes first, bit 0 last.
		shifts = 7 - jnp.arange(8)
		return ((rule_number >> shifts) & 1).astype(jnp.float32)

	@nnx.jit
	def render(self, state: Array) -> Array:
		"""Render state to RGB image.

		Converts the one-dimensional cellular automaton state to an RGB visualization
		by replicating the single-channel state values across all three color channels,
		resulting in a grayscale image.

		Args:
			state: Array with shape (num_steps, width, 1) representing the
				cellular automaton state, where each cell contains a value in [0, 1].

		Returns:
			RGB image with dtype uint8 and shape (num_steps, width, 3), where cell values are mapped
				to grayscale colors in the range [0, 255].

		"""
		rgb = jnp.repeat(state, 3, axis=-1)

		return clip_and_uint8(rgb)

__init__(*, wolfram_code, rngs)

Initialize Elementary Cellular Automaton.

Parameters:

Name Type Description Default
wolfram_code Array

Array of 8 binary values defining the Wolfram rule. Each element corresponds to the output for one of the 8 possible three-cell neighborhood configurations (111, 110, 101, 100, 011, 010, 001, 000).

required
rngs Rngs

rng key (an nnx.Rngs instance). The constructor shown below does not read it.

required
Source code in src/cax/cs/elementary/cs.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def __init__(
	self,
	*,
	wolfram_code: Array,
	rngs: nnx.Rngs,
):
	"""Build the perception and update modules of the automaton.

	Args:
		wolfram_code: Array of 8 binary values forming the Wolfram rule lookup table, one
			output per three-cell neighborhood, in the order
			111, 110, 101, 100, 011, 010, 001, 000.
		rngs: rng key. This constructor does not read it.

	"""
	perceive_module = ElementaryPerceive()
	update_module = ElementaryUpdate(wolfram_code=wolfram_code)

	self.perceive = perceive_module
	self.update = update_module

wolfram_code_from_rule_number(rule_number) classmethod

Create Wolfram code array from a rule number.

Converts a Wolfram rule number (0-255) to its binary representation as an array of 8 floats. For example, rule 30 becomes [0, 0, 0, 1, 1, 1, 1, 0].

Parameters:

Name Type Description Default
rule_number int

Integer between 0 and 255 representing the Wolfram rule.

required

Returns:

Type Description
Array

Array of shape (8,) containing binary values (0.0 or 1.0) representing the rule's lookup table.

Source code in src/cax/cs/elementary/cs.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
@classmethod
def wolfram_code_from_rule_number(cls, rule_number: int) -> Array:
	"""Create Wolfram code array from a rule number.

	Converts a Wolfram rule number (0-255) to its binary representation as an array
	of 8 floats, most significant bit first. For example, rule 30 becomes
	[0, 0, 0, 1, 1, 1, 1, 0].

	Args:
		rule_number: Integer between 0 and 255 representing the Wolfram rule.

	Returns:
		Array of shape (8,) containing binary values (0.0 or 1.0) representing
			the rule's lookup table.

	Raises:
		ValueError: If `rule_number` is outside the range 0-255.

	"""
	# Explicit raise instead of `assert`: asserts are removed under `python -O`, which
	# would let out-of-range rule numbers through silently.
	if not 0 <= rule_number < 256:
		raise ValueError(
			f"Wolfram rule number must be between 0 and 255, got {rule_number!r}."
		)

	# Bit 7 (most significant) comes first, bit 0 last.
	shifts = 7 - jnp.arange(8)
	return ((rule_number >> shifts) & 1).astype(jnp.float32)

render(state)

Render state to RGB image.

Converts the one-dimensional cellular automaton state to an RGB visualization by replicating the single-channel state values across all three color channels, resulting in a grayscale image.

Parameters:

Name Type Description Default
state Array

Array with shape (num_steps, width, 1) representing the cellular automaton state, where each cell contains a value in [0, 1].

required

Returns:

Type Description
Array

RGB image with dtype uint8 and shape (num_steps, width, 3), where cell values are mapped to grayscale colors in the range [0, 255].

Source code in src/cax/cs/elementary/cs.py
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
@nnx.jit
def render(self, state: Array) -> Array:
	"""Turn an automaton state history into a grayscale RGB image.

	The single state channel is copied into each of the three color channels, so the
	picture is grayscale.

	Args:
		state: Array with shape (num_steps, width, 1) holding the cellular automaton
			state, each cell being a value in [0, 1].

	Returns:
		RGB image with dtype uint8 and shape (num_steps, width, 3), with cell values mapped
			to grayscale colors in the range [0, 255].

	"""
	num_channels = 3
	# Duplicate the trailing channel axis once per color channel.
	grayscale = jnp.repeat(state, num_channels, axis=-1)
	return clip_and_uint8(grayscale)

__call__(state, input=None, *, num_steps=1, input_in_axis=None, sow=False)

Step the system for multiple time steps.

This method wraps _step inside a JAX scan for efficiency and JIT-compiles the loop. If input is time-varying, set input_in_axis to the axis containing the time dimension so that each step receives the corresponding slice of input.

When remat is enabled, the scan body is wrapped with nnx.remat to reduce memory usage during backpropagation at the cost of recomputing intermediates.

Parameters:

Name Type Description Default
state State

Current state.

required
input Input | None

Optional input.

None
num_steps int

Number of steps.

1
input_in_axis int | None

Axis of input that holds the time dimension when input is time-varying, so that each step receives its own slice; leave as None otherwise.

None
sow bool

Whether to sow intermediate values.

False

Returns:

Type Description
State

Final state after num_steps applications of _step.

Source code in src/cax/core/cs.py
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
@nnx.jit(static_argnames=("num_steps", "input_in_axis", "sow"))
def __call__(
	self,
	state: State,
	input: Input | None = None,
	*,
	num_steps: int = 1,
	input_in_axis: int | None = None,
	sow: bool = False,
) -> State:
	"""Run the system forward for `num_steps` steps.

	`_step` is applied repeatedly through `nnx.scan`, and the whole call is JIT-compiled
	with `num_steps`, `input_in_axis` and `sow` treated as static arguments. For a
	time-varying `input`, pass the axis holding the time dimension as `input_in_axis`
	so that each step receives the corresponding slice of input.

	When `self.remat` is truthy, the per-step function is wrapped with `nnx.remat` to
	reduce memory usage during backpropagation at the cost of recomputing intermediates.

	Args:
		state: Current state.
		input: Optional input.
		num_steps: Number of steps.
		input_in_axis: Axis for input if provided for each step.
		sow: Whether to sow intermediate values.

	Returns:
		Final state after `num_steps` applications of `_step`.

	"""

	def advance(system: ComplexSystem, carry: State, step_input: Input | None) -> State:
		return system._step(carry, step_input, sow=sow)

	scan_body = nnx.remat(advance) if self.remat else advance

	# Module state: Intermediate variables are mapped to axis 0, everything else is carried.
	module_axes = nnx.StateAxes({nnx.Intermediate: 0, ...: nnx.Carry})
	run_scan = nnx.scan(
		scan_body,
		in_axes=(module_axes, nnx.Carry, input_in_axis),
		out_axes=nnx.Carry,
		length=num_steps,
	)

	return run_scan(self, state, input)