Skip to content

QDBBOB

bbobax.QDBBOB

Bases: BBOB

Quality-Diversity Black-box Optimization Benchmarking Task class.

Source code in src/bbobax/bbob.py
class QDBBOB(BBOB):
    """Quality-Diversity Black-box Optimization Benchmarking Task class.

    Extends ``BBOB`` with a set of descriptor functions. Each task sampled by
    :meth:`sample` selects one descriptor function by index and carries a
    masked Gaussian random projection matrix in its parameters;
    :meth:`evaluate` returns both the BBOB fitness and the selected descriptor.
    """

    def __init__(
        self,
        descriptor_fns: list[Callable[[jax.Array, BBOBState, QDBBOBParams], jax.Array]]
        | dict[str, Callable[[jax.Array, BBOBState, QDBBOBParams], jax.Array]],
        fitness_fns: list[Callable[[jax.Array, BBOBState, BBOBParams], jax.Array]]
        | dict[str, Callable[[jax.Array, BBOBState, BBOBParams], jax.Array]],
        descriptor_size: int = 2,
        **kwargs,
    ):
        """Initialize the QD-BBOB task.

        Args:
            descriptor_fns: List or dictionary of descriptor functions. Each is
                dispatched by :meth:`evaluate` as ``fn(x, state, params)``. For
                a dictionary only the values are used, in insertion order; a
                function's position in that order is its descriptor id.
            fitness_fns: List or dictionary of fitness functions, forwarded
                unchanged to ``BBOB.__init__``.
            descriptor_size: Size of the descriptor vector; also the number of
                rows of the matrix built by :meth:`gaussian_random_projection`.
            **kwargs: Additional arguments for BBOB.

        Raises:
            ValueError: If ``descriptor_fns`` is empty.

        """
        super().__init__(fitness_fns=fitness_fns, **kwargs)

        # Always store a private list: snapshotting the caller's container
        # means later mutation of it cannot desynchronize ``descriptor_fns``
        # from ``num_descriptors`` and the vmapped variants built below.
        if isinstance(descriptor_fns, dict):
            self.descriptor_fns = list(descriptor_fns.values())
        else:
            self.descriptor_fns = list(descriptor_fns)

        # ``sample`` draws ids from [0, num_descriptors) and ``evaluate``
        # dispatches with ``jax.lax.switch``, which rejects an empty branch
        # sequence; fail here with a clear message instead.
        if not self.descriptor_fns:
            raise ValueError("descriptor_fns must contain at least one function.")

        self.descriptor_size = descriptor_size

        # Batched variants: map over the leading axis of ``x`` while
        # broadcasting ``state`` and ``params``. Not used by the methods of
        # this class; presumably consumed by batched callers — verify.
        self._vmapped_descriptor_fns = [
            jax.vmap(fn, in_axes=(0, None, None)) for fn in self.descriptor_fns
        ]

        self.num_descriptors = len(self.descriptor_fns)

    def sample(self, key: jax.Array) -> QDBBOBParams:
        """Sample BBOB task parameters including descriptor params.

        The key is split three ways: one subkey for ``BBOB.sample``, one for
        drawing the descriptor id uniformly from ``[0, self.num_descriptors)``
        and one for the random projection matrix, which is masked to the
        sampled ``num_dims``.

        Args:
            key: JAX random key.

        Returns:
            The base BBOB parameters extended with ``descriptor_params`` and
            ``descriptor_id``.

        """
        key_base, key_desc_id, key_desc_params = jax.random.split(key, 3)

        base_params = super().sample(key_base)

        # ``maxval`` is exclusive, so every valid descriptor index is reachable.
        desc_id = jax.random.randint(
            key_desc_id, (), minval=0, maxval=self.num_descriptors
        )

        # Descriptor params: projection with columns beyond num_dims zeroed.
        descriptor_params = self.gaussian_random_projection(
            key_desc_params, base_params.num_dims
        )

        return QDBBOBParams(
            fn_id=base_params.fn_id,
            num_dims=base_params.num_dims,
            x_opt=base_params.x_opt,
            f_opt=base_params.f_opt,
            noise_params=base_params.noise_params,
            descriptor_params=descriptor_params,
            descriptor_id=desc_id,
        )

    def evaluate(
        self,
        key: jax.Array,
        x: jax.Array,
        state: BBOBState,
        params: QDBBOBParams,
    ) -> tuple[BBOBState, QDBBOBEval]:
        """Evaluate the fitness and descriptor of a solution.

        Args:
            key: JAX random key.
            x: Input solution.
            state: Current task state.
            params: Task parameters.

        Returns:
            Updated state and evaluation results.

        """
        state, bbob_eval = super().evaluate(key, x, state, params)

        # Select the descriptor function chosen at sampling time. Note that
        # it receives the state returned by ``BBOB.evaluate``, not the input
        # state.
        descriptor = jax.lax.switch(
            params.descriptor_id, self.descriptor_fns, x, state, params
        )

        bbob_eval = QDBBOBEval(fitness=bbob_eval.fitness, descriptor=descriptor)
        return state, bbob_eval

    def gaussian_random_projection(self, key: jax.Array, num_dims: int) -> jax.Array:
        """Generate a random Gaussian projection matrix.

        Args:
            key: JAX random key.
            num_dims: Number of active input dimensions; columns with index
                ``>= num_dims`` are set to zero.

        Returns:
            Array of shape ``(self.descriptor_size, self.max_num_dims)`` whose
            active entries are standard-normal draws scaled by
            ``1 / sqrt(self.descriptor_size)``.

        """
        descriptor_params = jax.random.normal(
            key,
            shape=(self.descriptor_size, self.max_num_dims),
        ) / jnp.sqrt(self.descriptor_size)
        # The (max_num_dims,) mask broadcasts across rows, zeroing whole columns.
        mask = jnp.arange(self.max_num_dims) < num_dims
        descriptor_params = jnp.where(mask, descriptor_params, 0)
        return descriptor_params

