Segmentación de Imágenes Eficiente utilizando PyTorch Parte 4

Efficient Image Segmentation with PyTorch Part 4.

Un modelo basado en Vision Transformer

En esta serie de 4 partes, implementaremos la segmentación de imágenes paso a paso desde cero utilizando técnicas de aprendizaje profundo en PyTorch. En esta parte nos centraremos en implementar un modelo basado en Vision Transformer para la segmentación de imágenes.

Co-autoría con Naresh Singh

Figura 1: Resultado de la ejecución de la segmentación de imágenes utilizando una arquitectura de modelo de Vision Transformer. De arriba abajo, imágenes de entrada, máscaras de segmentación de la verdad del suelo y máscaras de segmentación predichas. Fuente: Autor(es)

Esquema del artículo

En este artículo, visitaremos la arquitectura Transformer que ha revolucionado el mundo del aprendizaje profundo. El transformer es una arquitectura multimodal que puede modelar diferentes modalidades como el lenguaje, la visión y el audio.

En este artículo, vamos a:

  1. Aprender acerca de la arquitectura del transformer y los conceptos clave involucrados
  2. Comprender la arquitectura del Vision Transformer
  3. Introducir un modelo de Vision Transformer que está escrito desde cero para que pueda apreciar todos los bloques de construcción y partes móviles
  4. Seguir un tensor de entrada alimentado en este modelo e inspeccionar cómo cambia la forma
  5. Utilizar este modelo para realizar la segmentación de imágenes en el conjunto de datos Oxford IIIT Pet
  6. Observar los resultados de esta tarea de segmentación
  7. Introducir brevemente el SegFormer, un Vision Transformer de última generación para la segmentación semántica

A lo largo de este artículo, haremos referencia al código y los resultados de este cuaderno para el entrenamiento del modelo. Si desea reproducir los resultados, necesitará una GPU para asegurarse de que el primer cuaderno se complete en un tiempo razonable.

Artículos en esta serie

Esta serie es para lectores de todos los niveles de experiencia en aprendizaje profundo. Si desea aprender sobre la práctica del aprendizaje profundo y la inteligencia artificial de visión junto con algo de teoría sólida y experiencia práctica, ¡ha venido al lugar correcto! Se espera que esta serie conste de 4 partes con los siguientes artículos:

  1. Conceptos e ideas
  2. Un modelo basado en CNN
  3. Convoluciones separables en profundidad
  4. Un modelo basado en Vision Transformer (este artículo)

Comencemos nuestro viaje hacia los Vision Transformers con una introducción y comprensión intuitiva de la arquitectura del Transformer.

La arquitectura del Transformer

Podemos pensar en la arquitectura del Transformer como una composición de capas intercaladas de comunicación y computación. Esta idea se representa visualmente en la figura 2. El Transformer tiene N unidades de procesamiento (N es 3 en la Figura 2), cada una de las cuales es responsable de procesar una fracción 1/N de la entrada. Para que esas unidades de procesamiento produzcan resultados significativos, cada una de ellas necesita tener una vista global de la entrada. Por lo tanto, el sistema comunica repetidamente información sobre los datos en cada unidad de procesamiento a todas las demás unidades de procesamiento; se muestra mediante flechas rojas, verdes y azules que van desde cada unidad de procesamiento a todas las demás unidades de procesamiento. Esto se sigue de alguna computación basada en esta información. Después de suficientes repeticiones de este proceso, el modelo es capaz de producir los resultados deseados.

Figura 2: Comunicación e intercalación de computación en transformers. La imagen muestra solo 2 capas de comunicación y computación. En la práctica, hay muchas capas más. Fuente: Autor(es).

Vale la pena señalar que la mayoría de los recursos en línea discuten tanto el codificador como el decodificador del Transformer como se presenta en el artículo titulado “Attention is all you need”. Sin embargo, en este artículo, describiremos solo la parte del codificador del Transformer.

Echemos un vistazo más de cerca a lo que constituye la comunicación y la computación en los Transformers.

