Skip to content

Particle Lenia

cax.cs.particle_lenia.cs.ParticleLenia

Bases: ComplexSystem

Particle Lenia class.

Source code in src/cax/cs/particle_lenia/cs.py
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
class ParticleLenia(ComplexSystem):
	"""Particle Lenia complex system.

	Combines a perception module (which evaluates kernel, growth and energy fields from
	particle positions) with an update module parameterized by a time resolution ``T``.
	"""

	def __init__(
		self,
		num_spatial_dims: int,
		*,
		T: int,
		kernel_fn: Callable = peak_kernel_fn,
		growth_fn: Callable = exponential_growth_fn,
		rule_params: ParticleLeniaRuleParams,
	):
		"""Initialize Particle Lenia.

		Args:
			num_spatial_dims: Number of spatial dimensions (e.g., 2 for 2D, 3 for 3D).
				Determines the dimensionality of particle positions and field computations.
			T: Time resolution controlling the temporal discretization. Higher values
				produce smoother temporal dynamics with smaller update steps.
			kernel_fn: Callable that computes pairwise kernel weights between particles
				based on their distance. Takes rule parameters and returns kernel values.
			growth_fn: Callable that maps kernel field values to growth field values.
				Defines how particles respond to their local neighborhood density.
			rule_params: Instance of ParticleLeniaRuleParams containing kernel and growth
				parameters such as radii, peak positions, widths, and heights.

		"""
		self.num_spatial_dims = num_spatial_dims
		self.perceive = ParticleLeniaPerceive(
			num_spatial_dims=num_spatial_dims,
			kernel_fn=kernel_fn,
			growth_fn=growth_fn,
			rule_params=rule_params,
		)
		self.update = ParticleLeniaUpdate(
			T=T,
		)

	def _step(self, state: State, input: Input | None = None, *, sow: bool = False) -> State:
		"""Advance the system by a single step.

		Perceives the current ``state``, then applies the update with the optional
		``input``. When ``sow`` is true, the resulting state is recorded as an
		``nnx.Intermediate`` under the name ``"state"``.
		"""
		perception = self.perceive(state)
		next_state = self.update(state, perception, input)

		if sow:
			self.sow(nnx.Intermediate, "state", next_state)

		return next_state

	@partial(nnx.jit, static_argnames=("resolution", "extent", "particle_radius", "type"))
	def render(
		self,
		state: State,
		*,
		resolution: int = 512,
		extent: float = 15.0,
		particle_radius: float = 0.3,
		type: str = "UG",  # Options: "particles", "UG", "E"
	) -> Array:
		"""Render state to RGB image.

		Renders Particle Lenia state as particles optionally overlaid on field visualizations.
		Particles appear as blue circles. The background can show kernel field (U), growth field
		(G), or energy field (E) to visualize the underlying dynamics driving particle motion.
		Field visualizations use color mapping to represent field intensities across space.

		Args:
			state: Array of shape (num_particles, num_spatial_dims) containing particle positions
				in continuous space. Currently only 2D visualization is supported.
			resolution: Size of the output image in pixels for both width and height.
				Higher values produce smoother field gradients but increase computation cost.
			extent: Half-width of the viewing area in coordinate space. The view spans from
				-extent to +extent in each dimension. Adjust to zoom in or out on the particle
				system.
			particle_radius: Radius of each particle in coordinate space. Particles are drawn
				as soft discs whose intensity falls off smoothly from the center to zero at
				this radius.
			type: Visualization mode determining what fields to display:
				"particles": Only show particles on white background.
				"UG": Show particles overlaid on kernel (U) and growth (G) field visualization
					(default).
				"E": Show particles overlaid on energy field visualization.

		Returns:
			RGB image with dtype uint8 and shape (resolution, resolution, 3), showing particles
				and optionally the underlying field structure that drives their motion.

		Raises:
			ValueError: If ``type`` is not one of "particles", "UG" or "E".

		"""
		assert self.num_spatial_dims == 2, "Particle Lenia only supports 2D visualization."
		# ``type`` is a static argument, so this check runs once at trace time.
		if type not in ("particles", "UG", "E"):
			raise ValueError(
				f"Unknown render type {type!r}; expected one of 'particles', 'UG', 'E'."
			)

		# Helper functions for normalization and colormapping
		def normalize(field: Array) -> Array:
			# Min-max scale to [0, 1]; the epsilon guards against a constant field.
			return (field - jnp.min(field)) / (jnp.max(field) - jnp.min(field) + 1e-8)

		def lerp(x: Array, a: Array, b: Array) -> Array:
			return a * (1.0 - x) + b * x

		def cmap_e(e: Array) -> Array:
			stacked = jnp.stack([e, -e], -1).clip(0)
			colors = jnp.array([[0.3, 1.0, 1.0], [1.0, 0.3, 1.0]], dtype=jnp.float32)
			return 1.0 - jnp.matmul(stacked, colors)

		def cmap_ug(u: Array, g: Array) -> Array:
			vis = lerp(u[..., None], jnp.array([0.1, 0.1, 0.3]), jnp.array([0.2, 0.7, 1.0]))
			return lerp(g[..., None], vis, jnp.array([1.17, 0.91, 0.13]))

		# Create a grid of coordinates spanning [-extent, extent] on both axes
		x = jnp.linspace(-extent, extent, resolution)
		y = jnp.linspace(-extent, extent, resolution)
		grid = jnp.stack(jnp.meshgrid(x, y), axis=-1)  # Shape: (resolution, resolution, 2)

		# Particle mask: 1 at a particle center, falling to 0 at particle_radius
		# (using the squared distance to the nearest particle).
		distance_sq = jnp.sum(jnp.square(grid[:, :, None, :] - state[None, None, :, :]), axis=-1)
		distance_sq_min = jnp.min(distance_sq, axis=-1)
		particle_mask = jnp.clip(1.0 - distance_sq_min / (particle_radius**2), 0.0, 1.0)
		particle_mask = particle_mask[:, :, None]

		# Base particle visualization (blue particles on white background)
		vis_particle = jnp.ones((resolution, resolution, 3))
		vis_particle = (
			vis_particle * (1.0 - particle_mask) + jnp.array([0.0, 0.0, 1.0]) * particle_mask
		)

		if type == "particles":
			# No field background needed; skip the (expensive) field evaluation entirely.
			return clip_and_uint8(vis_particle)

		# Evaluate the fields at every grid point (vmapped over points, state shared)
		flat_E, flat_U, flat_G = jax.vmap(self.perceive.compute_fields, in_axes=(None, 0))(
			state, grid.reshape(-1, 2)
		)

		if type == "UG":
			U_field = flat_U.reshape(resolution, resolution)
			G_field = flat_G.reshape(resolution, resolution)
			background = cmap_ug(normalize(U_field), normalize(G_field))
		else:  # "E": the energy colormap is applied to the raw (unnormalized) field
			E_field = flat_E.reshape(resolution, resolution)
			background = cmap_e(E_field)

		# Blend particles over the field background with 70% particle opacity
		alpha = particle_mask * 0.7
		rgb = background * (1.0 - alpha) + vis_particle * alpha

		return clip_and_uint8(rgb)

