Meta-Black-Box Optimization
In this notebook, we show how to use bbobax for meta-black-box optimization: we sample many BBOB tasks and compare Evolution Strategies (from evosax) across all of them.
Install
You will need Python 3.11 or later and a working JAX installation. For example, you can install JAX with NVIDIA GPU (CUDA) support using:
In [ ]:
Copied!
# Install/upgrade JAX with NVIDIA GPU (CUDA) support; see the JAX installation
# guide for CPU-only or other accelerator variants.
# NOTE(review): current JAX install docs name this extra `cuda12`; confirm that
# `jax[cuda]` still resolves for the JAX version you target.
# NOTE(review): the command appears twice (likely a notebook-export artifact);
# running it once is sufficient.
%pip install -U "jax[cuda]"
%pip install -U "jax[cuda]"
Then, install bbobax from PyPI:
In [ ]:
Copied!
# Install/upgrade bbobax with its `notebooks` extra.
# NOTE(review): presumably this extra pulls in the other packages imported
# below (evosax, optax, matplotlib); confirm against bbobax's packaging.
# NOTE(review): the command appears twice (likely a notebook-export artifact);
# running it once is sufficient.
%pip install -U "bbobax[notebooks]"
%pip install -U "bbobax[notebooks]"
Imports
In [ ]:
Copied!
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from evosax.algorithms import algorithms
from bbobax import BBOB
from bbobax.fitness_fns import bbob_fns
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from evosax.algorithms import algorithms
from bbobax import BBOB
from bbobax.fitness_fns import bbob_fns
Initialize the BBOB Problem
In [ ]:
Copied!
# BBOB functions that tasks are sampled from, given by their key in `bbob_fns`.
# The order is preserved as-is when building the list of callables below.
fn_names = [
    "sphere",
    "ellipsoidal",
    "rastrigin",
    "bueche_rastrigin",
    "linear_slope",
    "attractive_sector",
    "step_ellipsoidal",
    "rosenbrock",
    "rosenbrock_rotated",
    "ellipsoidal_rotated",
    "discus",
    "bent_cigar",
    "sharp_ridge",
    "different_powers",
    "rastrigin_rotated",
    "weierstrass",
    "schaffers_f7",
    "schaffers_f7_ill_conditioned",
    "griewank_rosenbrock",
    "katsuura",
    "lunacek",
]

# Resolve every name to its fitness callable; an unknown name raises KeyError.
fitness_fns = [bbob_fns[name] for name in fn_names]

# Problem generator over the functions above.
bbob = BBOB(
    fitness_fns=fitness_fns,
    # Bounds on problem dimensionality (presumably sampled per task).
    min_num_dims=2,
    max_num_dims=16,
    x_range=[-5.0, 5.0],
    x_opt_range=[-4.0, 4.0],
    f_opt_range=[0.0, 0.0],  # Force optimal fitness to 0.0
)
# BBOB functions that tasks are sampled from, given by their key in `bbob_fns`.
# The order is preserved as-is when building the list of callables below.
fn_names = [
    "sphere",
    "ellipsoidal",
    "rastrigin",
    "bueche_rastrigin",
    "linear_slope",
    "attractive_sector",
    "step_ellipsoidal",
    "rosenbrock",
    "rosenbrock_rotated",
    "ellipsoidal_rotated",
    "discus",
    "bent_cigar",
    "sharp_ridge",
    "different_powers",
    "rastrigin_rotated",
    "weierstrass",
    "schaffers_f7",
    "schaffers_f7_ill_conditioned",
    "griewank_rosenbrock",
    "katsuura",
    "lunacek",
]

# Resolve every name to its fitness callable; an unknown name raises KeyError.
fitness_fns = [bbob_fns[name] for name in fn_names]

# Problem generator over the functions above.
bbob = BBOB(
    fitness_fns=fitness_fns,
    # Bounds on problem dimensionality (presumably sampled per task).
    min_num_dims=2,
    max_num_dims=16,
    x_range=[-5.0, 5.0],
    x_opt_range=[-4.0, 4.0],
    f_opt_range=[0.0, 0.0],  # Force optimal fitness to 0.0
)
Compare ES Across Many BBOB Problems
In [ ]:
Copied!
# Experiment size: generations per run, candidates per generation, and the
# number of sampled BBOB tasks every ES is evaluated on.
num_generations = 8_192
population_size = 1_024
num_tasks = 128
key = jax.random.key(0)

# ES algorithms to compare (names index evosax's `algorithms` registry), each
# mapped to extra keyword arguments for its constructor.
es_dict = dict(
    SimpleES={},
    PGPE={},
    Open_ES={"optimizer": optax.adam(1e-3)},
    SNES={},
    Sep_CMA_ES={},
    CMA_ES={},
)

# Task-averaged metrics per ES, filled in by the comparison loop.
results = {}

