Texture Neural Cellular Automata
¶
Installation¶
You will need Python 3.12 or later, and a working JAX installation. For example, you can install JAX with:
In [ ]:
Copied!
%pip install -U "jax[cuda]"
%pip install -U "jax[cuda]"
Then, install CAX from PyPI:
In [ ]:
Copied!
%pip install -U "cax[examples]"
%pip install -U "cax[examples]"
Import¶
In [ ]:
Copied!
import io
import random
from functools import partial
import jax
import jax.numpy as jnp
import mediapy
import numpy as np
import optax
import PIL.Image
import requests
import torch
import torchvision.models as models
from flax import nnx
from jax import Array
from tqdm.notebook import tqdm
from cax.core import ComplexSystem
from cax.core.perceive import ConvPerceive, grad2_kernel, grad_kernel, identity_kernel
from cax.core.update import ResidualUpdate
from cax.nn.pool import Pool
from cax.utils import clip_and_uint8
import io
import random
from functools import partial
import jax
import jax.numpy as jnp
import mediapy
import numpy as np
import optax
import PIL.Image
import requests
import torch
import torchvision.models as models
from flax import nnx
from jax import Array
from tqdm.notebook import tqdm
from cax.core import ComplexSystem
from cax.core.perceive import ConvPerceive, grad2_kernel, grad_kernel, identity_kernel
from cax.core.update import ResidualUpdate
from cax.nn.pool import Pool
from cax.utils import clip_and_uint8
Configuration¶
In [ ]:
Copied!
# Seed shared by the JAX key and the NNX rngs created at the end of this cell.
seed = 0
# Size of the trailing channel axis of the NCA state (see `sample_state`).
channel_size = 12
# Perception kernels per state channel; perception size is num_kernels * channel_size.
num_kernels = 4
# Width of the single hidden layer passed to ResidualUpdate.
hidden_size = 96
# Forwarded to ResidualUpdate as `cell_dropout_rate`.
cell_dropout_rate = 0.5
# Rollout lengths the training loop picks from for each train step.
step_choices = [32, 48, 64, 80, 96]
# Rollout length used by the "Run" cell after training.
num_steps = 256
# Number of states kept in the training pool.
pool_size = 256
# Number of pool states sampled per train step.
batch_size = 4
# Initial value of the learning-rate schedule.
learning_rate = 1e-3
# Spatial shape of the NCA state; also used to resize the target texture.
spatial_dims = (128, 128)
# Multiplier applied to the overflow penalty in `loss_fn`.
overflow_weight = 1.0
# Explicit JAX key (split in the training loop) and NNX rng container for modules.
key = jax.random.key(seed)
rngs = nnx.Rngs(seed)
# Seed shared by the JAX key and the NNX rngs created at the end of this cell.
seed = 0
# Size of the trailing channel axis of the NCA state (see `sample_state`).
channel_size = 12
# Perception kernels per state channel; perception size is num_kernels * channel_size.
num_kernels = 4
# Width of the single hidden layer passed to ResidualUpdate.
hidden_size = 96
# Forwarded to ResidualUpdate as `cell_dropout_rate`.
cell_dropout_rate = 0.5
# Rollout lengths the training loop picks from for each train step.
step_choices = [32, 48, 64, 80, 96]
# Rollout length used by the "Run" cell after training.
num_steps = 256
# Number of states kept in the training pool.
pool_size = 256
# Number of pool states sampled per train step.
batch_size = 4
# Initial value of the learning-rate schedule.
learning_rate = 1e-3
# Spatial shape of the NCA state; also used to resize the target texture.
spatial_dims = (128, 128)
# Multiplier applied to the overflow penalty in `loss_fn`.
overflow_weight = 1.0
# Explicit JAX key (split in the training loop) and NNX rng container for modules.
key = jax.random.key(seed)
rngs = nnx.Rngs(seed)
Target texture¶
In [ ]:
Copied!
# Download the target texture, convert it to float RGB in [0, 1] and display it.
url = "https://www.robots.ox.ac.uk/~vgg/data/dtd/thumbs/dotted/dotted_0201.jpg"
# A timeout keeps the cell from hanging forever on a stalled connection, and
# raise_for_status() surfaces HTTP errors here instead of as an obscure PIL
# decoding failure on an error page.
response = requests.get(url, timeout=30)
response.raise_for_status()
target_pil = PIL.Image.open(io.BytesIO(response.content)).convert("RGB")
# PIL sizes are (width, height) whereas `spatial_dims` is (height, width), so the
# tuple is reversed; for square dims this is a no-op.
target_pil = target_pil.resize(spatial_dims[::-1], resample=PIL.Image.Resampling.LANCZOS)
target_image = jnp.array(target_pil, dtype=jnp.float32) / 255.0
mediapy.show_image(target_image)
# Download the target texture, convert it to float RGB in [0, 1] and display it.
url = "https://www.robots.ox.ac.uk/~vgg/data/dtd/thumbs/dotted/dotted_0201.jpg"
# A timeout keeps the cell from hanging forever on a stalled connection, and
# raise_for_status() surfaces HTTP errors here instead of as an obscure PIL
# decoding failure on an error page.
response = requests.get(url, timeout=30)
response.raise_for_status()
target_pil = PIL.Image.open(io.BytesIO(response.content)).convert("RGB")
# PIL sizes are (width, height) whereas `spatial_dims` is (height, width), so the
# tuple is reversed; for square dims this is a no-op.
target_pil = target_pil.resize(spatial_dims[::-1], resample=PIL.Image.Resampling.LANCZOS)
target_image = jnp.array(target_pil, dtype=jnp.float32) / 255.0
mediapy.show_image(target_image)
VGG16 style model¶
In [ ]:
Copied!
# Indices into the torchvision VGG16 `.features` sequence whose outputs are
# collected as style features by VGGFeatureExtractor.
STYLE_LAYERS = [1, 6, 11, 18, 25]
# Number of leading `.features` layers to port: everything up to and including
# the last style layer.
MAX_LAYER = max(STYLE_LAYERS) + 1
class VGGFeatureExtractor(nnx.Module):
"""VGG16 feature extractor that returns activations at specified style layers."""
def __init__(self, torch_features: torch.nn.Sequential, *, rngs: nnx.Rngs):
"""Initialize from a torchvision VGG features sequential module.
Args:
torch_features: The `.features` attribute of a torchvision VGG16 model.
rngs: Flax NNX random number generators.
"""
convs = []
layer_types = []
for i in range(MAX_LAYER):
layer = torch_features[i]
if isinstance(layer, torch.nn.Conv2d):
conv = nnx.Conv(
in_features=layer.in_channels,
out_features=layer.out_channels,
kernel_size=(layer.kernel_size[0], layer.kernel_size[1]),
strides=(layer.stride[0], layer.stride[1]),
padding=(
(layer.padding[0], layer.padding[0]),
(layer.padding[1], layer.padding[1]),
),
use_bias=layer.bias is not None,
rngs=rngs,
)
weight_np = layer.weight.detach().cpu().numpy()
conv.kernel.value = jnp.array(np.transpose(weight_np, (2, 3, 1, 0)))
if layer.bias is not None:
conv.bias.value = jnp.array(layer.bias.detach().cpu().numpy())
convs.append(conv)
layer_types.append("conv")
elif isinstance(layer, torch.nn.ReLU):
layer_types.append("relu")
elif isinstance(layer, torch.nn.MaxPool2d):
layer_types.append("maxpool")
else:
raise ValueError(f"Unexpected layer type: {type(layer)}")
self.convs = nnx.List(convs)
self.layer_types = layer_types
def __call__(self, x: jax.Array) -> list[jax.Array]:
"""Extract features at style layers.
Args:
x: Input images with shape (batch, height, width, 3) in [0, 1] range.
Returns:
List of feature maps at each style layer index.
"""
mean = jnp.array([0.485, 0.456, 0.406])
std = jnp.array([0.229, 0.224, 0.225])
x = (x - mean) / std
features = []
conv_idx = 0
for i, layer_type in enumerate(self.layer_types):
if layer_type == "conv":
x = self.convs[conv_idx](x)
conv_idx += 1
elif layer_type == "relu":
x = nnx.relu(x)
elif layer_type == "maxpool":
x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))
if i in STYLE_LAYERS:
features.append(x)
return features
# Load torchvision's VGG16 with the "IMAGENET1K_V1" weights and keep only its
# `.features` stack, switched to eval mode.
vgg16_torch = models.vgg16(weights="IMAGENET1K_V1").features.eval()
# Port the layers needed for the style loss into the NNX feature extractor.
vgg = VGGFeatureExtractor(vgg16_torch, rngs=rngs)
# The torch module is no longer needed once its weights have been copied.
del vgg16_torch
# Indices into the torchvision VGG16 `.features` sequence whose outputs are
# collected as style features by VGGFeatureExtractor.
STYLE_LAYERS = [1, 6, 11, 18, 25]
# Number of leading `.features` layers to port: everything up to and including
# the last style layer.
MAX_LAYER = max(STYLE_LAYERS) + 1
class VGGFeatureExtractor(nnx.Module):
"""VGG16 feature extractor that returns activations at specified style layers."""
def __init__(self, torch_features: torch.nn.Sequential, *, rngs: nnx.Rngs):
"""Initialize from a torchvision VGG features sequential module.
Args:
torch_features: The `.features` attribute of a torchvision VGG16 model.
rngs: Flax NNX random number generators.
"""
convs = []
layer_types = []
for i in range(MAX_LAYER):
layer = torch_features[i]
if isinstance(layer, torch.nn.Conv2d):
conv = nnx.Conv(
in_features=layer.in_channels,
out_features=layer.out_channels,
kernel_size=(layer.kernel_size[0], layer.kernel_size[1]),
strides=(layer.stride[0], layer.stride[1]),
padding=(
(layer.padding[0], layer.padding[0]),
(layer.padding[1], layer.padding[1]),
),
use_bias=layer.bias is not None,
rngs=rngs,
)
weight_np = layer.weight.detach().cpu().numpy()
conv.kernel.value = jnp.array(np.transpose(weight_np, (2, 3, 1, 0)))
if layer.bias is not None:
conv.bias.value = jnp.array(layer.bias.detach().cpu().numpy())
convs.append(conv)
layer_types.append("conv")
elif isinstance(layer, torch.nn.ReLU):
layer_types.append("relu")
elif isinstance(layer, torch.nn.MaxPool2d):
layer_types.append("maxpool")
else:
raise ValueError(f"Unexpected layer type: {type(layer)}")
self.convs = nnx.List(convs)
self.layer_types = layer_types
def __call__(self, x: jax.Array) -> list[jax.Array]:
"""Extract features at style layers.
Args:
x: Input images with shape (batch, height, width, 3) in [0, 1] range.
Returns:
List of feature maps at each style layer index.
"""
mean = jnp.array([0.485, 0.456, 0.406])
std = jnp.array([0.229, 0.224, 0.225])
x = (x - mean) / std
features = []
conv_idx = 0
for i, layer_type in enumerate(self.layer_types):
if layer_type == "conv":
x = self.convs[conv_idx](x)
conv_idx += 1
elif layer_type == "relu":
x = nnx.relu(x)
elif layer_type == "maxpool":
x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))
if i in STYLE_LAYERS:
features.append(x)
return features
# Load torchvision's VGG16 with the "IMAGENET1K_V1" weights and keep only its
# `.features` stack, switched to eval mode.
vgg16_torch = models.vgg16(weights="IMAGENET1K_V1").features.eval()
# Port the layers needed for the style loss into the NNX feature extractor.
vgg = VGGFeatureExtractor(vgg16_torch, rngs=rngs)
# The torch module is no longer needed once its weights have been copied.
del vgg16_torch
Style loss¶
In [ ]:
Copied!
def gram_matrix(features: jax.Array) -> jax.Array:
    """Compute the channel-by-channel Gram matrix of a batch of feature maps.

    Args:
        features: Feature maps with shape (batch, height, width, channels).

    Returns:
        Array of shape (batch, channels, channels). Entry [b, c, d] is the sum
        over all spatial positions of features[b, ..., c] * features[b, ..., d],
        divided by the number of positions (height * width).
    """
    b, h, w, c = features.shape
    f = features.reshape(b, h * w, c)
    # Contract over the spatial index `i` and keep both channel indices. The
    # previous subscripts ("bic,bjc->bcj") produced a (batch, channels, h*w)
    # array of per-channel sums times features, which is not a Gram matrix.
    return jnp.einsum("bic,bid->bcd", f, f) / (h * w)
