Attention is All You Need부터 FlashAttention-3까지: Transformer Attention 메커니즘의 진화와 O(1) 메모리 최적화 실전 구현

Updated Feb 6, 2026

Transformer Attention의 탄생과 메모리 병목

2017년 “Attention is All You Need” 논문은 자연어 처리의 패러다임을 완전히 바꿔놓았습니다. 하지만 Self-Attention 메커니즘은 시퀀스 길이 nn에 대해 O(n2)O(n^2) 메모리와 계산량을 요구하는 치명적인 한계를 가지고 있었죠.

Self-Attention의 시간 복잡도: O(n2d)O(n^2 \cdot d), 메모리 복잡도: O(n2)O(n^2) (n: 시퀀스 길이, d: 차원)

표준 Attention 연산 과정

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

  • QQ (Query): 현재 토큰이 “무엇을 찾고 있는가”
  • KK (Key): 각 토큰이 “어떤 정보를 가지고 있는가”
  • VV (Value): 실제 전달할 정보
  • dkd_k: Key 벡터의 차원 (스케일링 팩터)
  • QKTQK^T: n×nn \times n 크기의 Attention Score 행렬 (메모리 병목의 주범)

문제는 QKTQK^T를 계산할 때 전체 시퀀스에 대한 n×nn \times n 행렬을 메모리에 저장해야 한다는 점입니다. GPT-3(2048 토큰)의 경우 약 16MB의 Attention 행렬이 필요하고, 긴 문서(8K 토큰)를 처리하면 256MB 이상으로 폭발합니다.

FlashAttention: IO-Aware Attention의 혁명

핵심 아이디어: Tiling + Recomputation

FlashAttention(2022)은 GPU 메모리 계층 구조를 활용한 천재적인 알고리즘입니다. HBM(High Bandwidth Memory)과 SRAM의 속도 차이(20배)를 이용해, 작은 블록(tile) 단위로 Attention을 계산하고 중간 결과를 재계산하는 방식으로 메모리 접근을 최소화합니다.

방법 메모리 복잡도 HBM 접근 속도
표준 Attention O(n2)O(n^2) O(n2)O(n^2) 기준
FlashAttention O(n)O(n) O(n2/B)O(n^2/B) 2~4배
FlashAttention-2 O(n)O(n) O(n2/B)O(n^2/B) 5~9배

(B: 블록 크기)

실전 구현: PyTorch에서 FlashAttention 사용하기

import torch
from flash_attn import flash_attn_qkvpacked_func

# 입력: [배치, 시퀀스, 헤드, 차원]
batch, seqlen, num_heads, head_dim = 2, 2048, 12, 64
qkv = torch.randn(batch, seqlen, 3, num_heads, head_dim, 
                  device='cuda', dtype=torch.float16)

# FlashAttention 실행 (causal=True로 자동 마스킹)
output = flash_attn_qkvpacked_func(qkv, causal=True)
print(output.shape)  # [2, 2048, 12, 64]

주의사항:
– CUDA 11.6+ 필수
– A100/H100 GPU에서 최적화
– FP16/BF16만 지원 (FP32는 오히려 느림)

FlashAttention-2: 병렬화 극대화

FlashAttention-2(2023)는 워크로드 분산 최적화에 집중했습니다.

개선 포인트

  1. 시퀀스 차원 병렬화 (기존: 배치+헤드 차원만)
  2. Warp-level 연산 감소 (메모리 공유 최소화)
  3. Loop 순서 조정 (캐시 히트율 향상)
# 기존 방식: 헤드별 순차 처리
for head in range(num_heads):
    for block in range(num_blocks):
        compute_attention(block, head)

# FlashAttention-2: 시퀀스 블록별 병렬 처리
for block in range(num_blocks):  # 병렬화!
    for head in range(num_heads):
        compute_attention(block, head)

결과적으로 A100에서 기존 대비 2배, 표준 Attention 대비 9배 빠릅니다.

FlashAttention-3: H100 시대의 최적화

Tensor Core와 비동기 처리

FlashAttention-3(2024)는 H100의 4세대 Tensor Core를 활용합니다.

  • FP8 연산 지원: FP16 대비 2배 처리량
  • 비동기 WGMMA(Warp Group Matrix Multiply-Accumulate): 데이터 로딩과 연산 동시 실행
  • Incoherent Processing: 블록 단위 비순차 처리로 레이턴시 숨김