# Sample the BBOB tasks once (params, then state) so all ES share them.
key, subkey = jax.random.split(key)
task_keys = jax.random.split(subkey, num_tasks)
bbob_params = jax.vmap(bbob.sample)(task_keys)

key, subkey = jax.random.split(key)
task_keys = jax.random.split(subkey, num_tasks)
bbob_state = jax.vmap(bbob.init)(task_keys, bbob_params)

# One sampled point that only serves as the solution template for ES init.
key, subkey = jax.random.split(key)
x = bbob.sample_x(subkey)

# One starting point per task, shared by every ES, clipped into [-4, 4].
key, subkey = jax.random.split(key)
task_keys = jax.random.split(subkey, num_tasks)
solutions = jnp.clip(jax.vmap(bbob.sample_x)(task_keys), -4.0, 4.0)
# Loop over the selected ES algorithms. Every ES runs on the same sampled
# tasks and starting points, so the resulting curves are directly comparable.
for es_name, es_kwargs in es_dict.items():
    print(f"Running {es_name}...")

    # Look up the ES class by name and build it. `x` is only a solution
    # template; each task's actual starting point is passed to `es.init`.
    ES = algorithms[es_name]
    es = ES(
        population_size=population_size,
        solution=x,
        **es_kwargs,
    )
    es_params = es.default_params

    # Evaluate a whole population on one task: vmap over (eval key, candidate)
    # while broadcasting that task's state and params.
    fitness_fn = jax.vmap(bbob.evaluate, in_axes=(0, 0, None, None))

    def step(carry, key):
        """One ask/evaluate/tell generation on a single task (lax.scan body)."""
        es_state, es_params, bbob_state, bbob_params = carry
        key_ask, key_eval, key_tell = jax.random.split(key, 3)

        # Ask, then clip candidates into [-5, 5] before evaluating them.
        population, es_state = es.ask(key_ask, es_state, es_params)
        population = jnp.clip(population, -5.0, 5.0)

        # Eval (vectorized over population)
        keys_eval = jax.random.split(key_eval, population.shape[0])
        bbob_state, bbob_eval = fitness_fn(
            keys_eval, population, bbob_state, bbob_params
        )
        # The vmap yields one updated bbob_state per individual; all copies are
        # identical (counter +1), so keep the first.
        bbob_state = jax.tree.map(lambda leaf: leaf[0], bbob_state)
        fitness = bbob_eval.fitness

        # Tell
        es_state, metrics = es.tell(
            key_tell, population, fitness, es_state, es_params
        )

        # Emit only the metrics as the per-generation scan output. Emitting the
        # ES state too would stack the entire state num_generations times per
        # task, just to read the final mean.
        return (es_state, es_params, bbob_state, bbob_params), metrics

    @jax.jit
    def run_es_eval(key, solution, es_params, bbob_state, bbob_params):
        """Run the ES for `num_generations` on a single task.

        Returns the per-generation metrics (stacked by lax.scan) and the ES
        mean after the final generation.
        """
        # Independent keys for initialization and for the generation loop.
        # Reusing one key for both would correlate the two random streams.
        key_init, key_scan = jax.random.split(key)
        es_state = es.init(key_init, solution, es_params)

        keys = jax.random.split(key_scan, num_generations)
        (es_state, _, _, _), metrics = jax.lax.scan(
            step,
            (es_state, es_params, bbob_state, bbob_params),
            keys,
            length=num_generations,
        )
        # The final carry holds the state after the last generation.
        return metrics, es_state.mean

    # Run evaluation across all tasks (one key, start point and task per row;
    # ES params are shared).
    key, subkey = jax.random.split(key)
    keys = jax.random.split(subkey, num_tasks)
    metrics_batch, final_means = jax.vmap(
        run_es_eval, in_axes=(0, 0, None, 0, 0)
    )(keys, solutions, es_params, bbob_state, bbob_params)

    # Average metrics across tasks
    metrics = jax.tree.map(lambda m: jnp.mean(m, axis=0), metrics_batch)
    # Store the results
    results[es_name] = metrics
# Experiment size: generations per run, candidates per generation, and the
# number of sampled BBOB tasks every ES is evaluated on.
num_generations = 8_192
population_size = 1_024
num_tasks = 128
key = jax.random.key(0)

# ES algorithms to compare (names index evosax's `algorithms` registry), each
# mapped to extra keyword arguments for its constructor.
es_dict = dict(
    SimpleES={},
    PGPE={},
    Open_ES={"optimizer": optax.adam(1e-3)},
    SNES={},
    Sep_CMA_ES={},
    CMA_ES={},
)

# Task-averaged metrics per ES, filled in by the comparison loop.
results = {}

# Sample the BBOB tasks once (params, then state) so all ES share them.
key, subkey = jax.random.split(key)
task_keys = jax.random.split(subkey, num_tasks)
bbob_params = jax.vmap(bbob.sample)(task_keys)