def style_loss(source_features: list[jax.Array], target_features: list[jax.Array]) -> jax.Array:
"""Compute style loss between source and target feature lists."""
loss = jnp.array(0.0)
for sf, tf in zip(source_features, target_features, strict=True):
loss += jnp.mean(jnp.square(gram_matrix(sf) - gram_matrix(tf)))
return loss
# Style-layer features of the target texture, computed once up front;
# `[None]` adds the leading batch axis that `vgg` expects.
target_features = vgg(target_image[None])
def gram_matrix(features: jax.Array) -> jax.Array:
    """Compute the channel-by-channel Gram matrix of a batch of feature maps.

    Args:
        features: Feature maps with shape (batch, height, width, channels).

    Returns:
        Array of shape (batch, channels, channels). Entry [b, c, d] is the sum
        over all spatial positions of features[b, ..., c] * features[b, ..., d],
        divided by the number of positions (height * width).
    """
    b, h, w, c = features.shape
    f = features.reshape(b, h * w, c)
    # Contract over the spatial index `i` and keep both channel indices. The
    # previous subscripts ("bic,bjc->bcj") produced a (batch, channels, h*w)
    # array of per-channel sums times features, which is not a Gram matrix.
    return jnp.einsum("bic,bid->bcd", f, f) / (h * w)
