Podado de redes neuronales con optimización combinatoria

Neural network pruning with combinatorial optimization

Publicado por Hussein Hazimeh, Científico Investigador, Equipo Athena, y Riade Benbaki, Estudiante de Posgrado en el MIT

Las redes neuronales modernas han logrado un rendimiento impresionante en una variedad de aplicaciones, como el lenguaje, el razonamiento matemático y la visión. Sin embargo, estas redes a menudo utilizan arquitecturas grandes que requieren muchos recursos computacionales. Esto puede hacer que sea impráctico servir dichos modelos a los usuarios, especialmente en entornos con recursos limitados como dispositivos portátiles y teléfonos inteligentes. Un enfoque ampliamente utilizado para mitigar los costos de inferencia de las redes pre-entrenadas es podarlas, es decir, eliminar algunos de sus pesos, de manera que no afecte significativamente su utilidad. En las redes neuronales estándar, cada peso define una conexión entre dos neuronas. Por lo tanto, después de podar los pesos, la entrada se propagará a través de un conjunto más pequeño de conexiones y, por lo tanto, requerirá menos recursos computacionales.

Red original vs. una red podada.

Los métodos de poda se pueden aplicar en diferentes etapas del proceso de entrenamiento de la red: después, durante o antes del entrenamiento (es decir, inmediatamente después de la inicialización de los pesos). En esta publicación, nos enfocamos en el escenario de post-entrenamiento: dado una red pre-entrenada, ¿cómo podemos determinar qué pesos deben ser podados? Uno de los métodos más populares es la poda por magnitud, que elimina los pesos con la magnitud más pequeña. Si bien es eficiente, este método no considera directamente el efecto de eliminar pesos en el rendimiento de la red. Otro paradigma popular es la poda basada en optimización, que elimina pesos en función de cuánto afecta su eliminación a la función de pérdida. Aunque es conceptualmente atractivo, la mayoría de los enfoques basados en optimización existentes parecen enfrentar un serio compromiso entre rendimiento y requisitos computacionales. Los métodos que hacen aproximaciones burdas (por ejemplo, asumiendo una matriz Hessiana diagonal) pueden escalar bien, pero tienen un rendimiento relativamente bajo. Por otro lado, aunque los métodos que hacen menos aproximaciones tienden a tener un mejor rendimiento, parecen ser mucho menos escalables.

En “Rápido como CHITA: Poda de Redes Neuronales con Optimización Combinatoria”, presentado en ICML 2023, describimos cómo desarrollamos un enfoque basado en optimización para podar redes neuronales pre-entrenadas a gran escala. CHITA (que significa “Combinatorial Hessian-free Iterative Thresholding Algorithm”) supera a los métodos de poda existentes en términos de escalabilidad y compensaciones de rendimiento, y lo hace aprovechando avances de varios campos, incluyendo estadísticas de alta dimensión, optimización combinatoria y poda de redes neuronales. Por ejemplo, CHITA puede ser de 20 a 1000 veces más rápido que los métodos de vanguardia para podar ResNet y mejora la precisión en más del 10% en muchos casos.

Resumen de las contribuciones

CHITA tiene dos mejoras técnicas destacadas en comparación con los métodos populares:

  • Uso eficiente de información de segundo orden: Los métodos de poda que utilizan información de segundo orden (es decir, relacionada con las segundas derivadas) logran el estado del arte en muchos casos. En la literatura, esta información se utiliza típicamente mediante el cálculo de la matriz Hessiana o su inversa, una operación que es muy difícil de escalar debido a que el tamaño de la Hessiana es cuadrático con respecto al número de pesos. A través de una reformulación cuidadosa, CHITA utiliza información de segundo orden sin tener que calcular o almacenar explícitamente la matriz Hessiana, lo que permite una mayor escalabilidad.
  • Optimización combinatoria: Los métodos populares basados en optimización utilizan una técnica de optimización simple que poda los pesos de forma aislada, es decir, al decidir podar un cierto peso, no tienen en cuenta si otros pesos han sido podados. Esto podría llevar a podar pesos importantes porque los pesos considerados no importantes de forma aislada pueden volverse importantes cuando se podan otros pesos. CHITA evita este problema utilizando un algoritmo de optimización combinatoria más avanzado que tiene en cuenta cómo la poda de un peso afecta a otros.

