Transformer Attention의 탄생과 메모리 병목
2017년 “Attention is All You Need” 논문은 자연어 처리의 패러다임을 완전히 바꿔놓았습니다. 하지만 Self-Attention 메커니즘은 시퀀스 길이 에 대해 메모리와 계산량을 요구하는 치명적인 한계를 가지고 있었죠.
Self-Attention의 시간 복잡도: , 메모리 복잡도: (n: 시퀀스 길이, d: 차원)
표준 Attention 연산 과정
- (Query): 현재 토큰이 “무엇을 찾고 있는가”
- (Key): 각 토큰이 “어떤 정보를 가지고 있는가”
- (Value): 실제 전달할 정보
- : Key 벡터의 차원 (스케일링 팩터)
- : 크기의 Attention Score 행렬 (메모리 병목의 주범)
문제는 를 계산할 때 전체 시퀀스에 대한 행렬을 메모리에 저장해야 한다는 점입니다. GPT-3(2048 토큰)의 경우 약 16MB의 Attention 행렬이 필요하고, 긴 문서(8K 토큰)를 처리하면 256MB 이상으로 폭발합니다.
FlashAttention: IO-Aware Attention의 혁명
핵심 아이디어: Tiling + Recomputation
FlashAttention(2022)은 GPU 메모리 계층 구조를 활용한 천재적인 알고리즘입니다. HBM(High Bandwidth Memory)과 SRAM의 속도 차이(20배)를 이용해, 작은 블록(tile) 단위로 Attention을 계산하고 중간 결과를 재계산하는 방식으로 메모리 접근을 최소화합니다.
| 방법 | 메모리 복잡도 | HBM 접근 | 속도 |
|---|---|---|---|
| 표준 Attention | 기준 | ||
| FlashAttention | 2~4배 | ||
| FlashAttention-2 | 5~9배 |
(B: 블록 크기)
실전 구현: PyTorch에서 FlashAttention 사용하기
import torch
from flash_attn import flash_attn_qkvpacked_func
# 입력: [배치, 시퀀스, 헤드, 차원]
batch, seqlen, num_heads, head_dim = 2, 2048, 12, 64
qkv = torch.randn(batch, seqlen, 3, num_heads, head_dim,
device='cuda', dtype=torch.float16)
# FlashAttention 실행 (causal=True로 자동 마스킹)
output = flash_attn_qkvpacked_func(qkv, causal=True)
print(output.shape) # [2, 2048, 12, 64]
주의사항:
– CUDA 11.6+ 필수
– A100/H100 GPU에서 최적화
– FP16/BF16만 지원 (FP32는 오히려 느림)
FlashAttention-2: 병렬화 극대화
FlashAttention-2(2023)는 워크로드 분산 최적화에 집중했습니다.
개선 포인트
- 시퀀스 차원 병렬화 (기존: 배치+헤드 차원만)
- Warp-level 연산 감소 (메모리 공유 최소화)
- Loop 순서 조정 (캐시 히트율 향상)
# 기존 방식: 헤드별 순차 처리
for head in range(num_heads):
for block in range(num_blocks):
compute_attention(block, head)
# FlashAttention-2: 시퀀스 블록별 병렬 처리
for block in range(num_blocks): # 병렬화!
for head in range(num_heads):
compute_attention(block, head)
결과적으로 A100에서 기존 대비 2배, 표준 Attention 대비 9배 빠릅니다.
FlashAttention-3: H100 시대의 최적화
Tensor Core와 비동기 처리
FlashAttention-3(2024)는 H100의 4세대 Tensor Core를 활용합니다.
- FP8 연산 지원: FP16 대비 2배 처리량
- 비동기 WGMMA(Warp Group Matrix Multiply-Accumulate): 데이터 로딩과 연산 동시 실행
- Incoherent Processing: 블록 단위 비순차 처리로 레이턴시 숨김
| GPU | FlashAttention-2 | FlashAttention-3 | 개선율 |
|---|---|---|---|
| A100 (FP16) | 100 TFLOPS | – | – |
| H100 (FP16) | 180 TFLOPS | 280 TFLOPS | 1.55배 |
| H100 (FP8) | – | 560 TFLOPS | 3.1배 |
실무 활용: Long Context LLM 학습
from transformers import AutoModelForCausalLM
import torch
# FlashAttention-3 활성화 (transformers 4.36+)
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B",
attn_implementation="flash_attention_3", # 명시적 지정
torch_dtype=torch.bfloat16,
device_map="auto"
)
# 32K 토큰 처리 (기존 방식은 OOM)
inputs = tokenizer("긴 문서...", return_tensors="pt", max_length=32768)
outputs = model.generate(**inputs, max_new_tokens=512)
실전 벤치마크 (Llama-3.1-70B, 32K 컨텍스트):
– 표준 Attention: OOM (메모리 부족)
– FlashAttention-2: 23 tok/s
– FlashAttention-3 (FP8): 41 tok/s + 메모리 40% 절감
O(1) 메모리는 가능한가? Linear Attention의 도전
Kernel Trick과 선형 복잡도
를 커널 함수 로 근사하면:
- 를 먼저 계산 → (메모리 )
- 시퀀스 길이 에 독립적인 O(1) 메모리!
실전 예제: Performer vs FlashAttention
from performer_pytorch import SelfAttention as PerformerAttention
from flash_attn.modules.mha import MHA as FlashMHA
# Performer (선형 복잡도, 근사)
performer = PerformerAttention(dim=512, heads=8, kernel_fn='relu').cuda()
# FlashAttention (정확한 Attention)
flash = FlashMHA(embed_dim=512, num_heads=8, use_flash_attn=True).cuda()
x = torch.randn(1, 16384, 512, device='cuda') # 16K 토큰
# 메모리 사용량 비교
import torch.cuda as cuda
cuda.reset_peak_memory_stats()
out_performer = performer(x)
print(f"Performer: {cuda.max_memory_allocated() / 1e9:.2f} GB") # 0.8 GB
cuda.reset_peak_memory_stats()
out_flash = flash(x)
print(f"FlashAttention: {cuda.max_memory_allocated() / 1e9:.2f} GB") # 1.2 GB
Trade-off:
– Performer: 메모리 효율↑, 정확도↓ (softmax 근사 오차)
– FlashAttention: 정확도 완벽, 메모리
마무리
| 기술 | 메모리 | 정확도 | 추천 상황 |
|---|---|---|---|
| 표준 Attention | ★★★ | 짧은 시퀀스(<512) | |
| FlashAttention-2 | ★★★ | A100, 범용 학습 | |
| FlashAttention-3 | ★★★ | H100, 초장문 LLM | |
| Linear Attention | ★★ | 메모리 극한 최적화 |
핵심 요약:
1. FlashAttention은 알고리즘 혁신으로 메모리를 으로 개선
2. FlashAttention-3는 H100 하드웨어 특성을 극한 활용해 3배 이상 가속
3. Linear Attention은 메모리를 달성하지만 정확도 손실 존재
4. 실무 권장: Hugging Face attn_implementation="flash_attention_3" 설정으로 무료 성능 향상
Attention 메커니즘은 단순한 수식이 아니라, 하드웨어·알고리즘·구현이 만나는 최적화의 예술입니다. 여러분의 프로젝트에도 적용해보세요!
Did you find this helpful?
☕ Buy me a coffee
Leave a Reply