Skip to content

Flow Lenia

cax.cs.flow_lenia.cs.FlowLenia

Bases: ComplexSystem[Array, Array]

Flow Lenia class.

Source code in `src/cax/cs/flow_lenia/cs.py` (lines 19–113)
class FlowLenia(ComplexSystem[Array, Array]):
	"""Flow Lenia complex system.

	Combines a ``LeniaPerceive`` module (stored as ``self.perceive``) with a
	``FlowLeniaUpdate`` module (stored as ``self.update``) and plugs them into the
	``ComplexSystem`` stepping interface via ``_step``.
	"""

	def __init__(
		self,
		spatial_dims: tuple[int, ...],
		channel_size: int,
		*,
		R: int,
		T: int,
		state_scale: float = 1.0,
		kernel_fn: Callable = gaussian_kernel_fn,
		growth_fn: Callable = exponential_growth_fn,
		rule_params: LeniaRuleParams,
		# Flow Lenia parameters
		theta_A: float = 1.0,
		n: int = 2,
		dd: int = 5,
		sigma: float = 0.65,
	):
		"""Build the perception and update modules of a Flow Lenia system.

		Args:
			spatial_dims: Spatial dimensions (e.g., (64, 64) for 2D or (32, 32, 32) for 3D).
			channel_size: Number of channels.
			R: Space resolution defining the kernel radius. Larger values create wider
				neighborhoods and smoother patterns.
			T: Time resolution controlling the temporal discretization. Higher values
				produce smoother temporal dynamics with smaller update steps.
			state_scale: Scaling factor applied to state values.
			kernel_fn: Callable that generates convolution kernels from the rule
				parameters.
			growth_fn: Callable that maps neighborhood potential to growth values.
			rule_params: LeniaRuleParams instance holding the kernel and growth
				parameters for each channel. Shared by perception and update.
			theta_A: Threshold used when computing the flow activation alpha. Higher
				values make flow less sensitive to local density.
			n: Exponent controlling the nonlinearity of the flow activation.
			dd: Maximum displacement distance that flow can induce per time step.
			sigma: Spread parameter of the displacement kernel.

		"""
		# Perception side: spatial layout, kernel radius and kernel generator.
		self.perceive = LeniaPerceive(
			spatial_dims=spatial_dims,
			channel_size=channel_size,
			R=R,
			state_scale=state_scale,
			kernel_fn=kernel_fn,
			rule_params=rule_params,
		)

		# Update side: time resolution, growth function and the flow-specific knobs.
		flow_params = dict(theta_A=theta_A, n=n, dd=dd, sigma=sigma)
		self.update = FlowLeniaUpdate(
			channel_size=channel_size,
			T=T,
			growth_fn=growth_fn,
			rule_params=rule_params,
			**flow_params,
		)

	def _step(self, state: Array, input: Array | None = None, *, sow: bool = False) -> Array:
		"""Advance the state by a single time step.

		Args:
			state: Current state.
			input: Optional input, forwarded unchanged to the update module.
			sow: When True, the new state is recorded as an ``nnx.Intermediate``
				under the name "state".

		Returns:
			The state after one perceive/update cycle.

		"""
		new_state = self.update(state, self.perceive(state), input)
		if sow:
			self.sow(nnx.Intermediate, "state", new_state)
		return new_state

	@nnx.jit
	def render(self, state: Array) -> Array:
		"""Convert a multi-channel state into an RGB image.

		Channels are mapped onto the Red, Green and Blue colour channels. Only the
		first 3 channels are shown when there are more. Missing channels are filled
		with zeros when there are fewer than 3.

		Args:
			state: Array with shape (*spatial_dims, channel_size) holding the Lenia
				state, with cell values typically in [0, 1].

		Returns:
			uint8 RGB image of shape (*spatial_dims, 3) with values in [0, 255].

		"""
		return clip_and_uint8(render_array_with_channels_to_rgb(state))