Comunicación en Transformers: Atención

En los Transformers, la comunicación se implementa mediante una capa conocida como capa de atención. En PyTorch, esto se llama MultiHeadAttention. Llegaremos al motivo de ese nombre en un momento.

La documentación dice:

“Permite al modelo asistir simultáneamente a la información de diferentes subespacios de representación, como se describe en el artículo: Attention is all you need”.

El mecanismo de atención consume un tensor de entrada x de forma (Batch, Longitud, Características) y produce un tensor de forma similar y tal que las características de cada entrada se actualizan en función de las otras entradas en la misma instancia a las que el tensor está prestando atención. Por lo tanto, las características de cada tensor de longitud “Características” en la instancia de tamaño “Longitud” se actualizan en función de todos los demás tensores. Aquí es donde entra en juego el costo cuadrático del mecanismo de atención.

En el contexto de un transformador de visión, la entrada al transformador es una imagen. Supongamos que esta es una imagen de 128 x 128 (ancho, alto). La dividimos en múltiples parches más pequeños de tamaño (16 x 16). Para una imagen de 128 x 128, obtenemos 64 parches (Longitud), 8 parches en cada fila y 8 filas de parches.

Cada uno de estos 64 parches de tamaño 16 x 16 píxeles se considera una entrada separada al modelo del transformador. Sin profundizar demasiado en los detalles, debería ser suficiente pensar en este proceso como impulsado por 64 unidades de procesamiento diferentes, cada una de las cuales procesa un solo parche de imagen de 16×16.

En cada ronda, el mecanismo de atención en cada unidad de procesamiento es responsable de mirar el parche de imagen del cual es responsable y consultar cada una de las otras 63 unidades de procesamiento restantes para solicitar información que pueda ser relevante y útil para ayudarlo a procesar efectivamente su propio parche de imagen.

El paso de comunicación a través de la atención es seguido por el cálculo, que veremos a continuación.

Cómputo en transformadores: Perceptrón de varias capas

El cómputo en los transformadores no es más que una unidad PerceptrónDeVariasCapas (MLP). Esta unidad está compuesta por 2 capas lineales, con una no linealidad GeLU en el medio. Se puede considerar el uso de otras no linealidades también. Esta unidad primero proyecta la entrada a 4 veces el tamaño y la vuelve a proyectar a 1x, que es el mismo que el tamaño de entrada.

En el código que veremos en nuestro cuaderno, esta clase se llama MultiLayerPerceptron. El código se muestra a continuación.

class MultiLayerPerceptron (nn.Sequential):    def __init__(self, tamaño_embebido, abandono):        super().__init__(            nn.Linear(tamaño_embebido, tamaño_embebido * 4),            nn.GELU(),            nn.Linear(tamaño_embebido * 4, tamaño_embebido),            nn.Dropout(p=abandono),        )    # end def# end class

Ahora que entendemos el funcionamiento de alto nivel de la arquitectura del transformador, centrémonos en el transformador de visión, ya que vamos a realizar la segmentación de imágenes.

El Transformador de Visión

El transformador de visión fue presentado por primera vez en el artículo titulado “Una imagen vale 16 x 16 palabras: Transformadores para el reconocimiento de imágenes a escala”. El artículo describe cómo los autores aplican la arquitectura del transformador de vainilla al problema de la clasificación de imágenes. Esto se hace dividiendo la imagen en parches de tamaño 16×16 y tratando cada parche como un token de entrada para el modelo. El modelo de codificador de transformador recibe estos tokens de entrada y se le pide que prediga una clase para la imagen de entrada.

Figura 4: Fuente: Transformadores para el reconocimiento de imágenes a escala.

En nuestro caso, estamos interesados en la segmentación de imágenes. Podemos considerarlo como una tarea de clasificación a nivel de píxel porque pretendemos predecir una clase objetivo por píxel.