def style_loss(source_features: list[jax.Array], target_features: list[jax.Array]) -> jax.Array:
"""Compute style loss between source and target feature lists."""
loss = jnp.array(0.0)
for sf, tf in zip(source_features, target_features, strict=True):
loss += jnp.mean(jnp.square(gram_matrix(sf) - gram_matrix(tf)))
return loss
# Style-layer features of the target texture, computed once up front;
# `[None]` adds the leading batch axis that `vgg` expects.
target_features = vgg(target_image[None])
Instantiate system¶
In [ ]:
Copied!
class TextureNCA(ComplexSystem):
    """Neural cellular automaton used to synthesise a texture."""

    def __init__(self, *, rngs: nnx.Rngs):
        """Create the perception and update modules and install the perception kernel.

        Args:
            rngs: Flax NNX random number generators used to build the submodules.
        """
        perception_size = num_kernels * channel_size

        self.perceive = ConvPerceive(
            channel_size=channel_size,
            perception_size=perception_size,
            feature_group_count=channel_size,
            padding="CIRCULAR",
            rngs=rngs,
        )
        self.update = ResidualUpdate(
            num_spatial_dims=2,
            channel_size=channel_size,
            perception_size=perception_size,
            hidden_layer_sizes=(hidden_size,),
            cell_dropout_rate=cell_dropout_rate,
            zeros_init=True,
            rngs=rngs,
        )

        # Initialize kernel: identity + sobel x/y + laplacian
        filter_bank = jnp.concatenate(
            (identity_kernel(num_dims=2), grad_kernel(num_dims=2), grad2_kernel(num_dims=2)),
            axis=-1,
        )
        # Repeat the bank once per state channel along the last axis, then insert
        # a singleton axis in front of it before overwriting the conv kernel.
        repeated = jnp.concatenate([filter_bank] * channel_size, axis=-1)
        self.perceive.conv.kernel[...] = jnp.expand_dims(repeated, axis=-2)

    def _step(self, state: Array, input: Array | None = None, *, sow: bool = False) -> Array:
        """Advance the state by one perceive/update step.

        Args:
            state: Current state.
            input: Optional input forwarded to the update module.
            sow: When true, record the new state as an `nnx.Intermediate` named "state".

        Returns:
            The next state.
        """
        next_state = self.update(state, self.perceive(state), input)
        if sow:
            self.sow(nnx.Intermediate, "state", next_state)
        return next_state

    @nnx.jit
    def render(self, state):
        """Render the first three state channels, shifted by 0.5, via `clip_and_uint8`."""
        return clip_and_uint8(state[..., :3] + 0.5)
