Abelian Sandpile
Installation
You will need Python 3.12 or later and a working JAX installation. For example, to install JAX with CUDA (GPU) support:
In [ ]:
Copied!
%pip install -U "jax[cuda]"
%pip install -U "jax[cuda]"
Then, install CAX from PyPI:
In [ ]:
Copied!
%pip install -U "cax[examples]"
%pip install -U "cax[examples]"
Import
In [ ]:
Copied!
import jax
import jax.numpy as jnp
import mediapy
from flax import nnx
from cax.cs.sandpile import Sandpile
import jax
import jax.numpy as jnp
import mediapy
from flax import nnx
from cax.cs.sandpile import Sandpile
Configuration
In [ ]:
Copied!
# Settings for the single-pile relaxation experiment.
seed = 0  # PRNG seed (also used by the driven experiment)
num_steps = 1024  # number of update steps to simulate
spatial_dims = (64, 64)  # spatial grid shape
initial_chips = 1024  # chips stacked on the center cell at the start
Instantiate system
In [ ]:
Copied!
# Sandpile system with the library's default settings.
cs = Sandpile()
Sample initial state
In [ ]:
Copied!
def sample_state(dims=None, chips=None):
    """Sample a state with all chips stacked on the center cell.

    Args:
        dims: Spatial grid shape. Defaults to the module-level
            ``spatial_dims`` setting, read at call time.
        chips: Number of chips placed on the center cell. Defaults to the
            module-level ``initial_chips`` setting, read at call time.

    Returns:
        A ``jnp`` array of shape ``(*dims, 1)`` that is zero everywhere except
        the center cell (index ``size // 2`` along each spatial axis), which
        holds ``chips``.
    """
    if dims is None:
        dims = spatial_dims
    if chips is None:
        chips = initial_chips
    state = jnp.zeros((*dims, 1))
    # Center index along every spatial axis; works for any number of axes.
    center = tuple(size // 2 for size in dims)
    return state.at[(*center, 0)].set(chips)
Run
In [ ]:
Copied!
# Build the initial pile and simulate it. `sow=True` records intermediate
# states for visualization. Run only once, so that a single trajectory is
# recorded on `cs`.
state_init = sample_state()
state_final = cs(state_init, num_steps=num_steps, sow=True)
Visualize
In [ ]:
Copied!
# Take the recorded intermediates out of `cs`. `nnx.pop` removes them from
# the module, so it must run exactly once: a second call finds nothing to pop.
intermediates = nnx.pop(cs, nnx.Intermediate)
# First sown entry: the trajectory, with time along the leading axis.
states = intermediates.state[0]
In [ ]:
Copied!
# Prepend the initial state once, so the video starts before the first update
# (assumes the sown trajectory holds only post-step states).
states = jnp.concatenate([state_init[None], states])
# Render every frame in one vectorized call: `cs` is shared (in_axes=None)
# and states are mapped over their leading time axis.
frames = nnx.vmap(
    lambda cs, state: cs.render(state),
    in_axes=(None, 0),
)(cs, states)
mediapy.show_video(frames, width=256, height=256)
Self-Organized Criticality
Instead of starting with a large pile and letting it relax, we can drive the system by dropping a single grain of sand at a random location at every step. Under this slow driving, the system self-organizes into a critical state in which avalanche sizes follow a power-law distribution.
In [ ]:
Copied!
# Settings for the driven experiment: a smaller grid, driven for many steps.
num_steps_driven = 4096  # number of steps (one grain dropped per step)
spatial_dims_driven = (32, 32)  # spatial grid shape
In [ ]:
Copied!
# One random drop site per step. Row and column indices are drawn from two
# independent split keys; `randint`'s upper bound is exclusive, so every site
# lies on the grid.
key = jax.random.key(seed)
key1, key2 = jax.random.split(key)
drop_x = jax.random.randint(key1, (num_steps_driven,), 0, spatial_dims_driven[0])
drop_y = jax.random.randint(key2, (num_steps_driven,), 0, spatial_dims_driven[1])
# inputs[t] is all zeros except a single 1.0 at step t's drop site.
inputs = jnp.zeros((num_steps_driven, *spatial_dims_driven, 1))
inputs = inputs.at[jnp.arange(num_steps_driven), drop_x, drop_y, 0].set(1.0)
In [ ]:
Copied!
# Driven sandpile started from an empty grid. `padding="OPEN"` presumably lets
# chips leave through the edges (confirm against the Sandpile docs).
# `input_in_axis=0` presumably feeds `inputs[t]` at step t. Run only once,
# so that a single trajectory is recorded on `cs_driven`.
cs_driven = Sandpile(padding="OPEN")
state_init_driven = jnp.zeros((*spatial_dims_driven, 1))
state_final_driven = cs_driven(
    state_init_driven, input=inputs, num_steps=num_steps_driven, input_in_axis=0, sow=True
)
In [ ]:
Copied!
# Take the recorded intermediates out of `cs_driven`. `nnx.pop` removes them
# from the module, so it must run exactly once: a second call finds nothing.
intermediates = nnx.pop(cs_driven, nnx.Intermediate)
# First sown entry: the trajectory, with time along the leading axis.
states = intermediates.state[0]
In [ ]:
Copied!
# Prepend the (empty) initial state once, so the video starts before the
# first drop (assumes the sown trajectory holds only post-step states).
states = jnp.concatenate([state_init_driven[None], states])
# Render every frame in one vectorized call: `cs_driven` is shared
# (in_axes=None) and states are mapped over their leading time axis.
frames = nnx.vmap(
    lambda cs, state: cs.render(state),
    in_axes=(None, 0),
)(cs_driven, states)
mediapy.show_video(frames, width=256, height=256)