Skip to content

Elementary Cellular Automata

cax.cs.elementary.cs.Elementary

Bases: ComplexSystem

Elementary Cellular Automata class.

A one-dimensional cellular automaton where each cell evolves based on its current state and the states of its two immediate neighbors according to a Wolfram rule. The system supports all 256 possible rules and can simulate classic patterns such as Rule 30, Rule 110, and Rule 184.

Source code in src/cax/cs/elementary/cs.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
class Elementary(ComplexSystem):
	"""Elementary Cellular Automata class.

	A one-dimensional cellular automaton where each cell evolves based on its current state and the
	states of its two immediate neighbors according to a Wolfram rule. The system supports all 256
	possible rules and can simulate classic patterns such as Rule 30, Rule 110, and Rule 184.
	"""

	def __init__(
		self,
		*,
		wolfram_code: Array,
		rngs: nnx.Rngs,
	):
		"""Initialize Elementary Cellular Automaton.

		Args:
			wolfram_code: Array of 8 binary values defining the Wolfram rule. Each element
				corresponds to the output for one of the 8 possible three-cell neighborhood
				configurations (111, 110, 101, 100, 011, 010, 001, 000).
			rngs: Random number generator streams, passed to the perception module.

		"""
		self.perceive = ElementaryPerceive(rngs=rngs)
		self.update = ElementaryUpdate(wolfram_code=wolfram_code)

	def _step(self, state: State, input: Input | None = None, *, sow: bool = False) -> State:
		"""Advance the automaton by a single time step.

		Args:
			state: Current state.
			input: Optional input, forwarded unchanged to the update module.
			sow: If True, record the new state as an `nnx.Intermediate` named "state".

		Returns:
			The next state produced by the update module.

		"""
		perception = self.perceive(state)
		next_state = self.update(state, perception, input)

		if sow:
			self.sow(nnx.Intermediate, "state", next_state)

		return next_state

	@classmethod
	def wolfram_code_from_rule_number(cls, rule_number: int) -> Array:
		"""Create Wolfram code array from a rule number.

		Converts a Wolfram rule number (0-255) to its binary representation as an array
		of 8 floats. For example, rule 30 becomes [0, 0, 0, 1, 1, 1, 1, 0].

		Args:
			rule_number: Integer between 0 and 255 representing the Wolfram rule.

		Returns:
			Array of shape (8,) containing binary values (0.0 or 1.0) representing
				the rule's lookup table.

		Raises:
			ValueError: If `rule_number` is outside the range [0, 255].

		"""
		# Explicit check instead of `assert`: assertions are stripped under `python -O`,
		# which would let 256 silently become rule 0 and negative numbers become rule 255.
		if not 0 <= rule_number < 256:
			raise ValueError(f"Rule number must be between 0 and 255, got {rule_number}.")
		# Most significant bit first, so index 0 holds the output for neighborhood 111
		# and index 7 the output for neighborhood 000.
		shifts = 7 - jnp.arange(8)
		return ((rule_number >> shifts) & 1).astype(jnp.float32)

	@nnx.jit
	def render(self, state: State) -> Array:
		"""Render state to RGB image.

		Converts the one-dimensional cellular automaton state to an RGB visualization
		by replicating the single-channel state values across all three color channels,
		resulting in a grayscale image.

		Args:
			state: Array with shape (num_steps, width, 1) representing the
				cellular automaton state, where each cell contains a value in [0, 1].

		Returns:
			RGB image with dtype uint8 and shape (num_steps, width, 3), where cell values are mapped
				to grayscale colors in the range [0, 255].

		"""
		rgb = jnp.repeat(state, 3, axis=-1)

		return clip_and_uint8(rgb)

__init__(*, wolfram_code, rngs)

Initialize Elementary Cellular Automaton.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `wolfram_code` | `Array` | Array of 8 binary values defining the Wolfram rule. Each element corresponds to the output for one of the 8 possible three-cell neighborhood configurations (111, 110, 101, 100, 011, 010, 001, 000). | required |
| `rngs` | `Rngs` | Random number generator streams, passed to the perception module. | required |
Source code in src/cax/cs/elementary/cs.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def __init__(
	self,
	*,
	wolfram_code: Array,
	rngs: nnx.Rngs,
):
	"""Build the perception and update modules of the elementary automaton.

	Args:
		wolfram_code: Lookup table of 8 binary values encoding the Wolfram rule, one
			output per three-cell neighborhood, ordered 111, 110, 101, 100, 011, 010,
			001, 000.
		rngs: Random number generator streams, handed to the perception module.

	"""
	# Perception is built first, then the rule-driven update, matching the original order.
	perceive_module = ElementaryPerceive(rngs=rngs)
	self.perceive = perceive_module
	update_module = ElementaryUpdate(wolfram_code=wolfram_code)
	self.update = update_module

wolfram_code_from_rule_number(rule_number) classmethod

Create Wolfram code array from a rule number.

Converts a Wolfram rule number (0-255) to its binary representation as an array of 8 floats. For example, rule 30 becomes [0, 0, 0, 1, 1, 1, 1, 0].

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `rule_number` | `int` | Integer between 0 and 255 representing the Wolfram rule. | required |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Array of shape (8,) containing binary values (0.0 or 1.0) representing the rule's lookup table. |

Source code in src/cax/cs/elementary/cs.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
@classmethod
def wolfram_code_from_rule_number(cls, rule_number: int) -> Array:
	"""Create Wolfram code array from a rule number.

	Converts a Wolfram rule number (0-255) to its binary representation as an array
	of 8 floats. For example, rule 30 becomes [0, 0, 0, 1, 1, 1, 1, 0].

	Args:
		rule_number: Integer between 0 and 255 representing the Wolfram rule.

	Returns:
		Array of shape (8,) containing binary values (0.0 or 1.0) representing
			the rule's lookup table.

	Raises:
		ValueError: If `rule_number` is outside the range [0, 255].

	"""
	# Explicit check instead of `assert`: assertions are stripped under `python -O`,
	# which would let 256 silently become rule 0 and negative numbers become rule 255.
	if not 0 <= rule_number < 256:
		raise ValueError(f"Rule number must be between 0 and 255, got {rule_number}.")
	# Most significant bit first, so index 0 holds the output for neighborhood 111
	# and index 7 the output for neighborhood 000.
	shifts = 7 - jnp.arange(8)
	return ((rule_number >> shifts) & 1).astype(jnp.float32)

render(state)

Render state to RGB image.

Converts the one-dimensional cellular automaton state to an RGB visualization by replicating the single-channel state values across all three color channels, resulting in a grayscale image.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `State` | Array with shape (num_steps, width, 1) representing the cellular automaton state, where each cell contains a value in [0, 1]. | required |

Returns:

| Type | Description |
| --- | --- |
| `Array` | RGB image with dtype uint8 and shape (num_steps, width, 3), where cell values are mapped to grayscale colors in the range [0, 255]. |

Source code in src/cax/cs/elementary/cs.py
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
@nnx.jit
def render(self, state: State) -> Array:
	"""Turn an automaton state into a grayscale RGB image.

	The single channel of `state` is duplicated into the red, green and blue
	channels, so every cell is drawn as a shade of gray.

	Args:
		state: Array with shape (num_steps, width, 1) holding the automaton history;
			cell values are expected to lie in [0, 1].

	Returns:
		uint8 RGB image of shape (num_steps, width, 3) with grayscale values in
			[0, 255].

	"""
	return clip_and_uint8(jnp.repeat(state, repeats=3, axis=-1))

__call__(state, input=None, *, num_steps=1, input_in_axis=None, sow=False)

Step the system for multiple time steps.

This method wraps `_step` inside a JAX scan for efficiency and JIT-compiles the loop. If `input` is time-varying, set `input_in_axis` to the axis containing the time dimension so that each step receives the corresponding slice of input.

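Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `State` | Current state. | required |
| `input` | `Input \| None` | Optional input. | `None` |
| `num_steps` | `int` | Number of steps. | `1` |
| `input_in_axis` | `int \| None` | Axis of `input` that indexes time steps, when a different input slice is supplied at each step. | `None` |
| `sow` | `bool` | Whether to sow intermediate values. | `False` |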

Returns:

| Type | Description |
| --- | --- |
| `State` | Final state after `num_steps` applications of `_step`. |

Source code in src/cax/core/cs.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@partial(nnx.jit, static_argnames=("num_steps", "input_in_axis", "sow"))
def __call__(
	self,
	state: State,
	input: Input | None = None,
	*,
	num_steps: int = 1,
	input_in_axis: int | None = None,
	sow: bool = False,
) -> State:
	"""Run the system forward for `num_steps` time steps.

	The single-step transition `_step` is scanned with `nnx.scan`, and the whole loop is
	JIT-compiled. When `input` varies over time, pass the axis holding the time
	dimension as `input_in_axis` so each step receives its own slice of `input`.

	Args:
		state: Current state.
		input: Optional input.
		num_steps: Number of steps.
		input_in_axis: Axis of `input` indexing time steps, if a different input is
			provided for each step.
		sow: Whether to sow intermediate values.

	Returns:
		The state reached after `num_steps` applications of `_step`.

	"""

	def step_once(system, carry, step_input):
		# One scan iteration: the state is the carry, and step_input is either the shared
		# input or this step's slice of it.
		return system._step(carry, step_input, sow=sow)

	# Sown intermediates are stacked along a new leading axis; all other module state
	# is threaded through the scan as carry.
	module_axes = nnx.StateAxes({nnx.Intermediate: 0, ...: nnx.Carry})
	scanned_step = nnx.scan(
		step_once,
		in_axes=(module_axes, nnx.Carry, input_in_axis),
		out_axes=nnx.Carry,
		length=num_steps,
	)
	return scanned_step(self, state, input)