class TextureNCA(ComplexSystem):
    """Neural cellular automaton used to synthesise a texture."""

    def __init__(self, *, rngs: nnx.Rngs):
        """Create the perception and update modules and install the perception kernel.

        Args:
            rngs: Flax NNX random number generators used to build the submodules.
        """
        perception_size = num_kernels * channel_size

        self.perceive = ConvPerceive(
            channel_size=channel_size,
            perception_size=perception_size,
            feature_group_count=channel_size,
            padding="CIRCULAR",
            rngs=rngs,
        )
        self.update = ResidualUpdate(
            num_spatial_dims=2,
            channel_size=channel_size,
            perception_size=perception_size,
            hidden_layer_sizes=(hidden_size,),
            cell_dropout_rate=cell_dropout_rate,
            zeros_init=True,
            rngs=rngs,
        )

        # Initialize kernel: identity + sobel x/y + laplacian
        filter_bank = jnp.concatenate(
            (identity_kernel(num_dims=2), grad_kernel(num_dims=2), grad2_kernel(num_dims=2)),
            axis=-1,
        )
        # Repeat the bank once per state channel along the last axis, then insert
        # a singleton axis in front of it before overwriting the conv kernel.
        repeated = jnp.concatenate([filter_bank] * channel_size, axis=-1)
        self.perceive.conv.kernel[...] = jnp.expand_dims(repeated, axis=-2)

    def _step(self, state: Array, input: Array | None = None, *, sow: bool = False) -> Array:
        """Advance the state by one perceive/update step.

        Args:
            state: Current state.
            input: Optional input forwarded to the update module.
            sow: When true, record the new state as an `nnx.Intermediate` named "state".

        Returns:
            The next state.
        """
        next_state = self.update(state, self.perceive(state), input)
        if sow:
            self.sow(nnx.Intermediate, "state", next_state)
        return next_state

    @nnx.jit
    def render(self, state):
        """Render the first three state channels, shifted by 0.5, via `clip_and_uint8`."""
        return clip_and_uint8(state[..., :3] + 0.5)