__init__(num_spatial_dims, *, T, kernel_fn=peak_kernel_fn, growth_fn=exponential_growth_fn, rule_params)

Initialize Particle Lenia.

Parameters:

Name Type Description Default
num_spatial_dims int

Number of spatial dimensions (e.g., 2 for 2D, 3 for 3D). Determines the dimensionality of particle positions and field computations.

required
T int

Time resolution controlling the temporal discretization. Higher values produce smoother temporal dynamics with smaller update steps.

required
kernel_fn Callable

Callable that computes pairwise kernel weights between particles based on their distance. Takes rule parameters and returns kernel values.

peak_kernel_fn
growth_fn Callable

Callable that maps kernel field values to growth field values. Defines how particles respond to their local neighborhood density.

exponential_growth_fn
rule_params ParticleLeniaRuleParams

Instance of ParticleLeniaRuleParams containing kernel and growth parameters such as radii, peak positions, widths, and heights.

required
Source code in src/cax/cs/particle_lenia/cs.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def __init__(
	self,
	num_spatial_dims: int,
	*,
	T: int,
	kernel_fn: Callable = peak_kernel_fn,
	growth_fn: Callable = exponential_growth_fn,
	rule_params: ParticleLeniaRuleParams,
):
	"""Assemble a Particle Lenia system from its perception and update modules.

	Args:
		num_spatial_dims: Dimensionality of the space the particles live in
			(e.g. 2 or 3). Stored on the instance and forwarded to the perception
			module.
		T: Time resolution handed to the update module; larger values give
			smaller, smoother update steps.
		kernel_fn: Function producing kernel weights from inter-particle
			distances. Forwarded to the perception module.
		growth_fn: Function mapping kernel field values to growth field values.
			Forwarded to the perception module.
		rule_params: ``ParticleLeniaRuleParams`` with the kernel and growth
			parameters (radii, peak positions, widths, heights).

	"""
	self.num_spatial_dims = num_spatial_dims

	# Perception module: configured with the spatial dimensionality, the kernel and
	# growth functions, and their parameters.
	perception_module = ParticleLeniaPerceive(
		num_spatial_dims=num_spatial_dims,
		kernel_fn=kernel_fn,
		growth_fn=growth_fn,
		rule_params=rule_params,
	)
	self.perceive = perception_module

	# Update module: parameterized only by the time resolution.
	update_module = ParticleLeniaUpdate(T=T)
	self.update = update_module

