HPO eficiente con EvoX
En este capítulo, exploraremos cómo utilizar EvoX para la optimización de hiperparámetros (HPO).
La HPO desempeña un papel crucial en muchas tareas de machine learning, pero a menudo se pasa por alto debido a su alto coste computacional, que a veces puede tardar días en procesarse, así como a los desafíos que conlleva su despliegue.
Con EvoX, podemos simplificar el despliegue de HPO utilizando el HPOProblemWrapper y lograr un cómputo eficiente aprovechando el método vmap y la aceleración por GPU.
Transformando el Workflow en un Problema

La clave para desplegar HPO con EvoX es transformar los workflows en problems utilizando el HPOProblemWrapper. Una vez transformados, podemos tratar los workflows como problems estándar. La entrada al ‘problema de HPO’ consiste en los hiperparámetros, y la salida son las métricas de evaluación.
El componente clave — HPOProblemWrapper
Para asegurar que el HPOProblemWrapper reconozca los hiperparámetros, necesitamos envolverlos usando Parameter. Con este sencillo paso, los hiperparámetros se identificarán automáticamente.
class ExampleAlgorithm(Algorithm):
def __init__(self,...):
self.omega = Parameter([1.0, 2.0]) # wrap the hyper-parameters with `Parameter`
self.beta = Parameter(0.1)
pass
def step(self):
# run algorithm step depending on the value of self.omega and self.beta
pass
Utilizando el HPOFitnessMonitor
Proporcionamos un HPOFitnessMonitor que admite el cálculo de las métricas ‘IGD’ y ‘HV’ para problemas multiobjetivo, así como el valor mínimo para problemas de un solo objetivo.
Es importante tener en cuenta que el HPOFitnessMonitor es un monitor básico diseñado para problemas de HPO. También puedes crear tu propio monitor personalizado de forma flexible utilizando el enfoque descrito en Deploy HPO with Custom Algorithms.
Un ejemplo sencillo
Aquí, mostraremos un ejemplo sencillo del uso de EvoX para HPO. Específicamente, utilizaremos el algoritmo PSO para optimizar los hiperparámetros del algoritmo PSO para resolver el problema de la esfera (sphere).
Ten en cuenta que este capítulo solo ofrece una breve visión general del despliegue de HPO. Para una guía más detallada, consulta Deploy HPO with Custom Algorithms.
Para empezar, importemos los módulos necesarios.
import torch
from evox.algorithms.pso_variants.pso import PSO
from evox.core import Problem
from evox.problems.hpo_wrapper import HPOFitnessMonitor, HPOProblemWrapper
from evox.workflows import EvalMonitor, StdWorkflow
A continuación, definimos un problema Sphere sencillo.
class Sphere(Problem):
def __init__(self):
super().__init__()
def evaluate(self, x: torch.Tensor):
return (x * x).sum(-1)
A continuación, podemos usar el StdWorkflow para envolver el problem, el algorithm y el monitor. Luego, usamos el HPOProblemWrapper para transformar el StdWorkflow en un problema de HPO.
# the inner loop is a PSO algorithm with a population size of 50
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
inner_algo = PSO(50, -10 * torch.ones(10), 10 * torch.ones(10))
inner_prob = Sphere()
inner_monitor = HPOFitnessMonitor()
inner_workflow = StdWorkflow(inner_algo, inner_prob, monitor=inner_monitor)
# Transform the inner workflow to an HPO problem
hpo_prob = HPOProblemWrapper(iterations=30, num_instances=128, workflow=inner_workflow, copy_init_state=True)
El HPOProblemWrapper toma 4 argumentos:
iterations: El número de iteraciones que se ejecutarán en el proceso de optimización.num_instances: El número de instancias que se ejecutarán en paralelo en el proceso de optimización.workflow: El workflow que se utilizará en el proceso de optimización.copy_init_state: Si se debe copiar el estado inicial del workflow para cada evaluación. Por defecto esTrue. Si tu workflow contiene operaciones que modifican IN-PLACE el tensor o tensores en el estado inicial, esto debe establecerse enTrue. De lo contrario, puedes establecerlo enFalsepara ahorrar memoria.
Podemos verificar si el HPOProblemWrapper reconoce correctamente los hiperparámetros que definimos. Dado que no se realizan modificaciones en los hiperparámetros en las 5 instancias, deberían permanecer idénticos para todas las instancias.
params = hpo_prob.get_init_params()
print("init params:\n", params)
También podemos definir un conjunto personalizado de valores de hiperparámetros. Es importante asegurarse de que el número de conjuntos de hiperparámetros coincida con el número de instancias en el HPOProblemWrapper. Además, los hiperparámetros personalizados deben proporcionarse como un diccionario cuyos valores estén envueltos mediante Parameter.
params = hpo_prob.get_init_params()
# since we have 128 instances, we need to pass 128 sets of hyperparameters
params["algorithm.w"] = torch.nn.Parameter(torch.rand(128, 1), requires_grad=False)
params["algorithm.phi_p"] = torch.nn.Parameter(torch.rand(128, 1), requires_grad=False)
params["algorithm.phi_g"] = torch.nn.Parameter(torch.rand(128, 1), requires_grad=False)
result = hpo_prob.evaluate(params)
print("The result of the first 3 parameter sets:\n", result[:3])
Ahora, utilizamos el algoritmo PSO para optimizar los hiperparámetros del algoritmo PSO.
Es importante asegurarse de que el tamaño de la población del PSO coincida con el número de instancias; de lo contrario, pueden ocurrir errores inesperados.
Además, la solución debe transformarse en el workflow externo, ya que el HPOProblemWrapper requiere que la entrada tenga forma de diccionario.
class solution_transform(torch.nn.Module):
def forward(self, x: torch.Tensor):
return {
"algorithm.w": x[:, 0],
"algorithm.phi_p": x[:, 1],
"algorithm.phi_g": x[:, 2],
}
outer_algo = PSO(128, 0 * torch.ones(3), 3 * torch.ones(3)) # search each hyperparameter in the range [0, 3]
monitor = EvalMonitor(full_sol_history=False)
outer_workflow = StdWorkflow(outer_algo, hpo_prob, monitor=monitor, solution_transform=solution_transform())
outer_workflow.init_step()
compiled_step = torch.compile(outer_workflow.step)
for _ in range(100):
compiled_step()
monitor = outer_workflow.get_submodule("monitor")
print("params:\n", monitor.topk_solutions, "\n")
print("result:\n", monitor.topk_fitness)
monitor.plot()