Reaction-Diffusion (Gray-Scott)
Installation¶
You will need Python 3.12 or later, and a working JAX installation. For example, you can install JAX with:
In [ ]:
Copied!
%pip install -U "jax[cuda]"
%pip install -U "jax[cuda]"
Then install CAX from PyPI:
In [ ]:
Copied!
%pip install -U "cax[examples]"
%pip install -U "cax[examples]"
Import¶
In [ ]:
Copied!
import jax
import jax.numpy as jnp
import mediapy
from flax import nnx
from cax.cs.reaction_diffusion import ReactionDiffusion
import jax
import jax.numpy as jnp
import mediapy
from flax import nnx
from cax.cs.reaction_diffusion import ReactionDiffusion
Configuration¶
In [ ]:
Copied!
# Simulation settings for the main Gray-Scott run (defined once; re-running is harmless).
seed = 0  # seeds both sample_state's PRNG key and the system's nnx.Rngs
num_steps = 10000  # number of update steps passed to cs(...)
spatial_dims = (128, 128)  # 2-D grid shape; the state has shape (*spatial_dims, 2)
diffusion_rate_u = 0.16  # diffusion rate of the U species (state channel 0)
diffusion_rate_v = 0.08  # diffusion rate of the V species (state channel 1)
feed_rate = 0.06  # Gray-Scott feed rate (f)
kill_rate = 0.062  # Gray-Scott kill rate (k)
dt = 1.0  # step size passed to ReactionDiffusion
rngs = nnx.Rngs(seed)
Instantiate system¶
In [ ]:
Copied!
# Build the Gray-Scott system once from the configuration above. The original cell was
# duplicated, which constructed the system twice from the same `rngs` object.
cs = ReactionDiffusion(
    diffusion_rate_u=diffusion_rate_u,
    diffusion_rate_v=diffusion_rate_v,
    feed_rate=feed_rate,
    kill_rate=kill_rate,
    dt=dt,
    rngs=rngs,
)
Sample initial state¶
In [ ]:
Copied!
def sample_state(shape=None, key=None):
    """Sample an initial Gray-Scott state with a centered square seed.

    Channel 0 (U) is drawn uniformly from [0.8, 1.0) and channel 1 (V) from
    [0.0, 0.2). A centered square is then overwritten with U = 0.5 and V = 0.25.

    Args:
        shape: 2-D grid shape. Defaults to the global ``spatial_dims``, read at
            call time.
        key: JAX PRNG key. Defaults to ``jax.random.key(seed)`` built from the
            global ``seed``, so repeated default calls return the same state.

    Returns:
        Array of shape ``(*shape, 2)`` with U and V stacked on the last axis.
    """
    if shape is None:
        shape = spatial_dims
    if key is None:
        key = jax.random.key(seed)
    key_u, key_v = jax.random.split(key)

    random_influence = 0.2
    u = (1.0 - random_influence) * jnp.ones(shape) + random_influence * jax.random.uniform(
        key_u, shape
    )
    v = random_influence * jax.random.uniform(key_v, shape)

    mid_x, mid_y = shape[0] // 2, shape[1] // 2
    # Size the seed from the smaller side so the square fits on both axes. Sizing
    # it from shape[0] alone lets `mid_y - radius` go negative on tall, narrow
    # grids, and a negative slice start wraps around to the far edge.
    radius = min(shape[0], shape[1]) // 10
    seed_region = (
        slice(mid_x - radius, mid_x + radius),
        slice(mid_y - radius, mid_y + radius),
    )
    u = u.at[seed_region].set(0.5)
    v = v.at[seed_region].set(0.25)
    return jnp.stack([u, v], axis=-1)
Run¶
In [ ]:
Copied!
# Sample the initial state and advance it `num_steps` steps. `sow=True` stores the
# intermediate states on `cs` (popped in the next cell via nnx.Intermediate). The
# original cell was duplicated, which ran the simulation twice and stored two
# trajectories.
state_init = sample_state()
state_final = cs(state_init, num_steps=num_steps, sow=True)
Visualize¶
In [ ]:
Copied!
# Remove the sown intermediates from `cs` and take the first recorded trajectory.
# Pop only once: a second pop returns nothing, so `.state` would fail.
intermediates = nnx.pop(cs, nnx.Intermediate)
states = intermediates.state[0]
In [ ]:
Copied!
# Prepend the initial state so the video starts from it. Running this again would
# prepend it a second time. Then render every state with cs.render, vectorized over
# the leading time axis, and show the frames as a video.
states = jnp.concatenate([state_init[None], states])
frames = nnx.vmap(
    lambda cs, state: cs.render(state),
    in_axes=(None, 0),
)(cs, states)
mediapy.show_video(frames, width=256, height=256)
Pattern Gallery¶
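You can experiment with other Gray-Scott patterns by changing the feed rate and kill rate; different parameter regimes produce diverse Turing patterns. Each preset cell below reassigns `cs`, so run the Run and Visualize cells immediately after the preset you want to view.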
Bacteria (f=0.035, k=0.065)¶
In [ ]:
Copied!
# Bacteria preset (f=0.035, k=0.065). The Coral cell below also assigns `cs`, so run
# the Run and Visualize cells right after this cell to view this preset.
cs = ReactionDiffusion(
    diffusion_rate_u=0.14,
    diffusion_rate_v=0.06,
    feed_rate=0.035,
    kill_rate=0.065,
    dt=dt,  # same step size as the main simulation
    rngs=rngs,
)
Coral (f=0.062, k=0.063)¶
In [ ]:
Copied!
# Coral preset (f=0.062, k=0.063). This reassigns `cs`, replacing any earlier preset;
# the Run and Visualize cells below use this system.
cs = ReactionDiffusion(
    diffusion_rate_u=0.16,
    diffusion_rate_v=0.08,
    feed_rate=0.062,
    kill_rate=0.063,
    dt=dt,  # same step size as the main simulation
    rngs=rngs,
)
Run¶
In [ ]:
Copied!
# Sample the initial state and advance the current preset `num_steps` steps.
# `sow=True` stores the intermediate states on `cs` for the next cell. The original
# cell was duplicated, which ran the simulation twice.
state_init = sample_state()
state_final = cs(state_init, num_steps=num_steps, sow=True)
Visualize¶
In [ ]:
Copied!
# Remove the sown intermediates from `cs` and take the first recorded trajectory.
# Pop only once: a second pop returns nothing, so `.state` would fail.
intermediates = nnx.pop(cs, nnx.Intermediate)
states = intermediates.state[0]
In [ ]:
Copied!
# Prepend the initial state so the video starts from it. Running this again would
# prepend it a second time. Then render every state with cs.render, vectorized over
# the leading time axis, and show the frames as a video.
states = jnp.concatenate([state_init[None], states])
frames = nnx.vmap(
    lambda cs, state: cs.render(state),
    in_axes=(None, 0),
)(cs, states)
mediapy.show_video(frames, width=256, height=256)