Hacemos un cambio pequeño pero importante en el transformador de visión vainilla y reemplazamos la cabeza MLP para clasificación por una cabeza MLP para clasificación a nivel de píxel. Tenemos una sola capa lineal en la salida que es compartida por cada parche cuya máscara de segmentación es predicha por el transformador de visión. Esta capa lineal compartida predice una máscara de segmentación para cada parche que se envió como entrada al modelo.

En el caso del transformador de visión, un parche de tamaño 16×16 se considera equivalente a un único token de entrada en un momento específico.

Figura 5: Funcionamiento de extremo a extremo del transformador de visión para segmentación de imágenes. Imagen generada usando este cuaderno. Fuente: Autor(es).

Construyendo una intuición para las dimensiones de tensores en transformadores de visión

Cuando se trabaja con CNNs profundos, las dimensiones de los tensores que usamos en su mayor parte son (N, C, H, W), donde las letras significan lo siguiente:

  • N: Tamaño del lote
  • C: Número de canales
  • H: Altura
  • W: Ancho

Puede ver que este formato está hecho para el procesamiento de imágenes 2D, ya que presenta características muy específicas para las imágenes.

Por otro lado, con los transformadores, las cosas se vuelven mucho más genéricas y agnósticas al dominio. Lo que veremos a continuación se aplica a vision, texto, NLP, audio u otros problemas donde los datos de entrada se pueden representar como una secuencia. Vale la pena señalar que hay poco sesgo específico de visión en la representación de tensores a medida que fluyen a través de nuestro transformador de visión.

Cuando se trabaja con transformadores y atención en general, esperamos que los tensores tengan la siguiente forma: (B, T, C), donde las letras significan lo siguiente:

  • B: Tamaño del lote (igual que en CNNs)
  • T: Dimensión temporal o longitud de la secuencia. Esta dimensión también se llama L a veces. En el caso de los transformadores de visión, cada parche de imagen corresponde a esta dimensión. Si tenemos 16 parches de imagen, entonces el valor de la dimensión T será 16
  • C: La dimensión del canal o tamaño de incrustación. Esta dimensión también se llama E a veces. Al procesar imágenes, cada parche de tamaño 3x16x16 (Canal, Ancho, Altura) se asigna a través de una capa de incrustación de parche a una incrustación de tamaño C. Veremos cómo se hace esto más adelante.

Sumergámonos en cómo el tensor de imagen de entrada cambia y se procesa en su camino hacia la predicción de la máscara de segmentación.

El viaje de un tensor en un transformador de visión

En las CNNs profundas, el viaje de un tensor se ve algo así (en una UNet, SegNet u otra arquitectura basada en CNN).

El tensor de entrada es típicamente de forma (1, 3, 128, 128). Este tensor pasa por una serie de operaciones de convolución y max-pooling donde sus dimensiones espaciales se reducen y las dimensiones del canal aumentan, típicamente por un factor de 2 cada uno. Esto se llama el codificador de características. Después de esto, hacemos la operación inversa donde aumentamos las dimensiones espaciales y reducimos las dimensiones del canal. Esto se llama el decodificador de características. Después del proceso de decodificación, obtenemos un tensor de forma (1, 64, 128, 128). Luego se proyecta en el número de canales de salida C que deseamos como (1, C, 128, 128) usando una convolución de punto 1×1 sin sesgo.

Figura 6: Progresión típica de formas de tensor a través de una CNN profunda utilizada para la segmentación de imágenes. Fuente: Autor(es).

Con los transformadores de visión, el flujo es mucho más complejo. Veamos una imagen a continuación e intentemos comprender cómo el tensor transforma las formas en cada paso del camino.

Figura 7: Progresión típica de formas de tensor a través de un transformador de visión para la segmentación de imágenes. Fuente: Autor(es).

