Skip to content

Particle Life

cax.cs.particle_life.cs.ParticleLife

Bases: ComplexSystem

Particle Life class.

Source code in src/cax/cs/particle_life/cs.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
class ParticleLife(ComplexSystem):
	"""Particle Life class.

	Particles belonging to `num_classes` types interact through a class-to-class
	attraction matrix. Each step computes forces with `ParticleLifePerceive` and then
	integrates them with `ParticleLifeUpdate`.
	"""

	def __init__(
		self,
		num_classes: int,
		*,
		dt: float = 0.01,
		force_factor: float = 1.0,
		velocity_half_life: float = 0.01,
		r_max: float = 0.15,
		beta: float = 0.3,
		A: Array,
	):
		"""Initialize Particle Life.

		Args:
			num_classes: Number of distinct particle types (classes). Each type can have
				different interactions with other types as specified in the attraction matrix.
			dt: Time step of the simulation in arbitrary time units. Smaller values
				produce smoother motion but require more steps for the same duration.
			force_factor: Global scaling factor for all interaction forces. Higher values
				create stronger, more dynamic interactions.
			velocity_half_life: Time constant for velocity decay due to friction. After
				this time, velocity is halved without force input. Smaller values create
				more damped, viscous dynamics.
			r_max: Maximum interaction distance in coordinate space [0, 1]. Particles beyond
				this distance do not interact. Larger values increase computation cost.
			beta: Distance threshold parameter controlling the transition from repulsion to
				attraction. Typically in range [0, 1], where smaller values create stronger
				short-range repulsion.
			A: Attraction matrix of shape (num_classes, num_classes) where A[i, j] defines
				the attraction strength from type i to type j. Positive values attract,
				negative values repel. Values typically range from -1 to 1.

		"""
		# Stored so render() can build one color per class.
		self.num_classes = num_classes

		self.perceive = ParticleLifePerceive(
			force_factor=force_factor,
			r_max=r_max,
			beta=beta,
			A=A,
		)
		self.update = ParticleLifeUpdate(
			dt=dt,
			velocity_half_life=velocity_half_life,
		)

	def _step(
		self, state: ParticleLifeState, input: Input | None = None, *, sow: bool = False
	) -> ParticleLifeState:
		"""Advance the system by a single time step.

		Args:
			state: Current particle state.
			input: Optional input, forwarded unchanged to the update module.
			sow: If True, record the resulting state as an `nnx.Intermediate` named
				"state".

		Returns:
			The state after one perceive/update step.

		"""
		perception = self.perceive(state)
		next_state = self.update(state, perception, input)

		if sow:
			self.sow(nnx.Intermediate, "state", next_state)

		return next_state

	@partial(nnx.jit, static_argnames=("resolution", "particle_radius"))
	def render(
		self,
		state: ParticleLifeState,
		*,
		resolution: int = 512,
		particle_radius: float = 0.005,
	) -> Array:
		"""Render state to RGB image.

		Renders particles as colored circles on a black background. Each particle type
		(class) is assigned a distinct hue from the color spectrum, with colors evenly
		distributed across the HSV color space. Particles are drawn with smooth edges:
		a pixel's intensity is `1 - d**2 / particle_radius**2` (clipped to [0, 1]),
		where `d` is the distance from the pixel center to the nearest particle. The
		visualization uses 2D coordinates in the range [0, 1]. Pixel (row, col) is
		sampled at x = (col + 0.5) / resolution and y = (row + 0.5) / resolution.

		Note: an intermediate array of shape (resolution, resolution, num_particles)
		is materialized, so memory grows with resolution**2 * num_particles.

		Args:
			state: ParticleLifeState containing class_, position, and velocity arrays.
				Position should have shape (num_particles, 2) with coordinates in [0, 1].
				Class array determines the color of each particle.
			resolution: Size of the output image in pixels for both width and height.
				Higher values produce smoother, more detailed renderings.
			particle_radius: Radius of each particle in coordinate space [0, 1]. Particles
				appear as smooth circles with this radius. Larger values make particles more
				visible but may cause overlap.

		Returns:
			RGB image with dtype uint8 and shape (resolution, resolution, 3), where
				particles appear as colored circles on a black background, with colors
				determined by particle type.

		"""
		# Shapes are static under jit, so this check runs at trace time.
		assert state.position.shape[-1] == 2, "Particle Life only supports 2D visualization."

		# Sample each pixel at its true center, (i + 0.5) / resolution. Using
		# linspace(0, 1, resolution) would instead place samples on both borders 0 and 1
		# (spacing 1 / (resolution - 1)), stretching the image by half a pixel per edge.
		centers = (jnp.arange(resolution) + 0.5) / resolution
		# meshgrid's default "xy" indexing gives grid[row, col] = (centers[col], centers[row]):
		# columns follow x and rows follow y.
		grid = jnp.stack(jnp.meshgrid(centers, centers), axis=-1)  # Shape: (resolution, resolution, 2)

		# Compute squared distances to all particles
		positions = state.position  # Shape: (num_particles, 2)
		distance_sq = jnp.sum(
			(grid[:, :, None, :] - positions[None, None, :, :]) ** 2, axis=-1
		)  # Shape: (resolution, resolution, num_particles)

		# Find minimum squared distance and index of closest particle
		min_distance_sq = jnp.min(distance_sq, axis=-1)  # Shape: (resolution, resolution)
		closest_particle_idx = jnp.argmin(distance_sq, axis=-1)  # Shape: (resolution, resolution)

		# Get class of the closest particle for each pixel
		closest_class = state.class_[closest_particle_idx]  # Shape: (resolution, resolution)

		# Quadratic falloff: 1 at the particle center, 0 at particle_radius and beyond.
		mask = jnp.clip(
			1.0 - min_distance_sq / (particle_radius**2), 0.0, 1.0
		)  # Shape: (resolution, resolution)

		# Generate colors for each class using HSV (evenly spaced hues, full saturation/value)
		hues = jnp.linspace(0, 1, self.num_classes, endpoint=False)
		hsv = jnp.stack([hues, jnp.ones_like(hues), jnp.ones_like(hues)], axis=-1)
		colors = hsv_to_rgb(hsv)  # Shape: (num_classes, 3)

		# Assign colors based on closest particle's class
		particle_colors = colors[closest_class]  # Shape: (resolution, resolution, 3)

		# Create black background
		background = jnp.zeros((resolution, resolution, 3))  # Shape: (resolution, resolution, 3)

		# Blend particle colors with background using the mask
		rgb = (
			background * (1.0 - mask[..., None]) + particle_colors * mask[..., None]
		)  # Shape: (resolution, resolution, 3)

		return clip_and_uint8(rgb)