In [ ]:
Copied!
# Build the texture NCA, drawing its initial parameters from the shared rngs.
cs = TextureNCA(rngs=rngs)
# Build the texture NCA, drawing its initial parameters from the shared rngs.
cs = TextureNCA(rngs=rngs)
In [ ]:
Copied!
# Extract only the `nnx.Param` state of the model and report the total
# number of scalar entries across all of its leaves.
params = nnx.state(cs, nnx.Param)
print("Number of params:", sum(x.size for x in jax.tree.leaves(params)))
# Extract only the `nnx.Param` state of the model and report the total
# number of scalar entries across all of its leaves.
params = nnx.state(cs, nnx.Param)
print("Number of params:", sum(x.size for x in jax.tree.leaves(params)))
Sample initial state¶
In [ ]:
Copied!
def sample_state():
    """Return a fresh all-zero state of shape `spatial_dims + (channel_size,)`."""
    state_shape = (*spatial_dims, channel_size)
    return jnp.zeros(state_shape)
def sample_state():
    """Return a fresh all-zero state of shape `spatial_dims + (channel_size,)`."""
    state_shape = (*spatial_dims, channel_size)
    return jnp.zeros(state_shape)
Train¶
Pool¶
In [ ]:
Copied!
# Stack `pool_size` initial states along a new leading axis; the dummy zeros
# array only provides the axis that vmap maps over.
state = jax.vmap(lambda _: sample_state())(jnp.zeros(pool_size))
# Training pool holding those states under the "state" key.
pool = Pool.create({"state": state})
# Stack `pool_size` initial states along a new leading axis; the dummy zeros
# array only provides the axis that vmap maps over.
state = jax.vmap(lambda _: sample_state())(jnp.zeros(pool_size))
# Training pool holding those states under the "state" key.
pool = Pool.create({"state": state})
Optimizer¶
In [ ]:
Copied!
# Learning rate starts at `learning_rate` and is multiplied by 0.3 at
# step 1000 and again at step 2000.
lr_sched = optax.piecewise_constant_schedule(
    init_value=learning_rate, boundaries_and_scales={1000: 0.3, 2000: 0.3}
)
# Adam over the model's `nnx.Param` variables only.
optimizer = nnx.Optimizer(cs, optax.adam(learning_rate=lr_sched), wrt=nnx.Param)
# Learning rate starts at `learning_rate` and is multiplied by 0.3 at
# step 1000 and again at step 2000.
lr_sched = optax.piecewise_constant_schedule(
    init_value=learning_rate, boundaries_and_scales={1000: 0.3, 2000: 0.3}
)
# Adam over the model's `nnx.Param` variables only.
optimizer = nnx.Optimizer(cs, optax.adam(learning_rate=lr_sched), wrt=nnx.Param)
Loss¶
In [ ]:
Copied!
def to_rgb(state: jax.Array) -> jax.Array:
    """Map an NCA state to RGB by taking its first three channels and adding 0.5.

    The result is not clipped, so it lies in [0, 1] only when the selected
    channels lie in [-0.5, 0.5].
    """
    rgb_channels = state[..., :3]
    return rgb_channels + 0.5
