Vectoriza y paraleliza entornos de RL con JAX Aprendizaje por refuerzo a la velocidad de la luz⚡

Optimiza y agiliza entornos de RL con JAX Aprendizaje por refuerzo a la velocidad de la luz⚡

En este artículo, aprendemos a vectorizar un entorno de RL y entrenar 30 agentes de Q-learning en paralelo en una CPU, a 1.8 millones de iteraciones por segundo.

Imagen de Google DeepMind en Unsplash

En la historia anterior, presentamos el Aprendizaje de Diferencia Temporal, particularmente Q-learning, en el contexto de un GridWorld.

Aprendizaje de Diferencia Temporal y la importancia de la exploración: Una guía ilustrada

Una comparación de los métodos de aprendizaje TD del modelo libre (Q-learning) y del modelo basado (Dyna-Q y Dyna-Q+) en un mundo de grid dinámico.

towardsdatascience.com

Aunque esta implementación sirvió para demostrar las diferencias en el rendimiento y los mecanismos de exploración de estos algoritmos, fue dolorosamente lenta.

De hecho, el entorno y los agentes fueron programados principalmente en Numpy, que de ninguna manera es estándar en RL, aunque hace que el código sea fácil de entender y depurar.

En este artículo, veremos cómo escalar los experimentos de RL mediante la vectorización de entornos y la paralelización fluida del entrenamiento de docenas de agentes utilizando JAX. En particular, este artículo cubre:

  • Conceptos básicos y características útiles de JAX para RL
  • Entornos vectorizados y por qué son tan rápidos
  • Implementación de un entorno, política y agente de Q-learning en JAX
  • Entrenamiento de un solo agente
  • Cómo paralelizar el entrenamiento de agentes ¡y qué tan fácil es!

Todo el código destacado en este artículo está disponible en GitHub:

GitHub – RPegoud/jax_rl: Implementación de algoritmos de ML y entornos vectorizados en JAX

Implementación de algoritmos de ML y entornos vectorizados en JAX – GitHub – RPegoud/jax_rl: Implementación de ML en JAX …

github.com

Conceptos Básicos de JAX

JAX es otro marco de aprendizaje profundo en Python desarrollado por Google y ampliamente utilizado por empresas como DeepMind.

“JAX es Autograd (diferenciación automática) y XLA (Álgebra Lineal Acelerada, un compilador de TensorFlow), combinados para cálculos numéricos de alto rendimiento.” — Documentación oficial

A diferencia de lo que la mayoría de los desarrolladores de Python están acostumbrados, JAX no adopta el paradigma de programación orientada a objetos (OOP), sino más bien la programación funcional (FP)[1].

En pocas palabras, se basa en funciones puras (determinísticas y sin efectos secundarios) y estructuras de datos inmutables (en lugar de cambiar los datos en su lugar, se crean nuevas estructuras de datos con las modificaciones deseadas) como bloques de construcción principales. Como resultado, FP fomenta un enfoque más funcional y matemático de la programación, lo que lo hace adecuado para tareas como computación numérica y aprendizaje automático.

Ilustraremos las diferencias entre esos dos paradigmas observando el pseudocódigo de una función de actualización Q:

  • El enfoque orientado a objetos se basa en una instancia de clase que contiene varias variables de estado (como los valores Q). La función de actualización se define como un método de clase que actualiza el estado interno de la instancia.
  • El enfoque de programación funcional se basa en una función pura. De hecho, esta actualización Q es determinista ya que los valores Q se pasan como argumento. Por lo tanto, cualquier llamada a esta función con las mismas entradas dará como resultado las mismas salidas, mientras que las salidas de un método de clase pueden depender del estado interno de la instancia. Además, las estructuras de datos como matrices se definen y se modifican en el ámbito global.
Implementación de una actualización Q en programación orientada a objetos y programación funcional (realizada por el autor)