__init__(spatial_dims, channel_size, *, R, T, state_scale=1.0, kernel_fn=gaussian_kernel_fn, growth_fn=exponential_growth_fn, rule_params, theta_A=1.0, n=2, dd=5, sigma=0.65)

Initialize Flow Lenia.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `spatial_dims` | `tuple[int, ...]` | Spatial dimensions (e.g., `(64, 64)` for 2D or `(32, 32, 32)` for 3D). | *required* |
| `channel_size` | `int` | Number of channels. | *required* |
| `R` | `int` | Space resolution defining the kernel radius. Larger values create wider neighborhoods and smoother patterns. | *required* |
| `T` | `int` | Time resolution controlling the temporal discretization. Higher values produce smoother temporal dynamics with smaller update steps. | *required* |
| `state_scale` | `float` | Scaling factor applied to state values. | `1.0` |
| `kernel_fn` | `Callable` | Callable that generates convolution kernels. Takes rule parameters and returns kernel weights. | `gaussian_kernel_fn` |
| `growth_fn` | `Callable` | Callable that maps neighborhood potential to growth values. Defines how cells respond to their local environment. | `exponential_growth_fn` |
| `rule_params` | `LeniaRuleParams` | Instance of `LeniaRuleParams` containing kernel and growth parameters for each channel. | *required* |
| `theta_A` | `float` | Threshold value for computing the flow activation alpha. Higher values make flow less sensitive to local density. | `1.0` |
| `n` | `int` | Exponent controlling the nonlinearity of flow activation. Higher values create sharper transitions between flow and no-flow regions. | `2` |
| `dd` | `int` | Maximum displacement distance in pixels that flow can induce per time step. Controls the strength of advective transport. | `5` |
| `sigma` | `float` | Spread parameter for the displacement kernel. Smaller values create more localized flow; larger values produce smoother displacement fields. | `0.65` |
Source code in `src/cax/cs/flow_lenia/cs.py` (lines 22–82)
def __init__(
	self,
	spatial_dims: tuple[int, ...],
	channel_size: int,
	*,
	R: int,
	T: int,
	state_scale: float = 1.0,
	kernel_fn: Callable = gaussian_kernel_fn,
	growth_fn: Callable = exponential_growth_fn,
	rule_params: LeniaRuleParams,
	# Flow Lenia parameters
	theta_A: float = 1.0,
	n: int = 2,
	dd: int = 5,
	sigma: float = 0.65,
):
	"""Build the perception and update modules of a Flow Lenia system.

	Args:
		spatial_dims: Spatial dimensions (e.g., (64, 64) for 2D or (32, 32, 32) for 3D).
		channel_size: Number of channels.
		R: Space resolution defining the kernel radius. Larger values create wider
			neighborhoods and smoother patterns.
		T: Time resolution controlling the temporal discretization. Higher values
			produce smoother temporal dynamics with smaller update steps.
		state_scale: Scaling factor applied to state values.
		kernel_fn: Callable that generates convolution kernels from the rule
			parameters.
		growth_fn: Callable that maps neighborhood potential to growth values.
		rule_params: LeniaRuleParams instance holding the kernel and growth
			parameters for each channel. Shared by perception and update.
		theta_A: Threshold used when computing the flow activation alpha. Higher
			values make flow less sensitive to local density.
		n: Exponent controlling the nonlinearity of the flow activation.
		dd: Maximum displacement distance that flow can induce per time step.
		sigma: Spread parameter of the displacement kernel.

	"""
	# Perception side: spatial layout, kernel radius and kernel generator.
	self.perceive = LeniaPerceive(
		spatial_dims=spatial_dims,
		channel_size=channel_size,
		R=R,
		state_scale=state_scale,
		kernel_fn=kernel_fn,
		rule_params=rule_params,
	)

	# Update side: time resolution, growth function and the flow-specific knobs.
	flow_params = dict(theta_A=theta_A, n=n, dd=dd, sigma=sigma)
	self.update = FlowLeniaUpdate(
		channel_size=channel_size,
		T=T,
		growth_fn=growth_fn,
		rule_params=rule_params,
		**flow_params,
	)

