Стандартное внимание имеет квадратичную сложность с точки зрения длины последовательности (количества токенов). Чтобы уменьшить сложность, в методах эффективного внимания были предложены разреженные и / или низкоранговые приближения. Эти приближения уменьшают сложность до линейной или почти линейной по отношению к длине последовательности. Тем не менее, эти методы либо отстают по производительности, либо не обеспечивают ускорения настенных часов по сравнению со стандартным вниманием.

В этой статье [1] предлагается формула внимания с учетом ввода-вывода, которая учитывает операции чтения и записи на разных уровнях быстрой и медленной памяти графического процессора. Современные графические процессоры имеют несколько уровней памяти с разными размерами и скоростями, как показано на рис. 1. В этой статье мы сосредоточимся на памяти с высокой пропускной способностью (HBM) и статической памяти с произвольным доступом (SRAM). Как и другие иерархии памяти, HBM одновременно большая и дешевая, но медленная по сравнению с SRAM, которая одновременно маленькая и дорогая, но быстрая.

HBM используется для хранения тензоров (например, карт функций/активаций), а SRAM используется для выполнения вычислительных операций с этими тензорами. Например, при применении операции RELU к тензору x мы (1) перемещаем x из HBM (операция чтения) в SRAM; (2) применить операцию RELU к x (операция вычисления) и (3) переместить x обратно из SRAM в HBM (операция записи).

В этой статье [1] предлагается алгоритм с учетом операций ввода-вывода, который вычисляет точное внимание при одновременном сокращении количества операций чтения и записи памяти (операций чтения и записи). Для достижения этой цели в статье делается два вклада:

  1. Реализуйте ядро ​​CUDA, чтобы объединить все операции внимания (matmul, mask, softmax и т. д.) в одном ядре графического процессора.
  2. Вычислите операцию softmax, не вычисляя и не сохраняя матрицу внимания NxN, где N — количество токенов.

1-й вклад: новое ядро ​​CUDA

По умолчанию каждая тензорная операция реализуется следующим образом: (1) операция чтения (операция чтения), (2) операция вычисления и (3) операция записи. Этот дизайн упрощает применение нескольких тензорных операций друг над другом. Конкретно, можно применить (1) умножение матриц, затем (2) маскирование, затем (3) softmax, не делая никаких предположений о предыдущей или последующей операции. Три вышеупомянутые операции будут выполнены, как показано на рис. 2.

Каждая операция требует операции чтения и записи, что выделено зеленым цветом. Эти операции доступа к памяти становятся узким местом, особенно для простых/быстрых вычислительных операций (например, RELU). Dao et al. [1] отмечают, что самовнимание стало стандартной операцией: (1) матмул, (2) маска, (3) софтмакс, (4) отсев, (5) матмул. Соответственно, в документе реализовано плавное ядро, единое ядро, объединяющее все эти пять операций. В объединенном ядре будет одна операция чтения и записи, что значительно снижает стоимость операций с памятью. На рис. 3 сравнивается стандартное внимание (слева) и мгновенное внимание (справа).

Имея слитое ядро, предлагаемое мгновенное внимание уменьшает количество операций памяти, что приводит к значительному ускорению во время обучения, как показано на рис. 4.

2nd Contribution: вычисление softmax без реализации матрицы внимания

Помимо реализации объединенного ядра, статья [1] вносит еще один вклад. Мгновенное внимание вычисляет точное внимание без учета матрицы NxN внимания A! Было ошибочно предположено, что точная операция softmax требует как вычисления, так и сохранения матрицы внимания A. Это предположение вытекает из знаменателя softmax, который работает со всей строкой A, как показано на рис. 5.

Мгновенное внимание опровергает это предположение с помощью двух уловок: (1) матричная мозаика, как показано на рис. 6; (2) сводная статистика, как показано на рисунках 7 и 8.

Посредством мозаичного отображения flash внимание разбивает входные данные Q, K и V на блоки, а затем загружает их из медленного HBM в быстрое SRAM. затем вычисляет вывод внимания по отношению к этим блокам, как показано на рис. 6. Конечно, вычисление softmax для отдельных блоков является неточным, поскольку softmax нормализует всю строку. Чтобы решить эту проблему, Flash Attention отслеживает сводную статистику по мере перехода от одного блока к другому. Когда мгновенное внимание достигает последнего блока, сводная статистика будет содержать точный знаменатель softmax.

На рис. 7 (слева) изображены игрушки Q, K и V, чтобы проиллюстрировать, как работает сводная статистика. На рис. 7 (справа) представлен псевдокод того, как операция softmax может быть вычислена по блокам с использованием сводной статистики {D и O}, т. е. без сохранения матрица внимания A.

На рис. 8 используется игрушка Q, K, V с рис. 7, и показан псевдокод в действии. На любой данной итерации мгновенное внимание обращается только к текущему блоку/элементу, а не ко всей строке. Чтобы вычислить точное внимание, флэш-внимание отслеживает сводную статистику {D и O}, которая обновляется после каждой итерации/блока.