__init__(num_classes, *, dt=0.01, force_factor=1.0, velocity_half_life=0.01, r_max=0.15, beta=0.3, A)

Initialize Particle Life.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `num_classes` | `int` | Number of distinct particle types (classes). Each type can have different interactions with other types as specified in the attraction matrix. | *required* |
| `dt` | `float` | Time step of the simulation in arbitrary time units. Smaller values produce smoother motion but require more steps for the same duration. | `0.01` |
| `force_factor` | `float` | Global scaling factor for all interaction forces. Higher values create stronger, more dynamic interactions. | `1.0` |
| `velocity_half_life` | `float` | Time constant for velocity decay due to friction. After this time, velocity is halved without force input. Smaller values create more damped, viscous dynamics. | `0.01` |
| `r_max` | `float` | Maximum interaction distance in coordinate space [0, 1]. Particles beyond this distance do not interact. Larger values increase computation cost. | `0.15` |
| `beta` | `float` | Distance threshold parameter controlling the transition from repulsion to attraction. Typically in range [0, 1], where smaller values create stronger short-range repulsion. | `0.3` |
| `A` | `Array` | Attraction matrix of shape (num_classes, num_classes) where A[i, j] defines the attraction strength from type i to type j. Positive values attract, negative values repel. Values typically range from -1 to 1. | *required* |
Source code in src/cax/cs/particle_life/cs.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def __init__(
	self,
	num_classes: int,
	*,
	dt: float = 0.01,
	force_factor: float = 1.0,
	velocity_half_life: float = 0.01,
	r_max: float = 0.15,
	beta: float = 0.3,
	A: Array,
):
	"""Build a Particle Life system from its physical parameters.

	Args:
		num_classes: How many particle types (classes) exist. The way each type
			interacts with every other type is given by the attraction matrix `A`.
		dt: Integration time step, in arbitrary time units. Smaller steps give
			smoother motion at the cost of more steps per unit of simulated time.
		force_factor: Multiplier applied to every interaction force; larger values
			yield stronger, more dynamic interactions.
		velocity_half_life: Friction time scale: with no force applied, velocity
			halves after this much time. Smaller values give more damped, viscous
			motion.
		r_max: Interaction cutoff distance in the [0, 1] coordinate space; particles
			farther apart than this do not interact. Larger values cost more compute.
		beta: Distance threshold parameter setting where short-range repulsion gives
			way to attraction, usually within [0, 1]; smaller values strengthen
			short-range repulsion.
		A: Attraction matrix of shape (num_classes, num_classes); `A[i, j]` is the
			strength of the attraction from type i to type j. Positive entries
			attract, negative entries repel, typically within [-1, 1].

	"""
	# Build the two sub-modules first (force computation, then integration)...
	force_model = ParticleLifePerceive(force_factor=force_factor, r_max=r_max, beta=beta, A=A)
	integrator = ParticleLifeUpdate(dt=dt, velocity_half_life=velocity_half_life)

	# ...then attach attributes in the same order as before.
	self.num_classes = num_classes
	self.perceive = force_model
	self.update = integrator

