Skip to content

Conway's Game of Life

cax.cs.life.cs.Life

Bases: ComplexSystem

Conway's Game of Life and Life-like cellular automata.

A two-dimensional cellular automaton where each cell evolves based on its current state (alive or dead) and the number of alive neighbors in its Moore neighborhood. The system is defined by birth and survival rules that determine when cells become alive or remain alive. Classic examples include Conway's Game of Life (B3/S23), HighLife (B36/S23), and Day & Night (B3678/S34678).

Source code in `src/cax/cs/life/cs.py` (lines 20–128)
class Life(ComplexSystem):
	"""Conway's Game of Life and Life-like cellular automata.

	A two-dimensional cellular automaton where each cell evolves based on its current
	state (alive or dead) and the number of alive neighbors in its Moore neighborhood.
	The system is defined by birth and survival rules that determine when cells become
	alive or remain alive. Classic examples include Conway's Game of Life (B3/S23),
	HighLife (B36/S23), and Day & Night (B3678/S34678).
	"""

	def __init__(
		self,
		*,
		birth: Array,
		survival: Array,
		rngs: nnx.Rngs,
	):
		"""Initialize Life.

		Args:
			birth: Array of shape (9,) defining birth conditions. Element i is 1.0 if a dead
				cell with i alive neighbors should become alive, 0.0 otherwise.
			survival: Array of shape (9,) defining survival conditions. Element i is 1.0 if a
				live cell with i alive neighbors should stay alive, 0.0 otherwise.
			rngs: Random number generator streams, forwarded to the perception module.

		"""
		self.perceive = LifePerceive(rngs=rngs)
		self.update = LifeUpdate(birth=birth, survival=survival)

	def _step(self, state: State, input: Input | None = None, *, sow: bool = False) -> State:
		"""Advance the system by a single time step.

		Args:
			state: Current state.
			input: Optional input, forwarded unchanged to the update module.
			sow: If True, record the resulting state as an ``nnx.Intermediate`` named
				"state".

		Returns:
			The next state produced by the update module.

		"""
		perception = self.perceive(state)
		next_state = self.update(state, perception, input)

		if sow:
			self.sow(nnx.Intermediate, "state", next_state)

		return next_state

	@classmethod
	def birth_survival_from_string(cls, rule_golly: str) -> tuple[Array, Array]:
		"""Create birth and survival arrays from a rule string in Golly format.

		Parses a rule string in the standard B/S notation used by Golly and other
		Life simulators. For example, "B3/S23" represents Conway's Game of Life,
		where dead cells with exactly 3 neighbors become alive (Birth), and live
		cells with 2 or 3 neighbors survive (Survival).

		Args:
			rule_golly: Rule string in format "B{birth_numbers}/S{survival_numbers}",
				where birth_numbers and survival_numbers are digits from 0 to 8.
				For example, "B3/S23" for Conway's Game of Life.

		Returns:
			Tuple of (birth, survival) arrays, each of shape (9,) containing binary
				values (0.0 or 1.0) indicating which neighbor counts activate the rule.

		Raises:
			ValueError: If the string is not of the form "B<digits>/S<digits>" or if a
				neighbor count is outside 0..8.

		"""
		format_error = (
			f"Invalid rule string format: {rule_golly}. Expected format: B{{digits}}/S{{digits}}"
		)

		# Validation uses explicit raises rather than `assert`, so it still runs under
		# `python -O` (asserts are stripped there and bad rules would parse silently).
		parts = rule_golly.split("/")
		if len(parts) != 2:
			raise ValueError(format_error)
		birth_string, survival_string = parts

		if not birth_string.startswith("B") or not survival_string.startswith("S"):
			raise ValueError(format_error)

		birth_digits = birth_string[1:]
		survival_digits = survival_string[1:]
		# Restrict to ASCII 0-9: `int()` would also accept other Unicode digits.
		if any(char not in "0123456789" for char in birth_digits + survival_digits):
			raise ValueError(format_error)

		birth_numbers = {int(digit) for digit in birth_digits}
		survival_numbers = {int(digit) for digit in survival_digits}

		if any(num > 8 for num in birth_numbers | survival_numbers):
			raise ValueError("Numbers in rule string must be between 0 and 8.")

		# One entry per possible neighbor count in the Moore neighborhood (0..8).
		birth = jnp.array(
			[1.0 if num_neighbors in birth_numbers else 0.0 for num_neighbors in range(9)]
		)
		survival = jnp.array(
			[1.0 if num_neighbors in survival_numbers else 0.0 for num_neighbors in range(9)]
		)
		return birth, survival

	@nnx.jit
	def render(self, state: State) -> Array:
		"""Render state to RGB image.

		Converts the Life state to an RGB visualization by replicating the single-channel
		state values across all three color channels, resulting in a grayscale image where
		alive cells appear white and dead cells appear black.

		Args:
			state: Array with shape (..., height, width, 1) representing the Life state,
				where each cell is 0.0 (dead) or 1.0 (alive).

		Returns:
			RGB image with dtype uint8 and shape (..., height, width, 3), where cell
				values are mapped to grayscale colors in the range [0, 255].

		"""
		rgb = jnp.repeat(state, 3, axis=-1)

		return clip_and_uint8(rgb)

__init__(*, birth, survival, rngs)

Initialize Life.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `birth` | `Array` | Array of shape (9,) defining birth conditions. Element i is 1.0 if a dead cell with i alive neighbors should become alive, 0.0 otherwise. | *required* |
| `survival` | `Array` | Array of shape (9,) defining survival conditions. Element i is 1.0 if a live cell with i alive neighbors should stay alive, 0.0 otherwise. | *required* |
| `rngs` | `Rngs` | Random number generator streams, passed to the perception module. | *required* |
Source code in `src/cax/cs/life/cs.py` (lines 30–48)
def __init__(
	self,
	*,
	birth: Array,
	survival: Array,
	rngs: nnx.Rngs,
):
	"""Build the perception and update modules of a Life-like automaton.

	Args:
		birth: Array of shape (9,); entry i is 1.0 when a dead cell with i alive
			neighbors is born, and 0.0 otherwise.
		survival: Array of shape (9,); entry i is 1.0 when a live cell with i alive
			neighbors survives, and 0.0 otherwise.
		rngs: Random number generator streams handed to the perception module.

	"""
	# Construct both sub-modules first, then attach them in the same order as before.
	perception_module = LifePerceive(rngs=rngs)
	update_module = LifeUpdate(birth=birth, survival=survival)
	self.perceive = perception_module
	self.update = update_module

birth_survival_from_string(rule_golly) classmethod

Create birth and survival arrays from a rule string in Golly format.

Parses a rule string in the standard B/S notation used by Golly and other Life simulators. For example, "B3/S23" represents Conway's Game of Life, where dead cells with exactly 3 neighbors become alive (Birth), and live cells with 2 or 3 neighbors survive (Survival).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `rule_golly` | `str` | Rule string in format `"B{birth_numbers}/S{survival_numbers}"`, where `birth_numbers` and `survival_numbers` are digits from 0 to 8. For example, `"B3/S23"` for Conway's Game of Life. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `tuple[Array, Array]` | Tuple of `(birth, survival)` arrays, each of shape (9,) containing binary values (0.0 or 1.0) indicating which neighbor counts activate the rule. |

Source code in `src/cax/cs/life/cs.py` (lines 59–107)
@classmethod
def birth_survival_from_string(cls, rule_golly: str) -> tuple[Array, Array]:
	"""Create birth and survival arrays from a rule string in Golly format.

	Parses a rule string in the standard B/S notation used by Golly and other
	Life simulators. For example, "B3/S23" represents Conway's Game of Life,
	where dead cells with exactly 3 neighbors become alive (Birth), and live
	cells with 2 or 3 neighbors survive (Survival).

	Args:
		rule_golly: Rule string in format "B{birth_numbers}/S{survival_numbers}",
			where birth_numbers and survival_numbers are digits from 0 to 8.
			For example, "B3/S23" for Conway's Game of Life.

	Returns:
		Tuple of (birth, survival) arrays, each of shape (9,) containing binary
			values (0.0 or 1.0) indicating which neighbor counts activate the rule.

	"""
	assert "/" in rule_golly, (
		f"Invalid rule string format: {rule_golly}. Expected format: B{{digits}}/S{{digits}}"
	)

	# Split the rule string into birth and survival parts
	birth_string, survival_string = rule_golly.split("/")

	assert birth_string.startswith("B"), (
		f"Invalid rule string format: {rule_golly}. Expected format: B{{digits}}/S{{digits}}"
	)
	assert survival_string.startswith("S"), (
		f"Invalid rule string format: {rule_golly}. Expected format: B{{digits}}/S{{digits}}"
	)

	# Extract the birth and survival numbers
	birth_numbers = [int(digit) for digit in birth_string[1:]]
	survival_numbers = [int(digit) for digit in survival_string[1:]]

	assert all(0 <= num <= 8 for num in birth_numbers + survival_numbers), (
		"Numbers in rule string must be between 0 and 8."
	)

	# Create birth and survival rules
	birth = jnp.array(
		[1.0 if num_neighbors in birth_numbers else 0.0 for num_neighbors in range(9)]
	)
	survival = jnp.array(
		[1.0 if num_neighbors in survival_numbers else 0.0 for num_neighbors in range(9)]
	)
	return birth, survival

render(state)

Render state to RGB image.

Converts the Life state to an RGB visualization by replicating the single-channel state values across all three color channels, resulting in a grayscale image where alive cells appear white and dead cells appear black.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `State` | Array with shape `(..., height, width, 1)` representing the Life state, where each cell is 0.0 (dead) or 1.0 (alive). | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Array` | RGB image with dtype `uint8` and shape `(..., height, width, 3)`, where cell values are mapped to grayscale colors in the range [0, 255]. |

Source code in `src/cax/cs/life/cs.py` (lines 109–128)
@nnx.jit
def render(self, state: State) -> Array:
	"""Turn a Life state into a grayscale RGB image.

	Every value along the last (channel) axis is copied into the red, green and
	blue channels, so alive cells come out white and dead cells black.

	Args:
		state: Life state of shape (..., height, width, 1); each cell holds 0.0
			(dead) or 1.0 (alive).

	Returns:
		A uint8 image of shape (..., height, width, 3) with grayscale values in
			[0, 255].

	"""
	return clip_and_uint8(jnp.repeat(state, repeats=3, axis=-1))

__call__(state, input=None, *, num_steps=1, input_in_axis=None, sow=False)

Step the system for multiple time steps.

This method wraps `_step` inside a JAX scan for efficiency and JIT-compiles the loop. If `input` is time-varying, set `input_in_axis` to the axis containing the time dimension so that each step receives the corresponding slice of `input`.

Parameters:

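| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `State` | Current state. | *required* |
| `input` | `Input \| None` | Optional input. | `None` |
| `num_steps` | `int` | Number of steps. | `1` |
| `input_in_axis` | `int \| None` | Axis of `input` holding the per-step slices, if a different input is provided for each step. | `None` |
| `sow` | `bool` | Whether to sow intermediate values. | `False` |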

Returns:

| Type | Description |
| --- | --- |
| `State` | Final state after `num_steps` applications of `_step`. |

Source code in `src/cax/core/cs.py` (lines 52–87)
@partial(nnx.jit, static_argnames=("num_steps", "input_in_axis", "sow"))
def __call__(
	self,
	state: State,
	input: Input | None = None,
	*,
	num_steps: int = 1,
	input_in_axis: int | None = None,
	sow: bool = False,
) -> State:
	"""Advance the system by `num_steps` applications of `_step`.

	The repeated stepping is expressed as an `nnx.scan` and the whole call is
	JIT-compiled (`num_steps`, `input_in_axis` and `sow` are static). To feed a
	different input to each step, put the time dimension of `input` on axis
	`input_in_axis`; each step then receives the matching slice.

	Args:
		state: State to start from.
		input: Optional input passed to every step (or sliced per step).
		num_steps: How many steps to run.
		input_in_axis: Axis of `input` that is sliced per step, if any.
		sow: Whether each step sows its intermediate values.

	Returns:
		The state reached after `num_steps` steps.

	"""

	def scan_body(system, carry_state, step_input):
		# A single scan iteration is exactly one call to `_step`.
		return system._step(carry_state, step_input, sow=sow)

	# Intermediate collections are scanned on axis 0; all other module state is carried.
	module_axes = nnx.StateAxes({nnx.Intermediate: 0, ...: nnx.Carry})
	scanned_step = nnx.scan(
		scan_body,
		in_axes=(module_axes, nnx.Carry, input_in_axis),
		out_axes=nnx.Carry,
		length=num_steps,
	)
	final_state = scanned_step(self, state, input)
	return final_state