Skip to content

Reaction-Diffusion (Gray-Scott)

cax.cs.reaction_diffusion.cs.ReactionDiffusion

Bases: ComplexSystem[Array, Array]

Gray-Scott reaction-diffusion system.

A continuous cellular automaton modeling two chemical species (U and V) that diffuse and react on a grid. The dynamics follow:

$$\frac{\partial U}{\partial t} = D_u \nabla^2 U - U V^2 + f\,(1 - U)$$

$$\frac{\partial V}{\partial t} = D_v \nabla^2 V + U V^2 - (f + k)\,V$$

Different parameter regimes (feed rate f and kill rate k) produce diverse pattern types: spots, stripes, waves, mitosis, and more.

Source code in src/cax/cs/reaction_diffusion/cs.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
class ReactionDiffusion(ComplexSystem[Array, Array]):
	"""Two-species Gray-Scott reaction-diffusion model.

	The state holds the concentrations of two chemical species, U and V, on a
	grid. Both species diffuse, and they react with each other according to:
		dU/dt = D_u * lap(U) - U*V^2 + f*(1 - U)
		dV/dt = D_v * lap(V) + U*V^2 - (f + k)*V

	The feed rate f and kill rate k select the pattern regime: spots, stripes,
	waves, mitosis, and more.
	"""

	def __init__(
		self,
		*,
		num_spatial_dims: int = 2,
		diffusion_rate_u: float = 0.16,
		diffusion_rate_v: float = 0.08,
		feed_rate: float = 0.06,
		kill_rate: float = 0.062,
		dt: float = 1.0,
		rngs: nnx.Rngs,
	):
		"""Build the perceive and update modules of the system.

		Args:
			num_spatial_dims: Number of spatial dimensions of the grid.
			diffusion_rate_u: Diffusion coefficient D_u of species U.
			diffusion_rate_v: Diffusion coefficient D_v of species V.
			feed_rate: Feed rate f, the rate at which U is replenished.
			kill_rate: Kill rate k, the rate at which V is removed.
			dt: Time step size for the Euler integration.
			rngs: NNX RNG container, forwarded to the perceive module.

		"""
		perceive = ReactionDiffusionPerceive(num_spatial_dims=num_spatial_dims, rngs=rngs)
		update = ReactionDiffusionUpdate(
			diffusion_rate_u=diffusion_rate_u,
			diffusion_rate_v=diffusion_rate_v,
			feed_rate=feed_rate,
			kill_rate=kill_rate,
			dt=dt,
		)

		self.perceive = perceive
		self.update = update

	def _step(self, state: Array, input: Array | None = None, *, sow: bool = False) -> Array:
		"""Advance the state by one perceive/update cycle.

		Args:
			state: Current state.
			input: Optional input, passed through to the update module.
			sow: If True, record the new state as an `nnx.Intermediate` named "state".

		Returns:
			The state produced by the update module.

		"""
		next_state = self.update(state, self.perceive(state), input)

		if sow:
			self.sow(nnx.Intermediate, "state", next_state)

		return next_state

	@nnx.jit
	def render(self, state: Array) -> Array:
		"""Convert a two-species state into an RGB image.

		The channels are computed as red = 1 - V, green = 1 - 0.5*V - 0.5*U and
		blue = U, so V is the main visual signal and U sets the background.

		Args:
			state: Array with shape (..., *spatial_dims, 2) where channel 0 is U
				concentration and channel 1 is V concentration, both in [0, 1].

		Returns:
			RGB image with dtype uint8 and shape (..., *spatial_dims, 3).

		"""
		# Slice (rather than index) so the trailing channel axis is kept for concatenation.
		species_u = state[..., 0:1]
		species_v = state[..., 1:2]

		channels = [
			1.0 - species_v,
			1.0 - 0.5 * species_v - 0.5 * species_u,
			species_u,
		]
		return clip_and_uint8(jnp.concatenate(channels, axis=-1))

__init__(*, num_spatial_dims=2, diffusion_rate_u=0.16, diffusion_rate_v=0.08, feed_rate=0.06, kill_rate=0.062, dt=1.0, rngs)

Initialize Reaction-Diffusion.

Parameters:

Name Type Description Default
num_spatial_dims int

Number of spatial dimensions (default 2).

2
diffusion_rate_u float

Diffusion coefficient for species U.

0.16
diffusion_rate_v float

Diffusion coefficient for species V.

0.08
feed_rate float

Feed rate f — controls how quickly U is replenished.

0.06
kill_rate float

Kill rate k — controls how quickly V is removed.

0.062
dt float

Time step size for the Euler integration.

1.0
rngs Rngs

RNG key.

