Particle Lenia
¶
Installation¶
You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:
In [ ]:
Copied!
%pip install -U "jax[cuda]"
%pip install -U "jax[cuda]"
Then, install CAX from PyPI:
In [ ]:
Copied!
%pip install -U "cax[examples]"
%pip install -U "cax[examples]"
Import¶
In [ ]:
Copied!
import jax
import jax.numpy as jnp
import mediapy
from flax import nnx
from cax.cs.particle_lenia import (
GrowthParams,
KernelParams,
ParticleLenia,
ParticleLeniaRuleParams,
bell,
)
import jax
import jax.numpy as jnp
import mediapy
from flax import nnx
from cax.cs.particle_lenia import (
GrowthParams,
KernelParams,
ParticleLenia,
ParticleLeniaRuleParams,
bell,
)
Configuration¶
In [ ]:
Copied!
# Seed used for both the JAX PRNG key and the Flax NNX RNG container below.
seed = 0
# Number of steps the system is run for (passed as num_steps when calling cs).
num_steps = 8_192
# Spatial dimensionality; feeds the kernel weight, the system, and the state shape.
num_spatial_dims = 2
# Number of particles, i.e. the first dimension of the sampled state.
num_particles = 200
# Passed to ParticleLenia as T.
# NOTE(review): presumably a time-resolution parameter -- confirm against the CAX docs.
T = 10
key = jax.random.key(seed)
# NOTE(review): rngs is not referenced by any other visible code in this notebook.
rngs = nnx.Rngs(seed)
# Seed used for both the JAX PRNG key and the Flax NNX RNG container below.
seed = 0
# Number of steps the system is run for (passed as num_steps when calling cs).
num_steps = 8_192
# Spatial dimensionality; feeds the kernel weight, the system, and the state shape.
num_spatial_dims = 2
# Number of particles, i.e. the first dimension of the sampled state.
num_particles = 200
# Passed to ParticleLenia as T.
# NOTE(review): presumably a time-resolution parameter -- confirm against the CAX docs.
T = 10
key = jax.random.key(seed)
# NOTE(review): rngs is not referenced by any other visible code in this notebook.
rngs = nnx.Rngs(seed)
Instantiate system¶
Rule parameters¶
In [ ]:
Copied!
# Kernel mean and standard deviation; used to compute the kernel weight and
# passed to KernelParams.
mean = 4.0
std = 1.0
def compute_weight(mean, std, num_spatial_dims):
    """Compute weight for the kernel.

    Integrates ``bell(r, mean, std) * r ** (num_spatial_dims - 1)`` with the
    trapezoidal rule over 51 evenly spaced radii in
    ``[max(mean - 4 * std, 0), mean + 4 * std]``, multiplies the result by
    ``2 * pi`` (2 spatial dimensions) or ``4 * pi`` (3 spatial dimensions),
    and returns the reciprocal.

    Args:
        mean: Mean of the bell function (a Python scalar, since it is passed
            to the built-in ``max``).
        std: Standard deviation of the bell function (a Python scalar).
        num_spatial_dims: Number of spatial dimensions; must be 2 or 3.

    Returns:
        The reciprocal of the scaled integral.

    Raises:
        KeyError: If ``num_spatial_dims`` is neither 2 nor 3.
    """
    # Sample radii within four standard deviations of the mean, clipped at 0
    # because a radius cannot be negative.
    r = jnp.linspace(max(mean - 4 * std, 0.0), mean + 4 * std, 51)
    # Radial integrand: bell profile times r ** (d - 1).
    y = bell(r, mean, std) * r ** (num_spatial_dims - 1)
    # Angular factor: 2 * pi in 2D, 4 * pi in 3D.
    s = jnp.trapezoid(y, r) * {2: 2, 3: 4}[num_spatial_dims] * jnp.pi
    return 1 / s
# Kernel weight derived from the kernel mean/std and the configured dimensionality.
weight = compute_weight(mean, std, num_spatial_dims)
# Kernel mean and standard deviation; used to compute the kernel weight and
# passed to KernelParams.
mean = 4.0
std = 1.0
def compute_weight(mean, std, num_spatial_dims):
    """Compute weight for the kernel.

    Integrates ``bell(r, mean, std) * r ** (num_spatial_dims - 1)`` with the
    trapezoidal rule over 51 evenly spaced radii in
    ``[max(mean - 4 * std, 0), mean + 4 * std]``, multiplies the result by
    ``2 * pi`` (2 spatial dimensions) or ``4 * pi`` (3 spatial dimensions),
    and returns the reciprocal.

    Args:
        mean: Mean of the bell function (a Python scalar, since it is passed
            to the built-in ``max``).
        std: Standard deviation of the bell function (a Python scalar).
        num_spatial_dims: Number of spatial dimensions; must be 2 or 3.

    Returns:
        The reciprocal of the scaled integral.

    Raises:
        KeyError: If ``num_spatial_dims`` is neither 2 nor 3.
    """
    # Sample radii within four standard deviations of the mean, clipped at 0
    # because a radius cannot be negative.
    r = jnp.linspace(max(mean - 4 * std, 0.0), mean + 4 * std, 51)
    # Radial integrand: bell profile times r ** (d - 1).
    y = bell(r, mean, std) * r ** (num_spatial_dims - 1)
    # Angular factor: 2 * pi in 2D, 4 * pi in 3D.
    s = jnp.trapezoid(y, r) * {2: 2, 3: 4}[num_spatial_dims] * jnp.pi
    return 1 / s
