Skip to content

Langton's Ant

cax.cs.langton_ant.cs.LangtonAnt

Bases: ComplexSystem[LangtonAntState, Array]

Langton's Ant and multi-color generalizations.

A two-dimensional cellular automaton with a mobile agent that traverses a grid of colored cells. At each step the ant turns according to the color of its current cell, advances the cell to the next color in a cyclic sequence, and moves forward one step. The rule is specified as a sequence of turn directions — one per color — encoded as a string over the alphabet {R, L, N, U} (right, left, no turn, U-turn).

Classic examples include "RL" (Langton's original ant, which eventually builds a "highway"), "LLRR" (grows symmetrically), and "LRRRRRLLR" (fills space in a square around itself).

Source code in src/cax/cs/langton_ant/cs.py
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
class _InvalidRuleStringError(ValueError, AssertionError):
	"""Raised when a Langton's Ant rule string is malformed.

	Subclasses ValueError, the conventional type for a bad argument value. It also
	subclasses AssertionError, the type this validation used to raise via ``assert``.
	Existing callers that catch AssertionError therefore keep working. Unlike
	``assert``, the check is not stripped when Python runs with ``-O``.
	"""


class LangtonAnt(ComplexSystem[LangtonAntState, Array]):
	"""Langton's Ant and multi-color generalizations.

	A two-dimensional cellular automaton with a mobile agent that traverses a grid of colored
	cells. At each step the ant turns according to the color of its current cell, advances the
	cell to the next color in a cyclic sequence, and moves forward one step. The rule is
	specified as a sequence of turn directions — one per color — encoded as a string over the
	alphabet {R, L, N, U} (right, left, no turn, U-turn).

	Classic examples include "RL" (Langton's original ant, which eventually builds a
	"highway"), "LLRR" (grows symmetrically), and "LRRRRRLLR" (fills space in a square
	around itself).
	"""

	def __init__(self, *, turns: Array, rngs: nnx.Rngs):
		"""Initialize Langton's Ant.

		Args:
			turns: Array of shape (num_colors,) with integer values in {0, 1, 2, 3}
				encoding turn directions for each cell color. Values correspond to
				clockwise rotation in multiples of 90 degrees: 0=no turn, 1=right,
				2=u-turn, 3=left.
			rngs: Random number generator streams. Not used by this constructor.

		"""
		self.perceive = LangtonAntPerceive()
		self.update = LangtonAntUpdate(turns=turns)

	def _step(
		self, state: LangtonAntState, input: Array | None = None, *, sow: bool = False
	) -> LangtonAntState:
		"""Advance the system by a single step.

		Args:
			state: Current Langton's Ant state.
			input: Optional input, forwarded unchanged to the update module.
			sow: If True, record the next state in the ``nnx.Intermediate`` collection
				under the name "state".

		Returns:
			The next state produced by the update module.

		"""
		perception = self.perceive(state)
		next_state = self.update(state, perception, input)

		if sow:
			self.sow(nnx.Intermediate, "state", next_state)

		return next_state

	@classmethod
	def turns_from_rule_string(cls, rule_string: str) -> Array:
		"""Create turns array from a rule string.

		Parses a rule string in the standard Langton's Ant notation where each character
		specifies the turn direction for the corresponding cell color.

		Args:
			rule_string: String of characters from {R, L, N, U}. For example, "RL" for
				classic Langton's Ant. The length of the string determines the number of
				cell colors.

		Returns:
			Array of shape (num_colors,) with integer turn values where 0=no turn,
				1=right (90 degrees clockwise), 2=u-turn (180 degrees), 3=left (90 degrees
				counter-clockwise).

		Raises:
			ValueError: If ``rule_string`` has fewer than 2 characters, or contains a
				character that is not a key of ``TURN_CHAR_TO_INT``. The raised exception
				is also an AssertionError, for backward compatibility.

		"""
		# Explicit checks instead of `assert`, so validation survives `python -O`.
		if len(rule_string) < 2:
			raise _InvalidRuleStringError(
				f"Rule string must have at least 2 characters, got: '{rule_string}'"
			)

		valid_chars = set(TURN_CHAR_TO_INT.keys())
		for char in rule_string:
			if char not in valid_chars:
				raise _InvalidRuleStringError(
					f"Invalid character '{char}' in rule string '{rule_string}'. "
					f"Valid characters are: {sorted(valid_chars)}"
				)

		turns = jnp.array([TURN_CHAR_TO_INT[char] for char in rule_string], dtype=jnp.int32)
		return turns

	@nnx.jit
	def render(self, state: LangtonAntState) -> Array:
		"""Render state to RGB image.

		Visualizes the grid by mapping each cell color to a distinct display color. For
		2-color rules the rendering is grayscale (black for color 0, white for color 1). For
		multi-color rules, color 0 (unvisited) is black and colors 1 through num_colors-1 are
		evenly spaced around the HSV hue wheel at full saturation and value.

		Args:
			state: Langton's Ant state with grid of shape (height, width, 1).

		Returns:
			RGB image with dtype uint8 and shape (height, width, 3).

		"""
		# Derived from an array shape, so this Python branch is resolved at trace time.
		num_colors = self.update.turns.shape[0]

		if num_colors == 2:
			# Copies the cell value into all three channels. Assumes grid values are 0/1 and
			# that clip_and_uint8 maps [0, 1] to [0, 255] — TODO confirm.
			rgb = jnp.repeat(state.grid, 3, axis=-1)
		else:
			visited = state.grid[..., 0:1] > 0
			# Dividing by (num_colors - 1) places colors 1..num_colors-1 at evenly spaced hues
			# in [0, 1), so the last color does not coincide with color 1 on the hue wheel.
			hue = jnp.where(visited, (state.grid[..., 0:1] - 1) / (num_colors - 1), 0.0)
			saturation = jnp.where(visited, 1.0, 0.0)
			# Zero value makes unvisited cells black regardless of hue.
			value = jnp.where(visited, 1.0, 0.0)
			hsv = jnp.concatenate([hue, saturation, value], axis=-1)
			rgb = hsv_to_rgb(hsv)

		return clip_and_uint8(rgb)

__init__(*, turns, rngs)

Initialize Langton's Ant.

Parameters:

Name Type Description Default
turns Array

Array of shape (num_colors,) with integer values in {0, 1, 2, 3} encoding turn directions for each cell color. Values correspond to clockwise rotation in multiples of 90 degrees: 0=no turn, 1=right, 2=u-turn, 3=left.

required
rngs Rngs

Random number generator streams. Not used by this constructor.

required
Source code in src/cax/cs/langton_ant/cs.py
37
38
39
40
41
42
43
44
45
46
47
48
49
def __init__(self, *, turns: Array, rngs: nnx.Rngs):
	"""Build the perception and update submodules for Langton's Ant.

	Args:
		turns: Per-color turn codes, an array of shape (num_colors,) holding integers in
			{0, 1, 2, 3}. Each value is a clockwise rotation in quarter turns:
			0=no turn, 1=right, 2=u-turn, 3=left.
		rngs: Random number generator streams; this constructor does not use them.

	"""
	perceive_module = LangtonAntPerceive()
	update_module = LangtonAntUpdate(turns=turns)
	self.perceive = perceive_module
	self.update = update_module

turns_from_rule_string(rule_string) classmethod

Create turns array from a rule string.

Parses a rule string in the standard Langton's Ant notation where each character specifies the turn direction for the corresponding cell color.

Parameters:

Name Type Description Default
rule_string str

String of characters from {R, L, N, U}. For example, "RL" for classic Langton's Ant. The length of the string determines the number of cell colors.

required

Returns:

Type Description
Array

Array of shape (num_colors,) with integer turn values where 0=no turn, 1=right (90 degrees clockwise), 2=u-turn (180 degrees), 3=left (90 degrees counter-clockwise).

Source code in src/cax/cs/langton_ant/cs.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
class _InvalidRuleStringError(ValueError, AssertionError):
	"""Raised when a Langton's Ant rule string is malformed.

	Subclasses ValueError, the conventional type for a bad argument value. It also
	subclasses AssertionError, the type this validation used to raise via ``assert``.
	Existing callers that catch AssertionError therefore keep working.
	"""


@classmethod
def turns_from_rule_string(cls, rule_string: str) -> Array:
	"""Create turns array from a rule string.

	Parses a rule string in the standard Langton's Ant notation where each character
	specifies the turn direction for the corresponding cell color.

	Args:
		rule_string: String of characters from {R, L, N, U}. For example, "RL" for
			classic Langton's Ant. The length of the string determines the number of
			cell colors.

	Returns:
		Array of shape (num_colors,) with integer turn values where 0=no turn,
			1=right (90 degrees clockwise), 2=u-turn (180 degrees), 3=left (90 degrees
			counter-clockwise).

	Raises:
		ValueError: If ``rule_string`` has fewer than 2 characters, or contains a
			character that is not a key of ``TURN_CHAR_TO_INT``. The raised exception
			is also an AssertionError, for backward compatibility.

	"""
	# Explicit checks instead of `assert`, so validation survives `python -O`.
	if len(rule_string) < 2:
		raise _InvalidRuleStringError(
			f"Rule string must have at least 2 characters, got: '{rule_string}'"
		)

	valid_chars = set(TURN_CHAR_TO_INT.keys())
	for char in rule_string:
		if char not in valid_chars:
			raise _InvalidRuleStringError(
				f"Invalid character '{char}' in rule string '{rule_string}'. "
				f"Valid characters are: {sorted(valid_chars)}"
			)

	turns = jnp.array([TURN_CHAR_TO_INT[char] for char in rule_string], dtype=jnp.int32)
	return turns

render(state)

Render state to RGB image.

Visualizes the grid by mapping each cell color to a distinct display color. For 2-color rules the rendering is grayscale (black for color 0, white for color 1). For multi-color rules, color 0 (unvisited) is black and colors 1 through num_colors-1 are evenly spaced around the HSV hue wheel at full saturation and value.

Parameters:

Name Type Description Default
state LangtonAntState

Langton's Ant state with grid of shape (height, width, 1).

required

Returns:

Type Description
Array

RGB image with dtype uint8 and shape (height, width, 3).

Source code in src/cax/cs/langton_ant/cs.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
@nnx.jit
def render(self, state: LangtonAntState) -> Array:
	"""Convert a Langton's Ant state into an RGB image.

	Two-color rules are drawn in grayscale by copying the cell value into all three
	channels. Rules with more colors draw color 0 (never visited) as black. Colors
	1..num_colors-1 get evenly spaced hues at full saturation and value.

	Args:
		state: Langton's Ant state whose grid has shape (height, width, 1).

	Returns:
		uint8 RGB image of shape (height, width, 3).

	"""
	# Taken from an array shape, so this is a plain Python int at trace time.
	color_count = self.update.turns.shape[0]

	if color_count == 2:
		return clip_and_uint8(jnp.repeat(state.grid, 3, axis=-1))

	cell_color = state.grid[..., 0:1]
	is_visited = cell_color > 0
	# Colors 1..color_count-1 land on evenly spaced hues in [0, 1), so the last color does
	# not coincide with color 1 on the hue wheel.
	hue = jnp.where(is_visited, (cell_color - 1) / (color_count - 1), 0.0)
	# Saturation and value share one mask: 1 where visited, 0 (black) elsewhere.
	intensity = jnp.where(is_visited, 1.0, 0.0)
	hsv = jnp.concatenate([hue, intensity, intensity], axis=-1)
	return clip_and_uint8(hsv_to_rgb(hsv))

__call__(state, input=None, *, num_steps=1, input_in_axis=None, sow=False)

Step the system for multiple time steps.

This method wraps _step inside a JAX scan for efficiency and JIT-compiles the loop. If input is time-varying, set input_in_axis to the axis containing the time dimension so that each step receives the corresponding slice of input.

When remat is enabled, the scan body is wrapped with nnx.remat to reduce memory usage during backpropagation at the cost of recomputing intermediates.

Parameters:

Name Type Description Default
state State

Current state.

required
input Input | None

Optional input.

None
num_steps int

Number of steps.

1
input_in_axis int | None

Axis for input if provided for each step.

None
sow bool

Whether to sow intermediate values.

False

Returns:

Type Description
State

Final state after num_steps applications of _step.

Source code in src/cax/core/cs.py
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
@nnx.jit(static_argnames=("num_steps", "input_in_axis", "sow"))
def __call__(
	self,
	state: State,
	input: Input | None = None,
	*,
	num_steps: int = 1,
	input_in_axis: int | None = None,
	sow: bool = False,
) -> State:
	"""Step the system for multiple time steps.

	This method wraps `_step` inside a JAX scan for efficiency and JIT-compiles the loop.
	If `input` is time-varying, set `input_in_axis` to the axis containing the time
	dimension so that each step receives the corresponding slice of input.

	When `remat` is enabled, the scan body is wrapped with `nnx.remat` to reduce memory
	usage during backpropagation at the cost of recomputing intermediates.

	Args:
		state: Current state.
		input: Optional input.
		num_steps: Number of steps. Listed in `static_argnames`, so it must be a
			hashable Python int; changing it retraces the function.
		input_in_axis: Axis of `input` to scan over, one slice per step. With the
			default None, the same `input` is presumably broadcast to every step
			(nnx.scan `in_axes=None` semantics — confirm).
		sow: Whether to sow intermediate values. Static under jit.

	Returns:
		Final state after `num_steps` applications of `_step`.

	"""

	# Scan body. `sow` is captured from the enclosing scope; it is a static argument, so
	# it is a plain Python bool here.
	def step_fn(cs: ComplexSystem, state: State, input: Input | None) -> State:
		return cs._step(state, input, sow=sow)

	# Optional rematerialization of the scan body (see the docstring).
	if self.remat:
		step_fn = nnx.remat(step_fn)

	# Module state handling inside the scan. Values sown into nnx.Intermediate are
	# stacked along a new leading axis 0 (one entry per step). All other module state
	# (`...`) is threaded through the loop as carry.
	state_axes = nnx.StateAxes({nnx.Intermediate: 0, ...: nnx.Carry})
	# Arguments map to in_axes positionally: `self` uses state_axes, `state` is the
	# carry, and `input` is scanned along input_in_axis. Only the carry is returned,
	# so the result is the state after the last step.
	state = nnx.scan(
		step_fn,
		in_axes=(state_axes, nnx.Carry, input_in_axis),
		out_axes=nnx.Carry,
		length=num_steps,
	)(self, state, input)

	return state