Прежде чем представить количественные оценки, необходимо рассмотреть еще одну техническую деталь. До сих пор мы объясняли, как работает внезапное внимание во время прохода с прямой связью. Тем не менее, для обратного прохода обычно требуются градиенты по отношению к матрице внимания A, чтобы распространить градиент на более ранние уровни. Поскольку матрица вниманияA никогда не реализуется, у мгновенного внимания нет этих градиентов, по крайней мере, без пересчета. Используя вычисленные выходные данные и сводную статистику из прямого прохода, flash-attention повторно вычисляет элементы матрицы внимания и их градиент во время обратного прохода, т. е. снова без сохранения всей матрицы. Это означает, что мгновенное внимание вызывает больше провалов по сравнению со стандартным вниманием. Тем не менее, даже при большем количестве FLOP мгновенное внимание ускоряет обратный проход из-за уменьшения количества обращений к HBM. Рис. 9 подчеркивает этот технический аспект, сравнивая стандартное внимание со мгновенным вниманием через GFLOP, доступ к памяти и время выполнения.

Несмотря на то, что флэш-память вызывает больше FLOP, она значительно быстрее с точки зрения времени выполнения из-за сокращения доступа к памяти HBM.

Мгновенное внимание достигает двух целей: (1) ускорение обучения, (2) поддержка более длинных последовательностей (контекст). Мгновенное внимание сокращает время обучения GPT-2. Флэш-внимание демонстрирует сквозное ускорение в 3 и 1,7 раза по сравнению с Huggingface и Megatron-LM соответственно, как показано на рис. 10. Это ускорение достигается без потери точности, поскольку мгновенное внимание вычисляет точное внимание.

Эффективность времени выполнения и памяти Flash-Attention позволяет увеличить длину контекста в 4 раза по сравнению с базовым уровнем GPT-2. Это по-прежнему работает быстрее, чем оптимизированная реализация от Megatron-LM. На рис. 11 показано, что GPT-2 с вниманием Flash и длиной контекста 4K по-прежнему на 30% быстрее, чем GPT-2 от Megatron с длиной контекста 1K. Конечно, этот большой контекст повышает производительность (на 0,7 больше недоумения).

В статье [1] представлены другие количественные оценки для тех, кто интересуется быстрым вниманием. Эта статья завершается тестами Path-X и Path-256. Это сложный бенчмарк, задача которого состоит в том, чтобы классифицировать, имеют ли две точки на черно-белом изображении 128×128 (или 256×256) путь, соединяющий их. В этом тесте изображения подаются на преобразователь по одному пикселю за раз, что приводит к очень большой длине последовательности. На рис. 12 показана пара примеров из тестов Path-X.

В предыдущих работах у моделей на основе трансформаторов либо заканчивалась память, либо достигалась только случайная производительность. Flash Attention — первый трансформер, который показал превосходную производительность в этих тестах. На рис. 13 показано, как Flash-внимание достигает точности 61,4 % в задачах Path-X (длина последовательности 16 КБ) и точности 63,1 % в задачах Path-256 (длина последовательности 64 КБ).

Заключительные мысли:

  1. [S/W] Flash Attention интегрирован в PyTorch 2.0. Таким образом, его легко использовать, но он имеет несколько ограничений (например, поддерживает только определенные графические процессоры и требует CUDA 11). На рис. 14 перечислены эти ограничения.
  2. [W] Flash-внимание объединяет все операции стандартного внимания (например, matmul, mask, softmax и т. д.) в единое объединенное ядро ​​CUDA. Соответственно, любое изменение этих операций требует соответствующего изменения ядра CUDA. Конкретно, каждый раз, когда для повышения стандартного внимания вводится новая операция, ядро ​​flash-attention необходимо соответствующим образом обновлять.
  3. [W] Ускорения Flash-внимания предполагают идеальную процедуру загрузки данных. В статье сообщалось о впечатляющем ускорении работы при мгновенном внимании во время тренировки. Тем не менее, если процесс обучения ограничен загрузкой данных, эти ускорения не будут реализованы. Аналогичное наблюдение можно сделать и в отношении экономии памяти. Хотя мгновенное внимание всегда будет использовать меньше памяти, эта экономия значительна для больших последовательностей только.
  4. [S] Мне нравится простота мгновенного внимания и то, как оно снижает затраты на стандартное внимание благодаря разумной инженерии. Мгновенное внимание — это всего лишь одна из плиток в большом аспекте лаборатории профессора Кристофера Ре в Стэнфорде [2].
  5. Я написал игрушечный скрипт [3] для сопоставления мгновенного внимания с обычным вниманием. Скрипт доступен здесь. На рис. 15 показано время работы (в секундах) по мере увеличения длины последовательности. Стандартное внимание выдает ошибку нехватки памяти для последовательностей размером ≥ 4 КБ, в то время как мгновенное внимание поддерживает seq_len=16 КБ.

[1] Дао, Т., Фу, Д., Эрмон, С., Рудра, А. и Ре, К., 2022. FlashAttention: быстрое и эффективное для памяти точное внимание с io-осведомленностью. НейриПС

[2] https://github.com/HazyResearch

[3] https://discuss.pytorch.org/t/flash-attention/174955/14?u=ahmdtaha