馃Ж 隆Difusi贸n estable en JAX / Flax!

'馃Ж 隆Difusi贸n estable en JAX / Flax!' can be condensed to '馃Ж隆Difusi贸n estable en JAX/Flax!'

馃 Hugging Face Diffusers soporta Flax desde la versi贸n 0.5.1! Esto permite una inferencia s煤per r谩pida en las TPUs de Google, como las disponibles en Colab, Kaggle o Google Cloud Platform.

Esta publicaci贸n muestra c贸mo ejecutar inferencia usando JAX / Flax. Si desea m谩s detalles sobre c贸mo funciona Stable Diffusion o desea ejecutarlo en GPU, consulte este notebook de Colab.

Si desea seguir, haga clic en el bot贸n de arriba para abrir esta publicaci贸n como un notebook de Colab.

Primero, aseg煤rese de estar utilizando un backend de TPU. Si est谩 ejecutando este notebook en Colab, seleccione Entorno de ejecuci贸n en el men煤 de arriba, luego seleccione la opci贸n “Cambiar tipo de entorno de ejecuci贸n” y luego seleccione TPU en la configuraci贸n de Acelerador de hardware.

Tenga en cuenta que JAX no es exclusivo de las TPUs, pero brilla en ese hardware porque cada servidor TPU tiene 8 aceleradores TPU trabajando en paralelo.

Configuraci贸n

import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Se encontraron {num_devices} dispositivos JAX de tipo {device_type}.")
assert "TPU" in device_type, "El dispositivo disponible no es una TPU, por favor seleccione TPU desde Editar > Configuraci贸n del cuaderno > Acelerador de hardware"

Salida:

    Se encontraron 8 dispositivos JAX de tipo TPU v2.

Aseg煤rese de tener instalado diffusers.

!pip install diffusers==0.5.1

Luego importamos todas las dependencias.

import numpy as np
import jax
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline

Carga del modelo

Antes de usar el modelo, debes aceptar la licencia del modelo para descargar y usar los pesos.

La licencia est谩 dise帽ada para mitigar los posibles efectos perjudiciales de un sistema de aprendizaje autom谩tico tan potente. Solicitamos a los usuarios que lean la licencia completa y cuidadosamente. Aqu铆 ofrecemos un resumen:

  1. No puedes utilizar el modelo para producir ni compartir deliberadamente salidas o contenido ilegal o perjudicial,
  2. No reclamamos derechos sobre las salidas que generes, eres libre de usarlas y eres responsable de su uso, el cual no debe ir en contra de las disposiciones establecidas en la licencia, y
  3. Puedes redistribuir los pesos y utilizar el modelo comercialmente y/o como servicio. Si lo haces, ten en cuenta que debes incluir las mismas restricciones de uso que las establecidas en la licencia y compartir una copia de CreativeML OpenRAIL-M con todos tus usuarios.

Los pesos de Flax est谩n disponibles en Hugging Face Hub como parte del repositorio Stable Diffusion. El modelo Stable Diffusion se distribuye bajo la licencia CreateML OpenRail-M. Es una licencia abierta que no reclama derechos sobre las salidas que generas y te proh铆be producir deliberadamente contenido ilegal o perjudicial. La tarjeta del modelo proporciona m谩s detalles, as铆 que t贸mate un momento para leerlos y considera cuidadosamente si aceptas la licencia. Si lo haces, debes ser un usuario registrado en el Hub y usar un token de acceso para que funcione el c贸digo. Tienes dos opciones para proporcionar tu token de acceso:

  • Usa la herramienta de l铆nea de comandos huggingface-cli login en tu terminal y pega tu token cuando se te solicite. Se guardar谩 en un archivo en tu computadora.
  • O usa notebook_login() en un notebook, que hace lo mismo.

La siguiente celda presentar谩 una interfaz de inicio de sesi贸n a menos que ya hayas autenticado antes en esta computadora. Deber谩s pegar tu token de acceso.

if not (Path.home()/'.huggingface'/'token').exists(): notebook_login()

Los dispositivos TPU admiten bfloat16, un tipo de semi-float eficiente. Lo usaremos para nuestras pruebas, pero tambi茅n puedes usar float32 para usar precisi贸n completa en su lugar.

dtype = jnp.bfloat16

Flax es un marco de trabajo funcional, por lo que los modelos son sin estado y los par谩metros se almacenan fuera de ellos. Al cargar el pipeline pre-entrenado de Flax, se devolver谩 tanto el pipeline en s铆 como los pesos del modelo (o par谩metros). Estamos utilizando una versi贸n bf16 de los pesos, lo cual genera advertencias de tipo que se pueden ignorar de forma segura.

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=dtype,
)

Inferencia

Dado que las TPUs generalmente tienen 8 dispositivos trabajando en paralelo, replicaremos nuestra entrada tantas veces como dispositivos tengamos. Luego realizaremos la inferencia en los 8 dispositivos al mismo tiempo, cada uno responsable de generar una imagen. Por lo tanto, obtendremos 8 im谩genes en el mismo tiempo que tarda un chip en generar una sola.