def overflow_loss(state: jax.Array) -> jax.Array:
    """Sum, over all entries, how far each value lies outside [-1, 1]."""
    # Entries inside the interval equal their clipped value and contribute zero.
    excess = state - jnp.clip(state, -1.0, 1.0)
    return jnp.sum(jnp.abs(excess))
def loss_fn(cs, state, num_steps):
    """Roll a batch of states forward and score the result against the target texture.

    Args:
        cs: The NCA model; called as `cs(state, num_steps=...)`.
        state: Batch of states with the batch on the leading axis.
        num_steps: Number of steps to roll each state forward.

    Returns:
        Tuple `(loss, state)`: the style loss plus the weighted overflow
        penalty, and the rolled-out batch of states.
    """

    def rollout(model, single_state):
        return model(single_state, num_steps=num_steps)

    # Give every batch element its own RNG stream; all other model state is shared.
    model_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})
    batched_rollout = nnx.split_rngs(splits=batch_size)(
        nnx.vmap(rollout, in_axes=(model_axes, 0))
    )
    state = batched_rollout(cs, state)

    # Style loss on RGB output
    style_term = style_loss(vgg(to_rgb(state)), target_features)
    # Overflow regularization
    overflow_term = overflow_weight * overflow_loss(state)
    return style_term + overflow_term, state
def to_rgb(state: jax.Array) -> jax.Array:
    """Map an NCA state to RGB by taking its first three channels and adding 0.5.

    The result is not clipped, so it lies in [0, 1] only when the selected
    channels lie in [-0.5, 0.5].
    """
    rgb_channels = state[..., :3]
    return rgb_channels + 0.5
def overflow_loss(state: jax.Array) -> jax.Array:
    """Sum, over all entries, how far each value lies outside [-1, 1]."""
    # Entries inside the interval equal their clipped value and contribute zero.
    excess = state - jnp.clip(state, -1.0, 1.0)
    return jnp.sum(jnp.abs(excess))
def loss_fn(cs, state, num_steps):
    """Roll a batch of states forward and score the result against the target texture.

    Args:
        cs: The NCA model; called as `cs(state, num_steps=...)`.
        state: Batch of states with the batch on the leading axis.
        num_steps: Number of steps to roll each state forward.

    Returns:
        Tuple `(loss, state)`: the style loss plus the weighted overflow
        penalty, and the rolled-out batch of states.
    """

    def rollout(model, single_state):
        return model(single_state, num_steps=num_steps)

    # Give every batch element its own RNG stream; all other model state is shared.
    model_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})
    batched_rollout = nnx.split_rngs(splits=batch_size)(
        nnx.vmap(rollout, in_axes=(model_axes, 0))
    )
    state = batched_rollout(cs, state)

    # Style loss on RGB output
    style_term = style_loss(vgg(to_rgb(state)), target_features)
    # Overflow regularization
    overflow_term = overflow_weight * overflow_loss(state)
    return style_term + overflow_term, state