render(state, *, resolution=512, particle_radius=0.005)

Render state to RGB image.

Renders particles as colored circles on a black background. Each particle type (class) is assigned a distinct hue from the color spectrum, with colors evenly distributed across the HSV color space. Particles are drawn with smooth anti-aliased edges based on their distance from pixel centers. The visualization uses 2D coordinates in the range [0, 1].

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `ParticleLifeState` | ParticleLifeState containing `class_`, `position`, and `velocity` arrays. Position should have shape (num_particles, 2) with coordinates in [0, 1]. Class array determines the color of each particle. | *required* |
| `resolution` | `int` | Size of the output image in pixels for both width and height. Higher values produce smoother, more detailed renderings. | `512` |
| `particle_radius` | `float` | Radius of each particle in coordinate space [0, 1]. Particles appear as smooth circles with this radius. Larger values make particles more visible but may cause overlap. | `0.005` |

Returns:

| Type | Description |
| --- | --- |
| `Array` | RGB image with dtype uint8 and shape (resolution, resolution, 3), where particles appear as colored circles on a black background, with colors determined by particle type. |

Source code in src/cax/cs/particle_life/cs.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
@partial(nnx.jit, static_argnames=("resolution", "particle_radius"))
def render(
	self,
	state: ParticleLifeState,
	*,
	resolution: int = 512,
	particle_radius: float = 0.005,
) -> Array:
	"""Render state to RGB image.

	Renders particles as colored circles on a black background. Each particle type
	(class) is assigned a distinct hue from the color spectrum, with colors evenly
	distributed across the HSV color space. Particles are drawn with smooth edges:
	a pixel's intensity is `1 - d**2 / particle_radius**2` (clipped to [0, 1]),
	where `d` is the distance from the pixel center to the nearest particle. The
	visualization uses 2D coordinates in the range [0, 1]. Pixel (row, col) is
	sampled at x = (col + 0.5) / resolution and y = (row + 0.5) / resolution.

	Note: an intermediate array of shape (resolution, resolution, num_particles)
	is materialized, so memory grows with resolution**2 * num_particles.

	Args:
		state: ParticleLifeState containing class_, position, and velocity arrays.
			Position should have shape (num_particles, 2) with coordinates in [0, 1].
			Class array determines the color of each particle.
		resolution: Size of the output image in pixels for both width and height.
			Higher values produce smoother, more detailed renderings.
		particle_radius: Radius of each particle in coordinate space [0, 1]. Particles
			appear as smooth circles with this radius. Larger values make particles more
			visible but may cause overlap.

	Returns:
		RGB image with dtype uint8 and shape (resolution, resolution, 3), where
			particles appear as colored circles on a black background, with colors
			determined by particle type.

	"""
	# Shapes are static under jit, so this check runs at trace time.
	assert state.position.shape[-1] == 2, "Particle Life only supports 2D visualization."

	# Sample each pixel at its true center, (i + 0.5) / resolution. Using
	# linspace(0, 1, resolution) would instead place samples on both borders 0 and 1
	# (spacing 1 / (resolution - 1)), stretching the image by half a pixel per edge.
	centers = (jnp.arange(resolution) + 0.5) / resolution
	# meshgrid's default "xy" indexing gives grid[row, col] = (centers[col], centers[row]):
	# columns follow x and rows follow y.
	grid = jnp.stack(jnp.meshgrid(centers, centers), axis=-1)  # Shape: (resolution, resolution, 2)

	# Compute squared distances to all particles
	positions = state.position  # Shape: (num_particles, 2)
	distance_sq = jnp.sum(
		(grid[:, :, None, :] - positions[None, None, :, :]) ** 2, axis=-1
	)  # Shape: (resolution, resolution, num_particles)

	# Find minimum squared distance and index of closest particle
	min_distance_sq = jnp.min(distance_sq, axis=-1)  # Shape: (resolution, resolution)
	closest_particle_idx = jnp.argmin(distance_sq, axis=-1)  # Shape: (resolution, resolution)

	# Get class of the closest particle for each pixel
	closest_class = state.class_[closest_particle_idx]  # Shape: (resolution, resolution)

	# Quadratic falloff: 1 at the particle center, 0 at particle_radius and beyond.
	mask = jnp.clip(
		1.0 - min_distance_sq / (particle_radius**2), 0.0, 1.0
	)  # Shape: (resolution, resolution)

	# Generate colors for each class using HSV (evenly spaced hues, full saturation/value)
	hues = jnp.linspace(0, 1, self.num_classes, endpoint=False)
	hsv = jnp.stack([hues, jnp.ones_like(hues), jnp.ones_like(hues)], axis=-1)
	colors = hsv_to_rgb(hsv)  # Shape: (num_classes, 3)

	# Assign colors based on closest particle's class
	particle_colors = colors[closest_class]  # Shape: (resolution, resolution, 3)

	# Create black background
	background = jnp.zeros((resolution, resolution, 3))  # Shape: (resolution, resolution, 3)

	# Blend particle colors with background using the mask
	rgb = (
		background * (1.0 - mask[..., None]) + particle_colors * mask[..., None]
	)  # Shape: (resolution, resolution, 3)

	return clip_and_uint8(rgb)