__init__(descriptor_fns, fitness_fns, descriptor_size=2, **kwargs)

Initialize the QD-BBOB task.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `descriptor_fns` | `list[Callable[[Array, BBOBState, QDBBOBParams], Array]] \| dict[str, Callable[[Array, BBOBState, QDBBOBParams], Array]]` | List or dictionary of descriptor functions. | *required* |
| `fitness_fns` | `list[Callable[[Array, BBOBState, BBOBParams], Array]] \| dict[str, Callable[[Array, BBOBState, BBOBParams], Array]]` | List or dictionary of fitness functions. | *required* |
| `descriptor_size` | `int` | Size of the descriptor vector. | `2` |
| `**kwargs` | | Additional arguments for BBOB. | `{}` |
Source code in src/bbobax/bbob.py
def __init__(
    self,
    descriptor_fns: list[Callable[[jax.Array, BBOBState, QDBBOBParams], jax.Array]]
    | dict[str, Callable[[jax.Array, BBOBState, QDBBOBParams], jax.Array]],
    fitness_fns: list[Callable[[jax.Array, BBOBState, BBOBParams], jax.Array]]
    | dict[str, Callable[[jax.Array, BBOBState, BBOBParams], jax.Array]],
    descriptor_size: int = 2,
    **kwargs,
):
    """Initialize the QD-BBOB task.

    Args:
        descriptor_fns: List or dictionary of descriptor functions. For a
            dictionary only the values are used, in insertion order; a
            function's position in that order is its descriptor id.
        fitness_fns: List or dictionary of fitness functions, forwarded
            unchanged to ``BBOB.__init__``.
        descriptor_size: Size of the descriptor vector.
        **kwargs: Additional arguments for BBOB.

    Raises:
        ValueError: If ``descriptor_fns`` is empty.

    """
    super().__init__(fitness_fns=fitness_fns, **kwargs)

    # Always store a private list: snapshotting the caller's container means
    # later mutation of it cannot desynchronize ``descriptor_fns`` from
    # ``num_descriptors`` and the vmapped variants built below.
    if isinstance(descriptor_fns, dict):
        self.descriptor_fns = list(descriptor_fns.values())
    else:
        self.descriptor_fns = list(descriptor_fns)

    # Descriptor ids are drawn from [0, num_descriptors) and dispatched with
    # ``jax.lax.switch``, which rejects an empty branch sequence; fail early.
    if not self.descriptor_fns:
        raise ValueError("descriptor_fns must contain at least one function.")

    self.descriptor_size = descriptor_size

    # Batched variants: map over the leading axis of ``x`` while broadcasting
    # ``state`` and ``params``.
    self._vmapped_descriptor_fns = [
        jax.vmap(fn, in_axes=(0, None, None)) for fn in self.descriptor_fns
    ]

    self.num_descriptors = len(self.descriptor_fns)

