Skip to content

Abelian Sandpile

cax.cs.sandpile.cs.Sandpile

Bases: ComplexSystem[Array, Array]

Abelian Sandpile model.

A discrete cellular automaton demonstrating self-organized criticality. The state is a grid of non-negative integers representing chip counts. When a cell reaches the critical threshold, it topples, distributing chips to neighbors. Cascading avalanches of topplings produce power-law distributed events.

Two boundary modes are supported:
  • "CIRCULAR": periodic (toroidal) boundaries conserving total mass.
  • "OPEN": dissipative boundaries where sand falling off the edge is lost, which is required for proper self-organized criticality.
Source code in src/cax/cs/sandpile/cs.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
class Sandpile(ComplexSystem[Array, Array]):
	"""Abelian Sandpile model.

	A discrete cellular automaton demonstrating self-organized criticality. The state
	is a grid of non-negative integers representing chip counts. When a cell reaches
	the critical threshold, it topples, distributing chips to neighbors. Cascading
	avalanches of topplings produce power-law distributed events.

	Two boundary modes are supported:
		- "CIRCULAR": periodic (toroidal) boundaries conserving total mass.
		- "OPEN": dissipative boundaries where sand falling off the edge is lost,
			which is required for proper self-organized criticality.
	"""

	def __init__(
		self,
		*,
		num_spatial_dims: int = 2,
		threshold: int | None = None,
		padding: Literal["CIRCULAR", "OPEN"] = "CIRCULAR",
	):
		"""Initialize Sandpile.

		Args:
			num_spatial_dims: Number of spatial dimensions (default 2).
			threshold: Critical chip count for toppling. Defaults to
				2 * num_spatial_dims (4 in 2D, 6 in 3D).
			padding: Boundary condition mode. "CIRCULAR" for periodic boundaries,
				"OPEN" for dissipative boundaries (required for SOC).

		"""
		self.num_spatial_dims = num_spatial_dims
		self.threshold = threshold if threshold is not None else 2 * num_spatial_dims
		self.perceive = SandpilePerceive(
			num_spatial_dims=num_spatial_dims,
			padding=padding,
		)
		self.update = SandpileUpdate(
			num_spatial_dims=num_spatial_dims,
			threshold=self.threshold,
		)

	def _step(self, state: Array, input: Array | None = None, *, sow: bool = False) -> Array:
		"""Advance the sandpile by a single step.

		Args:
			state: Current chip-count grid.
			input: Optional input forwarded unchanged to the update rule.
			sow: If True, record the resulting state as an ``nnx.Intermediate``
				named "state".

		Returns:
			The next state produced by ``self.update``.

		"""
		perception = self.perceive(state)
		next_state = self.update(state, perception, input)

		if sow:
			self.sow(nnx.Intermediate, "state", next_state)

		return next_state

	@nnx.jit
	def render(self, state: Array) -> Array:
		"""Render state to RGB image.

		Maps chip counts to distinct colors using the classic sandpile palette:
		0 chips → dark blue, 1 chip → cyan, 2 chips → yellow, 3 chips → orange.
		Stable cells holding more than 3 chips (only possible when the threshold
		exceeds 4, e.g. the 3D default of 6) reuse the orange of 3 chips. Cells at
		or above the critical threshold (``self.threshold``) are rendered in red.

		Args:
			state: Array with shape (..., *spatial_dims, 1) containing integer chip
				counts stored as float32.

		Returns:
			RGB image with dtype uint8 and shape (..., *spatial_dims, 3).

		"""
		chips = state[..., 0]

		# One row per stable chip count 0..3; counts above 3 share the last row.
		palette = jnp.array(
			[
				[0.1, 0.1, 0.4],  # 0 chips: dark blue
				[0.0, 0.7, 0.8],  # 1 chip: cyan
				[0.9, 0.9, 0.1],  # 2 chips: yellow
				[1.0, 0.5, 0.0],  # 3+ chips (below threshold): orange
			]
		)
		unstable_color = jnp.array([0.8, 0.0, 0.0])  # at/above threshold: red

		# Chip counts are integral per the docstring, so truncating the clipped
		# float values yields the palette index directly.
		index = jnp.clip(chips, 0, palette.shape[0] - 1).astype(jnp.int32)
		rgb = palette[index]

		# Highlight cells that will topple, using the configured threshold rather
		# than a hard-coded 4 so 3D / custom thresholds render correctly.
		unstable = (chips >= self.threshold)[..., None]
		rgb = jnp.where(unstable, unstable_color, rgb)

		return clip_and_uint8(rgb)

__init__(*, num_spatial_dims=2, threshold=None, padding='CIRCULAR')

Initialize Sandpile.

Parameters:

Name Type Description Default
num_spatial_dims int

Number of spatial dimensions (default 2).

2
threshold int | None

Critical chip count for toppling. Defaults to 2 * num_spatial_dims (4 in 2D, 6 in 3D).

None
padding Literal['CIRCULAR', 'OPEN']

Boundary condition mode. "CIRCULAR" for periodic boundaries, "OPEN" for dissipative boundaries (required for SOC).

