들어가며
Transformer 아키텍처는 자연어 처리, 컴퓨터 비전, 오디오 등 거의 모든 시퀀스 모델링 분야에서 압도적인 성능을 보여왔습니다. 하지만 Self-Attention의 시간복잡도는 긴 시퀀스를 다룰 때 심각한 병목이 됩니다. 시퀀스 길이가 두 배가 되면 연산량은 네 배로 증가하죠.
이 문제를 해결하기 위해 다양한 Efficient Transformer 변형들이 등장했지만, 대부분 성능과 효율성 사이에서 타협해야 했습니다. 2023년 12월, Albert Gu와 Tri Dao가 발표한 “Mamba: Linear-Time Sequence Modeling with Selective State Spaces” 논문은 완전히 다른 접근법을 제시합니다.
Mamba는 Attention 메커니즘 없이 선형 시간복잡도 로 동작하면서도 Transformer와 동등하거나 더 나은 성능을 달성한 State Space Model(SSM) 기반 아키텍처입니다.
이 글에서는 Mamba의 핵심 아이디어인 Selective State Space Model(S6)의 구조, 하드웨어 최적화 알고리즘, 실험 결과를 상세히 분석합니다.
배경: State Space Model(SSM)이란?
연속 시간 SSM
State Space Model은 제어 이론에서 유래한 시퀀스 모델링 프레임워크입니다. 연속 시간 SSM은 다음과 같은 미분 방정식으로 정의됩니다:
각 항의 의미는 다음과 같습니다:
- : 잠재 상태(hidden state) — 시퀀스의 정보를 압축 저장하는 벡터 (차원 )
- : 입력 신호 — 현재 시점의 입력값
- : 출력 신호 — 현재 시점의 출력값
- : 상태 전이 행렬 — 이전 상태가 다음 상태에 어떻게 전달되는지 결정
- : 입력 행렬 — 입력이 상태에 어떻게 반영되는지 결정
- : 출력 행렬 — 상태에서 출력을 어떻게 읽어내는지 결정
이산화(Discretization)
실제 디지털 시퀀스를 처리하려면 연속 시간 SSM을 이산화해야 합니다. 스텝 크기 를 사용하여 다음과 같이 변환합니다:
이산화된 SSM의 순환 관계식은 RNN과 유사한 형태가 됩니다:
SSM의 이중 연산 모드
기존 SSM(S4 등)의 가장 큰 장점은 두 가지 연산 모드를 지원한다는 것입니다:
| 모드 | 연산 방식 | 용도 | 시간복잡도 |
|---|---|---|---|
| 순환 모드(Recurrence) | 추론(Auto-regressive 생성) | 순차적 | |
| 합성곱 모드(Convolution) | 학습(병렬 처리) | FFT 활용 |
합성곱 커널 는 다음과 같이 미리 계산할 수 있습니다:
핵심 포인트: 기존 SSM은 학습 시 합성곱 모드로 병렬 처리하고, 추론 시 순환 모드로 효율적으로 생성할 수 있었습니다. 하지만 이 이중 모드는 파라미터가 입력에 독립적(시불변, LTI)이어야만 가능합니다.
기존 SSM의 근본적 한계: LTI 제약
시불변(Linear Time Invariant) 시스템의 문제
기존 S4, H3 등의 SSM은 LTI(Linear Time Invariant) 시스템입니다. 즉 , , , 파라미터가 입력과 무관하게 고정되어 있습니다. 어떤 토큰이 들어오든 동일한 방식으로 상태를 업데이트합니다.
이것이 왜 문제일까요? 논문에서 제시하는 두 가지 합성 태스크로 설명할 수 있습니다.
Selective Copying 태스크
일반 Copying 태스크에서는 입력 시퀀스를 그대로 출력하면 됩니다. 기존 SSM도 이를 잘 수행합니다. 하지만 Selective Copying에서는 특정 토큰만 선택적으로 기억하고 나머지는 무시해야 합니다.
입력: A _ _ B _ _ C _ _ → 출력: A B C
LTI 시스템은 모든 토큰을 동일하게 처리하므로, 어떤 토큰이 중요한지 구분할 수 없습니다. 결과적으로 이 태스크에서 완전히 실패합니다.
Induction Head 태스크
Induction Head는 “A B … A → B” 패턴을 학습하는 것으로, in-context learning의 핵심 메커니즘입니다. 이전에 A 다음에 B가 나왔으므로, 다시 A가 나오면 B를 예측해야 합니다.
LTI 시스템은 현재 토큰의 내용(content)에 기반한 조건부 처리를 할 수 없으므로, 이 패턴도 학습하지 못합니다.
핵심 통찰: LTI 제약은 SSM이 내용 기반 추론(content-based reasoning)을 하지 못하게 합니다. Transformer의 Attention이 강력한 이유가 바로 이 내용 기반 추론 능력에 있습니다.
Mamba의 핵심: Selective State Space Model (S6)
Selection Mechanism
Mamba의 핵심 기여는 SSM 파라미터를 입력의 함수로 만드는 Selection Mechanism입니다. 구체적으로, 시점 에서의 파라미터가 현재 입력 에 따라 동적으로 결정됩니다:
각 함수의 역할은 다음과 같습니다:
- : 선형 투영 — 입력을 상태에 어떻게 쓸지 결정
- : 선형 투영 — 상태에서 어떻게 읽을지 결정
- : 선형 투영 + softplus — 이산화 스텝 크기를 결정
는 입력 독립적으로 유지됩니다. 이는 를 통한 이산화 과정에서 가 간접적으로 입력 의존성을 갖기 때문입니다.
의 직관적 이해
는 Mamba의 Selection Mechanism에서 가장 중요한 역할을 합니다. 이산화된 상태 전이를 다시 살펴보면:
- 가 클 때: 이므로 이전 상태를 “잊고” 현재 입력 를 강하게 반영 → 새로운 정보 기록
- 가 작을 때: 이므로 이전 상태를 유지하고 현재 입력을 무시 → 기존 정보 보존
이는 LSTM의 Forget Gate와 깊은 유사성을 가집니다. LSTM에서 forget gate가 1이면 상태를 유지하고, 0이면 상태를 초기화하는 것과 같은 원리입니다. Mamba의 는 이를 연속적이고 미분 가능한 방식으로 구현한 것입니다.
정보 이론적 관점
논문은 Selection Mechanism의 필요성을 정보 이론적으로도 설명합니다:
- LTI 모델: 상태 크기 에 의해 제한된 고정 용량. 시퀀스 길이와 무관하게 최대 비트의 정보만 저장 가능
- Selective 모델: 어떤 정보를 저장하고 어떤 정보를 버릴지 선택할 수 있으므로, 같은 상태 크기로 더 효율적인 정보 저장 가능
이는 데이터의 정보 밀도가 불균일할 때 특히 중요합니다. 자연어에서 모든 토큰이 동등하게 중요하지는 않기 때문이죠.
Selection의 대가: 합성곱 모드의 상실
파라미터가 입력 의존적이 되면, 더 이상 시불변이 아니므로 합성곱 커널을 미리 계산할 수 없습니다. 즉, 학습 시 병렬화에 사용하던 합성곱 모드를 사용할 수 없게 됩니다.
이 문제를 해결하기 위해 Mamba는 Hardware-Aware Parallel Algorithm을 도입합니다.
Hardware-Aware 병렬 알고리즘
GPU 메모리 계층 구조의 이해
현대 GPU에는 두 가지 주요 메모리 계층이 있습니다:
| 메모리 | 용량 | 대역폭 | 용도 |
|---|---|---|---|
| HBM (High Bandwidth Memory) | 크다 (40-80GB) | 상대적으로 느림 | 메인 메모리 |
| SRAM (on-chip) | 작다 (수 MB) | 매우 빠름 | 연산용 캐시 |
성능의 핵심은 HBM 접근을 최소화하고 SRAM에서 최대한 많은 연산을 수행하는 것입니다.
Mamba의 최적화 전략
Mamba는 세 가지 핵심 전략으로 이 문제를 해결합니다:
1. 커널 퓨전(Kernel Fusion)
이산화된 파라미터 , 를 HBM에 저장하지 않고, SRAM에서 실시간으로 계산합니다. 즉, 로부터 를 계산하는 것과 순환 연산을 하나의 GPU 커널로 융합합니다.
2. 병렬 스캔(Parallel Scan)
순환 연산 는 본질적으로 순차적으로 보이지만, prefix sum(병렬 스캔) 알고리즘으로 병렬화할 수 있습니다:
- 총 작업량:
- 병렬 시간:
이는 고전적인 병렬 알고리즘 기법으로, 결합 가능한(associative) 연산에 적용할 수 있습니다.
3. 역전파 시 재계산(Recomputation)
일반적인 역전파에서는 중간 상태 를 모두 저장해야 합니다. 이 경우 메모리 요구량이 이 됩니다 (: 배치, : 시퀀스 길이, : 모델 차원, : 상태 차원).
Mamba는 FlashAttention에서 영감을 받아, 역전파 시 중간 상태를 저장하지 않고 입력으로부터 재계산합니다. 이로써 메모리를 으로 절감합니다.
# 의사 코드: Mamba의 순방향 패스 (개념적)
def mamba_forward(x, A, B_proj, C_proj, delta_proj):
# x: (B, L, D)
B_t = B_proj(x) # (B, L, N) - 입력 의존적
C_t = C_proj(x) # (B, L, N) - 입력 의존적
delta = softplus(delta_proj(x)) # (B, L, D) - 입력 의존적
# 이산화 (SRAM에서 실시간 계산, HBM 저장 없음)
A_bar = exp(delta.unsqueeze(-1) * A) # (B, L, D, N)
B_bar = delta.unsqueeze(-1) * B_t.unsqueeze(2) # (B, L, D, N)
# 병렬 스캔으로 순환 연산 수행
h = parallel_scan(A_bar, B_bar * x.unsqueeze(-1)) # (B, L, D, N)
# 출력 계산
y = (h * C_t.unsqueeze(2)).sum(-1) # (B, L, D)
return y
Mamba 블록 아키텍처
전체 구조
Mamba 블록은 Transformer의 Attention + MLP 두 블록을 하나로 통합한 구조입니다:
Input (D차원)
│
├───────────────────────────┐
│ │
▼ ▼
Linear (D → E) Linear (D → E)
│ │
Conv1D (k=4) SiLU 활성화
│ │
SiLU 활성화 │
│ │
Selective SSM (S6) │
│ │
└──── Element-wise × ──────┘
│
Linear (E → D)
│
Output
설계 핵심 요소
| 요소 | 설명 | 설계 이유 |
|---|---|---|
| 확장 비율 | Transformer의 SwiGLU MLP()와 파라미터 수 매칭 | |
| Conv1D (커널=4) | 짧은 1D 합성곱 | SSM 처리 전 로컬 컨텍스트 제공 |
| 게이팅(Gating) | 분기 후 element-wise 곱 | LSTM/GRU의 게이팅에서 영감. 정보 흐름 제어 |
| SiLU 활성화 | 비선형성 + 부드러운 게이팅 효과 | |
| RMSNorm | 각 블록 앞에 적용 (Pre-Norm) | 안정적 학습. Transformer와 동일한 패턴 |
| 잔차 연결 | 블록 출력에 입력을 더함 | 깊은 네트워크 학습 안정화 |
Transformer와의 구조적 차이
Transformer는 Attention 블록 + MLP 블록을 번갈아 쌓는 반면, Mamba는 단일 블록만 반복합니다. 이 단순화로 같은 파라미터 예산에서 더 많은 레이어를 쌓을 수 있습니다.
모델 구성
논문에서는 다양한 규모의 Mamba 모델을 실험했습니다:
| 모델 | 차원(D) | 상태 차원(N) | 레이어 수 | 파라미터 수 |
|---|---|---|---|---|
| Mamba-130M | 768 | 16 | 24 | ~130M |
| Mamba-370M | 1024 | 16 | 48 | ~370M |
| Mamba-790M | 1536 | 16 | 48 | ~790M |
| Mamba-1.4B | 2048 | 16 | 48 | ~1.4B |
| Mamba-2.8B | 2560 | 16 | 64 | ~2.8B |
모든 모델에서 상태 차원 을 사용한 것이 주목할 만합니다. 비교적 작은 상태 차원으로도 충분한 표현력을 달성할 수 있음을 보여줍니다.
학습 세부 사항
| 항목 | 설정 |
|---|---|
| 옵티마이저 | AdamW |
| 학습률 | 코사인 스케줄 + 웜업 (모델 크기에 따라 3e-4 ~ 8e-4) |
| Weight Decay | 0.1 |
| 시퀀스 길이 | 2048 (학습), 평가 시 더 긴 시퀀스 테스트 |
| 정밀도 | Mixed Precision (bf16/fp32) |
| 하드웨어 | A100 80GB GPU |
초기화 전략
- : S4의 HiPPO 행렬의 대각(diagonal) 버전으로 초기화. 장기 의존성 포착에 유리한 구조
- : softplus 역함수를 사용하여 초기값이 합리적인 범위에 오도록 설정 ( 균등 분포의 지수)
- , : 표준 초기화
실험 결과
합성 태스크(Synthetic Tasks)
합성 태스크에서 Selection Mechanism의 효과를 직접적으로 검증합니다:
| 모델 | Selective Copying | Induction Head |
|---|---|---|
| S4 (LTI SSM) | 실패 | 실패 |
| H3 | 부분 성공 | 부분 성공 |
| Transformer | 성공 | 성공 |
| Mamba (S6) | 성공 | 성공 |
Selection Mechanism 덕분에 Mamba는 Transformer처럼 내용 기반 추론(content-based reasoning)이 가능하며, 기존 SSM의 근본적 한계를 극복합니다.
언어 모델링 (The Pile)
300B 토큰 규모의 The Pile 데이터셋에서 학습한 결과입니다:
| 모델 | 파라미터 | Pile 성능 |
|---|---|---|
| GPT-Neo | 125M | 기준선 |
| Mamba | 130M | GPT-Neo 125M과 동등 또는 우세 |
| Pythia | 350M | 기준선 |
| Mamba | 370M | Pythia 350M보다 우세 |
| Pythia | 1.4B | 기준선 |
| Mamba | 1.4B | Pythia 1.4B보다 우세 |
| Transformer++ (SwiGLU, RoPE 등) | ~3.1B | 강화된 기준선 |
| Mamba | 2.8B | Transformer++ 3.1B과 동등 |
주목할 결과: Mamba-2.8B는 자신보다 약 2배 큰 Transformer 모델과 비슷한 성능을 달성합니다. 이는 파라미터 효율성 측면에서 매우 인상적입니다.
Zero-shot 다운스트림 평가
사전 학습된 모델을 다양한 벤치마크에서 평가한 결과:
| 벤치마크 | 평가 내용 | Mamba-2.8B 성능 |
|---|---|---|
| LAMBADA | 마지막 단어 예측 | Pythia-2.8B와 동등/우세 |
| HellaSwag | 상식 추론 | 경쟁적 |
| PIQA | 물리적 직관 | 경쟁적 |
| ARC-Easy/Challenge | 과학 질의응답 | 경쟁적 |
| WinoGrande | 상식 추론 | 경쟁적 |
Mamba-2.8B는 GPT-Neo 2.7B, Pythia-2.8B 등 동급 Transformer 모델들과 전반적으로 경쟁적이거나 우세한 성능을 보입니다.
DNA 시퀀스 모델링
HG38 인간 게놈 데이터셋에서의 결과:
| 모델 | 긴 시퀀스 처리 | 성능 |
|---|---|---|
| HyenaDNA (긴 합성곱 기반) | 최대 1M 토큰 | 이전 SOTA |
| Transformer | 메모리 한계로 긴 시퀀스 불가 | 제한적 |
| Mamba | 최대 1M 토큰 | 새로운 SOTA |
Mamba의 특징적인 결과:
– 시퀀스 길이가 길어질수록 성능 향상이 지속됨 (최대 100만 토큰까지)
– 다른 모델들은 특정 길이 이후 성능 향상이 정체
– 다운스트림 유전체 분류 태스크에서도 SOTA 달성
오디오 모델링
SC09 음성 생성 벤치마크에서 Mamba는 기존 SSM 기반 모델(SaShiMi/S4)을 FID와 IS 지표 모두에서 능가했습니다.
추론 속도
Mamba의 가장 극적인 장점은 추론 속도입니다:
| 시퀀스 길이 | Transformer 대비 Mamba 처리량 |
|---|---|
| 512 | ~2-3배 빠름 |
| 2048 | ~3-4배 빠름 |
| 8192+ | ~5배 이상 빠름 |
이 차이의 원인:
– Transformer: KV 캐시가 시퀀스 길이에 비례하여 증가 → 메모리 대역폭 병목
– Mamba: 고정 크기 상태 → 시퀀스 길이와 무관하게 일정한 메모리
Ablation Study 상세 분석
Selection Mechanism 구성 요소 분석
어떤 파라미터를 입력 의존적으로 만드는 것이 가장 효과적인지 분석합니다:
| 구성 | Selective Copying | Induction Head | 언어 모델링 |
|---|---|---|---|
| 선택 없음 (S4) | 실패 | 실패 | 기준선 |
| 만 선택적 | 부분 성공 | 성공 | 개선 |
| , 만 선택적 | 부분 성공 | 성공 | 개선 |
| + + 모두 선택적 | 성공 | 성공 | 최고 성능 |
를 입력 의존적으로 만드는 것이 단일 변경으로는 가장 큰 영향을 미칩니다. , 의 선택을 추가하면 추가 성능 향상이 있습니다.
상태 차원(N) 분석
| 상태 차원 N | 상대 성능 | 비고 |
|---|---|---|
| 1 | 상당히 나쁨 | 상태 용량 부족 |
| 4 | 보통 | |
| 8 | 양호 | |
| 16 | 최적 | 기본값으로 채택 |
| 32 | 미미한 개선 | 메모리 증가 대비 효과 작음 |
| 64 | 수확 체감 | 비용 대비 효과 미미 |
이 성능과 효율성의 최적 균형점입니다.
아키텍처 요소 분석
| 요소 | 제거 시 영향 |
|---|---|
| Conv1D | 성능 하락 (로컬 컨텍스트 상실) |
| 게이팅 메커니즘 | 상당한 성능 하락 |
| 확장 비율 E=2D | 최적 (E=D보다 우수, E=4D는 수확 체감) |
기존 방법론과의 종합 비교
| 비교 항목 | Transformer | S4 (기존 SSM) | H3 | RWKV | Mamba |
|---|---|---|---|---|---|
| 학습 복잡도 | |||||
| 추론 복잡도 (스텝당) | |||||
| 추론 메모리 | 증가 | 고정 | 고정 | 고정 | 고정 |
| 내용 기반 추론 | 가능 (Attention) | 불가 (LTI) | 제한적 | 제한적 | 가능 (Selection) |
| 긴 시퀀스 확장성 | 나쁨 | 좋음 | 좋음 | 좋음 | 좋음 |
| 언어 모델링 품질 | 최고 수준 | 낮음 | 중간 | 중간 | Transformer와 동등 |
| 구조 단순성 | Attention+MLP | SSM+MLP | H3+MLP | RWKV블록 | 단일 Mamba 블록 |
Mamba의 강점
1. 선형 시간복잡도로 Transformer 급 성능
Mamba는 복잡도로 동작하면서도 Transformer와 동등한 언어 모델링 성능을 달성합니다. 이는 기존 Efficient Transformer들이 성능 저하를 감수해야 했던 것과 대비됩니다.
2. 고정 크기 상태로 효율적 추론
추론 시 KV 캐시 없이 크기의 고정 상태만 유지하므로, 시퀀스 길이에 무관하게 일정한 추론 비용을 가집니다. 이는 실시간 서비스에서 큰 장점입니다.
3. 다중 도메인에서의 강건한 성능
언어뿐 아니라 DNA 시퀀스, 오디오에서도 SOTA를 달성하여, 범용 시퀀스 모델링 아키텍처로서의 가능성을 보여줍니다.
4. 구조적 단순함
Attention과 MLP를 별도로 가지는 Transformer와 달리, 단일 Mamba 블록만으로 전체 아키텍처를 구성합니다. 이 단순함은 구현과 최적화를 용이하게 합니다.
5. 이론적 기반의 견고함
Selection Mechanism의 설계가 정보 이론적 근거와 LSTM 게이팅과의 연결성에 기반하여, 단순한 경험적 트릭이 아닌 원리적 설계임을 보여줍니다.
Mamba의 한계점
1. 대규모 스케일링 검증 부족
논문 발표 시점 기준 최대 2.8B 파라미터까지만 실험했습니다. 100B+ 규모에서의 성능은 검증되지 않았으며, Transformer의 스케일링 법칙과 비교할 수 있는 데이터가 제한적입니다.
2. 정확한 정보 검색의 한계
Mamba의 고정 크기 상태는 시퀀스의 모든 정보를 압축해야 합니다. 긴 문서에서 정확한 정보를 조회(retrieval)하는 태스크에서는 Transformer의 Attention이 더 유리할 수 있습니다. Attention은 시퀀스의 모든 위치에 직접 접근할 수 있기 때문입니다.
3. 커스텀 CUDA 커널 의존성
Hardware-Aware 알고리즘의 성능은 커스텀 CUDA/Triton 커널에 크게 의존합니다. 이는 특정 GPU 하드웨어에서의 최적화를 전제로 하며, 새로운 하드웨어나 프레임워크에 대한 이식성이 제한될 수 있습니다.
4. 양방향 처리의 비자명성
Mamba는 기본적으로 단방향(causal) 모델입니다. BERT처럼 양방향 컨텍스트가 필요한 태스크에서는 추가적인 설계 수정이 필요합니다.
5. In-Context Learning의 깊이
합성 태스크에서 Induction Head를 학습할 수 있음을 보여주었지만, Transformer가 보여주는 다양하고 복잡한 in-context learning 패턴을 동일한 수준으로 구현할 수 있는지는 추가 연구가 필요합니다.
후속 연구 방향
Mamba-2
Mamba 이후 Gu와 Dao는 Mamba-2를 발표하여 SSM과 Attention의 이론적 연결성을 밝히고, State Space Duality(SSD) 프레임워크를 제시했습니다. 더 큰 상태 차원()을 효율적으로 사용할 수 있도록 알고리즘을 개선했습니다.
하이브리드 아키텍처
Mamba와 Attention을 결합한 하이브리드 아키텍처들이 활발히 연구되고 있습니다:
- Jamba (AI21 Labs): Mamba + Transformer 레이어 혼합
- Zamba: 선별적으로 Attention 레이어를 배치
이러한 하이브리드 접근은 Mamba의 효율성과 Attention의 정확한 정보 검색 능력을 결합합니다.
비전·멀티모달 확장
- Vision Mamba (Vim): 이미지 패치 시퀀스에 Mamba 적용
- VideoMamba: 비디오의 시공간 시퀀스 처리
- 멀티모달 Mamba: 텍스트 + 이미지 + 오디오 통합 처리
스케일링 연구
대규모 모델에서의 Mamba 스케일링 법칙 연구와 100B+ 규모의 학습이 중요한 미래 연구 방향입니다.
핵심 구현 코드 살펴보기
Mamba의 핵심 Selection Mechanism을 PyTorch로 간략하게 표현하면 다음과 같습니다:
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelectiveSSM(nn.Module):
def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
super().__init__()
self.d_model = d_model
self.d_state = d_state
d_inner = int(expand * d_model)
self.d_inner = d_inner
# 입력 투영
self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
# Conv1D (로컬 컨텍스트)
self.conv1d = nn.Conv1d(
d_inner, d_inner,
kernel_size=d_conv,
padding=d_conv - 1,
groups=d_inner # depthwise
)
# Selection 파라미터 투영
self.x_proj = nn.Linear(d_inner, d_state * 2 + 1, bias=False) # B, C, delta
self.dt_proj = nn.Linear(1, d_inner, bias=True) # delta 브로드캐스트
# 상태 행렬 A (학습 가능하지만 입력 독립적)
A = torch.arange(1, d_state + 1, dtype=torch.float32)
self.A_log = nn.Parameter(torch.log(A)) # log 공간에서 파라미터화
# 출력 투영
self.out_proj = nn.Linear(d_inner, d_model, bias=False)
def forward(self, x):
B, L, D = x.shape
# 입력 투영 + 게이팅 분기
xz = self.in_proj(x) # (B, L, 2*d_inner)
x_branch, z = xz.chunk(2, dim=-1) # 각각 (B, L, d_inner)
# Conv1D
x_branch = x_branch.transpose(1, 2) # (B, d_inner, L)
x_branch = self.conv1d(x_branch)[:, :, :L] # causal padding
x_branch = x_branch.transpose(1, 2) # (B, L, d_inner)
x_branch = F.silu(x_branch)
# Selection: 입력으로부터 B, C, delta 계산
x_dbl = self.x_proj(x_branch) # (B, L, 2*N + 1)
B_t = x_dbl[:, :, :self.d_state] # (B, L, N)
C_t = x_dbl[:, :, self.d_state:2*self.d_state] # (B, L, N)
delta = F.softplus(x_dbl[:, :, -1:]) # (B, L, 1)
# 이산화
A = -torch.exp(self.A_log) # (N,) 음수로 안정화
# 순환 연산 (실제로는 병렬 스캔 사용)
y = self.selective_scan(x_branch, delta, A, B_t, C_t)
# 게이팅
y = y * F.silu(z)
return self.out_proj(y)
def selective_scan(self, x, delta, A, B, C):
"""단순화된 순환 연산 (실제로는 CUDA 커널로 병렬 처리)"""
B_batch, L, D = x.shape
N = self.d_state
h = torch.zeros(B_batch, D, N, device=x.device)
outputs = []
for t in range(L):
# 이산화된 파라미터
A_bar = torch.exp(delta[:, t, :, None] * A[None, None, :]) # (B, D, N)
B_bar = delta[:, t, :, None] * B[:, t, None, :] # (B, D, N)
# 상태 업데이트
h = A_bar * h + B_bar * x[:, t, :, None] # (B, D, N)
# 출력
y_t = (h * C[:, t, None, :]).sum(-1) # (B, D)
outputs.append(y_t)
return torch.stack(outputs, dim=1) # (B, L, D)
위 코드는 개념 이해를 위한 단순화 버전입니다. 실제 Mamba 구현에서는
selective_scan함수가 커스텀 CUDA 커널과 병렬 스캔 알고리즘으로 대체되어 GPU에서 효율적으로 동작합니다.
마무리
Mamba는 시퀀스 모델링의 패러다임에 중요한 전환점을 제시한 연구입니다. 핵심 기여를 정리하면:
-
Selection Mechanism: SSM 파라미터(, , )를 입력 의존적으로 만들어 LTI 제약을 극복하고, 내용 기반 추론을 가능하게 함
-
Hardware-Aware Algorithm: 커널 퓨전, 병렬 스캔, 역전파 재계산을 결합하여 Selection의 연산 비용을 실용적 수준으로 관리
-
통합 아키텍처: Attention + MLP를 하나의 Mamba 블록으로 대체하는 단순하고 효과적인 설계
-
다중 도메인 SOTA: 언어, DNA, 오디오에서 Transformer와 동등하거나 우수한 성능을 선형 시간복잡도로 달성
-
5배 빠른 추론: 고정 크기 상태 덕분에 긴 시퀀스에서 Transformer 대비 최대 5배 빠른 생성 처리량
Mamba는 “Attention이 필요한 전부인가?”라는 근본적인 질문에 대해 설득력 있는 대안을 제시합니다. 물론 대규모 스케일링 검증과 정확한 정보 검색 태스크에서의 한계가 남아있지만, 후속 연구인 Mamba-2와 Jamba 같은 하이브리드 아키텍처들이 이러한 한계를 빠르게 보완해 나가고 있습니다.
에서 로의 전환은 단순한 효율성 개선이 아니라, 100만 토큰 이상의 초장문 시퀀스를 실용적으로 처리할 수 있게 하는 질적 변화입니다. Mamba가 열어놓은 이 방향이 앞으로 시퀀스 모델링의 미래를 어떻게 바꿔갈지 주목할 필요가 있습니다.
Did you find this helpful?
☕ Buy me a coffee
Leave a Reply