Veamos cada paso con más detalle y veamos cómo actualiza la forma del tensor que fluye a través del transformador de visión. Para entender esto mejor, tomemos valores concretos para las dimensiones de nuestro tensor.

  1. Normalización por lotes: Los tensores de entrada y salida tienen forma (1, 3, 128, 128). La forma no cambia, pero los valores se normalizan a una media cero y una varianza unitaria.
  2. Imagen a parches: El tensor de entrada de forma (1, 3, 128, 128) se convierte en un parche apilado de imágenes de 16×16. El tensor de salida tiene forma (1, 64, 768).
  3. Incrustación de parche: La capa de incrustación de parches mapea los 768 canales de entrada a 512 canales de incrustación (para este ejemplo). El tensor de salida tiene forma (1, 64, 512). La capa de incrustación de parches es básicamente solo una capa nn.Linear en PyTorch.
  4. Incrustación de posición: La capa de incrustación de posición no tiene un tensor de entrada, pero contribuye efectivamente a un parámetro aprendible (tensor entrenable en PyTorch) de la misma forma que la incrustación de parche. Esto tiene una forma de (1, 64, 512).
  5. Suma: Las incrustaciones de parche y posición se suman pieza por pieza para producir la entrada a nuestro codificador de transformador de visión. Este tensor tiene forma (1, 64, 512). Observarás que el principal trabajo del transformador de visión, es decir, el codificador básicamente deja sin cambios la forma de este tensor.
  6. Codificador de transformador: El tensor de entrada de forma (1, 64, 512) fluye a través de múltiples bloques de codificador de transformador, cada uno de los cuales tiene múltiples cabezales de atención (comunicación) seguidos de una capa MLP (cálculo). La forma del tensor permanece sin cambios como (1, 64, 512).
  7. Proyección lineal de salida: Si asumimos que queremos segmentar cada imagen en 10 clases, entonces necesitamos que cada parche de tamaño 16×16 tenga 10 canales. La capa nn.Linear para la proyección de salida ahora convertirá los 512 canales de incrustación en 16x16x10 = 2560 canales de salida, y este tensor se verá como (1, 64, 2560). En el diagrama anterior, C’ = 10. Idealmente, esto sería un perceptrón multicapa, ya que “los MLP son aproximadores universales de funciones”, pero usamos una sola capa lineal ya que esto es un ejercicio educativo.
  8. Parche a imagen: Esta capa convierte los 64 parches codificados como un tensor (1, 64, 2560) de nuevo en algo que parece una máscara de segmentación. Esto pueden ser 10 imágenes de un solo canal, o en este caso una sola imagen de 10 canales, donde cada canal es la máscara de segmentación para una de las 10 clases. El tensor de salida tiene forma (1, 10, 128, 128).

¡Eso es todo! ¡Hemos segmentado con éxito una imagen de entrada usando un transformador de visión! A continuación, veamos un experimento junto con algunos resultados.

Transformadores de visión en acción

Este cuaderno contiene todo el código para esta sección.

En cuanto al código y la estructura de clases se refiere, imita de cerca el diagrama de bloques anterior. La mayoría de los conceptos mencionados anteriormente tienen una correspondencia 1:1 con los nombres de clase en este cuaderno.

Hay algunos conceptos relacionados con las capas de atención que son hiperparámetros críticos para nuestro modelo. No mencionamos nada sobre los detalles de la atención multi-cabeza anteriormente ya que está fuera del alcance de este artículo. Recomendamos encarecidamente leer el material de referencia mencionado anteriormente antes de continuar si no tiene una comprensión básica del mecanismo de atención en los transformadores.

Utilizamos los siguientes parámetros del modelo para el transformador de visión para la segmentación.

  1. 768 dimensiones de incrustación para la capa PatchEmbedding
  2. 12 bloques de codificador de transformador
  3. 8 cabezales de atención en cada bloque de codificador de transformador
  4. 20% de eliminación en la atención multi-cabeza y MLP

Esta configuración se puede ver en la clase de datos Python VisionTransformerArgs.

@dataclassclass VisionTransformerArgs:    """Argumentos para VisionTransformerForSegmentation."""    image_size: int = 128    patch_size: int = 16    in_channels: int = 3    out_channels: int = 3    embed_size: int = 768    num_blocks: int = 12    num_heads: int = 8    dropout: float = 0.2# end class