GPU FlashAttention-2 FlashAttention-3 개선율
A100 (FP16) 100 TFLOPS
H100 (FP16) 180 TFLOPS 280 TFLOPS 1.55배
H100 (FP8) 560 TFLOPS 3.1배

실무 활용: Long Context LLM 학습

from transformers import AutoModelForCausalLM
import torch

# FlashAttention-3 활성화 (transformers 4.36+)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    attn_implementation="flash_attention_3",  # 명시적 지정
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# 32K 토큰 처리 (기존 방식은 OOM)
inputs = tokenizer("긴 문서...", return_tensors="pt", max_length=32768)
outputs = model.generate(**inputs, max_new_tokens=512)

실전 벤치마크 (Llama-3.1-70B, 32K 컨텍스트):
– 표준 Attention: OOM (메모리 부족)
– FlashAttention-2: 23 tok/s
– FlashAttention-3 (FP8): 41 tok/s + 메모리 40% 절감

O(1) 메모리는 가능한가? Linear Attention의 도전

Kernel Trick과 선형 복잡도

Attention(Q,K,V)=sim(Q,K)Vnorm(Q,K)\text{Attention}(Q, K, V) = \frac{\text{sim}(Q, K) V}{\text{norm}(Q, K)}

sim(qi,kj)\text{sim}(q_i, k_j)를 커널 함수 ϕ(qi)Tϕ(kj)\phi(q_i)^T \phi(k_j)로 근사하면:

LinearAttn(Q,K,V)=ϕ(Q)(ϕ(K)TV)\text{LinearAttn}(Q, K, V) = \phi(Q) \left( \phi(K)^T V \right)

  • ϕ(K)TV\phi(K)^T V를 먼저 계산 → O(nd2)O(n \cdot d^2) (메모리 O(d2)O(d^2))
  • 시퀀스 길이 nn에 독립적인 O(1) 메모리!

실전 예제: Performer vs FlashAttention

from performer_pytorch import SelfAttention as PerformerAttention
from flash_attn.modules.mha import MHA as FlashMHA

# Performer (선형 복잡도, 근사)
performer = PerformerAttention(dim=512, heads=8, kernel_fn='relu').cuda()

# FlashAttention (정확한 Attention)
flash = FlashMHA(embed_dim=512, num_heads=8, use_flash_attn=True).cuda()

x = torch.randn(1, 16384, 512, device='cuda')  # 16K 토큰

# 메모리 사용량 비교
import torch.cuda as cuda
cuda.reset_peak_memory_stats()
out_performer = performer(x)
print(f"Performer: {cuda.max_memory_allocated() / 1e9:.2f} GB")  # 0.8 GB

cuda.reset_peak_memory_stats()
out_flash = flash(x)
print(f"FlashAttention: {cuda.max_memory_allocated() / 1e9:.2f} GB")  # 1.2 GB

Trade-off:
– Performer: 메모리 효율↑, 정확도↓ (softmax 근사 오차)
– FlashAttention: 정확도 완벽, 메모리 O(n)O(n)

마무리

기술 메모리 정확도 추천 상황
표준 Attention O(n2)O(n^2) ★★★ 짧은 시퀀스(<512)
FlashAttention-2 O(n)O(n) ★★★ A100, 범용 학습
FlashAttention-3 O(n)O(n) ★★★ H100, 초장문 LLM
Linear Attention O(1)O(1) ★★ 메모리 극한 최적화

핵심 요약:
1. FlashAttention은 알고리즘 혁신으로 메모리를 O(n2)O(n)O(n^2) \to O(n)으로 개선
2. FlashAttention-3는 H100 하드웨어 특성을 극한 활용해 3배 이상 가속
3. Linear Attention은 O(1)O(1) 메모리를 달성하지만 정확도 손실 존재
4. 실무 권장: Hugging Face attn_implementation="flash_attention_3" 설정으로 무료 성능 향상

Attention 메커니즘은 단순한 수식이 아니라, 하드웨어·알고리즘·구현이 만나는 최적화의 예술입니다. 여러분의 프로젝트에도 적용해보세요!

Did you find this helpful?

☕ Buy me a coffee

Comments

Leave a Reply

Your email address will not be published. Required fields are marked *

TODAY 390 | TOTAL 2,613