Como tal, JAX ofrece una variedad de decoradores de funciones que son particularmente útiles en el contexto del RL:

  • vmap (mapeo vectorizado): Permite aplicar una función que actúa sobre una sola muestra a un lote. Por ejemplo, si env.step() es una función que realiza un paso en un solo entorno, vmap(env.step)() es una función que realiza un paso en múltiples entornos. En otras palabras, vmap agrega una dimensión de lote a una función.
Ilustración de una función de paso vectorizada utilizando vmap (realizada por el autor)
  • jit (compilación justo a tiempo): Permite que JAX realice una “Compilación justo a tiempo de una función Python JAX” haciéndola compatible con XLA. Básicamente, al usar jit podemos compilar funciones y proporciona mejoras significativas de velocidad (a cambio de una sobrecarga adicional al compilar por primera vez la función).
  • pmap (mapeo paralelo): Similarmente a vmap, pmap permite la paralelización fácil. Sin embargo, en lugar de agregar una dimensión de lote a una función, replica la función y la ejecuta en varios dispositivos XLA. Nota: al aplicar pmap, jit también se aplica automáticamente.
Ilustración de una función de paso paralelizada utilizando pmap (realizada por el autor)

Ahora que hemos sentado las bases de JAX, veremos cómo obtener velocidades masivas al vectorizar entornos.

Entornos Vectorizados:

Primero, ¿qué es un entorno vectorizado y qué problemas resuelve la vectorización?

En la mayoría de los casos, los experimentos de RL se ralentizan debido a las transferencias de datos entre la CPU y la GPU. Los algoritmos de RL de aprendizaje profundo, como Proximal Policy Optimization (PPO), utilizan redes neuronales para aproximar la política.

Como siempre en Deep Learning, las Redes Neuronales utilizan GPUs durante el proceso de entrenamiento y inferencia. Sin embargo, en la mayoría de los casos, los entornos se ejecutan en la CPU (incluso en el caso de utilizar múltiples entornos en paralelo).

Esto significa que el ciclo usual del aprendizaje por refuerzo (RL), que consiste en seleccionar acciones a través de la política (Redes Neuronales) y recibir observaciones y recompensas del entorno, requiere de constantes intercambios entre la GPU y la CPU, lo cual afecta el rendimiento.

Además, utilizar frameworks como PyTorch sin “jitting” puede causar cierta sobrecarga, ya que la GPU podría tener que esperar a que Python envíe de vuelta las observaciones y recompensas desde la CPU.

Configuración usual de entrenamiento en batch para RL en PyTorch (hecha por el autor)

Por otro lado, JAX nos permite ejecutar fácilmente entornos en batch en la GPU, eliminando la fricción causada por la transferencia de datos entre la GPU y la CPU.

Además, al compilar nuestro código en JAX con jit a XLA, la ejecución ya no se ve (o al menos se ve menos) afectada por la ineficiencia de Python.

Configuración de entrenamiento en batch para RL en JAX (hecha por el autor)

Para más detalles y emocionantes aplicaciones en la investigación de meta-aprendizaje con RL, recomiendo encarecidamente esta publicación de blog de Chris Lu.

Implementaciones del Entorno, Agente y Política:

Echemos un vistazo a la implementación de las diferentes partes de nuestro experimento de RL. Aquí hay una visión general de alto nivel de las funciones básicas que necesitaremos:

Métodos de clase requeridos para una configuración simple de RL (hecha por el autor)

El entorno

Esta implementación sigue el esquema proporcionado por Nikolaj Goodger en su excelente artículo sobre cómo escribir entornos en JAX.

Escribiendo un entorno de RL en JAX

Cómo ejecutar CartPole a 1.25 Mil Millones de Pasos/Seg

VoAGI.com

Comencemos con una vista general de alto nivel del entorno y sus métodos. Este es un plan general para implementar un entorno en JAX:

Veamos más de cerca los métodos de clase (como recordatorio, las funciones que comienzan con “_” son privadas y no deben ser llamadas fuera del alcance de la clase):

  • _get_obs: Este método convierte el estado del entorno en una observación para el agente. En un entorno parcialmente observable o estocástico, aquí se aplicarían las funciones de procesamiento al estado.
  • _reset: Como ejecutaremos múltiples agentes en paralelo, necesitamos un método para reiniciar individualmente al completar un episodio.
  • _reset_if_done: Este método se llamará en cada paso y activará _reset si el indicador “done” está configurado en True.
  • reset: Este método se llama al comienzo del experimento para obtener el estado inicial de cada agente, así como las claves aleatorias asociadas.
  • step: Dado un estado y una acción, el entorno devuelve una observación (nuevo estado), una recompensa y el indicador de “done” actualizado.

En la práctica, una implementación genérica de un entorno de GridWorld se vería así:

Obsérvese que, como se mencionó anteriormente, todos los métodos de la clase siguen el paradigma de programación funcional. De hecho, nunca actualizamos el estado interno de la instancia de la clase. Además, los atributos de clase son todos constantes que no se modificarán después de la instanciación.

Echemos un vistazo más de cerca:

  • __init__: En el contexto de nuestro GridWorld, las acciones disponibles son [0, 1, 2, 3]. Estas acciones se traducen en una matriz bidimensional utilizando self.movements y se agregan al estado en la función step.
  • _get_obs: Nuestro entorno es determinista y totalmente observable, por lo tanto, el agente recibe el estado directamente en lugar de una observación procesada.
  • _reset_if_done: El argumento env_state corresponde a la tupla (state, key) donde key es un jax.random.PRNGKey. Esta función simplemente devuelve el estado inicial si el indicador done está configurado en True, sin embargo, no podemos usar el flujo de control convencional de Python dentro de las funciones jitted de JAX. Usando jax.lax.cond obtenemos esencialmente una expresión equivalente a:
def cond(condition, true_fun, false_fun, operand):  if condition: # si el indicador done == True    return true_fun(operand)  # return self._reset(key)  else:    return false_fun(operand) # return env_state
  • step: Convertimos la acción en un movimiento y lo agregamos al estado actual (jax.numpy.clip asegura que el agente permanezca dentro de la cuadrícula). Luego actualizamos la tupla env_state antes de verificar si es necesario reiniciar el entorno. Como la función step se utiliza con frecuencia durante el entrenamiento, jitearlo permite obtener mejoras significativas en el rendimiento. El decorador @partial(jit, static_argnums=(0, ) indica que el argumento “self” del método de la clase debe considerarse estático. En otras palabras, las propiedades de la clase son constantes y no cambiarán durante las llamadas sucesivas a la función step.

Agente de Q-Learning

El agente de Q-Learning está definido por la función update, así como una tasa de aprendizaje y un factor de descuento estáticos.

Nuevamente, al jitear la función de actualización, pasamos el argumento “self” como estático. Además, obsérvese que la matriz de valores Q se modifica in situ usando set() y su valor no se almacena como un atributo de clase.

Política Epsilon-Greedy

Por último, la política utilizada en este experimento es la política epsilon-greedy estándar. Un detalle importante es que utiliza desempates aleatorios, lo que significa que si el valor Q máximo no es único, la acción se seleccionará de manera uniforme entre los valores Q máximos (usar argmax siempre devolvería la primera acción con valor Q máximo). Esto es especialmente importante si los valores Q se inicializan como una matriz de ceros, ya que siempre se seleccionaría la acción 0 (moverse a la derecha).

De lo contrario, la política se puede resumir con este fragmento de código:

action = lax.cond(            explore, # si p < epsilon            _random_action_fn, # seleccionar una acción aleatoria según la clave            _greedy_action_fn, # seleccionar la acción codiciosa con respecto a los valores Q            operand=subkey, # utilizar subkey como argumento para las funciones anteriores        )return action, subkey

Observa que cuando usamos una clave en JAX (por ejemplo, aquí hemos elegido un número aleatorio y hemos usado random.choice) es práctica común dividir la clave posteriormente (es decir, “pasar a un nuevo estado aleatorio”, más detalles aquí).

Bucle de entrenamiento de un solo agente:

Ahora que tenemos todos los componentes necesarios, entrenemos a un solo agente.

Aquí tienes un bucle de entrenamiento pythonico, como puedes ver, esencialmente seleccionamos una acción utilizando la política, realizamos un paso en el entorno y actualizamos los valores Q hasta el final de un episodio. Luego repetimos el proceso para N episodios. Como veremos dentro de un minuto, esta forma de entrenar a un agente es bastante ineficiente, sin embargo, resume los pasos clave del algoritmo de una manera legible:

En una CPU única, completamos 10.000 episodios en 11 segundos, a una tasa de 881 episodios y 21 680 pasos por segundo.

100%|██████████| 10000/10000 [00:11<00:00, 881.86it/s]Número total de pasos: 238 488Número de pasos por segundo: 21 680

Ahora, replicaremos el mismo bucle de entrenamiento utilizando la sintaxis de JAX. Aquí tienes una descripción de alto nivel de la función rollout:

Training rollout function using JAX syntax (made by the author)

En resumen, la función de lanzamiento:

  1. Inicializa las observaciones, las <strong recompensas y las banderas finalizadas como matrices vacías con una dimensión igual al número de pasos de tiempo utilizando jax.numpy.zeros. Los Q-valores se inicializan como una matriz vacía con forma [pasos_de_tiempo+1, dimensión_x_del_grid, dimensión_y_del_grid, n_acciones]
  2. Llama a la función env.reset() para obtener el estado inicial.
  3. Utiliza la función jax.lax.fori_loop() para llamar a una función fori_body() N veces, donde N es el parámetro de paso_de_tiempo
  4. La función fori_body() se comporta de manera similar al bucle Python anterior. Después de seleccionar una acción, realizar un paso y calcular la actualización de Q, actualizamos las matrices de obs, recompensas, finalizadas y q_values en su lugar (la actualización de Q apunta al paso de tiempo t+1)

Esta complejidad adicional conduce a una velocidad de hasta 85 veces más rápida, ahora entrenamos nuestro agente a aproximadamente 1,83 millones de pasos por segundo. Ten en cuenta que aquí, el entrenamiento se realiza en una CPU única, ya que el entorno es sencillo.

Sin embargo, la vectorización de extremo a extremo escala aún mejor cuando se aplica a entornos complejos y a algoritmos que se benefician de múltiples GPU (un artículo de Chris Lu informa de una asombrosa mejora de 4000 veces más rápido entre una implementación de PPO de PyTorch de CleanRL y una reproducción de JAX).

100%|██████████| 1000000/1000000 [00:00<00:00, 1837563.94it/s]Número total de pasos: 1 000 000Número de pasos por segundo: 1 837 563

Después de entrenar a nuestro agente, trazamos el valor Q máximo para cada celda (es decir, estado) de GridWorld y observamos que ha aprendido eficazmente cómo ir desde el estado inicial (esquina inferior derecha) hasta el objetivo (esquina superior izquierda).

Representación en forma de mapa de calor del valor Q máximo para cada celda de GridWorld (hecha por el autor)

Bucle de entrenamiento de agentes en paralelo:

Como prometimos, ahora que hemos escrito las funciones necesarias para entrenar a un único agente, nos queda poco trabajo para entrenar a múltiples agentes en paralelo en entornos en lotes!

Gracias a vmap, podemos transformar rápidamente nuestras funciones anteriores para trabajar con lotes de datos. Solo tenemos que especificar las formas de entrada y salida esperadas, por ejemplo para env.step:

  • in_axes = ((0,0), 0) representa la forma de entrada, que está compuesta por la tupla env_state (dimensión (0, 0)) y una observación (dimensión 0).
  • out_axes = ((0, 0), 0, 0, 0) representa la forma de salida, con la salida siendo ((env_state), obs, recompensa, finalizada).
  • Ahora, podemos llamar a v_step en una matriz de env_states y acciones y recibir una matriz de env_states procesados, observaciones, recompensas y banderas finalizadas.
  • Ten en cuenta que también utilizamos jit para todas las funciones en lotes para mejorar el rendimiento (es discutible que jit en env.reset() no sea necesario dado que solo se llama una vez en nuestra función de entrenamiento).

El último ajuste que debemos hacer es agregar una dimensión de lote a nuestras matrices para tener en cuenta los datos de cada agente.

Al hacer esto, obtenemos una función que nos permite entrenar múltiples agentes en paralelo, con ajustes mínimos en comparación con la función de un solo agente:

Obtenemos resultados similares con esta versión de nuestra función de entrenamiento:

100%|██████████| 100000/100000 [00:02<00:00, 49036.11it/s]Número total de pasos: 100 000 * 30 = 3 000 000Número de pasos por segundo: 49 036 * 30 = 1 471 080

¡Y eso es todo! Gracias por leer hasta aquí, espero que este artículo haya sido una introducción útil para implementar entornos vectorizados en JAX.

Si disfrutaste la lectura, por favor considera compartir este artículo y darle una estrella a mi repositorio de GitHub, ¡gracias por tu apoyo! 🙏

GitHub – RPegoud/jax_rl: Implementación de algoritmos de RL y entornos vectorizados en JAX

Implementación de algoritmos de RL y entornos vectorizados en JAX – GitHub – RPegoud/jax_rl: Implementación de RL…

github.com

Finalmente, para aquellos interesados en profundizar un poco más, aquí hay una lista de recursos útiles que me ayudaron a comenzar con JAX y redactar este artículo:

Una lista seleccionada de artículos y recursos increíbles sobre JAX:

[1] Coderized, (programación funcional) El estilo de codificación más puro, donde los errores son casi imposibles, YouTube

[2] Aleksa Gordić, JAX From Zero to Hero YouTube Playlist (2022), The AI Epiphany

[3] Nikolaj Goodger, Escribir un entorno de RL en JAX (2021)

[4] Chris Lu, Lograr una aceleración de 4000 veces y descubrimientos meta-evolutivos con PureJaxRL (2023), Universidad de Oxford, Foerster Lab for AI Research

[5] Nicholas Vadivelu, Awesome-JAX (2020), una lista de bibliotecas, proyectos y recursos de JAX

[6] Documentación oficial de JAX, Entrenando una red neuronal simple con carga de datos de PyTorch

We will continue to update Zepes; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

Inteligencia Artificial

Empresa derivada de la Universidad de Glasgow recauda $43 millones para 'Digitalizar la Química

Chemify, que se separó de la Universidad de Glasgow del Reino Unido en 2022, ha recibido $43 millones de financiamien...

Inteligencia Artificial

Investigadores de Stanford y Microsoft presentan Inteligencia Artificial de Auto-Mejora Aprovechando GPT-4 para elevar el rendimiento del programa de andamiaje.

Casi todos los objetivos descritos en lenguaje natural pueden optimizarse mediante la consulta a un modelo de lenguaj...

Inteligencia Artificial

Top Herramientas/Startups de Datos Sintéticos para Modelos de Aprendizaje Automático en 2023

La información creada intencionalmente en lugar de ser el resultado de eventos reales se conoce como datos sintéticos...

Inteligencia Artificial

Apple entra en la competencia de la IA generativa con el chatbot 'AppleGPT

El gigante tecnológico Apple sigue adelante con su esperado chatbot impulsado por IA, tentativamente llamado “A...