evaluate(key, x, state, params)

Evaluate the fitness and descriptor of a solution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `key` | `Array` | JAX random key. | *required* |
| `x` | `Array` | Input solution. | *required* |
| `state` | `BBOBState` | Current task state. | *required* |
| `params` | `QDBBOBParams` | Task parameters. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `tuple[BBOBState, QDBBOBEval]` | Updated state and evaluation results. |

Source code in src/bbobax/bbob.py
def evaluate(
    self,
    key: jax.Array,
    x: jax.Array,
    state: BBOBState,
    params: QDBBOBParams,
) -> tuple[BBOBState, QDBBOBEval]:
    """Compute the BBOB fitness of ``x`` plus its task-selected descriptor.

    Args:
        key: JAX random key, forwarded to ``BBOB.evaluate``.
        x: Input solution.
        state: Current task state.
        params: Task parameters; ``params.descriptor_id`` picks which
            descriptor function is applied.

    Returns:
        The state returned by ``BBOB.evaluate`` together with a
        ``QDBBOBEval`` holding the fitness and the descriptor.

    """
    new_state, base_eval = super().evaluate(key, x, state, params)

    # The chosen descriptor function sees the post-evaluation state.
    branches = self.descriptor_fns
    descriptor = jax.lax.switch(params.descriptor_id, branches, x, new_state, params)

    return new_state, QDBBOBEval(fitness=base_eval.fitness, descriptor=descriptor)

gaussian_random_projection(key, num_dims)

Generate a random Gaussian projection matrix.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `key` | `Array` | JAX random key. | *required* |
| `num_dims` | `int` | Number of dimensions. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Random projection matrix. |

Source code in src/bbobax/bbob.py
def gaussian_random_projection(self, key: jax.Array, num_dims: int) -> jax.Array:
    """Draw a scaled Gaussian projection matrix with unused columns zeroed.

    Args:
        key: JAX random key.
        num_dims: Number of active input dimensions; every column whose index
            is ``>= num_dims`` is set to zero.

    Returns:
        Array of shape ``(self.descriptor_size, self.max_num_dims)`` whose
        active entries are standard-normal draws divided by
        ``sqrt(self.descriptor_size)``.

    """
    rows, cols = self.descriptor_size, self.max_num_dims
    projection = jax.random.normal(key, shape=(rows, cols)) / jnp.sqrt(rows)

    # A 1-D column mask that broadcasts over every row.
    active_columns = jnp.arange(cols) < num_dims
    return jnp.where(active_columns, projection, 0)

sample(key)

Sample BBOB task parameters including descriptor params.

Source code in src/bbobax/bbob.py
def sample(self, key: jax.Array) -> QDBBOBParams:
    """Sample base BBOB parameters and attach descriptor settings.

    The key is split into three subkeys: the first drives ``BBOB.sample``,
    the second picks a descriptor index uniformly from
    ``[0, self.num_descriptors)``, and the third draws the projection
    matrix, which is masked to the sampled number of dimensions.

    Args:
        key: JAX random key.

    Returns:
        ``QDBBOBParams`` copying the base fields and adding
        ``descriptor_params`` and ``descriptor_id``.

    """
    subkeys = jax.random.split(key, 3)
    base = super().sample(subkeys[0])

    chosen_id = jax.random.randint(
        subkeys[1], (), minval=0, maxval=self.num_descriptors
    )
    projection = self.gaussian_random_projection(subkeys[2], base.num_dims)

    return QDBBOBParams(
        descriptor_id=chosen_id,
        descriptor_params=projection,
        fn_id=base.fn_id,
        num_dims=base.num_dims,
        x_opt=base.x_opt,
        f_opt=base.f_opt,
        noise_params=base.noise_params,
    )