render(state, *, resolution=512, extent=15.0, particle_radius=0.3, type='UG')

Render state to RGB image.

Renders Particle Lenia state as particles optionally overlaid on field visualizations. Particles appear as blue circles. The background can show kernel field (U), growth field (G), or energy field (E) to visualize the underlying dynamics driving particle motion. Field visualizations use color mapping to represent field intensities across space.

Parameters:

Name Type Description Default
state State

Array of shape (num_particles, num_spatial_dims) containing particle positions in continuous space. Currently only 2D visualization is supported.

required
resolution int

Size of the output image in pixels for both width and height. Higher values produce smoother field gradients but increase computation cost.

512
extent float

Half-width of the viewing area in coordinate space. The view spans from -extent to +extent in each dimension. Adjust to zoom in or out on the particle system.

15.0
particle_radius float

Radius of each particle in coordinate space. Particles are drawn as soft discs whose intensity falls off smoothly from the center to zero at this radius.

0.3
type str

Visualization mode determining what fields to display: "particles": Only show particles on white background. "UG" (default): Show particles overlaid on kernel (U) and growth (G) field visualization. "E": Show particles overlaid on energy field visualization.

'UG'

Returns:

Type Description
Array

RGB image with dtype uint8 and shape (resolution, resolution, 3), showing particles and optionally the underlying field structure that drives their motion.

Source code in src/cax/cs/particle_lenia/cs.py
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
@partial(nnx.jit, static_argnames=("resolution", "extent", "particle_radius", "type"))
def render(
	self,
	state: State,
	*,
	resolution: int = 512,
	extent: float = 15.0,
	particle_radius: float = 0.3,
	type: str = "UG",  # Options: "particles", "UG", "E"
) -> Array:
	"""Render state to RGB image.

	Renders Particle Lenia state as particles optionally overlaid on field visualizations.
	Particles appear as blue circles. The background can show kernel field (U), growth field
	(G), or energy field (E) to visualize the underlying dynamics driving particle motion.
	Field visualizations use color mapping to represent field intensities across space.

	Args:
		state: Array of shape (num_particles, num_spatial_dims) containing particle positions
			in continuous space. Currently only 2D visualization is supported.
		resolution: Size of the output image in pixels for both width and height.
			Higher values produce smoother field gradients but increase computation cost.
		extent: Half-width of the viewing area in coordinate space. The view spans from
			-extent to +extent in each dimension. Adjust to zoom in or out on the particle
			system.
		particle_radius: Radius of each particle in coordinate space. Particles are drawn
			as soft discs whose intensity falls off smoothly from the center to zero at
			this radius.
		type: Visualization mode determining what fields to display:
			"particles": Only show particles on white background.
			"UG": Show particles overlaid on kernel (U) and growth (G) field visualization
				(default).
			"E": Show particles overlaid on energy field visualization.

	Returns:
		RGB image with dtype uint8 and shape (resolution, resolution, 3), showing particles
			and optionally the underlying field structure that drives their motion.

	Raises:
		ValueError: If ``type`` is not one of "particles", "UG" or "E".

	"""
	assert self.num_spatial_dims == 2, "Particle Lenia only supports 2D visualization."
	# ``type`` is a static argument, so this check runs once at trace time.
	if type not in ("particles", "UG", "E"):
		raise ValueError(
			f"Unknown render type {type!r}; expected one of 'particles', 'UG', 'E'."
		)

	# Helper functions for normalization and colormapping
	def normalize(field: Array) -> Array:
		# Min-max scale to [0, 1]; the epsilon guards against a constant field.
		return (field - jnp.min(field)) / (jnp.max(field) - jnp.min(field) + 1e-8)

	def lerp(x: Array, a: Array, b: Array) -> Array:
		return a * (1.0 - x) + b * x

	def cmap_e(e: Array) -> Array:
		stacked = jnp.stack([e, -e], -1).clip(0)
		colors = jnp.array([[0.3, 1.0, 1.0], [1.0, 0.3, 1.0]], dtype=jnp.float32)
		return 1.0 - jnp.matmul(stacked, colors)

	def cmap_ug(u: Array, g: Array) -> Array:
		vis = lerp(u[..., None], jnp.array([0.1, 0.1, 0.3]), jnp.array([0.2, 0.7, 1.0]))
		return lerp(g[..., None], vis, jnp.array([1.17, 0.91, 0.13]))

	# Create a grid of coordinates spanning [-extent, extent] on both axes
	x = jnp.linspace(-extent, extent, resolution)
	y = jnp.linspace(-extent, extent, resolution)
	grid = jnp.stack(jnp.meshgrid(x, y), axis=-1)  # Shape: (resolution, resolution, 2)

	# Particle mask: 1 at a particle center, falling to 0 at particle_radius
	# (using the squared distance to the nearest particle).
	distance_sq = jnp.sum(jnp.square(grid[:, :, None, :] - state[None, None, :, :]), axis=-1)
	distance_sq_min = jnp.min(distance_sq, axis=-1)
	particle_mask = jnp.clip(1.0 - distance_sq_min / (particle_radius**2), 0.0, 1.0)
	particle_mask = particle_mask[:, :, None]

	# Base particle visualization (blue particles on white background)
	vis_particle = jnp.ones((resolution, resolution, 3))
	vis_particle = (
		vis_particle * (1.0 - particle_mask) + jnp.array([0.0, 0.0, 1.0]) * particle_mask
	)

	if type == "particles":
		# No field background needed; skip the (expensive) field evaluation entirely.
		return clip_and_uint8(vis_particle)

	# Evaluate the fields at every grid point (vmapped over points, state shared)
	flat_E, flat_U, flat_G = jax.vmap(self.perceive.compute_fields, in_axes=(None, 0))(
		state, grid.reshape(-1, 2)
	)

	if type == "UG":
		U_field = flat_U.reshape(resolution, resolution)
		G_field = flat_G.reshape(resolution, resolution)
		background = cmap_ug(normalize(U_field), normalize(G_field))
	else:  # "E": the energy colormap is applied to the raw (unnormalized) field
		E_field = flat_E.reshape(resolution, resolution)
		background = cmap_e(E_field)

	# Blend particles over the field background with 70% particle opacity
	alpha = particle_mask * 0.7
	rgb = background * (1.0 - alpha) + vis_particle * alpha

	return clip_and_uint8(rgb)

