Resolver Problemas Brax en EvoX
EvoX profundiza en la neuroevolucion con Brax. Aqui mostraremos un ejemplo de resolucion de un problema Brax en EvoX.
# install EvoX and Brax, skip it if you have already installed EvoX or Brax
from importlib.util import find_spec
from IPython.display import HTML
if find_spec("evox") is None:
%pip install evox
if find_spec("brax") is None:
%pip install brax
# The dependent packages or functions in this example
import torch
import torch.nn as nn
from evox.algorithms import PSO
from evox.problems.neuroevolution.brax import BraxProblem
from evox.utils import ParamsAndVector
from evox.workflows import EvalMonitor, StdWorkflow
Que es Brax
Brax es un motor de fisica rapido y completamente diferenciable utilizado para investigacion y desarrollo de robotica, percepcion humana, ciencia de materiales, aprendizaje por refuerzo y otras aplicaciones intensivas en simulacion.
Aqui demostraremos un entorno “swimmer” de Brax.
Para mas informacion, puede visitar el Github de Brax.
Disenar una clase de red neuronal
Para comenzar, necesitamos decidir que red neuronal vamos a construir.
Aqui daremos una clase simple de Perceptron Multicapa (MLP).
# Construct an MLP using PyTorch.
# This MLP has 3 layers.
class SimpleMLP(nn.Module):
def __init__(self):
super(SimpleMLP, self).__init__()
self.features = nn.Sequential(nn.Linear(17, 8), nn.Tanh(), nn.Linear(8, 6))
def forward(self, x):
x = self.features(x)
return torch.tanh(x)
Iniciar un modelo
A traves de la clase SimpleMLP, podemos iniciar un modelo MLP.
# Make sure that the model is on the same device, better to be on the GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Reset the random seed
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Initialize the MLP model
model = SimpleMLP().to(device)
Iniciar un adaptador
Un adaptador puede ayudarnos a convertir los datos de ida y vuelta.
adapter = ParamsAndVector(dummy_model=model)
Con un adaptador, podemos comenzar a realizar esta tarea de Neuroevolucion.
Configurar el proceso de ejecucion
Iniciar un algoritmo y un problema
Iniciamos un algoritmo PSO, y el problema es un problema Brax en el entorno “swimmer”.
# Set the population size
POP_SIZE = 1024
# Get the bound of the PSO algorithm
model_params = dict(model.named_parameters())
pop_center = adapter.to_vector(model_params)
lower_bound = torch.full_like(pop_center, -5)
upper_bound = torch.full_like(pop_center, 5)
# Initialize the PSO, and you can also use any other algorithms
algorithm = PSO(
pop_size=POP_SIZE,
lb=lower_bound,
ub=upper_bound,
device=device,
)
# Initialize the Brax problem
problem = BraxProblem(
policy=model,
env_name="halfcheetah",
max_episode_length=1000,
num_episodes=3,
pop_size=POP_SIZE,
device=device,
)
En este caso, usaremos 1000 pasos para cada episodio, y la recompensa promedio de 3 episodios se devolvera como el valor de aptitud.
Establecer un monitor
# set an monitor, and it can record the top 3 best fitnesses
monitor = EvalMonitor(
topk=3,
device=device,
)
Iniciar un flujo de trabajo
# Initiate an workflow
workflow = StdWorkflow(
algorithm=algorithm,
problem=problem,
monitor=monitor,
opt_direction="max",
solution_transform=adapter,
device=device,
)
Ejecutar el flujo de trabajo
Ejecute el flujo de trabajo y vea la magia!
Nota: El siguiente bloque tardara alrededor de 20 minutos en ejecutarse. El tiempo puede variar dependiendo de su hardware.
# Set the maximum number of generations
max_generation = 50
# Run the workflow
workflow.init_step()
compiled_step = torch.compile(workflow.step)
for i in range(max_generation):
if i % 10 == 0:
print(f"Generation {i}")
compiled_step()
print(f"Top fitness: {monitor.get_best_fitness()}")
best_params = adapter.to_params(monitor.get_best_solution())
print(f"Best params: {best_params}")
monitor.get_best_fitness()
monitor.plot()
html_string = problem.visualize(best_params)
escaped_string = html_string.replace('"', """)
HTML(f'<iframe srcdoc="{escaped_string}" width="100%" height="480" frameborder="0"></iframe>')
Importante:
- Normalmente, solo necesita
HTML(problem.visualize(best_params))para renderizar. El codigo anterior es una solucion alternativa para asegurar que el resultado se muestre correctamente en nuestro sitio web.- El algoritmo PSO no esta especificamente optimizado para este tipo de tarea, por lo que se esperan limitaciones de rendimiento. Este ejemplo es con fines de demostracion.
Esperamos que disfrute resolviendo problemas Brax con EvoX y se divierta!