render(state)

Render state to RGB image.

Converts the multi-channel Lenia state to an RGB visualization. Channels are mapped to color channels (Red, Green, Blue) for visualization. If there are more than 3 channels, only the first 3 are displayed. If there are fewer than 3 channels, the missing channels are filled with zeros.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `Array` | Array with shape `(*spatial_dims, channel_size)` representing the Lenia state, where each cell contains continuous values typically in [0, 1]. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Array` | RGB image with dtype `uint8` and shape `(*spatial_dims, 3)`, where state values are mapped to colors in the range [0, 255]. |

Source code in `src/cax/cs/flow_lenia/cs.py` (lines 93–113)
@nnx.jit
def render(self, state: Array) -> Array:
	"""Convert a multi-channel state into an RGB image.

	Channels are mapped onto the Red, Green and Blue colour channels. Only the
	first 3 channels are shown when there are more. Missing channels are filled
	with zeros when there are fewer than 3.

	Args:
		state: Array with shape (*spatial_dims, channel_size) holding the Lenia
			state, with cell values typically in [0, 1].

	Returns:
		uint8 RGB image of shape (*spatial_dims, 3) with values in [0, 255].

	"""
	return clip_and_uint8(render_array_with_channels_to_rgb(state))

__call__(state, input=None, *, num_steps=1, input_in_axis=None, sow=False)

Step the system for multiple time steps.

This method wraps `_step` inside a JAX scan for efficiency and JIT-compiles the loop. If `input` is time-varying, set `input_in_axis` to the axis containing the time dimension so that each step receives the corresponding slice of input.

When `remat` is enabled, the scan body is wrapped with `nnx.remat` to reduce memory usage during backpropagation at the cost of recomputing intermediates.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `State` | Current state. | *required* |
| `input` | `Input \| None` | Optional input. | `None` |
| `num_steps` | `int` | Number of steps. | `1` |
| `input_in_axis` | `int \| None` | Axis for input if provided for each step. | `None` |
| `sow` | `bool` | Whether to sow intermediate values. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `State` | Final state after `num_steps` applications of `_step`. |

Source code in `src/cax/core/cs.py` (lines 56–101)
@nnx.jit(static_argnames=("num_steps", "input_in_axis", "sow"))
def __call__(
	self,
	state: State,
	input: Input | None = None,
	*,
	num_steps: int = 1,
	input_in_axis: int | None = None,
	sow: bool = False,
) -> State:
	"""Apply `_step` `num_steps` times using `nnx.scan`, under `nnx.jit`.

	The system itself is passed through the scan. Its `nnx.Intermediate` state is
	mapped along axis 0, and all of its other state is carried between iterations.
	The `state` argument is the scan carry. `input` uses `input_in_axis` as its
	scan axis, so a time-varying input is sliced per step along that axis.
	With `input_in_axis=None`, `input` is not sliced.

	If `self.remat` is truthy, the per-step body is wrapped in `nnx.remat`. This
	lowers backpropagation memory by recomputing intermediates.

	Args:
		state: Current state.
		input: Optional input.
		num_steps: Number of steps (static under jit).
		input_in_axis: Scan axis of `input` when it provides one slice per step
			(static under jit).
		sow: Whether each step sows intermediate values (static under jit).

	Returns:
		Final state after `num_steps` applications of `_step`.

	"""

	def scan_body(system: ComplexSystem, carry: State, step_input: Input | None) -> State:
		return system._step(carry, step_input, sow=sow)

	body = nnx.remat(scan_body) if self.remat else scan_body

	# Intermediates (e.g. sown states) are mapped along axis 0; everything else is carried.
	module_axes = nnx.StateAxes({nnx.Intermediate: 0, ...: nnx.Carry})
	scanned = nnx.scan(
		body,
		in_axes=(module_axes, nnx.Carry, input_in_axis),
		out_axes=nnx.Carry,
		length=num_steps,
	)
	return scanned(self, state, input)