Langton's Ant
¶
Installation¶
You will need Python 3.12 or later and a working JAX installation. For example, to install JAX with CUDA (GPU) support, run:
In [ ]:
Copied!
%pip install -U "jax[cuda]"
%pip install -U "jax[cuda]"
Then install CAX from PyPI:
In [ ]:
Copied!
%pip install -U "cax[examples]"
%pip install -U "cax[examples]"
Import¶
In [ ]:
Copied!
import jax
import jax.numpy as jnp
import mediapy
from flax import nnx
from cax.cs.langton_ant import LangtonAnt, LangtonAntState
import jax
import jax.numpy as jnp
import mediapy
from flax import nnx
from cax.cs.langton_ant import LangtonAnt, LangtonAntState
Configuration¶
In [ ]:
Copied!
# Simulation configuration.
seed = 0  # seed for the nnx.Rngs below
num_steps = 11_000  # number of steps to simulate
spatial_dims = (64, 64)  # spatial size of the grid
rule_string = "RL"  # Classic Langton's Ant

# Random-number streams passed to the system constructor.
rngs = nnx.Rngs(seed)
Instantiate system¶
In [ ]:
Copied!
# Parse the rule string into the `turns` value expected by LangtonAnt,
# then display it as the cell output.
turns = LangtonAnt.turns_from_rule_string(rule_string)
turns
In [ ]:
Copied!
# Build the Langton's Ant system once from the parsed turns, using the shared
# `rngs` from the configuration cell. Constructing it a second time is
# redundant and could draw from `rngs` twice.
cs = LangtonAnt(turns=turns, rngs=rngs)
Sample initial state¶
In [ ]:
Copied!
def sample_state():
    """Sample a state with the ant at the center of an empty grid.

    The grid size is read from the module-level ``spatial_dims``.

    Returns:
        A ``LangtonAntState`` whose grid is all zeros with shape
        ``(*spatial_dims, 1)``, whose position is the integer center of the
        first two grid axes (stored as float32), and whose direction is 0.
    """
    grid = jnp.zeros((*spatial_dims, 1))
    # Integer center of each spatial axis, stored as float32.
    # NOTE(review): float32 mirrors the original code; confirm it is the
    # dtype LangtonAntState expects for positions.
    position = jnp.array(
        [spatial_dims[0] // 2, spatial_dims[1] // 2], dtype=jnp.float32
    )
    direction = jnp.array(0.0, dtype=jnp.float32)  # North
    return LangtonAntState(grid=grid, position=position, direction=direction)
Run¶
In [ ]:
Copied!
# Run the simulation once from the initial state. With sow=True the visited
# states are recorded on `cs` and later retrieved via
# nnx.pop(cs, nnx.Intermediate). Running twice would double the cost and sow
# a second trajectory.
state_init = sample_state()
state_final = cs(state_init, num_steps=num_steps, sow=True)
Visualize¶
In [ ]:
Copied!
# Detach the recorded Intermediate variables from `cs`. Pop only once:
# after this call `cs` no longer holds them, so a second pop would return
# nothing under `state`.
intermediates = nnx.pop(cs, nnx.Intermediate)
# `state` holds the sown values; index 0 selects the first recorded trajectory.
states = intermediates.state[0]
In [ ]:
Copied!
# Prepend the initial state so the video starts from step 0. Do this only
# once; repeating it would duplicate the first frame.
states = jax.tree.map(
    lambda initial, trajectory: jnp.concatenate([initial[None], trajectory]),
    state_init,
    states,
)
# Render every state. The system is shared across calls (in_axes=None), and
# the states are mapped over their leading (time) axis.
frames = nnx.vmap(
    lambda system, state: system.render(state),
    in_axes=(None, 0),
)(cs, states)
mediapy.show_video(frames, width=256, height=256)
Langton's Ant Family¶
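You can experiment with other generalized Langton's Ants by changing the rule string. The Run and Visualize cells below simulate whichever system `cs` was instantiated last, so run the rule cell you want immediately before them.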
LLRR¶
In [ ]:
Copied!
# Alternative rule "LLRR", built once. The Run cell below simulates
# whichever `cs` was built last.
turns = LangtonAnt.turns_from_rule_string("LLRR")
cs = LangtonAnt(turns=turns, rngs=rngs)
LRRRRRLLR¶
In [ ]:
Copied!
# Alternative rule "LRRRRRLLR", built once. The Run cell below simulates
# whichever `cs` was built last.
turns = LangtonAnt.turns_from_rule_string("LRRRRRLLR")
cs = LangtonAnt(turns=turns, rngs=rngs)
Run¶
In [ ]:
Copied!
# Run the simulation once from the initial state. With sow=True the visited
# states are recorded on `cs` and later retrieved via
# nnx.pop(cs, nnx.Intermediate). Running twice would double the cost and sow
# a second trajectory.
state_init = sample_state()
state_final = cs(state_init, num_steps=num_steps, sow=True)
Visualize¶
In [ ]:
Copied!
# Detach the recorded Intermediate variables from `cs`. Pop only once:
# after this call `cs` no longer holds them, so a second pop would return
# nothing under `state`.
intermediates = nnx.pop(cs, nnx.Intermediate)
# `state` holds the sown values; index 0 selects the first recorded trajectory.
states = intermediates.state[0]
In [ ]:
Copied!
# Prepend the initial state so the video starts from step 0. Do this only
# once; repeating it would duplicate the first frame.
states = jax.tree.map(
    lambda initial, trajectory: jnp.concatenate([initial[None], trajectory]),
    state_init,
    states,
)
# Render every state. The system is shared across calls (in_axes=None), and
# the states are mapped over their leading (time) axis.
frames = nnx.vmap(
    lambda system, state: system.render(state),
    in_axes=(None, 0),
)(cs, states)
mediapy.show_video(frames, width=256, height=256)