Train step¶
In [ ]:
Copied!
@partial(nnx.jit, static_argnames=("num_steps",))
def train_step(cs, optimizer, pool, key, *, num_steps):
    """Run one optimisation step on a batch drawn from the pool.

    Args:
        cs: The NCA model being trained.
        optimizer: Optimizer updated in place with the normalised gradients.
        pool: Pool of states to sample from.
        key: JAX random key used for pool sampling.
        num_steps: Rollout length; static, so each distinct value is compiled separately.

    Returns:
        Tuple `(loss, pool)` with the batch loss and the pool after writing
        the rolled-out states back at the sampled indices.
    """
    # Sample from pool
    pool_idx, batch = pool.sample(key, batch_size=batch_size)
    # Reset one sample to initial state
    batch_state = batch["state"].at[0].set(sample_state())

    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True, argnums=nnx.DiffState(0, nnx.Param))
    (loss, batch_state), grad = grad_fn(cs, batch_state, num_steps)

    # Normalize gradients
    def unit_norm(g):
        return g / (jnp.linalg.norm(g) + 1e-8)

    optimizer.update(cs, jax.tree.map(unit_norm, grad))
    return loss, pool.update(pool_idx, {"state": batch_state})
@partial(nnx.jit, static_argnames=("num_steps",))
def train_step(cs, optimizer, pool, key, *, num_steps):
    """Run one optimisation step on a batch drawn from the pool.

    Args:
        cs: The NCA model being trained.
        optimizer: Optimizer updated in place with the normalised gradients.
        pool: Pool of states to sample from.
        key: JAX random key used for pool sampling.
        num_steps: Rollout length; static, so each distinct value is compiled separately.

    Returns:
        Tuple `(loss, pool)` with the batch loss and the pool after writing
        the rolled-out states back at the sampled indices.
    """
    # Sample from pool
    pool_idx, batch = pool.sample(key, batch_size=batch_size)
    # Reset one sample to initial state
    batch_state = batch["state"].at[0].set(sample_state())

    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True, argnums=nnx.DiffState(0, nnx.Param))
    (loss, batch_state), grad = grad_fn(cs, batch_state, num_steps)

    # Normalize gradients
    def unit_norm(g):
        return g / (jnp.linalg.norm(g) + 1e-8)

    optimizer.update(cs, jax.tree.map(unit_norm, grad))
    return loss, pool.update(pool_idx, {"state": batch_state})
Main loop¶
In [ ]:
Copied!
# Training loop: one pool-based train step per iteration, with a running
# average of the loss shown on the progress bar.
num_train_steps = 5_000
print_interval = 100

# Rollout lengths come from a dedicated generator seeded with `seed`, so the
# sequence of `num_steps` values is reproducible like the JAX key and NNX rngs.
# Python's global `random` state is unseeded and would differ between runs.
step_rng = random.Random(seed)

pbar = tqdm(range(num_train_steps), desc="Training", unit="train_step")
losses = []
for i in pbar:
    key, subkey = jax.random.split(key)
    loss, pool = train_step(cs, optimizer, pool, subkey, num_steps=step_rng.choice(step_choices))
    losses.append(loss)

    if i % print_interval == 0 or i == num_train_steps - 1:
        # Average over (at most) the last `print_interval` recorded losses.
        recent = losses[-print_interval:]
        avg_loss = sum(recent) / len(recent)
        pbar.set_postfix({"Average Loss": f"{avg_loss:.3e}"})
# Training loop: one pool-based train step per iteration, with a running
# average of the loss shown on the progress bar.
num_train_steps = 5_000
print_interval = 100

# Rollout lengths come from a dedicated generator seeded with `seed`, so the
# sequence of `num_steps` values is reproducible like the JAX key and NNX rngs.
# Python's global `random` state is unseeded and would differ between runs.
step_rng = random.Random(seed)

