Resolución de problemas de Brax en EvoX

Resolución de problemas de Brax en EvoX

EvoX se adentra en la neuroevolución con Brax. Aquí mostraremos un ejemplo de cómo resolver un problema de Brax en EvoX.

# install EvoX and Brax, skip it if you have already installed EvoX or Brax
from importlib.util import find_spec
from IPython.display import HTML

if find_spec("evox") is None:
    %pip install evox
if find_spec("brax") is None:
    %pip install brax
# The dependent packages or functions in this example
import torch
import torch.nn as nn

from evox.algorithms import PSO
from evox.problems.neuroevolution.brax import BraxProblem
from evox.utils import ParamsAndVector
from evox.workflows import EvalMonitor, StdWorkflow

¿Qué es Brax?

Brax es un motor de física rápido y totalmente diferenciable utilizado para la investigación y el desarrollo de robótica, percepción humana, ciencia de materiales, reinforcement learning y otras aplicaciones con gran carga de simulación.

Aquí demostraremos un entorno “swimmer” de Brax.

Para más información, puedes consultar el Github de Brax.

Diseñar una clase de red neuronal

Para empezar, debemos decidir qué red neuronal vamos a construir.

Aquí presentaremos una clase sencilla de Perceptrón Multicapa (MLP).

# Construct an MLP using PyTorch.
# This MLP has 3 layers.


class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.features = nn.Sequential(nn.Linear(17, 8), nn.Tanh(), nn.Linear(8, 6))

    def forward(self, x):
        x = self.features(x)
        return torch.tanh(x)

Iniciar un modelo

A través de la clase SimpleMLP, podemos iniciar un modelo MLP.

# Make sure that the model is on the same device, better to be on the GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Reset the random seed
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Initialize the MLP model
model = SimpleMLP().to(device)

Iniciar un adaptador

Un adaptador puede ayudarnos a convertir los datos de un formato a otro.

adapter = ParamsAndVector(dummy_model=model)

Con un adaptador, podemos disponernos a realizar esta tarea de neuroevolución.

Configurar el proceso de ejecución

Iniciar un algoritmo y un problema

Iniciamos un algoritmo PSO, y el problema es un problema de Brax en el entorno “swimmer”.

# Set the population size
POP_SIZE = 1024

# Get the bound of the PSO algorithm
model_params = dict(model.named_parameters())
pop_center = adapter.to_vector(model_params)
lower_bound = torch.full_like(pop_center, -5)
upper_bound = torch.full_like(pop_center, 5)

# Initialize the PSO, and you can also use any other algorithms
algorithm = PSO(
    pop_size=POP_SIZE,
    lb=lower_bound,
    ub=upper_bound,
    device=device,
)

# Initialize the Brax problem
problem = BraxProblem(
    policy=model,
    env_name="halfcheetah",
    max_episode_length=1000,
    num_episodes=3,
    pop_size=POP_SIZE,
    device=device,
)

En este caso, utilizaremos 1000 pasos para cada episodio, y la recompensa media de 3 episodios se devolverá como el valor de aptitud (fitness).

Configurar un monitor

# set an monitor, and it can record the top 3 best fitnesses
monitor = EvalMonitor(
    topk=3,
    device=device,
)

Iniciar un flujo de trabajo (workflow)

# Initiate an workflow
workflow = StdWorkflow(
    algorithm=algorithm,
    problem=problem,
    monitor=monitor,
    opt_direction="max",
    solution_transform=adapter,
    device=device,
)

Ejecutar el flujo de trabajo

¡Ejecuta el flujo de trabajo y observa la magia!

Nota: El siguiente bloque tardará unos 20 minutos en ejecutarse. El tiempo puede variar dependiendo de tu hardware.

# Set the maximum number of generations
max_generation = 50

# Run the workflow
workflow.init_step()
compiled_step = torch.compile(workflow.step)
for i in range(max_generation):
    if i % 10 == 0:
        print(f"Generation {i}")
    compiled_step()

print(f"Top fitness: {monitor.get_best_fitness()}")
best_params = adapter.to_params(monitor.get_best_solution())
print(f"Best params: {best_params}")
monitor.get_best_fitness()
monitor.plot()
html_string = problem.visualize(best_params)
escaped_string = html_string.replace('"', """)
HTML(f'<iframe srcdoc="{escaped_string}" width="100%" height="480" frameborder="0"></iframe>')

Importante:

  • Normalmente, solo necesitas HTML(problem.visualize(best_params)) para renderizar. El código anterior es una solución alternativa para asegurar que el resultado se muestre correctamente en nuestro sitio web.
  • El algoritmo PSO no está optimizado específicamente para este tipo de tareas, por lo que se esperan limitaciones de rendimiento. Este ejemplo es para fines demostrativos.

¡Esperamos que disfrutes resolviendo problemas de Brax con EvoX y que te diviertas!