# Kernel weight derived from the kernel mean/std and the configured dimensionality.
weight = compute_weight(mean, std, num_spatial_dims)
In [ ]:
Copied!
# Kernel parameters: the weight from compute_weight together with the
# mean/std it was computed from.
kernel_params = KernelParams(
weight=weight,
mean=mean,
std=std,
)
# Growth parameters, given as fixed constants for this example.
growth_params = GrowthParams(
mean=0.6,
std=0.15,
)
# Bundle the kernel and growth parameters into the rule parameters.
# NOTE(review): c_rep is presumably the repulsion strength -- confirm against the CAX docs.
rule_params = ParticleLeniaRuleParams(
c_rep=1.0,
kernel_params=kernel_params,
growth_params=growth_params,
)
# Kernel parameters: the weight from compute_weight together with the
# mean/std it was computed from.
kernel_params = KernelParams(
weight=weight,
mean=mean,
std=std,
)
# Growth parameters, given as fixed constants for this example.
growth_params = GrowthParams(
mean=0.6,
std=0.15,
)
# Bundle the kernel and growth parameters into the rule parameters.
# NOTE(review): c_rep is presumably the repulsion strength -- confirm against the CAX docs.
rule_params = ParticleLeniaRuleParams(
c_rep=1.0,
kernel_params=kernel_params,
growth_params=growth_params,
)
In [ ]:
Copied!
# Build the Particle Lenia system from the configured dimensionality, T,
# and the rule parameters assembled above.
cs = ParticleLenia(
num_spatial_dims=num_spatial_dims,
T=T,
rule_params=rule_params,
)
# Build the Particle Lenia system from the configured dimensionality, T,
# and the rule parameters assembled above.
cs = ParticleLenia(
num_spatial_dims=num_spatial_dims,
T=T,
rule_params=rule_params,
)
Sample initial state¶
In [ ]:
Copied!
def sample_state(key):
    """Sample a state with random particle positions.

    Draws a ``(num_particles, num_spatial_dims)`` array from
    ``jax.random.uniform`` (default range ``[0, 1)``), shifts it by ``-0.5``
    and scales it by ``12.0``, so every coordinate lies in ``[-6, 6)``.

    Args:
        key: JAX PRNG key used for sampling.

    Returns:
        Array of shape ``(num_particles, num_spatial_dims)`` holding the
        sampled positions.
    """
    state = 12.0 * (jax.random.uniform(key, (num_particles, num_spatial_dims)) - 0.5)
    return state
def sample_state(key):
    """Sample a state with random particle positions.

    Draws a ``(num_particles, num_spatial_dims)`` array from
    ``jax.random.uniform`` (default range ``[0, 1)``), shifts it by ``-0.5``
    and scales it by ``12.0``, so every coordinate lies in ``[-6, 6)``.

    Args:
        key: JAX PRNG key used for sampling.

    Returns:
        Array of shape ``(num_particles, num_spatial_dims)`` holding the
        sampled positions.
    """
    state = 12.0 * (jax.random.uniform(key, (num_particles, num_spatial_dims)) - 0.5)
    return state
Run¶
In [ ]:
Copied!
# Split the key so sampling consumes a fresh subkey while `key` is carried forward.
key, subkey = jax.random.split(key)
state_init = sample_state(subkey)
# Run the system for num_steps steps.
# NOTE(review): sow=True presumably records the intermediate states that are
# later popped as nnx.Intermediate -- confirm against the CAX docs.
state_final = cs(state_init, num_steps=num_steps, sow=True)
# Split the key so sampling consumes a fresh subkey while `key` is carried forward.
key, subkey = jax.random.split(key)
state_init = sample_state(subkey)
# Run the system for num_steps steps.
# NOTE(review): sow=True presumably records the intermediate states that are
# later popped as nnx.Intermediate -- confirm against the CAX docs.
state_final = cs(state_init, num_steps=num_steps, sow=True)
Visualize¶
In [ ]:
Copied!
# Pop the nnx.Intermediate variables out of the system.
intermediates = nnx.pop(cs, nnx.Intermediate)
# Take element 0 of the stored `state` value.
# NOTE(review): presumably sown values are wrapped in a tuple -- confirm.
states = intermediates.state.value[0]
# Pop the nnx.Intermediate variables out of the system.
intermediates = nnx.pop(cs, nnx.Intermediate)
# Take element 0 of the stored `state` value.
# NOTE(review): presumably sown values are wrapped in a tuple -- confirm.
states = intermediates.state.value[0]
In [ ]:
Copied!
# Prepend the initial state (given a new leading axis) to the recorded states.
states = jnp.concatenate([state_init[None], states])
# Render every state: map over the leading axis of `states` while leaving
# `cs` unmapped (in_axes=(None, 0)).
frames = nnx.vmap(
lambda cs, state: cs.render(state, resolution=512, particle_radius=0.3),
in_axes=(None, 0),
)(cs, states)
# Show the rendered frames as a 256x256 video at 600 frames per second.
mediapy.show_video(frames, width=256, height=256, fps=600)
# Prepend the initial state (given a new leading axis) to the recorded states.
states = jnp.concatenate([state_init[None], states])
# Render every state: map over the leading axis of `states` while leaving
# `cs` unmapped (in_axes=(None, 0)).
frames = nnx.vmap(
lambda cs, state: cs.render(state, resolution=512, particle_radius=0.3),
in_axes=(None, 0),
)(cs, states)
# Show the rendered frames as a 256x256 video at 600 frames per second.
mediapy.show_video(frames, width=256, height=256, fps=600)