required
Source code in src/cax/cs/reaction_diffusion/cs.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def __init__(
	self,
	*,
	num_spatial_dims: int = 2,
	diffusion_rate_u: float = 0.16,
	diffusion_rate_v: float = 0.08,
	feed_rate: float = 0.06,
	kill_rate: float = 0.062,
	dt: float = 1.0,
	rngs: nnx.Rngs,
):
	"""Build the perceive and update modules of the system.

	Args:
		num_spatial_dims: Number of spatial dimensions of the grid.
		diffusion_rate_u: Diffusion coefficient D_u of species U.
		diffusion_rate_v: Diffusion coefficient D_v of species V.
		feed_rate: Feed rate f, the rate at which U is replenished.
		kill_rate: Kill rate k, the rate at which V is removed.
		dt: Time step size for the Euler integration.
		rngs: NNX RNG container, forwarded to the perceive module.

	"""
	perceive = ReactionDiffusionPerceive(num_spatial_dims=num_spatial_dims, rngs=rngs)
	update = ReactionDiffusionUpdate(
		diffusion_rate_u=diffusion_rate_u,
		diffusion_rate_v=diffusion_rate_v,
		feed_rate=feed_rate,
		kill_rate=kill_rate,
		dt=dt,
	)

	self.perceive = perceive
	self.update = update

render(state)

Render state to RGB image.

Maps the two-species state to an RGB visualization. Species V concentration is used as the primary visual signal: high V appears as colored regions against a background determined by U.

Parameters:

Name Type Description Default
state Array

Array with shape (..., *spatial_dims, 2) where channel 0 is U concentration and channel 1 is V concentration, both in [0, 1].

required

Returns:

Type Description
Array

RGB image with dtype uint8 and shape (..., *spatial_dims, 3).

Source code in src/cax/cs/reaction_diffusion/cs.py
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
@nnx.jit
def render(self, state: Array) -> Array:
	"""Convert a two-species state into an RGB image.

	The channels are computed as red = 1 - V, green = 1 - 0.5*V - 0.5*U and
	blue = U, so V is the main visual signal and U sets the background.

	Args:
		state: Array with shape (..., *spatial_dims, 2) where channel 0 is U
			concentration and channel 1 is V concentration, both in [0, 1].

	Returns:
		RGB image with dtype uint8 and shape (..., *spatial_dims, 3).

	"""
	# Slice (rather than index) so the trailing channel axis is kept for concatenation.
	species_u = state[..., 0:1]
	species_v = state[..., 1:2]

	channels = [
		1.0 - species_v,
		1.0 - 0.5 * species_v - 0.5 * species_u,
		species_u,
	]
	return clip_and_uint8(jnp.concatenate(channels, axis=-1))

__call__(state, input=None, *, num_steps=1, input_in_axis=None, sow=False)

Step the system for multiple time steps.

This method wraps _step inside a JAX scan for efficiency and JIT-compiles the loop. If input is time-varying, set input_in_axis to the axis containing the time dimension so that each step receives the corresponding slice of input.

When remat is enabled, the scan body is wrapped with nnx.remat to reduce memory usage during backpropagation at the cost of recomputing intermediates.

Parameters:

Name Type Description Default
state State

Current state.

required
input Input | None

Optional input.

None
num_steps int

Number of steps.

1
input_in_axis int | None

Axis of input that indexes the time steps when input is time-varying; leave as None when the same input is used at every step.

None
sow bool

Whether to sow intermediate values.

False

Returns:

Type Description
State

Final state after num_steps applications of _step.

Source code in src/cax/core/cs.py
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
@nnx.jit(static_argnames=("num_steps", "input_in_axis", "sow"))
def __call__(
	self,
	state: State,
	input: Input | None = None,
	*,
	num_steps: int = 1,
	input_in_axis: int | None = None,
	sow: bool = False,
) -> State:
	"""Run the system forward for `num_steps` steps.

	The loop is expressed as an `nnx.scan` over `_step` and is JIT-compiled, with
	`num_steps`, `input_in_axis` and `sow` treated as static arguments. For a
	time-varying `input`, pass the axis holding the time dimension as
	`input_in_axis` so that each step receives its own slice.

	If `self.remat` is truthy, the per-step function is wrapped with `nnx.remat`,
	which lowers memory use during backpropagation by recomputing intermediates.

	Args:
		state: Current state.
		input: Optional input.
		num_steps: Number of steps.
		input_in_axis: Axis for input if provided for each step.
		sow: Whether to sow intermediate values.

	Returns:
		Final state after `num_steps` applications of `_step`.

	"""

	def advance(cs: ComplexSystem, state: State, input: Input | None) -> State:
		return cs._step(state, input, sow=sow)

	scan_body = nnx.remat(advance) if self.remat else advance

	# Module state: nnx.Intermediate entries use scan axis 0, everything else is carried.
	module_axes = nnx.StateAxes({nnx.Intermediate: 0, ...: nnx.Carry})
	run_steps = nnx.scan(
		scan_body,
		in_axes=(module_axes, nnx.Carry, input_in_axis),
		out_axes=nnx.Carry,
		length=num_steps,
	)

	return run_steps(self, state, input)