Despu茅s de replicar la entrada, obtenemos los identificadores de texto tokenizados invocando la funci贸n prepare_inputs del pipeline. La longitud del texto tokenizado se establece en 77 tokens, como lo requiere la configuraci贸n del modelo de texto CLIP subyacente.

prompt = "Una imagen fija de una pel铆cula de Morgan Freeman interpretando a Jimi Hendrix, retrato, lente de 40mm, poca profundidad de campo, primer plano, iluminaci贸n dividida, cinematogr谩fica"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape

Salida :

(8, 77)

Replicaci贸n y paralelizaci贸n

Los par谩metros del modelo y las entradas deben replicarse en los 8 dispositivos paralelos que tenemos. El diccionario de par谩metros se replica utilizando flax.jax_utils.replicate, que recorre el diccionario y cambia la forma de los pesos para que se repitan 8 veces. Los arrays se replican utilizando shard.

p_params = replicate(params)

prompt_ids = shard(prompt_ids)
prompt_ids.shape

Salida :

(8, 1, 77)

Esa forma significa que cada uno de los 8 dispositivos recibir谩 como entrada un array jnp con una forma de (1, 77). Por lo tanto, 1 es el tama帽o del lote por dispositivo. En TPUs con suficiente memoria, podr铆a ser mayor que 1 si quisi茅ramos generar varias im谩genes (por chip) a la vez.

隆Ya casi estamos listos para generar im谩genes! Solo necesitamos crear un generador de n煤meros aleatorios para pasar a la funci贸n de generaci贸n. Este es el procedimiento est谩ndar en Flax, que es muy serio y tiene opiniones sobre los n煤meros aleatorios: se espera que todas las funciones que tratan con n煤meros aleatorios reciban un generador. Esto garantiza la reproducibilidad, incluso cuando estamos entrenando en m煤ltiples dispositivos distribuidos.

La funci贸n auxiliar a continuaci贸n utiliza una semilla para inicializar un generador de n煤meros aleatorios. Siempre que usemos la misma semilla, obtendremos los mismos resultados exactos. Si茅ntete libre de usar diferentes semillas al explorar los resultados m谩s adelante en el cuaderno.

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

Obtenemos un rng y luego lo “dividimos” 8 veces para que cada dispositivo reciba un generador diferente. Por lo tanto, cada dispositivo crear谩 una imagen diferente y todo el proceso es reproducible.

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

El c贸digo de JAX se puede compilar en una representaci贸n eficiente que se ejecuta muy r谩pido. Sin embargo, debemos asegurarnos de que todas las entradas tengan la misma forma en llamadas posteriores; de lo contrario, JAX tendr谩 que volver a compilar el c贸digo y no podremos aprovechar la velocidad optimizada.

El pipeline de Flax puede compilar el c贸digo por nosotros si pasamos jit = True como argumento. Tambi茅n se asegurar谩 de que el modelo se ejecute en paralelo en los 8 dispositivos disponibles.

La primera vez que ejecutemos la siguiente celda, llevar谩 mucho tiempo compilar, pero las llamadas posteriores (incluso con entradas diferentes) ser谩n mucho m谩s r谩pidas. Por ejemplo, tom贸 m谩s de un minuto compilar en una TPU v2-8 cuando prob茅, pero luego tarda aproximadamente 7s para futuras ejecuciones de inferencia.

images = pipeline(prompt_ids, p_params, rng, jit=True)[0]

Salida:

    Tiempos de CPU: user 464 ms, sys: 105 ms, total: 569 ms
    Tiempo de ejecuci贸n: 7.07 s

El arreglo devuelto tiene forma (8, 1, 512, 512, 3). Lo reestructuramos para eliminar la segunda dimensi贸n y obtener 8 im谩genes de 512 脳 512 脳 3 y luego las convertimos a formato PIL.