En las secciones siguientes, discutimos la formulación y los algoritmos de poda de CHITA.

Una formulación de poda amigable para la computación

Existen muchos posibles candidatos para la poda, que se obtienen reteniendo solo un subconjunto de los pesos de la red original. Sea k un parámetro especificado por el usuario que denota el número de pesos a retener. La poda se puede formular naturalmente como un problema de selección del mejor subconjunto (BSS, por sus siglas en inglés): entre todos los posibles candidatos para la poda (es decir, subconjuntos de pesos) con solo k pesos retenidos, se selecciona el candidato que tiene la menor pérdida.

La poda como un problema BSS: entre todos los posibles candidatos para la poda con el mismo número total de pesos, el mejor candidato se define como aquel con la menor pérdida. Esta ilustración muestra cuatro candidatos, pero este número generalmente es mucho mayor.

Resolver el problema de poda BSS en la función de pérdida original generalmente es computacionalmente intratable. Por lo tanto, al igual que en trabajos anteriores, como OBD y OBS, aproximamos la pérdida con una función cuadrática utilizando una serie de Taylor de segundo orden, donde la Hessiana se estima con la matriz de información empírica de Fisher. Si bien los gradientes se pueden calcular de manera eficiente, calcular y almacenar la matriz Hessiana es prohibitivamente costoso debido a su gran tamaño. En la literatura, es común enfrentar este desafío haciendo suposiciones restrictivas sobre la Hessiana (por ejemplo, matriz diagonal) y también sobre el algoritmo (por ejemplo, pesos de poda de forma aislada).

CHITA utiliza una reformulación eficiente del problema de poda (BSS utilizando la pérdida cuadrática) que evita calcular explícitamente la matriz Hessiana, al tiempo que utiliza toda la información de esta matriz. Esto es posible gracias a la estructura de bajo rango de la matriz de información empírica de Fisher. Esta reformulación se puede ver como un problema de regresión lineal dispersa, donde cada coeficiente de regresión corresponde a un cierto peso en la red neuronal. Después de obtener una solución a este problema de regresión, los coeficientes establecidos en cero corresponderán a los pesos que deben ser podados. Nuestra matriz de datos de regresión es (n x p), donde n es el tamaño del lote (submuestra) y p es el número de pesos en la red original. Normalmente, n << p, por lo que almacenar y operar con esta matriz de datos es mucho más escalable que los enfoques comunes de poda que operan con la matriz Hessiana (p x p).

CHITA reformula la aproximación de pérdida cuadrática, que requiere una matriz Hessiana costosa, como un problema de regresión lineal (LR). La matriz de datos de LR es lineal en p, lo que hace que la reformulación sea más escalable que la aproximación cuadrática original.

Algoritmos de optimización escalables

CHITA reduce la poda a un problema de regresión lineal bajo la siguiente restricción de dispersión: como máximo, k coeficientes de regresión pueden ser diferentes de cero. Para obtener una solución a este problema, consideramos una modificación del conocido algoritmo de umbral duro iterativo (IHT, por sus siglas en inglés). IHT realiza un descenso de gradiente donde, después de cada actualización, se realiza el siguiente paso de postprocesamiento: todos los coeficientes de regresión fuera de los k mejores (es decir, los k coeficientes con la magnitud más grande) se establecen en cero. IHT generalmente ofrece una buena solución al problema, y lo hace explorando iterativamente diferentes candidatos para la poda y optimizando conjuntamente los pesos.

Debido a la magnitud del problema, la IHT estándar con una tasa de aprendizaje constante puede sufrir de una convergencia muy lenta. Para una convergencia más rápida, desarrollamos un nuevo método de búsqueda de línea que explota la estructura del problema para encontrar una tasa de aprendizaje adecuada, es decir, una que conduzca a una disminución suficientemente grande en la pérdida. También utilizamos varios esquemas computacionales para mejorar la eficiencia de CHITA y la calidad de la aproximación de segundo orden, lo que nos lleva a una versión mejorada que llamamos CHITA++.

Experimentos