key, subkey = jax.random.split(key)
task_keys = jax.random.split(subkey, num_tasks)
bbob_state = jax.vmap(bbob.init)(task_keys, bbob_params)

# One sampled point that only serves as the solution template for ES init.
key, subkey = jax.random.split(key)
x = bbob.sample_x(subkey)

# One starting point per task, shared by every ES, clipped into [-4, 4].
key, subkey = jax.random.split(key)
task_keys = jax.random.split(subkey, num_tasks)
solutions = jnp.clip(jax.vmap(bbob.sample_x)(task_keys), -4.0, 4.0)
# Loop over the selected ES algorithms. Every ES runs on the same sampled
# tasks and starting points, so the resulting curves are directly comparable.
for es_name, es_kwargs in es_dict.items():
    print(f"Running {es_name}...")

    # Look up the ES class by name and build it. `x` is only a solution
    # template; each task's actual starting point is passed to `es.init`.
    ES = algorithms[es_name]
    es = ES(
        population_size=population_size,
        solution=x,
        **es_kwargs,
    )
    es_params = es.default_params

    # Evaluate a whole population on one task: vmap over (eval key, candidate)
    # while broadcasting that task's state and params.
    fitness_fn = jax.vmap(bbob.evaluate, in_axes=(0, 0, None, None))

    def step(carry, key):
        """One ask/evaluate/tell generation on a single task (lax.scan body)."""
        es_state, es_params, bbob_state, bbob_params = carry
        key_ask, key_eval, key_tell = jax.random.split(key, 3)

        # Ask, then clip candidates into [-5, 5] before evaluating them.
        population, es_state = es.ask(key_ask, es_state, es_params)
        population = jnp.clip(population, -5.0, 5.0)

        # Eval (vectorized over population)
        keys_eval = jax.random.split(key_eval, population.shape[0])
        bbob_state, bbob_eval = fitness_fn(
            keys_eval, population, bbob_state, bbob_params
        )
        # The vmap yields one updated bbob_state per individual; all copies are
        # identical (counter +1), so keep the first.
        bbob_state = jax.tree.map(lambda leaf: leaf[0], bbob_state)
        fitness = bbob_eval.fitness

        # Tell
        es_state, metrics = es.tell(
            key_tell, population, fitness, es_state, es_params
        )

        # Emit only the metrics as the per-generation scan output. Emitting the
        # ES state too would stack the entire state num_generations times per
        # task, just to read the final mean.
        return (es_state, es_params, bbob_state, bbob_params), metrics

    @jax.jit
    def run_es_eval(key, solution, es_params, bbob_state, bbob_params):
        """Run the ES for `num_generations` on a single task.

        Returns the per-generation metrics (stacked by lax.scan) and the ES
        mean after the final generation.
        """
        # Independent keys for initialization and for the generation loop.
        # Reusing one key for both would correlate the two random streams.
        key_init, key_scan = jax.random.split(key)
        es_state = es.init(key_init, solution, es_params)

        keys = jax.random.split(key_scan, num_generations)
        (es_state, _, _, _), metrics = jax.lax.scan(
            step,
            (es_state, es_params, bbob_state, bbob_params),
            keys,
            length=num_generations,
        )
        # The final carry holds the state after the last generation.
        return metrics, es_state.mean

    # Run evaluation across all tasks (one key, start point and task per row;
    # ES params are shared).
    key, subkey = jax.random.split(key)
    keys = jax.random.split(subkey, num_tasks)
    metrics_batch, final_means = jax.vmap(
        run_es_eval, in_axes=(0, 0, None, 0, 0)
    )(keys, solutions, es_params, bbob_state, bbob_params)

    # Average metrics across tasks
    metrics = jax.tree.map(lambda m: jnp.mean(m, axis=0), metrics_batch)
    # Store the results
    results[es_name] = metrics
Visualize
In [ ]:
Copied!
# Plot the fitness over generations for each ES
plt.figure(figsize=(10, 6))
for es_name, metrics in results.items():
plt.plot(metrics["best_fitness"], label=es_name)
plt.title(f"Evolution Strategies Comparison on {num_tasks} Sampled BBOB Functions")
plt.xlabel("Generations")
plt.ylabel("Fitness")
plt.yscale("log")
plt.legend()
plt.grid(True)
plt.show()
# Plot the fitness over generations for each ES
plt.figure(figsize=(10, 6))
for es_name, metrics in results.items():
plt.plot(metrics["best_fitness"], label=es_name)
plt.title(f"Evolution Strategies Comparison on {num_tasks} Sampled BBOB Functions")
plt.xlabel("Generations")
plt.ylabel("Fitness")
plt.yscale("log")
plt.legend()
plt.grid(True)
plt.show()