Se utilizó una configuración similar a la anterior durante el entrenamiento y la validación del modelo. La configuración se especifica a continuación.

  1. Se aplican las aumentaciones de datos de volteo horizontal aleatorio y color jittering al conjunto de entrenamiento para evitar el sobreajuste
  2. Las imágenes se redimensionan a 128×128 píxeles en una operación de redimensionamiento no conservadora del aspecto
  3. No se aplica normalización de entrada a las imágenes, en su lugar se utiliza una capa de normalización de lote como primera capa del modelo
  4. El modelo se entrena durante 50 épocas utilizando el optimizador Adam con una tasa de aprendizaje de 0,0004 y un programador StepLR que reduce la tasa de aprendizaje en 0,8x cada 12 épocas
  5. La función de pérdida de entropía cruzada se utiliza para clasificar un píxel como perteneciente a una mascota, el fondo o un borde de mascota

El modelo tiene 86,28 millones de parámetros y logró una precisión de validación del 85,89% después de 50 épocas de entrenamiento. Esto es menos que la precisión del 88,28% lograda por el modelo CNN profundo después de 20 épocas de entrenamiento. Esto podría deberse a algunos factores que necesitan ser validados experimentalmente.

  1. La última capa de proyección de salida es una sola nn.Linear y no un perceptrón multicapa
  2. El tamaño de parche de 16×16 es demasiado grande para capturar detalles más finos
  3. No hay suficientes épocas de entrenamiento
  4. No hay suficientes datos de entrenamiento, se sabe que los modelos transformadores necesitan muchos más datos para entrenar de manera efectiva en comparación con los modelos CNN profundos
  5. La tasa de aprendizaje es demasiado baja

Hemos trazado un gif que muestra cómo el modelo está aprendiendo a predecir las máscaras de segmentación para 21 imágenes en el conjunto de validación.

Figura 8: Un gif que muestra la progresión de las máscaras de segmentación predichas por el transformador de visión para el modelo de segmentación de imágenes. Fuente: Autor(es).

Observamos algo interesante en las primeras épocas de entrenamiento. Las máscaras de segmentación predichas tienen algunos artefactos de bloqueo extraños. La única razón que se nos ocurre para esto es que estamos descomponiendo la imagen en parches de tamaño 16×16 y después de muy pocas épocas de entrenamiento, el modelo no ha aprendido nada útil más allá de información muy gruesa sobre si este parche de 16×16 está cubierto generalmente por una mascota o por píxeles de fondo.

Figura 9: Los artefactos de bloqueo vistos en las máscaras de segmentación predichas al usar el transformador de visión para la segmentación de imágenes. Fuente: Autor(es).

Ahora que hemos visto en acción un transformador de visión básico, centrémonos en un transformador de visión de vanguardia para tareas de segmentación.

SegFormer: Segmentación semántica con transformadores

La arquitectura SegFormer fue propuesta en este documento en 2021. El transformador que vimos anteriormente es una versión más simple de la arquitectura SegFormer.

Figura 10: La arquitectura SegFormer. Fuente: SegFormer paper (2021).

Lo más notable es que el SegFormer:

  1. Genera 4 conjuntos de imágenes con parches de tamaño 4×4, 8×8, 16×16 y 32×32 en lugar de una sola imagen parcheada con parches de tamaño 16×16
  2. Utiliza 4 bloques codificadores de transformadores en lugar de solo 1. Esto se siente como un conjunto de modelos
  3. Utiliza convoluciones en las fases previas y posteriores de la autoatención
  4. No utiliza incrustaciones posicionales
  5. Cada bloque de transformador procesa imágenes a una resolución espacial de H/4 x W/4, H/8 x W/8, H/16 x W/16 y H/32, W/32
  6. De manera similar, los canales aumentan cuando las dimensiones espaciales se reducen. Esto se siente similar a los modelos CNN profundos
  7. Las predicciones en múltiples dimensiones espaciales se interpolan y luego se fusionan en el decodificador
  8. Un MLP combina todas estas predicciones para proporcionar una predicción final
  9. La predicción final es en la dimensión espacial H/4, W/4 y no en H, W.

