Skip to content

Lenia

cax.cs.lenia.cs.Lenia

Bases: ComplexSystem

Lenia complex system, composed of a `LeniaPerceive` perception module and a `LeniaUpdate` update module.

Source code in src/cax/cs/lenia/cs.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
class Lenia(ComplexSystem):
	"""Lenia continuous cellular automaton.

	The system is composed of two submodules: a `LeniaPerceive` module configured
	with the spatial layout, kernel radius `R` and `kernel_fn`, and a `LeniaUpdate`
	module configured with the time resolution `T` and `growth_fn`.
	"""

	def __init__(
		self,
		spatial_dims: tuple[int, ...],
		channel_size: int,
		*,
		R: int,
		T: int,
		state_scale: float = 1.0,
		kernel_fn: Callable = gaussian_kernel_fn,
		growth_fn: Callable = exponential_growth_fn,
		rule_params: LeniaRuleParams,
	):
		"""Build the perception and update modules of a Lenia system.

		Args:
			spatial_dims: Spatial dimensions, e.g. (64, 64) for 2D or (32, 32, 32) for 3D.
			channel_size: Number of channels.
			R: Space resolution setting the kernel radius; larger values give wider
				neighborhoods and smoother patterns.
			T: Time resolution setting the temporal discretization; larger values give
				smoother temporal dynamics with smaller update steps.
			state_scale: Scaling factor applied to state values.
			kernel_fn: Callable that produces the convolution kernels (kernel weights)
				from the rule parameters.
			growth_fn: Callable mapping neighborhood potential to growth values, i.e. how
				cells respond to their local environment.
			rule_params: LeniaRuleParams instance holding the kernel and growth
				parameters for each channel.

		"""
		# Perception needs the grid layout, kernel radius and kernel construction.
		perceive = LeniaPerceive(
			spatial_dims=spatial_dims,
			channel_size=channel_size,
			R=R,
			state_scale=state_scale,
			kernel_fn=kernel_fn,
			rule_params=rule_params,
		)
		# The update only needs the channel count, time resolution and growth function.
		update = LeniaUpdate(
			channel_size=channel_size,
			T=T,
			growth_fn=growth_fn,
			rule_params=rule_params,
		)
		self.perceive = perceive
		self.update = update

	def _step(self, state: State, input: Input | None = None, *, sow: bool = False) -> State:
		"""Advance the state by a single perceive-then-update step.

		Args:
			state: Current state.
			input: Optional input, forwarded to the update module.
			sow: When True, record the new state and its metrics as
				`nnx.Intermediate` values named "state" and "metrics".

		Returns:
			The state after one step.

		"""
		next_state = self.update(state, self.perceive(state), input)
		if not sow:
			return next_state

		# Metrics are computed with the kernel radius held by the perception module.
		metrics = metrics_fn(next_state, R=self.perceive.R)
		self.sow(nnx.Intermediate, "state", next_state)
		self.sow(nnx.Intermediate, "metrics", metrics)
		return next_state

	@nnx.jit
	def render(self, state: State) -> Array:
		"""Convert a Lenia state into an RGB image.

		The state's channels are mapped onto the Red, Green and Blue color
		channels. Only the first 3 channels are shown when there are more than 3;
		missing color channels are zero-filled when there are fewer than 3.

		Args:
			state: Array of shape (*spatial_dims, channel_size) holding the Lenia
				state; cell values are continuous, typically in [0, 1].

		Returns:
			uint8 RGB image of shape (*spatial_dims, 3) with values in [0, 255].

		"""
		return clip_and_uint8(render_array_with_channels_to_rgb(state))

__init__(spatial_dims, channel_size, *, R, T, state_scale=1.0, kernel_fn=gaussian_kernel_fn, growth_fn=exponential_growth_fn, rule_params)

Initialize Lenia.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `spatial_dims` | `tuple[int, ...]` | Spatial dimensions (e.g., (64, 64) for 2D or (32, 32, 32) for 3D). | required |
| `channel_size` | `int` | Number of channels. | required |
| `R` | `int` | Space resolution defining the kernel radius. Larger values create wider neighborhoods and smoother patterns. | required |
| `T` | `int` | Time resolution controlling the temporal discretization. Higher values produce smoother temporal dynamics with smaller update steps. | required |
| `state_scale` | `float` | Scaling factor applied to state values. | `1.0` |
| `kernel_fn` | `Callable` | Callable that generates convolution kernels. Takes rule parameters and returns kernel weights. | `gaussian_kernel_fn` |
| `growth_fn` | `Callable` | Callable that maps neighborhood potential to growth values. Defines how cells respond to their local environment. | `exponential_growth_fn` |
| `rule_params` | `LeniaRuleParams` | Instance of LeniaRuleParams containing kernel and growth parameters for each channel. | required |
Source code in src/cax/cs/lenia/cs.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def __init__(
	self,
	spatial_dims: tuple[int, ...],
	channel_size: int,
	*,
	R: int,
	T: int,
	state_scale: float = 1.0,
	kernel_fn: Callable = gaussian_kernel_fn,
	growth_fn: Callable = exponential_growth_fn,
	rule_params: LeniaRuleParams,
):
	"""Build the perception and update modules of a Lenia system.

	Args:
		spatial_dims: Spatial dimensions, e.g. (64, 64) for 2D or (32, 32, 32) for 3D.
		channel_size: Number of channels.
		R: Space resolution setting the kernel radius; larger values give wider
			neighborhoods and smoother patterns.
		T: Time resolution setting the temporal discretization; larger values give
			smoother temporal dynamics with smaller update steps.
		state_scale: Scaling factor applied to state values.
		kernel_fn: Callable that produces the convolution kernels (kernel weights)
			from the rule parameters.
		growth_fn: Callable mapping neighborhood potential to growth values, i.e. how
			cells respond to their local environment.
		rule_params: LeniaRuleParams instance holding the kernel and growth
			parameters for each channel.

	"""
	# Perception needs the grid layout, kernel radius and kernel construction.
	perceive = LeniaPerceive(
		spatial_dims=spatial_dims,
		channel_size=channel_size,
		R=R,
		state_scale=state_scale,
		kernel_fn=kernel_fn,
		rule_params=rule_params,
	)
	# The update only needs the channel count, time resolution and growth function.
	update = LeniaUpdate(
		channel_size=channel_size,
		T=T,
		growth_fn=growth_fn,
		rule_params=rule_params,
	)
	self.perceive = perceive
	self.update = update

render(state)

Render state to RGB image.

Converts the multi-channel Lenia state to an RGB visualization. Channels are mapped to color channels (Red, Green, Blue) for visualization. If there are more than 3 channels, only the first 3 are displayed. If there are fewer than 3 channels, the missing channels are filled with zeros.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state` | `State` | Array with shape `(*spatial_dims, channel_size)` representing the Lenia state, where each cell contains continuous values typically in [0, 1]. | required |

Returns:

| Type | Description |
|------|-------------|
| `Array` | RGB image with dtype uint8 and shape `(*spatial_dims, 3)`, where state values are mapped to colors in the range [0, 255]. |

Source code in src/cax/cs/lenia/cs.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
@nnx.jit
def render(self, state: State) -> Array:
	"""Convert a Lenia state into an RGB image.

	The state's channels are mapped onto the Red, Green and Blue color
	channels. Only the first 3 channels are shown when there are more than 3;
	missing color channels are zero-filled when there are fewer than 3.

	Args:
		state: Array of shape (*spatial_dims, channel_size) holding the Lenia
			state; cell values are continuous, typically in [0, 1].

	Returns:
		uint8 RGB image of shape (*spatial_dims, 3) with values in [0, 255].

	"""
	return clip_and_uint8(render_array_with_channels_to_rgb(state))

__call__(state, input=None, *, num_steps=1, input_in_axis=None, sow=False)

Step the system for multiple time steps.

This method wraps `_step` inside a JAX scan for efficiency and JIT-compiles the loop. If `input` is time-varying, set `input_in_axis` to the axis containing the time dimension so that each step receives the corresponding slice of `input`.

Parameters:

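| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state` | `State` | Current state. | required |
| `input` | `Input \| None` | Optional input passed to each step. | `None` |
| `num_steps` | `int` | Number of steps to run. | `1` |
| `input_in_axis` | `int \| None` | Axis of `input` that holds the time dimension; `input` is sliced along it so each step receives its own slice. With `None`, the same `input` is passed to every step. | `None` |
| `sow` | `bool` | Whether to sow intermediate values at each step. | `False` |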

Returns:

| Type | Description |
|------|-------------|
| `State` | Final state after `num_steps` applications of `_step`. |

Source code in src/cax/core/cs.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@partial(nnx.jit, static_argnames=("num_steps", "input_in_axis", "sow"))
def __call__(
	self,
	state: State,
	input: Input | None = None,
	*,
	num_steps: int = 1,
	input_in_axis: int | None = None,
	sow: bool = False,
) -> State:
	"""Run `num_steps` consecutive steps of the system.

	The per-step logic in `_step` is looped with `nnx.scan`, and the whole loop is
	compiled with `nnx.jit`. For time-varying input, pass the index of the time
	axis as `input_in_axis` so that each step receives its own slice of `input`.

	Args:
		state: Current state.
		input: Optional input.
		num_steps: Number of steps.
		input_in_axis: Axis for input if provided for each step.
		sow: Whether to sow intermediate values.

	Returns:
		Final state after `num_steps` applications of `_step`.

	"""

	def scan_body(cs, carry, step_input):
		# One loop iteration: advance the carried state by a single `_step`.
		return cs._step(carry, step_input, sow=sow)

	# Sown intermediates are scanned over axis 0; all other module state is
	# threaded through the loop as carry.
	module_axes = nnx.StateAxes({nnx.Intermediate: 0, ...: nnx.Carry})
	scanned_step = nnx.scan(
		scan_body,
		in_axes=(module_axes, nnx.Carry, input_in_axis),
		out_axes=nnx.Carry,
		length=num_steps,
	)
	return scanned_step(self, state, input)