Comparamos el tiempo de ejecución y la precisión de CHITA con varios métodos de poda de última generación utilizando arquitecturas diferentes, incluyendo ResNet y MobileNet.

Tiempo de ejecución: CHITA es mucho más escalable que los métodos comparables que realizan una optimización conjunta (en lugar de podar pesos de forma aislada). Por ejemplo, la aceleración de CHITA puede superar las 1000 veces al podar ResNet.

Precisión después de la poda: A continuación, comparamos el rendimiento de CHITA y CHITA++ con la poda de magnitud (MP), Woodfisher (WF) y Combinatorial Brain Surgeon (CBS), para podar el 70% de los pesos del modelo. En general, observamos mejoras significativas con CHITA y CHITA++.

Precisión después de la poda de varios métodos en ResNet20. Se reportan los resultados para podar el 70% de los pesos del modelo.
Precisión después de la poda de varios métodos en MobileNet. Se reportan los resultados para podar el 70% de los pesos del modelo.

A continuación, presentamos los resultados para la poda de una red más grande: ResNet50 (en esta red, algunos de los métodos mencionados en la figura de ResNet20 no pudieron escalar). Aquí comparamos con la poda de magnitud y M-FAC. La figura a continuación muestra que CHITA logra una mejor precisión en las pruebas para una amplia gama de niveles de dispersión.

Precisión de la prueba de redes podadas, obtenidas utilizando diferentes métodos.

Conclusiones, limitaciones y trabajo futuro

Presentamos CHITA, un enfoque basado en optimización para la poda de redes neuronales pre-entrenadas. CHITA ofrece escalabilidad y rendimiento competitivo al utilizar eficientemente información de segundo orden y aprovechar ideas de la optimización combinatoria y la estadística de alta dimensión.

CHITA está diseñado para la poda no estructurada en la que se pueden eliminar cualquier peso. En teoría, la poda no estructurada puede reducir significativamente los requisitos computacionales. Sin embargo, para lograr estas reducciones en la práctica se requiere un software especial (y posiblemente hardware) que admita cálculos dispersos. En contraste, la poda estructurada, que elimina estructuras completas como neuronas, puede ofrecer mejoras más fáciles de lograr en software y hardware de propósito general. Sería interesante extender CHITA a la poda estructurada.

Agradecimientos

Este trabajo es parte de una colaboración de investigación entre Google y MIT. Agradecemos a Rahul Mazumder, Natalia Ponomareva, Wenyu Chen, Xiang Meng, Zhe Zhao y Sergei Vassilvitskii por su ayuda en la preparación de esta publicación y del artículo. También agradecemos a John Guilyard por crear los gráficos en esta publicación.

We will continue to update Zepes; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

Inteligencia Artificial

Un superordenador de inteligencia artificial cobra vida, impulsado por gigantes chips de computadora

La nueva supercomputadora, creada por la start-up de Silicon Valley Cerebras, fue presentada al mundo debido al auge ...

Inteligencia Artificial

Pagaste $1,000 por un iPhone, pero Apple todavía lo controla

La empresa codifica sus dispositivos con software que complican las reparaciones al activar advertencias de seguridad...

Inteligencia Artificial

El Ejército de los Estados Unidos pone a prueba la Inteligencia Artificial Generativa

El Departamento de Defensa de los Estados Unidos está probando cinco modelos de lenguaje grandes como parte de un esf...

Aprendizaje Automático

Conoce a DORSal Un modelo de difusión estructurada en 3D para la generación y edición a nivel de objeto de escenas en 3D.

La Inteligencia Artificial está evolucionando con la introducción de la IA Generativa y los Modelos de Lenguaje de Gr...

Inteligencia Artificial

Aumenta la productividad en Amazon SageMaker Studio Presentamos JupyterLab Spaces y herramientas de inteligencia artificial generativa

Amazon SageMaker Studio ofrece un conjunto amplio de entornos de desarrollo integrados completamente administrados (I...

Inteligencia Artificial

Keshav Pingali reconocido con el Premio ACM-IEEE CS Ken Kennedy

El premio se entregará formalmente a Pingali en noviembre en la Conferencia Internacional sobre Computación de Alto R...