Conway's Game of Life
¶
Installation¶
You will need Python 3.11 or later and a working JAX installation. For example, to install JAX with CUDA (GPU) support, run the command below (use `jax` instead of `jax[cuda]` for a CPU-only install):
In [ ]:
Copied!
%pip install -U "jax[cuda]"
%pip install -U "jax[cuda]"
Then install CAX from PyPI:
In [ ]:
Copied!
%pip install -U "cax[examples]"
%pip install -U "cax[examples]"
Import¶
In [ ]:
Copied!
import jax.numpy as jnp
import mediapy
from flax import nnx
from cax.cs.life import Life
import jax.numpy as jnp
import mediapy
from flax import nnx
from cax.cs.life import Life
Configuration¶
In [ ]:
Copied!
# Experiment configuration (each value is defined exactly once).
seed = 0  # seed for the nnx.Rngs created below
num_steps = 128  # number of steps the system is run for
spatial_dims = (32, 32)  # grid size; sample_state builds a (*spatial_dims, 1) array
rule_golly = "B3/S23"  # Conway's Game of Life, in Golly "B<birth>/S<survival>" notation
rngs = nnx.Rngs(seed)
Instantiate system¶
In [ ]:
Copied!
# Parse the Golly-style rule string into the birth and survival conditions.
birth, survival = Life.birth_survival_from_string(rule_golly)
# Bare expression: the notebook displays the parsed values.
birth, survival
In [ ]:
Copied!
# Build the Life cellular-automaton system from the parsed rule.
cs = Life(birth=birth, survival=survival, rngs=rngs)
Sample initial state¶
In [ ]:
Copied!
def sample_state(dims=None):
    """Build an initial Game of Life state containing a single glider.

    Despite the name, nothing here is random: the same grid is returned on
    every call.

    Args:
        dims: Grid size as ``(rows, cols)``. Defaults to the notebook-level
            ``spatial_dims``, so ``sample_state()`` behaves as before.

    Returns:
        A float array of shape ``(*dims, 1)`` that is zero everywhere except
        for one glider. The glider's 3x3 bounding box starts at the grid
        centre, shifted back only when the grid is too small for that.

    Raises:
        ValueError: If either grid dimension is smaller than 3, so no glider
            can fit.
    """
    if dims is None:
        dims = spatial_dims
    rows, cols = dims[0], dims[1]
    if rows < 3 or cols < 3:
        raise ValueError(
            f"grid must be at least 3x3 to hold a glider, got {tuple(dims)}"
        )
    state = jnp.zeros((*dims, 1))
    # Start at the centre, but keep the 3x3 box inside the grid: on 3- or
    # 4-cell axes, d // 2 + 3 would run past the edge. JAX clamps the slice
    # there, which previously caused a confusing shape-mismatch error.
    row = min(rows // 2, rows - 3)
    col = min(cols // 2, cols - 3)
    glider = jnp.array(
        [
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 1.0, 1.0],
        ]
    )
    return state.at[row : row + 3, col : col + 3, 0].set(glider)
Run¶
In [ ]:
Copied!
# Run the system once from the glider state. sow=True records intermediate
# states on the module so the next cell can collect them with
# nnx.pop(cs, nnx.Intermediate).
state_init = sample_state()
state_final = cs(state_init, num_steps=num_steps, sow=True)
Visualize¶
In [ ]:
Copied!
# Pop the sown intermediates off the module exactly once. A second pop would
# find nothing left to remove, and the `.state` lookup would then fail.
# sow appears to store values in a tuple, hence [0]. Presumably this is the
# only entry; confirm if cs is called repeatedly with sow=True.
intermediates = nnx.pop(cs, nnx.Intermediate)
states = intermediates.state.value[0]
In [ ]:
Copied!
# Prepend the initial state exactly once, so the video starts from it.
# Concatenating twice would duplicate the first frame.
states = jnp.concatenate([state_init[None], states])
# Render each state, vectorised over the leading (time) axis. in_axes=None
# shares the same system across the batch; the lambda parameter is named
# `system` to avoid shadowing the global `cs`.
frames = nnx.vmap(
    lambda system, state: system.render(state),
    in_axes=(None, 0),
)(cs, states)
mediapy.show_video(frames, width=256, height=256, codec="gif")
Life Family¶
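You can experiment with other Life-like cellular automata by changing the rule string. Examples are HighLife (`B36/S23`) and Life without Death (`B3/S012345678`). Run only the cell for the rule you want: each cell rebuilds `cs`, replacing the previously built system.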
In [ ]:
Copied!
# HighLife (B36/S23): same survival rule as Conway's Life, plus an extra
# birth condition on 6 neighbours. Rebuilds `cs` with this rule.
birth, survival = Life.birth_survival_from_string("B36/S23")
cs = Life(birth=birth, survival=survival, rngs=rngs)
In [ ]:
Copied!
# Life without Death (B3/S012345678): births as in Conway's Life, and live
# cells survive with any neighbour count. Running this cell replaces any
# system built by the previous cell.
birth, survival = Life.birth_survival_from_string("B3/S012345678")
cs = Life(birth=birth, survival=survival, rngs=rngs)
Run¶
In [ ]:
Copied!
# Run the selected rule once from the glider state. sow=True records
# intermediate states for the nnx.pop(cs, nnx.Intermediate) call in the
# next cell.
state_init = sample_state()
state_final = cs(state_init, num_steps=num_steps, sow=True)
Visualize¶
In [ ]:
Copied!
# Pop the sown intermediates exactly once. A repeated pop finds nothing left,
# so the `.state` lookup would fail. sow appears to store values in a tuple,
# hence [0] (presumably the only entry; confirm if cs is re-run with sow=True).
intermediates = nnx.pop(cs, nnx.Intermediate)
states = intermediates.state.value[0]
In [ ]:
Copied!
# Prepend the initial state exactly once, so the video starts from it
# without a duplicated first frame.
states = jnp.concatenate([state_init[None], states])
# Render each state, vectorised over the leading (time) axis, with the system
# shared across the batch (in_axes=None). `system` avoids shadowing `cs`.
frames = nnx.vmap(
    lambda system, state: system.render(state),
    in_axes=(None, 0),
)(cs, states)
mediapy.show_video(frames, width=256, height=256, codec="gif")