__call__(state, input=None, *, num_steps=1, input_in_axis=None, sow=False)

Step the system for multiple time steps.

This method wraps `_step` inside a JAX scan for efficiency and JIT-compiles the loop. If `input` is time-varying, set `input_in_axis` to the axis containing the time dimension so that each step receives the corresponding slice of `input`.

Parameters:

Name Type Description Default
state State

Current state.

required
input Input | None

Optional input.

None
num_steps int

Number of steps.

1
input_in_axis int | None

Axis for input if provided for each step.

None
sow bool

Whether to sow intermediate values.

False

Returns:

Type Description
State

Final state after `num_steps` applications of `_step`.

Source code in src/cax/core/cs.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@partial(nnx.jit, static_argnames=("num_steps", "input_in_axis", "sow"))
def __call__(
	self,
	state: State,
	input: Input | None = None,
	*,
	num_steps: int = 1,
	input_in_axis: int | None = None,
	sow: bool = False,
) -> State:
	"""Run ``_step`` repeatedly, ``num_steps`` times, as a single compiled scan.

	The loop is expressed with ``nnx.scan`` and the whole call is JIT-compiled.
	For inputs that change over time, pass ``input_in_axis`` as the axis of
	``input`` that indexes time; each step then receives its own slice.

	Args:
		state: State to start from.
		input: Optional input forwarded to every step.
		num_steps: How many steps to run.
		input_in_axis: Axis of ``input`` to scan over, or ``None`` to broadcast
			the same input to every step.
		sow: Whether each step should sow intermediate values.

	Returns:
		The state obtained after ``num_steps`` calls to ``_step``.

	"""

	def scan_body(system, carry, step_input):
		# One iteration: advance the carried state by a single step.
		return system._step(carry, step_input, sow=sow)

	# Intermediate collections are mapped over axis 0; all other module state is
	# threaded through the loop as carry.
	module_axes = nnx.StateAxes({nnx.Intermediate: 0, ...: nnx.Carry})
	run_steps = nnx.scan(
		scan_body,
		in_axes=(module_axes, nnx.Carry, input_in_axis),
		out_axes=nnx.Carry,
		length=num_steps,
	)
	final_state = run_steps(self, state, input)
	return final_state