Quality-Diversity
¶
In this notebook, we show how to do black-box optimization on bbobax using Quality-Diversity.
Install¶
You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:
In [ ]:
Copied!
%pip install -U "jax[cuda]"
%pip install -U "jax[cuda]"
Then, install bbobax from PyPi:
In [ ]:
Copied!
%pip install -U "bbobax[notebooks]"
%pip install -U "bbobax[notebooks]"
Import¶
In [ ]:
Copied!
from functools import partial
from typing import Any
import flax.struct
import jax
import jax.numpy as jnp
from numpy.random import RandomState
from sklearn.cluster import KMeans
# Types
type Genotype = Any
type Fitness = jax.Array
type Descriptor = jax.Array
type RNGKey = jax.Array
type Centroid = jax.Array
from functools import partial
from typing import Any
import flax.struct
import jax
import jax.numpy as jnp
from numpy.random import RandomState
from sklearn.cluster import KMeans
# Types
type Genotype = Any
type Fitness = jax.Array
type Descriptor = jax.Array
type RNGKey = jax.Array
type Centroid = jax.Array
Metrics¶
In [ ]:
Copied!
def novelty_and_dominated_novelty(
fitness, descriptor, novelty_k=3, dominated_novelty_k=3
):
"""Compute novelty and dominated novelty."""
valid = fitness != -jnp.inf
# Neighbors
neighbor = valid[:, None] & valid[None, :]
neighbor = jnp.fill_diagonal(neighbor, False, inplace=False)
# Fitter
fitter = fitness[:, None] <= fitness[None, :]
fitter = jnp.where(neighbor, fitter, False)
# Distance to neighbors
distance = jnp.linalg.norm(descriptor[:, None, :] - descriptor[None, :, :], axis=-1)
distance = jnp.where(neighbor, distance, jnp.inf)
# Distance to fitter neighbors
distance_fitter = jnp.where(fitter, distance, jnp.inf)
# Novelty - distance to k-nearest neighbors
values, indices = jax.vmap(partial(jax.lax.top_k, k=novelty_k))(-distance)
novelty = jnp.mean(
-values, axis=-1, where=jnp.take_along_axis(neighbor, indices, axis=-1)
)
# Dominated Novelty - distance to k-fitter-nearest neighbors
values, indices = jax.vmap(partial(jax.lax.top_k, k=dominated_novelty_k))(
-distance_fitter
)
dominated_novelty = jnp.mean(
-values, axis=-1, where=jnp.take_along_axis(fitter, indices, axis=-1)
) # only max fitness individual should be nan
return novelty, dominated_novelty
def metrics_fn(
key: RNGKey,
population: Genotype,
fitness: Fitness,
descriptor: Descriptor,
state: "QDState",
params: "QDParams",
) -> dict:
"""Compute QD metrics."""
k = 3
novelty, dominated_novelty = novelty_and_dominated_novelty(
fitness,
descriptor,
novelty_k=k,
dominated_novelty_k=k,
)
dominated_novelty = jnp.where(
jnp.isposinf(dominated_novelty), jnp.nan, dominated_novelty
)
return {
"fitness": fitness,
"descriptor": descriptor,
"novelty": novelty,
"dominated_novelty": dominated_novelty,
}
@jax.jit
def metrics_agg_fn(metrics: dict) -> dict:
"""Aggregate QD metrics."""
valid = metrics["fitness"] != -jnp.inf
descriptor_mean = jnp.mean(metrics["descriptor"], axis=-2, where=valid[..., None])
distance_to_mean = jnp.linalg.norm(
metrics["descriptor"] - descriptor_mean[..., None, :], axis=-1
)
descriptor_std = jnp.std(distance_to_mean, axis=-1, where=valid)
return {
"population_size": jnp.sum(valid, axis=-1),
"fitness_max": jnp.max(
metrics["fitness"], axis=-1, initial=-jnp.inf, where=valid
),
"fitness_mean": jnp.mean(metrics["fitness"], axis=-1, where=valid),
"novelty_mean": jnp.mean(metrics["novelty"], axis=-1, where=valid),
"dominated_novelty_mean": jnp.nanmean(
metrics["dominated_novelty"], axis=-1, where=valid
),
"descriptor_std": descriptor_std,
}
def novelty_and_dominated_novelty(
fitness, descriptor, novelty_k=3, dominated_novelty_k=3
):
"""Compute novelty and dominated novelty."""
valid = fitness != -jnp.inf
# Neighbors
neighbor = valid[:, None] & valid[None, :]
neighbor = jnp.fill_diagonal(neighbor, False, inplace=False)
# Fitter
fitter = fitness[:, None] <= fitness[None, :]
fitter = jnp.where(neighbor, fitter, False)
# Distance to neighbors
distance = jnp.linalg.norm(descriptor[:, None, :] - descriptor[None, :, :], axis=-1)
distance = jnp.where(neighbor, distance, jnp.inf)
# Distance to fitter neighbors
distance_fitter = jnp.where(fitter, distance, jnp.inf)
# Novelty - distance to k-nearest neighbors
values, indices = jax.vmap(partial(jax.lax.top_k, k=novelty_k))(-distance)
novelty = jnp.mean(
-values, axis=-1, where=jnp.take_along_axis(neighbor, indices, axis=-1)
)
# Dominated Novelty - distance to k-fitter-nearest neighbors
values, indices = jax.vmap(partial(jax.lax.top_k, k=dominated_novelty_k))(
-distance_fitter
)
dominated_novelty = jnp.mean(
-values, axis=-1, where=jnp.take_along_axis(fitter, indices, axis=-1)
) # only max fitness individual should be nan
return novelty, dominated_novelty
def metrics_fn(
key: RNGKey,
population: Genotype,
fitness: Fitness,
descriptor: Descriptor,
state: "QDState",
params: "QDParams",
) -> dict:
"""Compute QD metrics."""
k = 3
novelty, dominated_novelty = novelty_and_dominated_novelty(
fitness,
descriptor,
novelty_k=k,
dominated_novelty_k=k,
)
dominated_novelty = jnp.where(
jnp.isposinf(dominated_novelty), jnp.nan, dominated_novelty
)
return {
"fitness": fitness,
"descriptor": descriptor,
"novelty": novelty,
"dominated_novelty": dominated_novelty,
}
@jax.jit
def metrics_agg_fn(metrics: dict) -> dict:
"""Aggregate QD metrics."""
valid = metrics["fitness"] != -jnp.inf
descriptor_mean = jnp.mean(metrics["descriptor"], axis=-2, where=valid[..., None])
distance_to_mean = jnp.linalg.norm(
metrics["descriptor"] - descriptor_mean[..., None, :], axis=-1
)
descriptor_std = jnp.std(distance_to_mean, axis=-1, where=valid)
return {
"population_size": jnp.sum(valid, axis=-1),
"fitness_max": jnp.max(
metrics["fitness"], axis=-1, initial=-jnp.inf, where=valid
),
"fitness_mean": jnp.mean(metrics["fitness"], axis=-1, where=valid),
"novelty_mean": jnp.mean(metrics["novelty"], axis=-1, where=valid),
"dominated_novelty_mean": jnp.nanmean(
metrics["dominated_novelty"], axis=-1, where=valid
),
"descriptor_std": descriptor_std,
}
Algorithms¶
In [ ]:
Copied!
@flax.struct.dataclass
class QDState:
"""State for QD algorithms."""
population: Genotype
fitness: Fitness
descriptor: Descriptor
generation_counter: int
@flax.struct.dataclass
class QDParams:
"""Parameters for QD algorithms."""
mutation_sigma: float = 0.1
@flax.struct.dataclass
class QDState:
"""State for QD algorithms."""
population: Genotype
fitness: Fitness
descriptor: Descriptor
generation_counter: int
@flax.struct.dataclass
class QDParams:
"""Parameters for QD algorithms."""
mutation_sigma: float = 0.1
In [ ]:
Copied!
def gaussian_mutation(key: RNGKey, genotype: Genotype, sigma: float) -> Genotype:
"""Apply Gaussian mutation to the genotype."""
return jax.tree.map(lambda x: x + sigma * jax.random.normal(key, x.shape), genotype)
def gaussian_mutation(key: RNGKey, genotype: Genotype, sigma: float) -> Genotype:
"""Apply Gaussian mutation to the genotype."""
return jax.tree.map(lambda x: x + sigma * jax.random.normal(key, x.shape), genotype)
In [ ]:
Copied!
class QDAlgorithm:
"""Base class for Quality-Diversity algorithms."""
def __init__(
self,
population_size: int,
solution: Genotype,
fitness_shaping_fn,
descriptor_size: int = 2,
):
"""Initialize the QD Algorithm."""
self.population_size = population_size
self.solution = solution
self.fitness_shaping_fn = fitness_shaping_fn
self.metrics_fn = metrics_fn
self.descriptor_size = descriptor_size
@partial(jax.jit, static_argnames=("self",))
def init(
self,
key: jax.Array,
population: Genotype,
fitness: Fitness,
descriptor: Descriptor,
params: QDParams,
) -> QDState:
"""Initialize evolutionary algorithm."""
state = self._init(key, params)
state, _ = self.tell(key, population, fitness, descriptor, state, params)
return state
@partial(jax.jit, static_argnames=("self",))
def ask(
self,
key: jax.Array,
state: QDState,
params: QDParams,
) -> tuple[Genotype, QDState]:
"""Ask evolutionary algorithm for new candidate solutions."""
return self._ask(key, state, params)
@property
def default_params(self) -> QDParams:
"""Return default parameters for the algorithm."""
return QDParams()
def tell(
self,
key: RNGKey,
population: Genotype,
fitness: Fitness,
descriptor: Descriptor,
state: QDState,
params: QDParams,
) -> tuple[QDState, dict]:
"""Tell Fitness and Descriptors."""
# Concatenate
all_genotype = jax.tree.map(
lambda x, y: jnp.concatenate([x, y], axis=0),
state.population,
population,
)
all_fitness = jnp.concatenate([state.fitness, fitness], axis=0)
all_descriptor = jnp.concatenate([state.descriptor, descriptor], axis=0)
# Compute competition fitness
key_shaping, key_metrics = jax.random.split(key)
shaped_fitness = self.fitness_shaping_fn(
key_shaping,
all_fitness,
all_descriptor,
state,
params,
)
# Sort by competition fitness
indices = jnp.argsort(shaped_fitness, descending=True)
indices = indices[: self.population_size]
# Keep best
new_genotype = jax.tree.map(lambda x: x[indices], all_genotype)
new_fitness = all_fitness[indices]
new_descriptor = all_descriptor[indices]
# Mark invalid individuals as -inf
is_valid = shaped_fitness[indices] != -jnp.inf
new_fitness = jnp.where(is_valid, new_fitness, -jnp.inf)
new_descriptor = jnp.where(is_valid[:, None], new_descriptor, jnp.nan)
state = state.replace(
population=new_genotype,
fitness=new_fitness,
descriptor=new_descriptor,
generation_counter=state.generation_counter + 1,
)
# Metrics
metrics = self.metrics_fn(
key_metrics,
state.population,
state.fitness,
state.descriptor,
state,
params,
)
metrics["generation"] = state.generation_counter
return state, metrics
def _init(self, key: RNGKey, params: QDParams) -> QDState:
genotype = jax.tree.map(
lambda x: jnp.full((self.population_size,) + x.shape, fill_value=jnp.nan),
self.solution,
)
fitness = jnp.full((self.population_size,), fill_value=-jnp.inf)
descriptor = jnp.full(
(self.population_size, self.descriptor_size), fill_value=jnp.nan
)
state = QDState(
population=genotype,
fitness=fitness,
descriptor=descriptor,
generation_counter=0,
)
return state
def _ask(
self, key: RNGKey, state: QDState, params: QDParams
) -> tuple[Genotype, QDState]:
"""Ask for new candidate solutions."""
# Simple Selection -> Mutation
valid = state.fitness != -jnp.inf
p = valid / jnp.sum(valid)
p = jnp.where(jnp.isnan(p), 1.0 / self.population_size, p)
population = jax.tree.map(
lambda x: jax.random.choice(key, x, shape=(self.population_size,), p=p),
state.population,
)
population = gaussian_mutation(key, population, params.mutation_sigma)
return population, state
class QDAlgorithm:
"""Base class for Quality-Diversity algorithms."""
def __init__(
self,
population_size: int,
solution: Genotype,
fitness_shaping_fn,
descriptor_size: int = 2,
):
"""Initialize the QD Algorithm."""
self.population_size = population_size
self.solution = solution
self.fitness_shaping_fn = fitness_shaping_fn
self.metrics_fn = metrics_fn
self.descriptor_size = descriptor_size
@partial(jax.jit, static_argnames=("self",))
def init(
self,
key: jax.Array,
population: Genotype,
fitness: Fitness,
descriptor: Descriptor,
params: QDParams,
) -> QDState:
"""Initialize evolutionary algorithm."""
state = self._init(key, params)
state, _ = self.tell(key, population, fitness, descriptor, state, params)
return state
@partial(jax.jit, static_argnames=("self",))
def ask(
self,
key: jax.Array,
state: QDState,
params: QDParams,
) -> tuple[Genotype, QDState]:
"""Ask evolutionary algorithm for new candidate solutions."""
return self._ask(key, state, params)
@property
def default_params(self) -> QDParams:
"""Return default parameters for the algorithm."""
return QDParams()
def tell(
self,
key: RNGKey,
population: Genotype,
fitness: Fitness,
descriptor: Descriptor,
state: QDState,
params: QDParams,
) -> tuple[QDState, dict]:
"""Tell Fitness and Descriptors."""
# Concatenate
all_genotype = jax.tree.map(
lambda x, y: jnp.concatenate([x, y], axis=0),
state.population,
population,
)
all_fitness = jnp.concatenate([state.fitness, fitness], axis=0)
all_descriptor = jnp.concatenate([state.descriptor, descriptor], axis=0)
# Compute competition fitness
key_shaping, key_metrics = jax.random.split(key)
shaped_fitness = self.fitness_shaping_fn(
key_shaping,
all_fitness,
all_descriptor,
state,
params,
)
# Sort by competition fitness
indices = jnp.argsort(shaped_fitness, descending=True)
indices = indices[: self.population_size]
# Keep best
new_genotype = jax.tree.map(lambda x: x[indices], all_genotype)
new_fitness = all_fitness[indices]
new_descriptor = all_descriptor[indices]
# Mark invalid individuals as -inf
is_valid = shaped_fitness[indices] != -jnp.inf
new_fitness = jnp.where(is_valid, new_fitness, -jnp.inf)
new_descriptor = jnp.where(is_valid[:, None], new_descriptor, jnp.nan)
state = state.replace(
population=new_genotype,
fitness=new_fitness,
descriptor=new_descriptor,
generation_counter=state.generation_counter + 1,
)
# Metrics
metrics = self.metrics_fn(
key_metrics,
state.population,
state.fitness,
state.descriptor,
state,
params,
)
metrics["generation"] = state.generation_counter
return state, metrics
def _init(self, key: RNGKey, params: QDParams) -> QDState:
genotype = jax.tree.map(
lambda x: jnp.full((self.population_size,) + x.shape, fill_value=jnp.nan),
self.solution,
)
fitness = jnp.full((self.population_size,), fill_value=-jnp.inf)
descriptor = jnp.full(
(self.population_size, self.descriptor_size), fill_value=jnp.nan
)
state = QDState(
population=genotype,
fitness=fitness,
descriptor=descriptor,
generation_counter=0,
)
return state
def _ask(
self, key: RNGKey, state: QDState, params: QDParams
) -> tuple[Genotype, QDState]:
"""Ask for new candidate solutions."""
# Simple Selection -> Mutation
valid = state.fitness != -jnp.inf
p = valid / jnp.sum(valid)
p = jnp.where(jnp.isnan(p), 1.0 / self.population_size, p)
population = jax.tree.map(
lambda x: jax.random.choice(key, x, shape=(self.population_size,), p=p),
state.population,
)
population = gaussian_mutation(key, population, params.mutation_sigma)
return population, state
Random Search¶
In [ ]:
Copied!
def random_fitness_shaping(
key: RNGKey,
fitness: Fitness,
descriptor: Descriptor,
state: QDState,
params: QDParams,
) -> Fitness:
"""Random Fitness."""
random_fitness = jax.random.uniform(key, fitness.shape)
valid = fitness != -jnp.inf
return jnp.where(valid, random_fitness, -jnp.inf)
def random_fitness_shaping(
key: RNGKey,
fitness: Fitness,
descriptor: Descriptor,
state: QDState,
params: QDParams,
) -> Fitness:
"""Random Fitness."""
random_fitness = jax.random.uniform(key, fitness.shape)
valid = fitness != -jnp.inf
return jnp.where(valid, random_fitness, -jnp.inf)
In [ ]:
Copied!
class RandomSearch(QDAlgorithm):
"""Random Search: replaces individuals randomly."""
def __init__(
self, population_size: int, solution: Genotype, descriptor_size: int = 2
):
"""Initialize Random Search."""
super().__init__(
population_size,
solution,
fitness_shaping_fn=random_fitness_shaping,
descriptor_size=descriptor_size,
)
class RandomSearch(QDAlgorithm):
"""Random Search: replaces individuals randomly."""
def __init__(
self, population_size: int, solution: Genotype, descriptor_size: int = 2
):
"""Initialize Random Search."""
super().__init__(
population_size,
solution,
fitness_shaping_fn=random_fitness_shaping,
descriptor_size=descriptor_size,
)
Genetic Algorithm¶
In [ ]:
Copied!
def identity_fitness_shaping(
key: RNGKey,
fitness: Fitness,
descriptor: Descriptor,
state: QDState,
params: QDParams,
) -> Fitness:
"""Raw Fitness."""
return fitness
def identity_fitness_shaping(
key: RNGKey,
fitness: Fitness,
descriptor: Descriptor,
state: QDState,
params: QDParams,
) -> Fitness:
"""Raw Fitness."""
return fitness
In [ ]:
Copied!
class GeneticAlgorithm(QDAlgorithm):
"""Genetic Algorithm: Standard unstructured population with identity fitness."""
def __init__(
self,
population_size: int,
solution: Genotype,
fitness_shaping_fn=identity_fitness_shaping,
descriptor_size: int = 2,
):
"""Initialize Genetic Algorithm."""
super().__init__(
population_size,
solution,
fitness_shaping_fn=fitness_shaping_fn,
descriptor_size=descriptor_size,
)
class GeneticAlgorithm(QDAlgorithm):
"""Genetic Algorithm: Standard unstructured population with identity fitness."""
def __init__(
self,
population_size: int,
solution: Genotype,
fitness_shaping_fn=identity_fitness_shaping,
descriptor_size: int = 2,
):
"""Initialize Genetic Algorithm."""
super().__init__(
population_size,
solution,
fitness_shaping_fn=fitness_shaping_fn,
descriptor_size=descriptor_size,
)
Novelty Search¶
In [ ]:
Copied!
def novelty_fitness_shaping(
key: RNGKey,
fitness: Fitness,
descriptor: Descriptor,
state: QDState,
params: QDParams,
novelty_k: int = 3,
) -> Fitness:
"""Novelty Score."""
novelty, _ = novelty_and_dominated_novelty(
fitness,
descriptor,
novelty_k=novelty_k,
)
valid = fitness != -jnp.inf
return jnp.where(valid, novelty, -jnp.inf)
def novelty_fitness_shaping(
key: RNGKey,
fitness: Fitness,
descriptor: Descriptor,
state: QDState,
params: QDParams,
novelty_k: int = 3,
) -> Fitness:
"""Novelty Score."""
novelty, _ = novelty_and_dominated_novelty(
fitness,
descriptor,
novelty_k=novelty_k,
)
valid = fitness != -jnp.inf
return jnp.where(valid, novelty, -jnp.inf)
In [ ]:
Copied!
class NoveltySearch(QDAlgorithm):
"""Novelty Search Algorithm."""
def __init__(
self,
population_size: int,
solution: Genotype,
novelty_k: int = 3,
descriptor_size: int = 2,
):
"""Initialize Novelty Search."""
super().__init__(
population_size,
solution,
fitness_shaping_fn=partial(novelty_fitness_shaping, novelty_k=novelty_k),
descriptor_size=descriptor_size,
)
class NoveltySearch(QDAlgorithm):
"""Novelty Search Algorithm."""
def __init__(
self,
population_size: int,
solution: Genotype,
novelty_k: int = 3,
descriptor_size: int = 2,
):
"""Initialize Novelty Search."""
super().__init__(
population_size,
solution,
fitness_shaping_fn=partial(novelty_fitness_shaping, novelty_k=novelty_k),
descriptor_size=descriptor_size,
)
Dominated Novelty Search¶
In [ ]:
Copied!
def dominated_novelty_fitness_shaping(
key: RNGKey,
fitness: Fitness,
descriptor: Descriptor,
state: QDState,
params: QDParams,
novelty_k: int = 3,
) -> Fitness:
"""Dominated Novelty Score."""
_, dominated_novelty = novelty_and_dominated_novelty(
fitness,
descriptor,
dominated_novelty_k=novelty_k,
)
valid = fitness != -jnp.inf
return jnp.where(valid, dominated_novelty, -jnp.inf)
def dominated_novelty_fitness_shaping(
key: RNGKey,
fitness: Fitness,
descriptor: Descriptor,
state: QDState,
params: QDParams,
novelty_k: int = 3,
) -> Fitness:
"""Dominated Novelty Score."""
_, dominated_novelty = novelty_and_dominated_novelty(
fitness,
descriptor,
dominated_novelty_k=novelty_k,
)
valid = fitness != -jnp.inf
return jnp.where(valid, dominated_novelty, -jnp.inf)
In [ ]:
Copied!
class DominatedNoveltySearch(QDAlgorithm):
"""Dominated Novelty Search Algorithm."""
def __init__(
self,
population_size: int,
solution: Genotype,
novelty_k: int = 3,
descriptor_size: int = 2,
):
"""Initialize Dominated Novelty Search."""
super().__init__(
population_size,
solution,
fitness_shaping_fn=partial(
dominated_novelty_fitness_shaping, novelty_k=novelty_k
),
descriptor_size=descriptor_size,
)
class DominatedNoveltySearch(QDAlgorithm):
"""Dominated Novelty Search Algorithm."""
def __init__(
self,
population_size: int,
solution: Genotype,
novelty_k: int = 3,
descriptor_size: int = 2,
):
"""Initialize Dominated Novelty Search."""
super().__init__(
population_size,
solution,
fitness_shaping_fn=partial(
dominated_novelty_fitness_shaping, novelty_k=novelty_k
),
descriptor_size=descriptor_size,
)
MAP-Elites¶
Grid helpers¶
In [ ]:
Copied!
def get_centroid_indices(descriptors: Descriptor, centroids: Centroid) -> jax.Array:
"""Assign descriptors to their closest centroid and return centroid indices."""
def _get_centroid_indices(descriptor: Descriptor) -> jax.Array:
return jnp.argmin(jnp.linalg.norm(descriptor - centroids, axis=-1))
indices = jax.vmap(_get_centroid_indices)(descriptors)
return indices
def get_centroids(
num_centroids: int,
descriptor_size: int,
descriptor_min: float | list[float],
descriptor_max: float | list[float],
num_init_cvt_samples: int,
key: RNGKey,
) -> jax.Array:
"""Compute centroids using CVT (Centroidal Voronoi Tessellation)."""
descriptor_min = jnp.array(descriptor_min)
descriptor_max = jnp.array(descriptor_max)
# Sample x uniformly in [0, 1]
key_x, key_kmeans = jax.random.split(key)
x = jax.random.uniform(key_x, (num_init_cvt_samples, descriptor_size))
# Generate an integer seed for RandomState
seed = jax.random.randint(key_kmeans, (), 0, 2**30, dtype=jnp.int32)
def _kmeans_host_fn(x_np, seed_np):
rs = RandomState(int(seed_np))
kmeans = KMeans(
init="k-means++",
n_clusters=num_centroids,
n_init=1,
random_state=rs,
)
kmeans.fit(x_np)
return kmeans.cluster_centers_.astype(x_np.dtype)
# Call host function
centroids = jax.pure_callback(
_kmeans_host_fn,
jax.ShapeDtypeStruct((num_centroids, descriptor_size), x.dtype),
x,
seed,
)
# Rescale
return descriptor_min + (descriptor_max - descriptor_min) * centroids
def segment_argmax(data, segment_ids, num_segments):
"""Compute the argmax of data for each segment."""
return jnp.argmax(
jax.vmap(lambda i: jnp.where(i == segment_ids, data, -jnp.inf))(
jnp.arange(num_segments)
),
axis=1,
)
def get_centroid_indices(descriptors: Descriptor, centroids: Centroid) -> jax.Array:
"""Assign descriptors to their closest centroid and return centroid indices."""
def _get_centroid_indices(descriptor: Descriptor) -> jax.Array:
return jnp.argmin(jnp.linalg.norm(descriptor - centroids, axis=-1))
indices = jax.vmap(_get_centroid_indices)(descriptors)
return indices
def get_centroids(
num_centroids: int,
descriptor_size: int,
descriptor_min: float | list[float],
descriptor_max: float | list[float],
num_init_cvt_samples: int,
key: RNGKey,
) -> jax.Array:
"""Compute centroids using CVT (Centroidal Voronoi Tessellation)."""
descriptor_min = jnp.array(descriptor_min)
descriptor_max = jnp.array(descriptor_max)
# Sample x uniformly in [0, 1]
key_x, key_kmeans = jax.random.split(key)
x = jax.random.uniform(key_x, (num_init_cvt_samples, descriptor_size))
# Generate an integer seed for RandomState
seed = jax.random.randint(key_kmeans, (), 0, 2**30, dtype=jnp.int32)
def _kmeans_host_fn(x_np, seed_np):
rs = RandomState(int(seed_np))
kmeans = KMeans(
init="k-means++",
n_clusters=num_centroids,
n_init=1,
random_state=rs,
)
kmeans.fit(x_np)
return kmeans.cluster_centers_.astype(x_np.dtype)
# Call host function
centroids = jax.pure_callback(
_kmeans_host_fn,
jax.ShapeDtypeStruct((num_centroids, descriptor_size), x.dtype),
x,
seed,
)
# Rescale
return descriptor_min + (descriptor_max - descriptor_min) * centroids
def segment_argmax(data, segment_ids, num_segments):
"""Compute the argmax of data for each segment."""
return jnp.argmax(
jax.vmap(lambda i: jnp.where(i == segment_ids, data, -jnp.inf))(
jnp.arange(num_segments)
),
axis=1,
)
In [ ]:
Copied!
def map_elites_fitness_shaping(
key: RNGKey,
fitness: Fitness,
descriptor: Descriptor,
state: QDState,
params: QDParams,
) -> Fitness:
"""Grid Fitness."""
centroids = state.centroids
# Get centroid assignments
centroid_indices = get_centroid_indices(descriptor, centroids)
num_centroids = centroids.shape[0]
best_index_per_centroid = segment_argmax(fitness, centroid_indices, num_centroids)
# Check which centroids have assigned individuals
centroid_assigned = jnp.isin(jnp.arange(num_centroids), centroid_indices)
# Handle empty centroids to avoid collision at index 0
best_index_per_centroid = jnp.where(
centroid_assigned,
best_index_per_centroid,
fitness.shape[0], # if centroid not used, put the best index out of bounds
)
# Create mask for individuals that are the best in their assigned cell
best_index = (
jnp.zeros_like(fitness, dtype=bool).at[best_index_per_centroid].set(True)
)
return jnp.where(best_index, fitness, -jnp.inf)
def map_elites_fitness_shaping(
key: RNGKey,
fitness: Fitness,
descriptor: Descriptor,
state: QDState,
params: QDParams,
) -> Fitness:
"""Grid Fitness."""
centroids = state.centroids
# Get centroid assignments
centroid_indices = get_centroid_indices(descriptor, centroids)
num_centroids = centroids.shape[0]
best_index_per_centroid = segment_argmax(fitness, centroid_indices, num_centroids)
# Check which centroids have assigned individuals
centroid_assigned = jnp.isin(jnp.arange(num_centroids), centroid_indices)
# Handle empty centroids to avoid collision at index 0
best_index_per_centroid = jnp.where(
centroid_assigned,
best_index_per_centroid,
fitness.shape[0], # if centroid not used, put the best index out of bounds
)
# Create mask for individuals that are the best in their assigned cell
best_index = (
jnp.zeros_like(fitness, dtype=bool).at[best_index_per_centroid].set(True)
)
return jnp.where(best_index, fitness, -jnp.inf)
In [ ]:
Copied!
@flax.struct.dataclass
class MAPElitesState(QDState):
"""State for MAP-Elites algorithm."""
centroids: Centroid
class MAPElites(QDAlgorithm):
"""MAP-Elites Algorithm."""
def __init__(
self,
population_size: int,
solution: Genotype,
descriptor_size: int,
descriptor_min: float | list[float],
descriptor_max: float | list[float],
num_init_cvt_samples: int = 10000,
):
"""Initialize MAP-Elites."""
super().__init__(
population_size,
solution,
fitness_shaping_fn=map_elites_fitness_shaping,
descriptor_size=descriptor_size,
)
self.descriptor_min = descriptor_min
self.descriptor_max = descriptor_max
self.num_init_cvt_samples = num_init_cvt_samples
def _init(self, key: RNGKey, params: QDParams) -> MAPElitesState:
genotype = jax.tree.map(
lambda x: jnp.full((self.population_size,) + x.shape, fill_value=jnp.nan),
self.solution,
)
fitness = jnp.full((self.population_size,), fill_value=-jnp.inf)
descriptor = jnp.full(
(self.population_size, self.descriptor_size), fill_value=jnp.nan
)
centroids = get_centroids(
num_centroids=self.population_size,
descriptor_size=self.descriptor_size,
descriptor_min=self.descriptor_min,
descriptor_max=self.descriptor_max,
num_init_cvt_samples=self.num_init_cvt_samples,
key=key,
)
state = MAPElitesState(
population=genotype,
fitness=fitness,
descriptor=descriptor,
centroids=centroids,
generation_counter=0,
)
return state
@flax.struct.dataclass
class MAPElitesState(QDState):
"""State for MAP-Elites algorithm."""
centroids: Centroid
class MAPElites(QDAlgorithm):
"""MAP-Elites Algorithm."""
def __init__(
self,
population_size: int,
solution: Genotype,
descriptor_size: int,
descriptor_min: float | list[float],
descriptor_max: float | list[float],
num_init_cvt_samples: int = 10000,
):
"""Initialize MAP-Elites."""
super().__init__(
population_size,
solution,
fitness_shaping_fn=map_elites_fitness_shaping,
descriptor_size=descriptor_size,
)
self.descriptor_min = descriptor_min
self.descriptor_max = descriptor_max
self.num_init_cvt_samples = num_init_cvt_samples
def _init(self, key: RNGKey, params: QDParams) -> MAPElitesState:
genotype = jax.tree.map(
lambda x: jnp.full((self.population_size,) + x.shape, fill_value=jnp.nan),
self.solution,
)
fitness = jnp.full((self.population_size,), fill_value=-jnp.inf)
descriptor = jnp.full(
(self.population_size, self.descriptor_size), fill_value=jnp.nan
)
centroids = get_centroids(
num_centroids=self.population_size,
descriptor_size=self.descriptor_size,
descriptor_min=self.descriptor_min,
descriptor_max=self.descriptor_max,
num_init_cvt_samples=self.num_init_cvt_samples,
key=key,
)
state = MAPElitesState(
population=genotype,
fitness=fitness,
descriptor=descriptor,
centroids=centroids,
generation_counter=0,
)
return state
Run¶
In [ ]:
Copied!
from bbobax import QDBBOB
from bbobax.descriptor_fns import get_random_projection_descriptor
from bbobax.fitness_fns import bbob_fns
# Configuration
seed = 1
pop_size = 1024
num_generations = 100
dim = 2
# Setup Task
bbob = QDBBOB(
min_num_dims=dim,
max_num_dims=dim,
fitness_fns=[bbob_fns["sphere"]],
descriptor_fns=[get_random_projection_descriptor()],
descriptor_size=2,
)
key = jax.random.key(seed)
key_bbob, key_init, key_qd, key_pop = jax.random.split(key, 4)
bbob_params = bbob.sample(key_bbob)
bbob_state = bbob.init(key_init, bbob_params)
# Solution template
solution_template = jnp.zeros((dim,))
# Sample initial population from task
keys = jax.random.split(key_pop, pop_size)
initial_population = jax.vmap(bbob.sample_x)(keys)
# Evaluate initial population to get fitness/descriptor for init
fitness_fn = jax.vmap(bbob.evaluate, in_axes=(0, 0, None, None))
eval_keys = jax.random.split(key_pop, pop_size)
bbob_state_batch, bbob_eval = fitness_fn(
eval_keys, initial_population, bbob_state, bbob_params
)
bbob_state = jax.tree.map(lambda x: x[0], bbob_state_batch)
# Algorithms to test
algorithms = {
"Random": RandomSearch(pop_size, solution_template),
"GA": GeneticAlgorithm(pop_size, solution_template),
"NoveltySearch": NoveltySearch(pop_size, solution_template),
"DominatedNoveltySearch": DominatedNoveltySearch(pop_size, solution_template),
"MAP-Elites": MAPElites(
pop_size,
solution_template,
descriptor_size=2,
descriptor_min=[-3.0, -3.0],
descriptor_max=[3.0, 3.0],
),
}
print(f"Starting benchmark on Sphere (dim={dim}) for {num_generations} generations...")
for name, qd in algorithms.items():
print(f"\n--- {name} ---")
# Init Algorithm
qd_params = qd.default_params
# Initialize with the sampled population
qd_state = qd.init(
key_qd,
population=initial_population,
fitness=-bbob_eval.fitness,
descriptor=bbob_eval.descriptor,
params=qd_params,
)
# Loop
for gen in range(num_generations):
key_qd, key_ask, key_eval, key_tell = jax.random.split(key_qd, 4)
# Ask
population, qd_state = qd.ask(key_ask, qd_state, qd_params)
# Evaluate
eval_keys = jax.random.split(key_eval, pop_size)
bbob_state_batch, bbob_eval_gen = fitness_fn(
eval_keys, population, bbob_state, bbob_params
)
bbob_state = jax.tree.map(lambda x: x[0], bbob_state_batch)
# Tell
qd_state, metrics = qd.tell(
key_tell,
population,
-bbob_eval_gen.fitness,
bbob_eval_gen.descriptor,
qd_state,
qd_params,
)
if gen % 20 == 0:
# Aggregate metrics for display
agg = metrics_agg_fn(metrics)
print(
f"Generation {gen:03d}: "
f"population_size={agg['population_size']:.0f}, "
f"fitness_max={agg['fitness_max']:.4f}, "
f"novelty_mean={agg['novelty_mean']:.4f}, "
f"dominated_novelty_mean={agg['dominated_novelty_mean']:.4f}"
)
from bbobax import QDBBOB
from bbobax.descriptor_fns import get_random_projection_descriptor
from bbobax.fitness_fns import bbob_fns
# Configuration
seed = 1
pop_size = 1024
num_generations = 100
dim = 2
# Setup Task
bbob = QDBBOB(
min_num_dims=dim,
max_num_dims=dim,
fitness_fns=[bbob_fns["sphere"]],
descriptor_fns=[get_random_projection_descriptor()],
descriptor_size=2,
)
key = jax.random.key(seed)
key_bbob, key_init, key_qd, key_pop = jax.random.split(key, 4)
bbob_params = bbob.sample(key_bbob)
bbob_state = bbob.init(key_init, bbob_params)
# Solution template
solution_template = jnp.zeros((dim,))
# Sample initial population from task
keys = jax.random.split(key_pop, pop_size)
initial_population = jax.vmap(bbob.sample_x)(keys)
# Evaluate initial population to get fitness/descriptor for init
fitness_fn = jax.vmap(bbob.evaluate, in_axes=(0, 0, None, None))
eval_keys = jax.random.split(key_pop, pop_size)
bbob_state_batch, bbob_eval = fitness_fn(
eval_keys, initial_population, bbob_state, bbob_params
)
bbob_state = jax.tree.map(lambda x: x[0], bbob_state_batch)
# Algorithms to test
algorithms = {
"Random": RandomSearch(pop_size, solution_template),
"GA": GeneticAlgorithm(pop_size, solution_template),
"NoveltySearch": NoveltySearch(pop_size, solution_template),
"DominatedNoveltySearch": DominatedNoveltySearch(pop_size, solution_template),
"MAP-Elites": MAPElites(
pop_size,
solution_template,
descriptor_size=2,
descriptor_min=[-3.0, -3.0],
descriptor_max=[3.0, 3.0],
),
}
print(f"Starting benchmark on Sphere (dim={dim}) for {num_generations} generations...")
for name, qd in algorithms.items():
print(f"\n--- {name} ---")
# Init Algorithm
qd_params = qd.default_params
# Initialize with the sampled population
qd_state = qd.init(
key_qd,
population=initial_population,
fitness=-bbob_eval.fitness,
descriptor=bbob_eval.descriptor,
params=qd_params,
)
# Loop
for gen in range(num_generations):
key_qd, key_ask, key_eval, key_tell = jax.random.split(key_qd, 4)
# Ask
population, qd_state = qd.ask(key_ask, qd_state, qd_params)
# Evaluate
eval_keys = jax.random.split(key_eval, pop_size)
bbob_state_batch, bbob_eval_gen = fitness_fn(
eval_keys, population, bbob_state, bbob_params
)
bbob_state = jax.tree.map(lambda x: x[0], bbob_state_batch)
# Tell
qd_state, metrics = qd.tell(
key_tell,
population,
-bbob_eval_gen.fitness,
bbob_eval_gen.descriptor,
qd_state,
qd_params,
)
if gen % 20 == 0:
# Aggregate metrics for display
agg = metrics_agg_fn(metrics)
print(
f"Generation {gen:03d}: "
f"population_size={agg['population_size']:.0f}, "
f"fitness_max={agg['fitness_max']:.4f}, "
f"novelty_mean={agg['novelty_mean']:.4f}, "
f"dominated_novelty_mean={agg['dominated_novelty_mean']:.4f}"
)