'CIRCULAR'
Source code in src/cax/cs/sandpile/cs.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def __init__(
	self,
	*,
	num_spatial_dims: int = 2,
	threshold: int | None = None,
	padding: Literal["CIRCULAR", "OPEN"] = "CIRCULAR",
):
	"""Set up the sandpile's perception and update rules.

	Args:
		num_spatial_dims: How many spatial axes the grid has (default 2).
		threshold: Chip count at which a cell topples. When omitted, it becomes
			twice the number of spatial dimensions (4 in 2D, 6 in 3D).
		padding: Boundary handling. "CIRCULAR" wraps around periodically;
			"OPEN" lets chips fall off the edge (needed for SOC).

	"""
	# Resolve the default toppling threshold before storing it anywhere.
	if threshold is None:
		threshold = 2 * num_spatial_dims

	self.num_spatial_dims = num_spatial_dims
	self.threshold = threshold
	self.perceive = SandpilePerceive(num_spatial_dims=num_spatial_dims, padding=padding)
	self.update = SandpileUpdate(num_spatial_dims=num_spatial_dims, threshold=threshold)

render(state)

Render state to RGB image.

Maps chip counts to distinct colors using the classic sandpile palette: 0 chips → dark blue, 1 chip → cyan, 2 chips → yellow, 3 chips → orange. Cells at or above the critical threshold are rendered in red.

Parameters:

Name Type Description Default
state Array

Array with shape (..., *spatial_dims, 1) containing integer chip counts stored as float32.

required

Returns:

Type Description
Array

RGB image with dtype uint8 and shape (..., *spatial_dims, 3).

Source code in src/cax/cs/sandpile/cs.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
@nnx.jit
def render(self, state: Array) -> Array:
	"""Render state to RGB image.

	Maps chip counts to distinct colors using the classic sandpile palette:
	0 chips → dark blue, 1 chip → cyan, 2 chips → yellow, 3 chips → orange.
	Stable cells holding more than 3 chips (only possible when the threshold
	exceeds 4, e.g. the 3D default of 6) reuse the orange of 3 chips. Cells at
	or above the critical threshold (``self.threshold``) are rendered in red.

	Args:
		state: Array with shape (..., *spatial_dims, 1) containing integer chip
			counts stored as float32.

	Returns:
		RGB image with dtype uint8 and shape (..., *spatial_dims, 3).

	"""
	chips = state[..., 0]

	# One row per stable chip count 0..3; counts above 3 share the last row.
	palette = jnp.array(
		[
			[0.1, 0.1, 0.4],  # 0 chips: dark blue
			[0.0, 0.7, 0.8],  # 1 chip: cyan
			[0.9, 0.9, 0.1],  # 2 chips: yellow
			[1.0, 0.5, 0.0],  # 3+ chips (below threshold): orange
		]
	)
	unstable_color = jnp.array([0.8, 0.0, 0.0])  # at/above threshold: red

	# Chip counts are integral per the docstring, so truncating the clipped
	# float values yields the palette index directly.
	index = jnp.clip(chips, 0, palette.shape[0] - 1).astype(jnp.int32)
	rgb = palette[index]

	# Highlight cells that will topple, using the configured threshold rather
	# than a hard-coded 4 so 3D / custom thresholds render correctly.
	unstable = (chips >= self.threshold)[..., None]
	rgb = jnp.where(unstable, unstable_color, rgb)

	return clip_and_uint8(rgb)

__call__(state, input=None, *, num_steps=1, input_in_axis=None, sow=False)

Step the system for multiple time steps.

This method wraps _step inside a JAX scan for efficiency and JIT-compiles the loop. If input is time-varying, set input_in_axis to the axis containing the time dimension so that each step receives the corresponding slice of input.

When remat is enabled, the scan body is wrapped with nnx.remat to reduce memory usage during backpropagation at the cost of recomputing intermediates.

Parameters:

Name Type Description Default
state State

Current state.

required
input Input | None

Optional input.

None
num_steps int

Number of steps.

1
input_in_axis int | None

Axis for input if provided for each step.

None
sow bool

Whether to sow intermediate values.

False

Returns:

Type Description
State

Final state after num_steps applications of _step.

Source code in src/cax/core/cs.py
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
@nnx.jit(static_argnames=("num_steps", "input_in_axis", "sow"))
def __call__(
	self,
	state: State,
	input: Input | None = None,
	*,
	num_steps: int = 1,
	input_in_axis: int | None = None,
	sow: bool = False,
) -> State:
	"""Run the system forward for ``num_steps`` steps.

	The single-step rule ``_step`` is repeated with ``nnx.scan`` under JIT. For an
	input that changes over time, pass ``input_in_axis`` as the axis holding the
	time dimension; each step then receives its own slice of ``input``.

	If ``self.remat`` is set, the scanned body is wrapped in ``nnx.remat``. This
	recomputes intermediates during backpropagation instead of storing them,
	saving memory.

	Args:
		state: Current state.
		input: Optional input.
		num_steps: Number of steps.
		input_in_axis: Axis for input if provided for each step.
		sow: Whether to sow intermediate values.

	Returns:
		Final state after `num_steps` applications of `_step`.

	"""

	def scan_body(system: ComplexSystem, carry: State, step_input: Input | None) -> State:
		# `sow` is captured from the enclosing (static) argument.
		return system._step(carry, step_input, sow=sow)

	body = nnx.remat(scan_body) if self.remat else scan_body

	# Sown intermediates are mapped over axis 0 (one entry per step); all other
	# module state is threaded through the loop as carry.
	module_axes = nnx.StateAxes({nnx.Intermediate: 0, ...: nnx.Carry})
	scanned = nnx.scan(
		body,
		in_axes=(module_axes, nnx.Carry, input_in_axis),
		out_axes=nnx.Carry,
		length=num_steps,
	)
	return scanned(self, state, input)