Conclusion

En la parte 4 de esta serie, se nos presentó la arquitectura de transformadores y en particular los transformadores de visión. Desarrollamos una comprensión intuitiva de cómo funcionan los transformadores de visión y el bloque de construcción básico involucrado en las fases de comunicación y computación de los mismos. Vimos el enfoque único basado en parches adoptado por los transformadores de visión para predecir las máscaras de segmentación y luego combinar las predicciones juntas.

Revisamos un experimento que muestra los transformadores de visión en acción y pudimos comparar los resultados con enfoques de CNN profundos. Si bien nuestro transformador de visión no es de última generación, logró resultados bastante decentes. Proporcionamos una visión de enfoques de última generación como SegFormer.

Debería quedar claro en este momento que los transformadores tienen muchas más partes móviles y son más complejos en comparación con los enfoques basados en CNN profundos. Desde el punto de vista de FLOPs brutos, los transformadores tienen la promesa de ser más eficientes. En los transformadores, la única capa real que es computacionalmente pesada es nn.Linear. Esto se implementa mediante multiplicación de matrices optimizada en la mayoría de las arquitecturas. Debido a esta simplicidad arquitectónica, los transformadores tienen la promesa de ser más fáciles de optimizar y acelerar en comparación con los enfoques basados en CNN profundos.

¡Felicitaciones por llegar tan lejos! Nos complace que haya disfrutado leyendo esta serie sobre segmentación de imágenes eficiente en PyTorch. Si tiene preguntas o comentarios, no dude en dejarlos en la sección de comentarios.

Lecturas adicionales

Los detalles del mecanismo de atención están fuera del alcance de este artículo. Además, hay numerosos recursos de alta calidad a los que puede referirse para comprender el mecanismo de atención en detalle. Aquí hay algunos que recomendamos encarecidamente.

  1. El Transformer ilustrado
  2. NanoGPT desde cero usando PyTorch

A continuación, proporcionaremos enlaces a artículos que brindan más detalles sobre los transformadores de visión.

  1. Implementación de Vision Transformer (ViT) en PyTorch: este artículo detalla la implementación de un transformador de visión para la clasificación de imágenes en PyTorch. Es notable que su implementación utiliza einops, que evitamos, ya que este es un ejercicio enfocado en la educación (aunque recomendamos aprender y usar einops para la legibilidad del código). En su lugar, utilizamos operadores nativos de PyTorch para permutar y reorganizar las dimensiones del tensor. Además, hay algunos lugares donde el autor usa Conv2d en lugar de capas Lineales. Queríamos construir una implementación de transformadores de visión sin el uso de capas convolucionales en absoluto.
  2. Vision Transformer: AI Summer
  3. Implementación de SegFormer en PyTorch

We will continue to update Zepes; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

Inteligencia Artificial

Destilando lo que sabemos

Los investigadores buscan reducir el tamaño de los modelos GPT grandes.

Aprendizaje Automático

Inmersión teórica profunda en la Regresión Lineal

La mayoría de los aspirantes a bloggers de ciencia de datos lo hacen escriben un artículo introductorio sobre regresi...

Inteligencia Artificial

Sistemas de IA Sesgos desenterrados y la apasionante búsqueda de la verdadera equidad

La Inteligencia Artificial (IA) ya no es un concepto futurista, se ha convertido en una parte intrínseca de nuestras ...

Investigación

Investigadores de MIT CSAIL discuten las fronteras del AI generativo.

Expertos se reúnen para examinar el código, lenguaje e imágenes generados por la inteligencia artificial, así como su...

Inteligencia Artificial

Construye una solución centralizada de monitoreo e informes para Amazon SageMaker utilizando Amazon CloudWatch

En esta publicación, presentamos un panel de observabilidad intercuentas que proporciona una vista centralizada para ...