Elementary Cellular Automata

Installation
You will need Python 3.11 or later and a working JAX installation. For example, to install JAX with NVIDIA GPU (CUDA) support, run:
In [ ]:
Copied!
%pip install -U "jax[cuda]"
%pip install -U "jax[cuda]"
Then, install CAX (with the extras used by the examples) from PyPI:
In [ ]:
Copied!
%pip install -U "cax[examples]"
%pip install -U "cax[examples]"
Imports
In [ ]:
Copied!
import jax.numpy as jnp
import mediapy
from flax import nnx
from cax.cs.elementary import Elementary
import jax.numpy as jnp
import mediapy
from flax import nnx
from cax.cs.elementary import Elementary
Configuration
In [ ]:
Copied!
# Experiment configuration (defined once; re-running it would also re-create
# the nnx.Rngs object below).
seed = 0  # seed used to build the nnx.Rngs passed to the model
num_steps = 512  # number of automaton updates performed in the "Run" cell
spatial_dims = (1_024,)  # one spatial axis with 1,024 cells
wolfram_code_int = 110  # Rule 110
rngs = nnx.Rngs(seed)
Instantiate the system
In [ ]:
Copied!
# Convert the integer rule number into the Wolfram-code representation that
# the Elementary constructor takes.
wolfram_code = Elementary.wolfram_code_from_rule_number(wolfram_code_int)
wolfram_code  # bare expression: shown as the cell output in a notebook
In [ ]:
Copied!
# Build the elementary CA system once for the chosen rule.
cs = Elementary(wolfram_code=wolfram_code, rngs=rngs)
Sample the initial state
In [ ]:
Copied!
def sample_state(dims=None):
    """Sample a state with a single active cell.

    Args:
        dims: Spatial shape of the lattice. When ``None`` (the default), the
            module-level ``spatial_dims`` is used. Only the first axis is used
            to position the active cell.

    Returns:
        A zero array of shape ``(*dims, 1)`` whose entry at index
        ``dims[0] // 2`` along the first axis is set to 1.0. For a
        one-dimensional lattice, this is a single active cell in the middle.
    """
    if dims is None:
        dims = spatial_dims
    # Trailing axis of size 1 is the (single) channel dimension.
    state = jnp.zeros((*dims, 1))
    # Functional update: JAX arrays are immutable, so .at[].set returns a copy.
    return state.at[dims[0] // 2].set(1.0)
Run
In [ ]:
Copied!
# Run the automaton exactly once. With sow=True the module records its
# intermediate states, which the next cell pops; running twice would record
# a second trajectory alongside the first.
state_init = sample_state()
state_final = cs(state_init, num_steps=num_steps, sow=True)
Visualize
In [ ]:
Copied!
# Pop (read and remove) the intermediates recorded during the sow=True run.
# Pop exactly once: a second pop finds nothing left on `cs`, and accessing
# `.state` on the empty result fails.
intermediates = nnx.pop(cs, nnx.Intermediate)
# NOTE(review): flax's default sow reducer stores one entry per sowing call in
# a tuple, so [0] presumably selects the trajectory from the first run.
states = intermediates.state.value[0]
In [ ]:
Copied!
# Prepend the initial state (adding a leading time axis with [None]) so the
# rendered space-time diagram starts at step 0. Do this only once; repeating it
# would prepend the initial row a second time.
states = jnp.concatenate([state_init[None], states])
frame = cs.render(states)
mediapy.show_image(frame)