pbar = tqdm(range(num_train_steps), desc="Training", unit="train_step")
losses = []
for i in pbar:
    key, subkey = jax.random.split(key)
    loss, pool = train_step(cs, optimizer, pool, subkey, num_steps=step_rng.choice(step_choices))
    losses.append(loss)

    if i % print_interval == 0 or i == num_train_steps - 1:
        # Average over (at most) the last `print_interval` recorded losses.
        recent = losses[-print_interval:]
        avg_loss = sum(recent) / len(recent)
        pbar.set_postfix({"Average Loss": f"{avg_loss:.3e}"})
Run¶
In [ ]:
Copied!
# Number of independent rollouts to generate.
num_examples = 4
# One initial state per example, stacked along a new leading axis.
state_init = jax.vmap(lambda _: sample_state())(jnp.zeros(num_examples))
# Map RNG state and sown intermediates over axis 0; share all other model state.
state_axes = nnx.StateAxes({nnx.RngState: 0, nnx.Intermediate: 0, ...: None})
# Roll every example forward `num_steps` steps; sow=True makes `_step` record
# each new state as an intermediate for the visualisation cells.
state_final = nnx.split_rngs(splits=num_examples)(
    nnx.vmap(
        lambda cs, state_init: cs(state_init, num_steps=num_steps, sow=True),
        in_axes=(state_axes, 0),
    )
)(cs, state_init)
# Number of independent rollouts to generate.
num_examples = 4
# One initial state per example, stacked along a new leading axis.
state_init = jax.vmap(lambda _: sample_state())(jnp.zeros(num_examples))
# Map RNG state and sown intermediates over axis 0; share all other model state.
state_axes = nnx.StateAxes({nnx.RngState: 0, nnx.Intermediate: 0, ...: None})
# Roll every example forward `num_steps` steps; sow=True makes `_step` record
# each new state as an intermediate for the visualisation cells.
state_final = nnx.split_rngs(splits=num_examples)(
    nnx.vmap(
        lambda cs, state_init: cs(state_init, num_steps=num_steps, sow=True),
        in_axes=(state_axes, 0),
    )
)(cs, state_init)
Visualize¶
In [ ]:
Copied!
# Remove the sown `nnx.Intermediate` values from the model and keep them here.
intermediates = nnx.pop(cs, nnx.Intermediate)
# NOTE(review): index 0 presumably selects the single entry sown under "state"
# during the rollout above — confirm against the ComplexSystem/sow semantics.
states = intermediates.state[0]
# Remove the sown `nnx.Intermediate` values from the model and keep them here.
intermediates = nnx.pop(cs, nnx.Intermediate)
# NOTE(review): index 0 presumably selects the single entry sown under "state"
# during the rollout above — confirm against the ComplexSystem/sow semantics.
states = intermediates.state[0]
In [ ]:
Copied!
# Render the final state of every example; the model is shared (in_axes None)
# and the states are mapped over their leading axis.
frames_final = nnx.vmap(
    lambda cs, state: cs.render(state),
    in_axes=(None, 0),
)(cs, state_final)
mediapy.show_images(frames_final, width=128, height=128)
# Render the final state of every example; the model is shared (in_axes None)
# and the states are mapped over their leading axis.
frames_final = nnx.vmap(
    lambda cs, state: cs.render(state),
    in_axes=(None, 0),
)(cs, state_final)
mediapy.show_images(frames_final, width=128, height=128)
In [ ]:
Copied!
# Prepend each example's initial state along axis 1 so it becomes the first
# frame of that example's sequence.
states = jnp.concatenate([state_init[:, None], states], axis=1)
# Render every example's sequence of states; the model is shared (in_axes None).
frames = nnx.vmap(
    lambda cs, state: cs.render(state),
    in_axes=(None, 0),
)(cs, states)
mediapy.show_videos(frames, width=128, height=128)
# Prepend each example's initial state along axis 1 so it becomes the first
# frame of that example's sequence.
states = jnp.concatenate([state_init[:, None], states], axis=1)
# Render every example's sequence of states; the model is shared (in_axes None).
frames = nnx.vmap(
    lambda cs, state: cs.render(state),
    in_axes=(None, 0),
)(cs, states)
mediapy.show_videos(frames, width=128, height=128)