Neural Network Utils

cax.nn.pool

Pool module.

Pool

Bases: PyTreeNode

A container for PyTree arrays supporting in-place updates and random sampling.

The pool holds a PyTree of arrays whose first dimension is the pool size. It can be created from a PyTree with a leading batch dimension. Sampling returns random indices together with the corresponding batch, sliced with the same indices across all leaves.

Attributes:

    size (int): Number of items in the pool (inferred from the leading dimension of the data).
    data (PyTree): PyTree of arrays stacked along the leading dimension.

Source code in src/cax/nn/pool.py
class Pool(struct.PyTreeNode):
	"""A container for PyTree arrays supporting in-place updates and random sampling.

	The pool holds a PyTree of arrays whose first dimension is the pool size. It can be created
	from a PyTree with leading batch dimension. Sampling returns indices and the sliced
	batch for the same indices across all leaves.

	Attributes:
		size: Number of items in the pool (inferred from the leading dimension of the data).
		data: PyTree of arrays stacked along the leading dimension.

	"""

	size: int = struct.field(pytree_node=False)
	data: PyTree

	@classmethod
	def create(cls, data: PyTree) -> "Pool":
		"""Create a new Pool instance.

		Args:
			data: PyTree whose leaves are arrays with shape `(N, ...)`, where `N` is the pool size.

		Returns:
			A new Pool instance with `size == N` and `data == data`.

		"""
		size = jax.tree.leaves(data)[0].shape[0]
		return cls(size=size, data=data)

	@jax.jit
	def update(self, idxs: Array, batch: PyTree) -> "Pool":
		"""Update batch in the pool at the specified indices.

		Args:
			idxs: Integer indices with shape `(B,)` indicating rows to overwrite.
			batch: PyTree matching `data` leaves sliced to `(B, ...)`.

		Returns:
			A new Pool instance with the updated batch applied at `idxs` across all leaves.

		"""
		data = jax.tree.map(
			lambda data_leaf, batch_leaf: data_leaf.at[idxs].set(batch_leaf), self.data, batch
		)
		return self.replace(data=data)

	@partial(jax.jit, static_argnames=("batch_size",))
	def sample(self, key: Array, *, batch_size: int) -> tuple[Array, PyTree]:
		"""Sample a batch from the pool.

		Args:
			key: JAX PRNG key.
			batch_size: Number of rows to sample.

		Returns:
			A tuple `(idxs, batch)` where `idxs` has shape `(batch_size,)` and `batch` is a PyTree
			with each leaf shaped `(batch_size, ...)`.

		"""
		idxs = jax.random.choice(key, self.size, shape=(batch_size,))
		batch = jax.tree.map(lambda leaf: leaf[idxs], self.data)
		return idxs, batch
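
A minimal usage sketch (the pool size, shapes, and leaf names below are illustrative, not part of the API):

import jax
import jax.numpy as jnp

from cax.nn.pool import Pool

# Build a pool of 64 items; every leaf shares the leading pool dimension.
data = {
	"state": jnp.zeros((64, 16, 16)),
	"score": jnp.zeros((64,)),
}
pool = Pool.create(data)

# Sample 8 rows; the same indices are used across all leaves.
key = jax.random.key(0)
idxs, batch = pool.sample(key, batch_size=8)

# Modify the sampled batch and write it back at the same indices.
batch = jax.tree.map(lambda leaf: leaf + 1.0, batch)
pool = pool.update(idxs, batch)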

create(data) classmethod

Create a new Pool instance.

Parameters:

    data (PyTree, required): PyTree whose leaves are arrays with shape (N, ...), where N is the pool size.

Returns:

    Pool: A new Pool instance with size == N and data == data.

Source code in src/cax/nn/pool.py
@classmethod
def create(cls, data: PyTree) -> "Pool":
	"""Create a new Pool instance.

	Args:
		data: PyTree whose leaves are arrays with shape `(N, ...)`, where `N` is the pool size.

	Returns:
		A new Pool instance with `size == N` and `data == data`.

	"""
	size = jax.tree.leaves(data)[0].shape[0]
	return cls(size=size, data=data)

update(idxs, batch)

Update batch in the pool at the specified indices.

Parameters:

    idxs (Array, required): Integer indices with shape (B,) indicating rows to overwrite.
    batch (PyTree, required): PyTree matching data leaves sliced to (B, ...).

Returns:

    Pool: A new Pool instance with the updated batch applied at idxs across all leaves.

Source code in src/cax/nn/pool.py
@jax.jit
def update(self, idxs: Array, batch: PyTree) -> "Pool":
	"""Update batch in the pool at the specified indices.

	Args:
		idxs: Integer indices with shape `(B,)` indicating rows to overwrite.
		batch: PyTree matching `data` leaves sliced to `(B, ...)`.

	Returns:
		A new Pool instance with the updated batch applied at `idxs` across all leaves.

	"""
	data = jax.tree.map(
		lambda data_leaf, batch_leaf: data_leaf.at[idxs].set(batch_leaf), self.data, batch
	)
	return self.replace(data=data)

sample(key, *, batch_size)

Sample a batch from the pool.

Parameters:

    key (Array, required): JAX PRNG key.
    batch_size (int, required): Number of rows to sample.

Returns:

    tuple[Array, PyTree]: A tuple (idxs, batch) where idxs has shape (batch_size,) and batch is a PyTree with each leaf shaped (batch_size, ...).

Source code in src/cax/nn/pool.py
@partial(jax.jit, static_argnames=("batch_size",))
def sample(self, key: Array, *, batch_size: int) -> tuple[Array, PyTree]:
	"""Sample a batch from the pool.

	Args:
		key: JAX PRNG key.
		batch_size: Number of rows to sample.

	Returns:
		A tuple `(idxs, batch)` where `idxs` has shape `(batch_size,)` and `batch` is a PyTree
		with each leaf shaped `(batch_size, ...)`.

	"""
	idxs = jax.random.choice(key, self.size, shape=(batch_size,))
	batch = jax.tree.map(lambda leaf: leaf[idxs], self.data)
	return idxs, batch

cax.nn.buffer

Buffer module.

Buffer

Bases: PyTreeNode

A container for PyTree arrays with circular writes and random sampling.

The buffer stores a PyTree of arrays with a fixed capacity along the leading dimension. New batches are written sequentially with wrap-around semantics. Sampling draws indices from the subset of entries that have been written at least once.

Attributes:

    size (int): Maximum number of items stored.
    data (PyTree): PyTree of arrays with leading dimension size.
    is_full (Array): Boolean mask of shape (size,) indicating which entries are initialized.
    idx (Array): Current write pointer (modulo size).

Source code in src/cax/nn/buffer.py
class Buffer(struct.PyTreeNode):
	"""A container for PyTree arrays with circular writes and random sampling.

	The buffer stores a PyTree of arrays with a fixed capacity along the leading dimension.
	New batches are written sequentially with wrap-around semantics. Sampling draws indices
	from the subset of entries that have been written at least once.

	Attributes:
		size: Maximum number of items stored.
		data: PyTree of arrays with leading dimension `size`.
		is_full: Boolean mask of shape `(size,)` indicating which entries are initialized.
		idx: Current write pointer (modulo `size`).

	"""

	size: int = struct.field(pytree_node=False)
	data: PyTree
	is_full: Array
	idx: Array

	@classmethod
	def create(cls, size: int, datum: PyTree) -> "Buffer":
		"""Create a new Buffer instance.

		Args:
			size: Size of the buffer.
			datum: PyTree example whose leaf dtypes/shapes are used to allocate storage.

		Returns:
			A new Buffer instance with empty storage of capacity `size`.

		"""
		data = jax.tree.map(jnp.empty_like, datum)
		data = jax.tree.map(
			lambda leaf: jnp.broadcast_to(leaf[None, ...], (size, *leaf.shape)), data
		)
		return cls(
			size=size,
			data=data,
			is_full=jnp.full((size,), False, dtype=jnp.bool),
			idx=jnp.array(0, dtype=jnp.int32),
		)

	@jax.jit
	def add(self, batch: PyTree) -> "Buffer":
		"""Add a batch to the buffer.

		Args:
			batch: PyTree whose leaves have shape `(B, ...)`, where `B` is the batch size.

		Returns:
			A new Buffer instance with the batch written at consecutive indices (with wrap-around).

		"""
		batch_size = jax.tree.leaves(batch)[0].shape[0]
		idxs = self.idx + jnp.arange(batch_size)
		idxs = idxs % self.size

		# Update data
		data = jax.tree.map(lambda data, batch: data.at[idxs].set(batch), self.data, batch)

		# Update is_full and idx
		is_full = self.is_full.at[idxs].set(True)
		new_idx = (self.idx + batch_size) % self.size

		return self.replace(data=data, is_full=is_full, idx=new_idx)

	@partial(jax.jit, static_argnames=("batch_size",))
	def sample(self, key: Array, *, batch_size: int) -> PyTree:
		"""Sample a batch from the buffer.

		Args:
			key: JAX PRNG key.
			batch_size: Number of rows to sample from initialized entries.

		Returns:
			A PyTree with each leaf shaped `(batch_size, ...)`, sampled from filled slots.

		"""
		idxs = jax.random.choice(key, self.size, shape=(batch_size,), p=self.is_full)
		batch: PyTree = jax.tree.map(lambda leaf: leaf[idxs], self.data)
		return batch
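
A minimal usage sketch (the capacity, shapes, and leaf names below are illustrative):

import jax
import jax.numpy as jnp

from cax.nn.buffer import Buffer

# Allocate a buffer of capacity 128 from a single example item; the datum has
# no leading batch dimension, only the per-item shapes and dtypes.
datum = {"obs": jnp.zeros((4,)), "reward": jnp.zeros(())}
buffer = Buffer.create(size=128, datum=datum)

# Write a batch of 32 items; once capacity is exceeded, writes wrap around and
# overwrite the oldest entries.
batch = {"obs": jnp.ones((32, 4)), "reward": jnp.ones((32,))}
buffer = buffer.add(batch)

# Sample only from entries that have been written at least once.
key = jax.random.key(0)
sampled = buffer.sample(key, batch_size=16)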

create(size, datum) classmethod

Create a new Buffer instance.

Parameters:

    size (int, required): Size of the buffer.
    datum (PyTree, required): PyTree example whose leaf dtypes/shapes are used to allocate storage.

Returns:

    Buffer: A new Buffer instance with empty storage of capacity size.

Source code in src/cax/nn/buffer.py
@classmethod
def create(cls, size: int, datum: PyTree) -> "Buffer":
	"""Create a new Buffer instance.

	Args:
		size: Size of the buffer.
		datum: PyTree example whose leaf dtypes/shapes are used to allocate storage.

	Returns:
		A new Buffer instance with empty storage of capacity `size`.

	"""
	data = jax.tree.map(jnp.empty_like, datum)
	data = jax.tree.map(
		lambda leaf: jnp.broadcast_to(leaf[None, ...], (size, *leaf.shape)), data
	)
	return cls(
		size=size,
		data=data,
		is_full=jnp.full((size,), False, dtype=jnp.bool),
		idx=jnp.array(0, dtype=jnp.int32),
	)

add(batch)

Add a batch to the buffer.

Parameters:

    batch (PyTree, required): PyTree whose leaves have shape (B, ...), where B is the batch size.

Returns:

    Buffer: A new Buffer instance with the batch written at consecutive indices (with wrap-around).

Source code in src/cax/nn/buffer.py
@jax.jit
def add(self, batch: PyTree) -> "Buffer":
	"""Add a batch to the buffer.

	Args:
		batch: PyTree whose leaves have shape `(B, ...)`, where `B` is the batch size.

	Returns:
		A new Buffer instance with the batch written at consecutive indices (with wrap-around).

	"""
	batch_size = jax.tree.leaves(batch)[0].shape[0]
	idxs = self.idx + jnp.arange(batch_size)
	idxs = idxs % self.size

	# Update data
	data = jax.tree.map(lambda data, batch: data.at[idxs].set(batch), self.data, batch)

	# Update is_full and idx
	is_full = self.is_full.at[idxs].set(True)
	new_idx = (self.idx + batch_size) % self.size

	return self.replace(data=data, is_full=is_full, idx=new_idx)

sample(key, *, batch_size)

Sample a batch from the buffer.

Parameters:

    key (Array, required): JAX PRNG key.
    batch_size (int, required): Number of rows to sample from initialized entries.

Returns:

    PyTree: A PyTree with each leaf shaped (batch_size, ...), sampled from filled slots.

Source code in src/cax/nn/buffer.py
@partial(jax.jit, static_argnames=("batch_size",))
def sample(self, key: Array, *, batch_size: int) -> PyTree:
	"""Sample a batch from the buffer.

	Args:
		key: JAX PRNG key.
		batch_size: Number of rows to sample from initialized entries.

	Returns:
		A PyTree with each leaf shaped `(batch_size, ...)`, sampled from filled slots.

	"""
	idxs = jax.random.choice(key, self.size, shape=(batch_size,), p=self.is_full)
	batch: PyTree = jax.tree.map(lambda leaf: leaf[idxs], self.data)
	return batch

cax.nn.vae

Variational Autoencoder module.

Encoder

Bases: Module

Encoder module for the VAE.

Applies a stack of strided convolutions followed by linear layers to produce mean and log-variance parameters of a diagonal Gaussian in latent space.

Source code in src/cax/nn/vae.py
class Encoder(nnx.Module):
	"""Encoder module for the VAE.

	Applies a stack of strided convolutions followed by linear layers to produce mean and
	log-variance parameters of a diagonal Gaussian in latent space.
	"""

	def __init__(
		self,
		spatial_dims: Sequence[int],
		features: Sequence[int],
		latent_size: int,
		*,
		rngs: nnx.Rngs,
	):
		"""Initialize the Encoder module.

		Args:
			spatial_dims: Spatial dimensions of the input.
			features: Sequence of feature sizes for convolutional layers.
			latent_size: Size of the latent space.
			rngs: rng key.

		"""
		self.features = features
		self.latent_size = latent_size

		self.convs = nnx.List(
			[
				nnx.Conv(
					in_features=in_features,
					out_features=out_features,
					kernel_size=(3, 3),
					strides=(2, 2),
					padding="SAME",
					rngs=rngs,
				)
				for in_features, out_features in zip(self.features[:-1], self.features[1:])
			]
		)

		flattened_size = spatial_dims[0] * spatial_dims[1] * self.features[-1]
		for _ in range(len(self.features) - 1):
			flattened_size //= 4

		self.linear = nnx.Linear(in_features=flattened_size, out_features=flattened_size, rngs=rngs)
		self.mean = nnx.Linear(in_features=flattened_size, out_features=self.latent_size, rngs=rngs)
		self.logvar = nnx.Linear(
			in_features=flattened_size, out_features=self.latent_size, rngs=rngs
		)
		self.rngs = rngs

	def __call__(self, x: Array) -> tuple[Array, Array]:
		"""Forward pass of the encoder.

		Args:
			x: Input tensor with shape `(..., H, W, channel_size)`.

		Returns:
			Tuple `(mean, logvar)` each with shape `(..., latent_size)`.

		"""
		for conv in self.convs:
			x = jax.nn.relu(conv(x))
		x = jnp.reshape(x, x.shape[:-3] + (-1,))
		x = jax.nn.relu(self.linear(x))
		mean = self.mean(x)
		logvar = self.logvar(x)
		return mean, logvar

	def reparameterize(self, mean: Array, logvar: Array) -> Array:
		"""Perform the reparameterization trick.

		Args:
			mean: Mean of the latent distribution.
			logvar: Log variance of the latent distribution.

		Returns:
			Sampled latent vector with shape matching `mean`.

		"""
		return mean + jnp.exp(logvar * 0.5) * jax.random.normal(self.rngs(), shape=mean.shape)
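
A minimal construction-and-call sketch (sizes are illustrative; features[0] is the number of input channels, and each subsequent entry adds a stride-2 convolution that halves the spatial dimensions):

import jax.numpy as jnp
from flax import nnx

from cax.nn.vae import Encoder

encoder = Encoder(
	spatial_dims=(32, 32),
	features=(1, 32, 64),  # two stride-2 convs: 32x32 -> 16x16 -> 8x8
	latent_size=16,
	rngs=nnx.Rngs(0),
)

x = jnp.zeros((8, 32, 32, 1))             # (batch, H, W, channels)
mean, logvar = encoder(x)                 # each of shape (8, 16)
z = encoder.reparameterize(mean, logvar)  # sampled latent, shape (8, 16)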

__init__(spatial_dims, features, latent_size, *, rngs)

Initialize the Encoder module.

Parameters:

    spatial_dims (Sequence[int], required): Spatial dimensions of the input.
    features (Sequence[int], required): Sequence of feature sizes for convolutional layers.
    latent_size (int, required): Size of the latent space.
    rngs (Rngs, required): Random number generators (nnx.Rngs).

Source code in src/cax/nn/vae.py
def __init__(
	self,
	spatial_dims: Sequence[int],
	features: Sequence[int],
	latent_size: int,
	*,
	rngs: nnx.Rngs,
):
	"""Initialize the Encoder module.

	Args:
		spatial_dims: Spatial dimensions of the input.
		features: Sequence of feature sizes for convolutional layers.
		latent_size: Size of the latent space.
		rngs: rng key.

	"""
	self.features = features
	self.latent_size = latent_size

	self.convs = nnx.List(
		[
			nnx.Conv(
				in_features=in_features,
				out_features=out_features,
				kernel_size=(3, 3),
				strides=(2, 2),
				padding="SAME",
				rngs=rngs,
			)
			for in_features, out_features in zip(self.features[:-1], self.features[1:])
		]
	)

	flattened_size = spatial_dims[0] * spatial_dims[1] * self.features[-1]
	for _ in range(len(self.features) - 1):
		flattened_size //= 4

	self.linear = nnx.Linear(in_features=flattened_size, out_features=flattened_size, rngs=rngs)
	self.mean = nnx.Linear(in_features=flattened_size, out_features=self.latent_size, rngs=rngs)
	self.logvar = nnx.Linear(
		in_features=flattened_size, out_features=self.latent_size, rngs=rngs
	)
	self.rngs = rngs

__call__(x)

Forward pass of the encoder.

Parameters:

    x (Array, required): Input tensor with shape (..., H, W, channel_size).

Returns:

    tuple[Array, Array]: Tuple (mean, logvar), each with shape (..., latent_size).

Source code in src/cax/nn/vae.py
def __call__(self, x: Array) -> tuple[Array, Array]:
	"""Forward pass of the encoder.

	Args:
		x: Input tensor with shape `(..., H, W, channel_size)`.

	Returns:
		Tuple `(mean, logvar)` each with shape `(..., latent_size)`.

	"""
	for conv in self.convs:
		x = jax.nn.relu(conv(x))
	x = jnp.reshape(x, x.shape[:-3] + (-1,))
	x = jax.nn.relu(self.linear(x))
	mean = self.mean(x)
	logvar = self.logvar(x)
	return mean, logvar

reparameterize(mean, logvar)

Perform the reparameterization trick.

Parameters:

    mean (Array, required): Mean of the latent distribution.
    logvar (Array, required): Log variance of the latent distribution.

Returns:

    Array: Sampled latent vector with shape matching mean.

Source code in src/cax/nn/vae.py
def reparameterize(self, mean: Array, logvar: Array) -> Array:
	"""Perform the reparameterization trick.

	Args:
		mean: Mean of the latent distribution.
		logvar: Log variance of the latent distribution.

	Returns:
		Sampled latent vector with shape matching `mean`.

	"""
	return mean + jnp.exp(logvar * 0.5) * jax.random.normal(self.rngs(), shape=mean.shape)

Decoder

Bases: Module

Decoder module for the VAE.

Maps latent vectors back to image space using a linear layer followed by transposed convolutions.

Source code in src/cax/nn/vae.py
class Decoder(nnx.Module):
	"""Decoder module for the VAE.

	Maps latent vectors back to image space using a linear layer followed by transposed
	convolutions.
	"""

	def __init__(
		self, spatial_dims: Sequence[int], features: Sequence[int], latent_size: int, rngs: nnx.Rngs
	):
		"""Initialize the Decoder module.

		Args:
			spatial_dims: Spatial dimensions of the output.
			features: Sequence of feature sizes for transposed convolutional layers.
			latent_size: Size of the latent space.
			rngs: rng key.

		"""
		self.features = features
		self.latent_size = latent_size

		self._spatial_dims = tuple(
			dim // (2 ** (len(self.features) - 1)) for dim in spatial_dims[:2]
		)

		flattened_size = self._spatial_dims[0] * self._spatial_dims[1] * self.features[0]

		self.linear = nnx.Linear(
			in_features=self.latent_size, out_features=flattened_size, rngs=rngs
		)

		self.convs = nnx.List(
			[
				nnx.ConvTranspose(
					in_features=in_features,
					out_features=out_features,
					kernel_size=(3, 3),
					strides=(2, 2),
					padding="SAME",
					rngs=rngs,
				)
				for in_features, out_features in zip(self.features[:-1], self.features[1:])
			]
		)

	def __call__(self, z: Array) -> Array:
		"""Forward pass of the decoder.

		Args:
			z: Latent vector with shape `(..., latent_size)`.

		Returns:
			Reconstructed output tensor with shape `(..., H, W, channel_size)`.

		"""
		x = jax.nn.relu(self.linear(z))
		x = jnp.reshape(x, x.shape[:-1] + self._spatial_dims + (self.features[0],))
		for conv in self.convs[:-1]:
			x = jax.nn.relu(conv(x))
		x = self.convs[-1](x)
		return x

__init__(spatial_dims, features, latent_size, rngs)

Initialize the Decoder module.

Parameters:

    spatial_dims (Sequence[int], required): Spatial dimensions of the output.
    features (Sequence[int], required): Sequence of feature sizes for transposed convolutional layers.
    latent_size (int, required): Size of the latent space.
    rngs (Rngs, required): Random number generators (nnx.Rngs).

Source code in src/cax/nn/vae.py
def __init__(
	self, spatial_dims: Sequence[int], features: Sequence[int], latent_size: int, rngs: nnx.Rngs
):
	"""Initialize the Decoder module.

	Args:
		spatial_dims: Spatial dimensions of the output.
		features: Sequence of feature sizes for transposed convolutional layers.
		latent_size: Size of the latent space.
		rngs: rng key.

	"""
	self.features = features
	self.latent_size = latent_size

	self._spatial_dims = tuple(
		dim // (2 ** (len(self.features) - 1)) for dim in spatial_dims[:2]
	)

	flattened_size = self._spatial_dims[0] * self._spatial_dims[1] * self.features[0]

	self.linear = nnx.Linear(
		in_features=self.latent_size, out_features=flattened_size, rngs=rngs
	)

	self.convs = nnx.List(
		[
			nnx.ConvTranspose(
				in_features=in_features,
				out_features=out_features,
				kernel_size=(3, 3),
				strides=(2, 2),
				padding="SAME",
				rngs=rngs,
			)
			for in_features, out_features in zip(self.features[:-1], self.features[1:])
		]
	)

__call__(z)

Forward pass of the decoder.

Parameters:

    z (Array, required): Latent vector with shape (..., latent_size).

Returns:

    Array: Reconstructed output tensor with shape (..., H, W, channel_size).

Source code in src/cax/nn/vae.py
def __call__(self, z: Array) -> Array:
	"""Forward pass of the decoder.

	Args:
		z: Latent vector with shape `(..., latent_size)`.

	Returns:
		Reconstructed output tensor with shape `(..., H, W, channel_size)`.

	"""
	x = jax.nn.relu(self.linear(z))
	x = jnp.reshape(x, x.shape[:-1] + self._spatial_dims + (self.features[0],))
	for conv in self.convs[:-1]:
		x = jax.nn.relu(conv(x))
	x = self.convs[-1](x)
	return x

VAE

Bases: Module

Variational Autoencoder module.

Combines an encoder and decoder with a reparameterization sampler for training with the evidence lower bound (ELBO).

Source code in src/cax/nn/vae.py
class VAE(nnx.Module):
	"""Variational Autoencoder module.

	Combines an encoder and decoder with a reparameterization sampler for training with the
	evidence lower bound (ELBO).

	"""

	def __init__(
		self,
		spatial_dims: tuple[int, int],
		features: Sequence[int],
		latent_size: int,
		rngs: nnx.Rngs,
	):
		"""Initialize the VAE module.

		Args:
			spatial_dims: Spatial dimensions of the input/output.
			features: Sequence of feature sizes for encoder and decoder.
			latent_size: Size of the latent space.
			rngs: rng key.

		"""
		super().__init__()
		self.encoder = Encoder(
			spatial_dims=spatial_dims, features=features, latent_size=latent_size, rngs=rngs
		)
		self.decoder = Decoder(
			spatial_dims=spatial_dims, features=features[::-1], latent_size=latent_size, rngs=rngs
		)

	def encode(self, x: Array) -> tuple[Array, Array, Array]:
		"""Encode input to latent space.

		Args:
			x: Input tensor with shape `(..., H, W, channel_size)`.

		Returns:
			Tuple `(z, mean, logvar)` where all have shape `(..., latent_size)`.

		"""
		mean, logvar = self.encoder(x)
		return self.encoder.reparameterize(mean, logvar), mean, logvar

	def decode(self, z: Array) -> Array:
		"""Decode latent vector to output space.

		Args:
			z: Latent vector with shape `(..., latent_size)`.

		Returns:
			Reconstructed output tensor with shape `(..., H, W, channel_size)`.

		"""
		return self.decoder(z)

	def generate(self, z: Array) -> Array:
		"""Generate output from latent vector.

		Args:
			z: Latent vector with shape `(..., latent_size)`.

		Returns:
			Generated output tensor with shape `(..., H, W, channel_size)` in the range `[0, 1]`.

		"""
		return jax.nn.sigmoid(self.decoder(z))

	def __call__(self, x: Array) -> tuple[Array, Array, Array]:
		"""Forward pass of the VAE.

		Args:
			x: Input tensor with shape `(..., H, W, channel_size)`.

		Returns:
			Tuple `(logits, mean, logvar)`

		"""
		z, mean, logvar = self.encode(x)
		logits = self.decode(z)
		return logits, mean, logvar
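
A minimal end-to-end sketch (the sizes are illustrative; each spatial dimension should be divisible by 2**(len(features) - 1) so the decoder reproduces the input resolution):

import jax.numpy as jnp
from flax import nnx

from cax.nn.vae import VAE

vae = VAE(
	spatial_dims=(28, 28),
	features=(1, 32, 64),
	latent_size=8,
	rngs=nnx.Rngs(0),
)

x = jnp.zeros((4, 28, 28, 1))      # batch of images scaled to [0, 1]
logits, mean, logvar = vae(x)      # logits: (4, 28, 28, 1); mean/logvar: (4, 8)

# Decode a latent sample into pixel probabilities.
z = jnp.zeros((4, 8))
images = vae.generate(z)           # sigmoid(decoder(z)), values in [0, 1]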

__init__(spatial_dims, features, latent_size, rngs)

Initialize the VAE module.

Parameters:

    spatial_dims (tuple[int, int], required): Spatial dimensions of the input/output.
    features (Sequence[int], required): Sequence of feature sizes for encoder and decoder.
    latent_size (int, required): Size of the latent space.
    rngs (Rngs, required): Random number generators (nnx.Rngs).

Source code in src/cax/nn/vae.py
def __init__(
	self,
	spatial_dims: tuple[int, int],
	features: Sequence[int],
	latent_size: int,
	rngs: nnx.Rngs,
):
	"""Initialize the VAE module.

	Args:
		spatial_dims: Spatial dimensions of the input/output.
		features: Sequence of feature sizes for encoder and decoder.
		latent_size: Size of the latent space.
		rngs: rng key.

	"""
	super().__init__()
	self.encoder = Encoder(
		spatial_dims=spatial_dims, features=features, latent_size=latent_size, rngs=rngs
	)
	self.decoder = Decoder(
		spatial_dims=spatial_dims, features=features[::-1], latent_size=latent_size, rngs=rngs
	)

encode(x)

Encode input to latent space.

Parameters:

    x (Array, required): Input tensor with shape (..., H, W, channel_size).

Returns:

    tuple[Array, Array, Array]: Tuple (z, mean, logvar), where all have shape (..., latent_size).

Source code in src/cax/nn/vae.py
def encode(self, x: Array) -> tuple[Array, Array, Array]:
	"""Encode input to latent space.

	Args:
		x: Input tensor with shape `(..., H, W, channel_size)`.

	Returns:
		Tuple `(z, mean, logvar)` where all have shape `(..., latent_size)`.

	"""
	mean, logvar = self.encoder(x)
	return self.encoder.reparameterize(mean, logvar), mean, logvar

decode(z)

Decode latent vector to output space.

Parameters:

    z (Array, required): Latent vector with shape (..., latent_size).

Returns:

    Array: Reconstructed output tensor with shape (..., H, W, channel_size).

Source code in src/cax/nn/vae.py
def decode(self, z: Array) -> Array:
	"""Decode latent vector to output space.

	Args:
		z: Latent vector with shape `(..., latent_size)`.

	Returns:
		Reconstructed output tensor with shape `(..., H, W, channel_size)`.

	"""
	return self.decoder(z)

generate(z)

Generate output from latent vector.

Parameters:

    z (Array, required): Latent vector with shape (..., latent_size).

Returns:

    Array: Generated output tensor with shape (..., H, W, channel_size) in the range [0, 1].

Source code in src/cax/nn/vae.py
def generate(self, z: Array) -> Array:
	"""Generate output from latent vector.

	Args:
		z: Latent vector with shape `(..., latent_size)`.

	Returns:
		Generated output tensor with shape `(..., H, W, channel_size)` in the range `[0, 1]`.

	"""
	return jax.nn.sigmoid(self.decoder(z))

__call__(x)

Forward pass of the VAE.

Parameters:

    x (Array, required): Input tensor with shape (..., H, W, channel_size).

Returns:

    tuple[Array, Array, Array]: Tuple (logits, mean, logvar), where logits are the pre-sigmoid reconstructions and mean/logvar parameterize the latent distribution.

Source code in src/cax/nn/vae.py
def __call__(self, x: Array) -> tuple[Array, Array, Array]:
	"""Forward pass of the VAE.

	Args:
		x: Input tensor with shape `(..., H, W, channel_size)`.

	Returns:
		Tuple `(logits, mean, logvar)`

	"""
	z, mean, logvar = self.encode(x)
	logits = self.decode(z)
	return logits, mean, logvar

kl_divergence(mean, logvar)

Compute KL divergence between latent distribution and standard normal.

Parameters:

    mean (Array, required): Mean of the latent distribution.
    logvar (Array, required): Log variance of the latent distribution.

Returns:

    Array: Scalar KL divergence value (summed over all elements).

Source code in src/cax/nn/vae.py
@jax.jit
def kl_divergence(mean: Array, logvar: Array) -> Array:
	"""Compute KL divergence between latent distribution and standard normal.

	Args:
		mean: Mean of the latent distribution.
		logvar: Log variance of the latent distribution.

	Returns:
		Scalar KL divergence value (sum over last dimension).

	"""
	return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))
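
A quick sanity check (a sketch, not part of the library): the divergence is zero when the latent distribution already equals the standard normal.

import jax.numpy as jnp

from cax.nn.vae import kl_divergence

mean = jnp.zeros((4, 8))
logvar = jnp.zeros((4, 8))          # unit variance for every latent dimension
print(kl_divergence(mean, logvar))  # 0.0: KL(N(0, I) || N(0, I)) = 0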

binary_cross_entropy_with_logits(logits, labels)

Compute binary cross-entropy loss with logits.

Parameters:

    logits (Array, required): Predicted logits.
    labels (Array, required): True labels.

Returns:

    Array: Binary cross-entropy loss summed over all elements.

Source code in src/cax/nn/vae.py
@jax.jit
def binary_cross_entropy_with_logits(logits: Array, labels: Array) -> Array:
	"""Compute binary cross-entropy loss with logits.

	Args:
		logits: Predicted logits.
		labels: True labels.

	Returns:
		Summed Binary Cross-Entropy loss over the last dimension.

	"""
	logits = jax.nn.log_sigmoid(logits)
	return -jnp.sum(labels * logits + (1.0 - labels) * jnp.log(-jnp.expm1(logits)))
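
A small sanity check (a sketch, not part of the library): with zero logits every element is predicted with probability 0.5, so the summed loss is log(2) per element regardless of the labels.

import jax.numpy as jnp

from cax.nn.vae import binary_cross_entropy_with_logits

logits = jnp.zeros((4,))
labels = jnp.array([0.0, 1.0, 0.25, 0.75])
loss = binary_cross_entropy_with_logits(logits, labels)
print(loss)  # approximately 4 * log(2) ~= 2.77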

vae_loss(logits, labels, mean, logvar)

Compute VAE loss.

Parameters:

    logits (Array, required): Predicted logits from the decoder.
    labels (Array, required): Target labels (e.g., normalized images or one-hot vectors).
    mean (Array, required): Mean of the latent distribution.
    logvar (Array, required): Log variance of the latent distribution.

Returns:

    Array: Total VAE loss equal to mean(BCE) + mean(KL).

Source code in src/cax/nn/vae.py
@jax.jit
def vae_loss(logits: Array, labels: Array, mean: Array, logvar: Array) -> Array:
	"""Compute VAE loss.

	Args:
		logits: Predicted logits from the decoder.
		labels: Target labels (e.g., normalized images or one-hot vectors).
		mean: Mean of the latent distribution.
		logvar: Log variance of the latent distribution.

	Returns:
		Total VAE loss equal to `mean(BCE) + mean(KL)`.

	"""
	bce_loss = jnp.mean(binary_cross_entropy_with_logits(logits, labels))
	kld_loss = jnp.mean(kl_divergence(mean, logvar))
	return bce_loss + kld_loss
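
A training-step sketch combining the VAE forward pass with vae_loss (shapes and hyper-parameters are illustrative, and the use of nnx.grad is an assumption about the surrounding training code rather than something this module prescribes):

import jax.numpy as jnp
from flax import nnx

from cax.nn.vae import VAE, vae_loss

vae = VAE(spatial_dims=(28, 28), features=(1, 32, 64), latent_size=8, rngs=nnx.Rngs(0))
images = jnp.zeros((16, 28, 28, 1))  # batch of images scaled to [0, 1]

def loss_fn(model, x):
	logits, mean, logvar = model(x)
	return vae_loss(logits, x, mean, logvar)

# Gradients with respect to the model parameters (the first argument of loss_fn).
grads = nnx.grad(loss_fn)(vae, images)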