__call__(state, input=None, *, num_steps=1, input_in_axis=None, sow=False)

Step the system for multiple time steps.

This method wraps `_step` inside a JAX scan for efficiency and JIT-compiles the loop. If `input` is time-varying, set `input_in_axis` to the axis containing the time dimension so that each step receives the corresponding slice of input.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `State` | Current state. | *required* |
| `input` | `Input \| None` | Optional input. | `None` |
| `num_steps` | `int` | Number of steps. | `1` |
| `input_in_axis` | `int \| None` | Axis for input if provided for each step. | `None` |
| `sow` | `bool` | Whether to sow intermediate values. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `State` | Final state after `num_steps` applications of `_step`. |

Source code in src/cax/core/cs.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@partial(nnx.jit, static_argnames=("num_steps", "input_in_axis", "sow"))
def __call__(
	self,
	state: State,
	input: Input | None = None,
	*,
	num_steps: int = 1,
	input_in_axis: int | None = None,
	sow: bool = False,
) -> State:
	"""Advance the system by `num_steps` time steps.

	The per-step logic in `_step` is looped with `nnx.scan`, and the whole loop is
	JIT-compiled. For time-varying input, pass `input_in_axis` as the axis of
	`input` that indexes time; each step then receives its own slice of `input`.

	Args:
		state: State to start from.
		input: Optional input, either shared across steps or sliced per step.
		num_steps: How many steps to run.
		input_in_axis: Axis of `input` that is scanned over, or None to give every
			step the same input.
		sow: Whether each step records intermediate values.

	Returns:
		The state reached after applying `_step` `num_steps` times.

	"""

	def scan_body(cs, carry, step_input):
		# One scan iteration: advance the carried state by a single step.
		return cs._step(carry, step_input, sow=sow)

	# Module state: sown intermediates are mapped to scan axis 0 (one slice per step);
	# everything else in the module is threaded through the loop as carry.
	module_axes = nnx.StateAxes({nnx.Intermediate: 0, ...: nnx.Carry})
	scanned = nnx.scan(
		scan_body,
		in_axes=(module_axes, nnx.Carry, input_in_axis),
		out_axes=nnx.Carry,
		length=num_steps,
	)
	return scanned(self, state, input)