images = images.reshape((images.shape[0],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

Visualizaci贸n

Creemos una funci贸n auxiliar para mostrar las im谩genes en una cuadr铆cula.

def image_grid(imgs, filas, columnas):
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(columnas*w, filas*h))
    for i, img in enumerate(imgs): grid.paste(img, box=(i%columnas*w, i//columnas*h))
    return grid

image_grid(images, 2, 4)

Usando diferentes indicaciones

No es necesario replicar la misma indicaci贸n en todos los dispositivos. Podemos hacer lo que queramos: generar 2 indicaciones 4 veces cada una, o incluso generar 8 indicaciones diferentes al mismo tiempo. 隆Hag谩moslo!

Primero, refactorizaremos el c贸digo de preparaci贸n de la entrada en una funci贸n pr谩ctica:

indicaciones = [
    "Labrador en el estilo de Hokusai",
    "Pintura de una ardilla patinando en Nueva York",
    "HAL-9000 en el estilo de Van Gogh",
    "Times Square bajo el agua, con peces y un delf铆n nadando",
    "Fresco romano antiguo que muestra a un hombre trabajando en su computadora port谩til",
    "Fotograf铆a en primer plano de una joven mujer negra contra un fondo urbano, alta calidad, bokeh",
    "Sill贸n en forma de aguacate",
    "Astronauta payaso en el espacio, con la Tierra en el fondo",
]

prompt_ids = pipeline.prepare_inputs(indicaciones)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0], ) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
image_grid(images, 2, 4)


驴C贸mo funciona la paralelizaci贸n?

Dijimos antes que el flujo de trabajo diffusers de Flax compila autom谩ticamente el modelo y lo ejecuta en paralelo en todos los dispositivos disponibles. Ahora veremos brevemente c贸mo funciona ese proceso.

La paralelizaci贸n de JAX se puede hacer de varias formas. La m谩s sencilla utiliza la funci贸n jax.pmap para lograr la paralelizaci贸n de un solo programa y m煤ltiples datos (SPMD, por sus siglas en ingl茅s). Esto significa que ejecutaremos varias copias del mismo c贸digo, cada una con datos diferentes. Tambi茅n es posible utilizar enfoques m谩s sofisticados, te invitamos a revisar la documentaci贸n de JAX y las p谩ginas de pjit para explorar este tema si est谩s interesado.

jax.pmap hace dos cosas por nosotros:

  • Compila (o realiza jit) el c贸digo, como si hubi茅ramos invocado jax.jit(). Esto no ocurre cuando llamamos a pmap, sino la primera vez que se invoca la funci贸n con pmap.
  • Asegura que el c贸digo compilado se ejecute en paralelo en todos los dispositivos disponibles.

Para mostrar c贸mo funciona, aplicamos pmap al m茅todo _generate del flujo de trabajo, que es el m茅todo privado que genera las im谩genes. Ten en cuenta que este m茅todo puede cambiar de nombre o eliminarse en futuras versiones de diffusers.

p_generate = pmap(pipeline._generate)

Despu茅s de utilizar pmap, la funci贸n preparada p_generate har谩 conceptualmente lo siguiente:

  • Invocar谩 una copia de la funci贸n subyacente pipeline._generate en cada dispositivo.
  • Enviar a cada dispositivo una porci贸n diferente de los argumentos de entrada. Para eso se utiliza el fragmentado. En nuestro caso, prompt_ids tiene forma (8, 1, 77, 768). Este arreglo se dividir谩 en 8 y cada copia de _generate recibir谩 una entrada con forma (1, 77, 768).

Podemos codificar _generate ignorando por completo el hecho de que se invocar谩 en paralelo. Solo nos importa nuestro tama帽o de lote (1 en este ejemplo) y las dimensiones que tienen sentido para nuestro c贸digo, y no tenemos que cambiar nada para hacer que funcione en paralelo.

De la misma manera que cuando usamos la llamada de canalizaci贸n, la primera vez que ejecutemos la siguiente celda llevar谩 un tiempo, pero luego ser谩 mucho m谩s r谩pido.

images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape

Salida :

    Tiempo de CPU: usuario 118 ms, sistema: 83.9 ms, total: 202 ms
    Tiempo de ejecuci贸n: 6.82 s

    (8, 1, 512, 512, 3)

Usamos block_until_ready() para medir correctamente el tiempo de inferencia, porque JAX utiliza un despacho as铆ncrono y devuelve el control al bucle de Python tan pronto como puede. No es necesario usar eso en tu c贸digo; el bloqueo ocurrir谩 autom谩ticamente cuando quieras usar el resultado de un c谩lculo que a煤n no se ha materializado.

We will continue to update Zepes; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

Inteligencia Artificial

Conoce a Rodin un nuevo marco de inteligencia artificial (IA) para generar avatares digitales en 3D a partir de diversas fuentes de entrada.

Los modelos generativos se est谩n convirtiendo en la soluci贸n por defecto para muchas tareas desafiantes en ciencias d...

Inteligencia Artificial

Investigadores de Inteligencia Artificial (IA) de la Universidad de Cornell proponen un nuevo marco de red neuronal para abordar el problema de la segmentaci贸n de video.

La edici贸n de im谩genes y videos son dos de las aplicaciones m谩s populares para los usuarios de computadoras. Con el a...

Inteligencia Artificial

La pantalla 3D podr铆a llevar el tacto al mundo digital

Los ingenieros dise帽aron una pantalla de transformaci贸n compuesta por una cuadr铆cula de m煤sculos rob贸ticos blandos qu...

Inteligencia Artificial

EE. UU. y la UE completan el tan esperado acuerdo sobre el intercambio de datos

El acuerdo pone fin a la incertidumbre legal para Meta, Google y decenas de empresas, al menos por ahora.