Skip to content

Lenia

cax.cs.lenia.cs.Lenia

Bases: ComplexSystem[Array, Array]

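Lenia complex system: composes a kernel-based perception module (`LeniaPerceive`) with a growth-based update module (`LeniaUpdate`) to advance a continuous, multi-channel state.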

Source code in src/cax/cs/lenia/cs.py (lines 23–102)
class Lenia(ComplexSystem[Array, Array]):
	"""Lenia complex system.

	Advances a continuous, multi-channel state by chaining two sub-modules:
	`LeniaPerceive` (built from `kernel_fn` and the rule parameters) and
	`LeniaUpdate` (built from `growth_fn` and the rule parameters).
	"""

	def __init__(
		self,
		spatial_dims: tuple[int, ...],
		channel_size: int,
		*,
		R: int,
		T: int,
		state_scale: float = 1.0,
		kernel_fn: Callable = gaussian_kernel_fn,
		growth_fn: Callable = exponential_growth_fn,
		rule_params: LeniaRuleParams,
	):
		"""Build the perception and update sub-modules.

		Args:
			spatial_dims: Shape of the spatial grid, e.g. (64, 64) in 2D or (32, 32, 32) in 3D.
			channel_size: Number of state channels.
			R: Space resolution; sets the kernel radius. Larger values give wider
				neighborhoods and smoother patterns.
			T: Time resolution; controls the temporal discretization. Larger values give
				smaller update steps and smoother temporal dynamics.
			state_scale: Scaling factor applied to state values.
			kernel_fn: Callable producing the convolution kernels from the rule parameters.
			growth_fn: Callable mapping neighborhood potential to growth values, i.e. how
				cells respond to their local environment.
			rule_params: `LeniaRuleParams` holding the per-channel kernel and growth
				parameters; shared by both sub-modules.

		"""
		# Perception depends on the spatial layout and the kernel; it is created first.
		self.perceive = LeniaPerceive(
			rule_params=rule_params,
			kernel_fn=kernel_fn,
			spatial_dims=spatial_dims,
			channel_size=channel_size,
			R=R,
			state_scale=state_scale,
		)
		# The update depends on the time resolution and the growth function.
		self.update = LeniaUpdate(
			rule_params=rule_params,
			growth_fn=growth_fn,
			channel_size=channel_size,
			T=T,
		)

	def _step(self, state: Array, input: Array | None = None, *, sow: bool = False) -> Array:
		"""Advance the state by a single step.

		Runs perception on `state`, then the update with the perception result and
		the optional `input`. When `sow` is True, the new state and the output of
		`metrics_fn` (called with `R=self.perceive.R`) are recorded as
		`nnx.Intermediate` values named "state" and "metrics".

		Args:
			state: Current Lenia state.
			input: Optional external input forwarded to the update module.
			sow: Whether to sow the new state and its metrics.

		Returns:
			The next state.

		"""
		perceived = self.perceive(state)
		new_state = self.update(state, perceived, input)

		if not sow:
			return new_state

		step_metrics = metrics_fn(new_state, R=self.perceive.R)
		self.sow(nnx.Intermediate, "state", new_state)
		self.sow(nnx.Intermediate, "metrics", step_metrics)
		return new_state

	@nnx.jit
	def render(self, state: Array) -> Array:
		"""Convert a Lenia state into an RGB image.

		State channels are mapped onto the Red, Green and Blue color channels. With
		more than 3 channels only the first 3 are shown; with fewer than 3 the
		missing color channels are zero-filled.

		Args:
			state: Array of shape (*spatial_dims, channel_size) holding the Lenia state,
				with continuous cell values typically in [0, 1].

		Returns:
			A uint8 RGB image of shape (*spatial_dims, 3) with values in [0, 255].

		"""
		return clip_and_uint8(render_array_with_channels_to_rgb(state))

__init__(spatial_dims, channel_size, *, R, T, state_scale=1.0, kernel_fn=gaussian_kernel_fn, growth_fn=exponential_growth_fn, rule_params)

Initialize Lenia.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `spatial_dims` | `tuple[int, ...]` | Spatial dimensions (e.g., `(64, 64)` for 2D or `(32, 32, 32)` for 3D). | required |
| `channel_size` | `int` | Number of channels. | required |
| `R` | `int` | Space resolution defining the kernel radius. Larger values create wider neighborhoods and smoother patterns. | required |
| `T` | `int` | Time resolution controlling the temporal discretization. Higher values produce smoother temporal dynamics with smaller update steps. | required |
| `state_scale` | `float` | Scaling factor applied to state values. | `1.0` |
| `kernel_fn` | `Callable` | Callable that generates convolution kernels. Takes rule parameters and returns kernel weights. | `gaussian_kernel_fn` |
| `growth_fn` | `Callable` | Callable that maps neighborhood potential to growth values. Defines how cells respond to their local environment. | `exponential_growth_fn` |
| `rule_params` | `LeniaRuleParams` | Instance of `LeniaRuleParams` containing kernel and growth parameters for each channel. | required |
Source code in src/cax/cs/lenia/cs.py (lines 26–69)
def __init__(
	self,
	spatial_dims: tuple[int, ...],
	channel_size: int,
	*,
	R: int,
	T: int,
	state_scale: float = 1.0,
	kernel_fn: Callable = gaussian_kernel_fn,
	growth_fn: Callable = exponential_growth_fn,
	rule_params: LeniaRuleParams,
):
	"""Build the perception and update sub-modules.

	Args:
		spatial_dims: Shape of the spatial grid, e.g. (64, 64) in 2D or (32, 32, 32) in 3D.
		channel_size: Number of state channels.
		R: Space resolution; sets the kernel radius. Larger values give wider
			neighborhoods and smoother patterns.
		T: Time resolution; controls the temporal discretization. Larger values give
			smaller update steps and smoother temporal dynamics.
		state_scale: Scaling factor applied to state values.
		kernel_fn: Callable producing the convolution kernels from the rule parameters.
		growth_fn: Callable mapping neighborhood potential to growth values, i.e. how
			cells respond to their local environment.
		rule_params: `LeniaRuleParams` holding the per-channel kernel and growth
			parameters; shared by both sub-modules.

	"""
	# Perception depends on the spatial layout and the kernel; it is created first.
	self.perceive = LeniaPerceive(
		rule_params=rule_params,
		kernel_fn=kernel_fn,
		spatial_dims=spatial_dims,
		channel_size=channel_size,
		R=R,
		state_scale=state_scale,
	)
	# The update depends on the time resolution and the growth function.
	self.update = LeniaUpdate(
		rule_params=rule_params,
		growth_fn=growth_fn,
		channel_size=channel_size,
		T=T,
	)

render(state)

Render state to RGB image.

Converts the multi-channel Lenia state to an RGB visualization. Channels are mapped to color channels (Red, Green, Blue) for visualization. If there are more than 3 channels, only the first 3 are displayed. If there are fewer than 3 channels, the missing channels are filled with zeros.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state` | `Array` | Array with shape `(*spatial_dims, channel_size)` representing the Lenia state, where each cell contains continuous values typically in [0, 1]. | required |

Returns:

| Type | Description |
|------|-------------|
| `Array` | RGB image with dtype `uint8` and shape `(*spatial_dims, 3)`, where state values are mapped to colors in the range [0, 255]. |

Source code in src/cax/cs/lenia/cs.py (lines 82–102)
@nnx.jit
def render(self, state: Array) -> Array:
	"""Convert a Lenia state into an RGB image.

	State channels are mapped onto the Red, Green and Blue color channels. With
	more than 3 channels only the first 3 are shown; with fewer than 3 the
	missing color channels are zero-filled.

	Args:
		state: Array of shape (*spatial_dims, channel_size) holding the Lenia state,
			with continuous cell values typically in [0, 1].

	Returns:
		A uint8 RGB image of shape (*spatial_dims, 3) with values in [0, 255].

	"""
	return clip_and_uint8(render_array_with_channels_to_rgb(state))

__call__(state, input=None, *, num_steps=1, input_in_axis=None, sow=False)

Step the system for multiple time steps.

This method wraps `_step` inside a JAX scan for efficiency and JIT-compiles the loop. If `input` is time-varying, set `input_in_axis` to the axis containing the time dimension so that each step receives the corresponding slice of `input`.

When `remat` is enabled, the scan body is wrapped with `nnx.remat` to reduce memory usage during backpropagation at the cost of recomputing intermediates.

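Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state` | `State` | Current state. | required |
| `input` | `Input \| None` | Optional input. | `None` |
| `num_steps` | `int` | Number of steps. | `1` |
| `input_in_axis` | `int \| None` | Axis of `input` to slice along, so that each step receives its own slice. If `None`, the same input is passed to every step. | `None` |
| `sow` | `bool` | Whether to sow intermediate values. | `False` |

Returns:

| Type | Description |
|------|-------------|
| `State` | Final state after `num_steps` applications of `_step`. |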

Source code in src/cax/core/cs.py (lines 56–101)
@nnx.jit(static_argnames=("num_steps", "input_in_axis", "sow"))
def __call__(
	self,
	state: State,
	input: Input | None = None,
	*,
	num_steps: int = 1,
	input_in_axis: int | None = None,
	sow: bool = False,
) -> State:
	"""Run the system for `num_steps` time steps.

	The single-step `_step` is iterated with `nnx.scan`, and the whole loop is
	JIT-compiled with `num_steps`, `input_in_axis` and `sow` as static arguments.
	For a time-varying `input`, pass the axis holding the time dimension as
	`input_in_axis` so that each step receives its own slice. With `None`, the same
	input is given to every step.

	When `self.remat` is truthy, the scan body is wrapped with `nnx.remat`. This
	trades recomputation of intermediates for lower memory use during
	backpropagation.

	Args:
		state: Current state.
		input: Optional input.
		num_steps: Number of steps.
		input_in_axis: Axis of `input` to slice per step, or None to broadcast it.
		sow: Whether to sow intermediate values.

	Returns:
		Final state after `num_steps` applications of `_step`.

	"""

	def _body(system: ComplexSystem, carry: State, step_input: Input | None) -> State:
		return system._step(carry, step_input, sow=sow)

	body = nnx.remat(_body) if self.remat else _body

	# Sown intermediates are scanned along axis 0; all other module state is
	# threaded through the loop as carry, like the system state itself.
	module_axes = nnx.StateAxes({nnx.Intermediate: 0, ...: nnx.Carry})
	scanned = nnx.scan(
		body,
		in_axes=(module_axes, nnx.Carry, input_in_axis),
		out_axes=nnx.Carry,
		length=num_steps,
	)
	return scanned(self, state, input)