Mamba 논문 리뷰: 선형 시간복잡도로 Transformer에 도전하는 Selective State Space Model

Updated Feb 6, 2026

들어가며

Transformer 아키텍처는 자연어 처리, 컴퓨터 비전, 오디오 등 거의 모든 시퀀스 모델링 분야에서 압도적인 성능을 보여왔습니다. 하지만 Self-Attention의 O(L2)O(L^2) 시간복잡도는 긴 시퀀스를 다룰 때 심각한 병목이 됩니다. 시퀀스 길이가 두 배가 되면 연산량은 네 배로 증가하죠.

이 문제를 해결하기 위해 다양한 Efficient Transformer 변형들이 등장했지만, 대부분 성능과 효율성 사이에서 타협해야 했습니다. 2023년 12월, Albert Gu와 Tri Dao가 발표한 “Mamba: Linear-Time Sequence Modeling with Selective State Spaces” 논문은 완전히 다른 접근법을 제시합니다.

Mamba는 Attention 메커니즘 없이 선형 시간복잡도 O(L)O(L)로 동작하면서도 Transformer와 동등하거나 더 나은 성능을 달성한 State Space Model(SSM) 기반 아키텍처입니다.

이 글에서는 Mamba의 핵심 아이디어인 Selective State Space Model(S6)의 구조, 하드웨어 최적화 알고리즘, 실험 결과를 상세히 분석합니다.


배경: State Space Model(SSM)이란?

연속 시간 SSM

State Space Model은 제어 이론에서 유래한 시퀀스 모델링 프레임워크입니다. 연속 시간 SSM은 다음과 같은 미분 방정식으로 정의됩니다:

h(t)=Ah(t)+Bx(t)h'(t) = \mathbf{A}h(t) + \mathbf{B}x(t)

y(t)=Ch(t)y(t) = \mathbf{C}h(t)

각 항의 의미는 다음과 같습니다:

  • h(t)RNh(t) \in \mathbb{R}^N: 잠재 상태(hidden state) — 시퀀스의 정보를 압축 저장하는 벡터 (차원 NN)
  • x(t)Rx(t) \in \mathbb{R}: 입력 신호 — 현재 시점의 입력값
  • y(t)Ry(t) \in \mathbb{R}: 출력 신호 — 현재 시점의 출력값
  • ARN×N\mathbf{A} \in \mathbb{R}^{N \times N}: 상태 전이 행렬 — 이전 상태가 다음 상태에 어떻게 전달되는지 결정
  • BRN×1\mathbf{B} \in \mathbb{R}^{N \times 1}: 입력 행렬 — 입력이 상태에 어떻게 반영되는지 결정
  • CR1×N\mathbf{C} \in \mathbb{R}^{1 \times N}: 출력 행렬 — 상태에서 출력을 어떻게 읽어내는지 결정

이산화(Discretization)

실제 디지털 시퀀스를 처리하려면 연속 시간 SSM을 이산화해야 합니다. 스텝 크기 Δ\Delta를 사용하여 다음과 같이 변환합니다:

A=exp(ΔA)\overline{\mathbf{A}} = \exp(\Delta \mathbf{A})

B=(ΔA)1(exp(ΔA)I)ΔB\overline{\mathbf{B}} = (\Delta \mathbf{A})^{-1}(\exp(\Delta \mathbf{A}) – \mathbf{I}) \cdot \Delta \mathbf{B}

이산화된 SSM의 순환 관계식은 RNN과 유사한 형태가 됩니다:

ht=Aht1+Bxth_t = \overline{\mathbf{A}} h_{t-1} + \overline{\mathbf{B}} x_t

yt=Chty_t = \mathbf{C} h_t

SSM의 이중 연산 모드

기존 SSM(S4 등)의 가장 큰 장점은 두 가지 연산 모드를 지원한다는 것입니다:

모드 연산 방식 용도 시간복잡도
순환 모드(Recurrence) ht=Aht1+Bxth_t = \overline{\mathbf{A}}h_{t-1} + \overline{\mathbf{B}}x_t 추론(Auto-regressive 생성) O(L)O(L) 순차적
합성곱 모드(Convolution) y=xKy = x * \overline{\mathbf{K}} 학습(병렬 처리) O(LlogL)O(L \log L) FFT 활용

합성곱 커널 K\overline{\mathbf{K}}는 다음과 같이 미리 계산할 수 있습니다:

K=(CB,  CAB,  CA2B,  ,  CAL1B)\overline{\mathbf{K}} = (\mathbf{C}\overline{\mathbf{B}}, \; \mathbf{C}\overline{\mathbf{A}}\overline{\mathbf{B}}, \; \mathbf{C}\overline{\mathbf{A}}^2\overline{\mathbf{B}}, \; \ldots, \; \mathbf{C}\overline{\mathbf{A}}^{L-1}\overline{\mathbf{B}})

핵심 포인트: 기존 SSM은 학습 시 합성곱 모드로 병렬 처리하고, 추론 시 순환 모드로 효율적으로 생성할 수 있었습니다. 하지만 이 이중 모드는 파라미터가 입력에 독립적(시불변, LTI)이어야만 가능합니다.


기존 SSM의 근본적 한계: LTI 제약

시불변(Linear Time Invariant) 시스템의 문제

기존 S4, H3 등의 SSM은 LTI(Linear Time Invariant) 시스템입니다. 즉 A\mathbf{A}, B\mathbf{B}, C\mathbf{C}, Δ\Delta 파라미터가 입력과 무관하게 고정되어 있습니다. 어떤 토큰이 들어오든 동일한 방식으로 상태를 업데이트합니다.

이것이 왜 문제일까요? 논문에서 제시하는 두 가지 합성 태스크로 설명할 수 있습니다.

Selective Copying 태스크

일반 Copying 태스크에서는 입력 시퀀스를 그대로 출력하면 됩니다. 기존 SSM도 이를 잘 수행합니다. 하지만 Selective Copying에서는 특정 토큰만 선택적으로 기억하고 나머지는 무시해야 합니다.

입력:  A _ _ B _ _ C _ _  출력: A B C

LTI 시스템은 모든 토큰을 동일하게 처리하므로, 어떤 토큰이 중요한지 구분할 수 없습니다. 결과적으로 이 태스크에서 완전히 실패합니다.

Induction Head 태스크

Induction Head는 “A B … A → B” 패턴을 학습하는 것으로, in-context learning의 핵심 메커니즘입니다. 이전에 A 다음에 B가 나왔으므로, 다시 A가 나오면 B를 예측해야 합니다.

LTI 시스템은 현재 토큰의 내용(content)에 기반한 조건부 처리를 할 수 없으므로, 이 패턴도 학습하지 못합니다.

핵심 통찰: LTI 제약은 SSM이 내용 기반 추론(content-based reasoning)을 하지 못하게 합니다. Transformer의 Attention이 강력한 이유가 바로 이 내용 기반 추론 능력에 있습니다.


Mamba의 핵심: Selective State Space Model (S6)

Selection Mechanism

Mamba의 핵심 기여는 SSM 파라미터를 입력의 함수로 만드는 Selection Mechanism입니다. 구체적으로, 시점 tt에서의 파라미터가 현재 입력 xtx_t에 따라 동적으로 결정됩니다:

Bt=sB(xt)\mathbf{B}_t = s_B(x_t)

Ct=sC(xt)\mathbf{C}_t = s_C(x_t)

Δt=softplus(sΔ(xt))\Delta_t = \text{softplus}(s_\Delta(x_t))

각 함수의 역할은 다음과 같습니다:

  • sBs_B: 선형 투영 RDRN\mathbb{R}^D \rightarrow \mathbb{R}^N — 입력을 상태에 어떻게 쓸지 결정
  • sCs_C: 선형 투영 RDRN\mathbb{R}^D \rightarrow \mathbb{R}^N — 상태에서 어떻게 읽을지 결정
  • sΔs_\Delta: 선형 투영 RDR1\mathbb{R}^D \rightarrow \mathbb{R}^1 + softplus — 이산화 스텝 크기를 결정

A\mathbf{A}는 입력 독립적으로 유지됩니다. 이는 Δ\Delta를 통한 이산화 과정에서 A=exp(ΔtA)\overline{\mathbf{A}} = \exp(\Delta_t \mathbf{A})가 간접적으로 입력 의존성을 갖기 때문입니다.

Δ\Delta의 직관적 이해

Δ\Delta는 Mamba의 Selection Mechanism에서 가장 중요한 역할을 합니다. 이산화된 상태 전이를 다시 살펴보면:

ht=exp(ΔtA)ht1+ΔtBtxth_t = \exp(\Delta_t \mathbf{A}) \cdot h_{t-1} + \Delta_t \mathbf{B}_t \cdot x_t

  • Δt\Delta_t가 클 때: exp(ΔtA)0\exp(\Delta_t \mathbf{A}) \approx 0이므로 이전 상태를 “잊고” 현재 입력 xtx_t를 강하게 반영 → 새로운 정보 기록
  • Δt\Delta_t가 작을 때: exp(ΔtA)I\exp(\Delta_t \mathbf{A}) \approx \mathbf{I}이므로 이전 상태를 유지하고 현재 입력을 무시 → 기존 정보 보존

이는 LSTM의 Forget Gate와 깊은 유사성을 가집니다. LSTM에서 forget gate가 1이면 상태를 유지하고, 0이면 상태를 초기화하는 것과 같은 원리입니다. Mamba의 Δ\Delta는 이를 연속적이고 미분 가능한 방식으로 구현한 것입니다.

정보 이론적 관점

논문은 Selection Mechanism의 필요성을 정보 이론적으로도 설명합니다:

  • LTI 모델: 상태 크기 NN에 의해 제한된 고정 용량. 시퀀스 길이와 무관하게 최대 O(N)O(N) 비트의 정보만 저장 가능
  • Selective 모델: 어떤 정보를 저장하고 어떤 정보를 버릴지 선택할 수 있으므로, 같은 상태 크기로 더 효율적인 정보 저장 가능

이는 데이터의 정보 밀도가 불균일할 때 특히 중요합니다. 자연어에서 모든 토큰이 동등하게 중요하지는 않기 때문이죠.

Selection의 대가: 합성곱 모드의 상실

파라미터가 입력 의존적이 되면, 더 이상 시불변이 아니므로 합성곱 커널을 미리 계산할 수 없습니다. 즉, 학습 시 병렬화에 사용하던 합성곱 모드를 사용할 수 없게 됩니다.

LTI (시불변)합성곱 가능, 하지만 내용 기반 추론 불가\text{LTI (시불변)} \rightarrow \text{합성곱 가능, 하지만 내용 기반 추론 불가}

Selective (시변)내용 기반 추론 가능, 하지만 합성곱 불가\text{Selective (시변)} \rightarrow \text{내용 기반 추론 가능, 하지만 합성곱 불가}

이 문제를 해결하기 위해 Mamba는 Hardware-Aware Parallel Algorithm을 도입합니다.


Hardware-Aware 병렬 알고리즘

GPU 메모리 계층 구조의 이해

현대 GPU에는 두 가지 주요 메모리 계층이 있습니다:

메모리 용량 대역폭 용도
HBM (High Bandwidth Memory) 크다 (40-80GB) 상대적으로 느림 메인 메모리
SRAM (on-chip) 작다 (수 MB) 매우 빠름 연산용 캐시

성능의 핵심은 HBM 접근을 최소화하고 SRAM에서 최대한 많은 연산을 수행하는 것입니다.

Mamba의 최적화 전략

Mamba는 세 가지 핵심 전략으로 이 문제를 해결합니다:

1. 커널 퓨전(Kernel Fusion)

이산화된 파라미터 A\overline{\mathbf{A}}, B\overline{\mathbf{B}}를 HBM에 저장하지 않고, SRAM에서 실시간으로 계산합니다. 즉, (Δ,A,B)(\Delta, \mathbf{A}, \mathbf{B})로부터 (A,B)(\overline{\mathbf{A}}, \overline{\mathbf{B}})를 계산하는 것과 순환 연산을 하나의 GPU 커널로 융합합니다.

2. 병렬 스캔(Parallel Scan)

순환 연산 ht=A<em>th</em>t1+Btxth_t = \overline{\mathbf{A}}<em>t h</em>{t-1} + \overline{\mathbf{B}}_t x_t는 본질적으로 순차적으로 보이지만, prefix sum(병렬 스캔) 알고리즘으로 병렬화할 수 있습니다:

  • 총 작업량: O(L)O(L)
  • 병렬 시간: O(logL)O(\log L)

이는 고전적인 병렬 알고리즘 기법으로, 결합 가능한(associative) 연산에 적용할 수 있습니다.

3. 역전파 시 재계산(Recomputation)

일반적인 역전파에서는 중간 상태 hth_t를 모두 저장해야 합니다. 이 경우 메모리 요구량이 O(BLDN)O(BLDN)이 됩니다 (BB: 배치, LL: 시퀀스 길이, DD: 모델 차원, NN: 상태 차원).

Mamba는 FlashAttention에서 영감을 받아, 역전파 시 중간 상태를 저장하지 않고 입력으로부터 재계산합니다. 이로써 메모리를 O(BLD+DN)O(BLD + DN)으로 절감합니다.

# 의사 코드: Mamba의 순방향 패스 (개념적)
def mamba_forward(x, A, B_proj, C_proj, delta_proj):
    # x: (B, L, D)
    B_t = B_proj(x)          # (B, L, N) - 입력 의존적
    C_t = C_proj(x)          # (B, L, N) - 입력 의존적
    delta = softplus(delta_proj(x))  # (B, L, D) - 입력 의존적

    # 이산화 (SRAM에서 실시간 계산, HBM 저장 없음)
    A_bar = exp(delta.unsqueeze(-1) * A)  # (B, L, D, N)
    B_bar = delta.unsqueeze(-1) * B_t.unsqueeze(2)  # (B, L, D, N)

    # 병렬 스캔으로 순환 연산 수행
    h = parallel_scan(A_bar, B_bar * x.unsqueeze(-1))  # (B, L, D, N)

    # 출력 계산
    y = (h * C_t.unsqueeze(2)).sum(-1)  # (B, L, D)
    return y

Mamba 블록 아키텍처

전체 구조

Mamba 블록은 Transformer의 Attention + MLP 두 블록을 하나로 통합한 구조입니다:

Input (D차원)
  │
  ├───────────────────────────┐
  │                           │
  ▼                           ▼
Linear (D → E)           Linear (D → E)
  │                           │
Conv1D (k=4)              SiLU 활성화
  │                           │
SiLU 활성화                   │
  │                           │
Selective SSM (S6)            │
  │                           │
  └──── Element-wise ×  ──────┘
              │
        Linear (E → D)
              │
           Output

설계 핵심 요소

요소 설명 설계 이유
확장 비율 E=2DE = 2D Transformer의 SwiGLU MLP(8D/3\approx 8D/3)와 파라미터 수 매칭
Conv1D (커널=4) 짧은 1D 합성곱 SSM 처리 전 로컬 컨텍스트 제공
게이팅(Gating) 분기 후 element-wise 곱 LSTM/GRU의 게이팅에서 영감. 정보 흐름 제어
SiLU 활성화 SiLU(x)=xσ(x)\text{SiLU}(x) = x \cdot \sigma(x) 비선형성 + 부드러운 게이팅 효과
RMSNorm 각 블록 앞에 적용 (Pre-Norm) 안정적 학습. Transformer와 동일한 패턴
잔차 연결 블록 출력에 입력을 더함 깊은 네트워크 학습 안정화

Transformer와의 구조적 차이

Transformer는 Attention 블록 + MLP 블록을 번갈아 쌓는 반면, Mamba는 단일 블록만 반복합니다. 이 단순화로 같은 파라미터 예산에서 더 많은 레이어를 쌓을 수 있습니다.


모델 구성

논문에서는 다양한 규모의 Mamba 모델을 실험했습니다:

모델 차원(D) 상태 차원(N) 레이어 수 파라미터 수
Mamba-130M 768 16 24 ~130M
Mamba-370M 1024 16 48 ~370M
Mamba-790M 1536 16 48 ~790M
Mamba-1.4B 2048 16 48 ~1.4B
Mamba-2.8B 2560 16 64 ~2.8B

모든 모델에서 상태 차원 N=16N=16을 사용한 것이 주목할 만합니다. 비교적 작은 상태 차원으로도 충분한 표현력을 달성할 수 있음을 보여줍니다.


학습 세부 사항

항목 설정
옵티마이저 AdamW
학습률 코사인 스케줄 + 웜업 (모델 크기에 따라 3e-4 ~ 8e-4)
Weight Decay 0.1
시퀀스 길이 2048 (학습), 평가 시 더 긴 시퀀스 테스트
정밀도 Mixed Precision (bf16/fp32)
하드웨어 A100 80GB GPU

초기화 전략

  • A\mathbf{A}: S4의 HiPPO 행렬의 대각(diagonal) 버전으로 초기화. 장기 의존성 포착에 유리한 구조
  • Δ\Delta: softplus 역함수를 사용하여 초기값이 합리적인 범위에 오도록 설정 ([4,4][-4, 4] 균등 분포의 지수)
  • B\mathbf{B}, C\mathbf{C}: 표준 초기화

실험 결과

합성 태스크(Synthetic Tasks)

합성 태스크에서 Selection Mechanism의 효과를 직접적으로 검증합니다:

모델 Selective Copying Induction Head
S4 (LTI SSM) 실패 실패
H3 부분 성공 부분 성공
Transformer 성공 성공
Mamba (S6) 성공 성공

Selection Mechanism 덕분에 Mamba는 Transformer처럼 내용 기반 추론(content-based reasoning)이 가능하며, 기존 SSM의 근본적 한계를 극복합니다.

언어 모델링 (The Pile)

300B 토큰 규모의 The Pile 데이터셋에서 학습한 결과입니다:

모델 파라미터 Pile 성능
GPT-Neo 125M 기준선
Mamba 130M GPT-Neo 125M과 동등 또는 우세
Pythia 350M 기준선
Mamba 370M Pythia 350M보다 우세
Pythia 1.4B 기준선
Mamba 1.4B Pythia 1.4B보다 우세
Transformer++ (SwiGLU, RoPE 등) ~3.1B 강화된 기준선
Mamba 2.8B Transformer++ 3.1B과 동등

주목할 결과: Mamba-2.8B는 자신보다 약 2배 큰 Transformer 모델과 비슷한 성능을 달성합니다. 이는 파라미터 효율성 측면에서 매우 인상적입니다.

Zero-shot 다운스트림 평가

사전 학습된 모델을 다양한 벤치마크에서 평가한 결과:

벤치마크 평가 내용 Mamba-2.8B 성능
LAMBADA 마지막 단어 예측 Pythia-2.8B와 동등/우세
HellaSwag 상식 추론 경쟁적
PIQA 물리적 직관 경쟁적
ARC-Easy/Challenge 과학 질의응답 경쟁적
WinoGrande 상식 추론 경쟁적

Mamba-2.8B는 GPT-Neo 2.7B, Pythia-2.8B 등 동급 Transformer 모델들과 전반적으로 경쟁적이거나 우세한 성능을 보입니다.

DNA 시퀀스 모델링

HG38 인간 게놈 데이터셋에서의 결과:

모델 긴 시퀀스 처리 성능
HyenaDNA (긴 합성곱 기반) 최대 1M 토큰 이전 SOTA
Transformer 메모리 한계로 긴 시퀀스 불가 제한적
Mamba 최대 1M 토큰 새로운 SOTA

Mamba의 특징적인 결과:
시퀀스 길이가 길어질수록 성능 향상이 지속됨 (최대 100만 토큰까지)
– 다른 모델들은 특정 길이 이후 성능 향상이 정체
– 다운스트림 유전체 분류 태스크에서도 SOTA 달성

오디오 모델링

SC09 음성 생성 벤치마크에서 Mamba는 기존 SSM 기반 모델(SaShiMi/S4)을 FID와 IS 지표 모두에서 능가했습니다.

추론 속도

Mamba의 가장 극적인 장점은 추론 속도입니다:

시퀀스 길이 Transformer 대비 Mamba 처리량
512 ~2-3배 빠름
2048 ~3-4배 빠름
8192+ ~5배 이상 빠름

이 차이의 원인:
Transformer: KV 캐시가 시퀀스 길이에 비례하여 증가 → 메모리 대역폭 병목
Mamba: 고정 크기 상태 O(D×N)O(D \times N) → 시퀀스 길이와 무관하게 일정한 메모리


Ablation Study 상세 분석

Selection Mechanism 구성 요소 분석

어떤 파라미터를 입력 의존적으로 만드는 것이 가장 효과적인지 분석합니다:

구성 Selective Copying Induction Head 언어 모델링
선택 없음 (S4) 실패 실패 기준선
Δ\Delta만 선택적 부분 성공 성공 개선
B\mathbf{B}, C\mathbf{C}만 선택적 부분 성공 성공 개선
Δ\Delta + B\mathbf{B} + C\mathbf{C} 모두 선택적 성공 성공 최고 성능

Δ\Delta를 입력 의존적으로 만드는 것이 단일 변경으로는 가장 큰 영향을 미칩니다. B\mathbf{B}, C\mathbf{C}의 선택을 추가하면 추가 성능 향상이 있습니다.

상태 차원(N) 분석

상태 차원 N 상대 성능 비고
1 상당히 나쁨 상태 용량 부족
4 보통
8 양호
16 최적 기본값으로 채택
32 미미한 개선 메모리 증가 대비 효과 작음
64 수확 체감 비용 대비 효과 미미

N=16N=16이 성능과 효율성의 최적 균형점입니다.

아키텍처 요소 분석

요소 제거 시 영향
Conv1D 성능 하락 (로컬 컨텍스트 상실)
게이팅 메커니즘 상당한 성능 하락
확장 비율 E=2D 최적 (E=D보다 우수, E=4D는 수확 체감)

기존 방법론과의 종합 비교

비교 항목 Transformer S4 (기존 SSM) H3 RWKV Mamba
학습 복잡도 O(L2D)O(L^2D) O(LlogLD)O(L \log L \cdot D) O(LlogLD)O(L \log L \cdot D) O(LD)O(LD) O(LDN)O(LDN)
추론 복잡도 (스텝당) O(LD)O(LD) O(DN)O(DN) O(DN)O(DN) O(D)O(D) O(DN)O(DN)
추론 메모리 O(LD)O(LD) 증가 O(DN)O(DN) 고정 O(DN)O(DN) 고정 O(D)O(D) 고정 O(DN)O(DN) 고정
내용 기반 추론 가능 (Attention) 불가 (LTI) 제한적 제한적 가능 (Selection)
긴 시퀀스 확장성 나쁨 좋음 좋음 좋음 좋음
언어 모델링 품질 최고 수준 낮음 중간 중간 Transformer와 동등
구조 단순성 Attention+MLP SSM+MLP H3+MLP RWKV블록 단일 Mamba 블록

Mamba의 강점

1. 선형 시간복잡도로 Transformer 급 성능

Mamba는 O(L)O(L) 복잡도로 동작하면서도 Transformer와 동등한 언어 모델링 성능을 달성합니다. 이는 기존 Efficient Transformer들이 성능 저하를 감수해야 했던 것과 대비됩니다.

2. 고정 크기 상태로 효율적 추론

추론 시 KV 캐시 없이 O(DN)O(DN) 크기의 고정 상태만 유지하므로, 시퀀스 길이에 무관하게 일정한 추론 비용을 가집니다. 이는 실시간 서비스에서 큰 장점입니다.

3. 다중 도메인에서의 강건한 성능

언어뿐 아니라 DNA 시퀀스, 오디오에서도 SOTA를 달성하여, 범용 시퀀스 모델링 아키텍처로서의 가능성을 보여줍니다.

4. 구조적 단순함

Attention과 MLP를 별도로 가지는 Transformer와 달리, 단일 Mamba 블록만으로 전체 아키텍처를 구성합니다. 이 단순함은 구현과 최적화를 용이하게 합니다.

5. 이론적 기반의 견고함

Selection Mechanism의 설계가 정보 이론적 근거와 LSTM 게이팅과의 연결성에 기반하여, 단순한 경험적 트릭이 아닌 원리적 설계임을 보여줍니다.


Mamba의 한계점

1. 대규모 스케일링 검증 부족

논문 발표 시점 기준 최대 2.8B 파라미터까지만 실험했습니다. 100B+ 규모에서의 성능은 검증되지 않았으며, Transformer의 스케일링 법칙과 비교할 수 있는 데이터가 제한적입니다.

2. 정확한 정보 검색의 한계

Mamba의 고정 크기 상태는 시퀀스의 모든 정보를 압축해야 합니다. 긴 문서에서 정확한 정보를 조회(retrieval)하는 태스크에서는 Transformer의 Attention이 더 유리할 수 있습니다. Attention은 시퀀스의 모든 위치에 직접 접근할 수 있기 때문입니다.

3. 커스텀 CUDA 커널 의존성

Hardware-Aware 알고리즘의 성능은 커스텀 CUDA/Triton 커널에 크게 의존합니다. 이는 특정 GPU 하드웨어에서의 최적화를 전제로 하며, 새로운 하드웨어나 프레임워크에 대한 이식성이 제한될 수 있습니다.

4. 양방향 처리의 비자명성

Mamba는 기본적으로 단방향(causal) 모델입니다. BERT처럼 양방향 컨텍스트가 필요한 태스크에서는 추가적인 설계 수정이 필요합니다.

5. In-Context Learning의 깊이

합성 태스크에서 Induction Head를 학습할 수 있음을 보여주었지만, Transformer가 보여주는 다양하고 복잡한 in-context learning 패턴을 동일한 수준으로 구현할 수 있는지는 추가 연구가 필요합니다.


후속 연구 방향

Mamba-2

Mamba 이후 Gu와 Dao는 Mamba-2를 발표하여 SSM과 Attention의 이론적 연결성을 밝히고, State Space Duality(SSD) 프레임워크를 제시했습니다. 더 큰 상태 차원(NN)을 효율적으로 사용할 수 있도록 알고리즘을 개선했습니다.

하이브리드 아키텍처

Mamba와 Attention을 결합한 하이브리드 아키텍처들이 활발히 연구되고 있습니다:

  • Jamba (AI21 Labs): Mamba + Transformer 레이어 혼합
  • Zamba: 선별적으로 Attention 레이어를 배치

이러한 하이브리드 접근은 Mamba의 효율성과 Attention의 정확한 정보 검색 능력을 결합합니다.

비전·멀티모달 확장

  • Vision Mamba (Vim): 이미지 패치 시퀀스에 Mamba 적용
  • VideoMamba: 비디오의 시공간 시퀀스 처리
  • 멀티모달 Mamba: 텍스트 + 이미지 + 오디오 통합 처리

스케일링 연구

대규모 모델에서의 Mamba 스케일링 법칙 연구와 100B+ 규모의 학습이 중요한 미래 연구 방향입니다.


핵심 구현 코드 살펴보기

Mamba의 핵심 Selection Mechanism을 PyTorch로 간략하게 표현하면 다음과 같습니다:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectiveSSM(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        d_inner = int(expand * d_model)
        self.d_inner = d_inner

        # 입력 투영
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)

        # Conv1D (로컬 컨텍스트)
        self.conv1d = nn.Conv1d(
            d_inner, d_inner, 
            kernel_size=d_conv, 
            padding=d_conv - 1,
            groups=d_inner  # depthwise
        )

        # Selection 파라미터 투영
        self.x_proj = nn.Linear(d_inner, d_state * 2 + 1, bias=False)  # B, C, delta
        self.dt_proj = nn.Linear(1, d_inner, bias=True)  # delta 브로드캐스트

        # 상태 행렬 A (학습 가능하지만 입력 독립적)
        A = torch.arange(1, d_state + 1, dtype=torch.float32)
        self.A_log = nn.Parameter(torch.log(A))  # log 공간에서 파라미터화

        # 출력 투영
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x):
        B, L, D = x.shape

        # 입력 투영 + 게이팅 분기
        xz = self.in_proj(x)  # (B, L, 2*d_inner)
        x_branch, z = xz.chunk(2, dim=-1)  # 각각 (B, L, d_inner)

        # Conv1D
        x_branch = x_branch.transpose(1, 2)  # (B, d_inner, L)
        x_branch = self.conv1d(x_branch)[:, :, :L]  # causal padding
        x_branch = x_branch.transpose(1, 2)  # (B, L, d_inner)
        x_branch = F.silu(x_branch)

        # Selection: 입력으로부터 B, C, delta 계산
        x_dbl = self.x_proj(x_branch)  # (B, L, 2*N + 1)
        B_t = x_dbl[:, :, :self.d_state]           # (B, L, N)
        C_t = x_dbl[:, :, self.d_state:2*self.d_state]  # (B, L, N)
        delta = F.softplus(x_dbl[:, :, -1:])        # (B, L, 1)

        # 이산화
        A = -torch.exp(self.A_log)  # (N,) 음수로 안정화

        # 순환 연산 (실제로는 병렬 스캔 사용)
        y = self.selective_scan(x_branch, delta, A, B_t, C_t)

        # 게이팅
        y = y * F.silu(z)

        return self.out_proj(y)

    def selective_scan(self, x, delta, A, B, C):
        """단순화된 순환 연산 (실제로는 CUDA 커널로 병렬 처리)"""
        B_batch, L, D = x.shape
        N = self.d_state

        h = torch.zeros(B_batch, D, N, device=x.device)
        outputs = []

        for t in range(L):
            # 이산화된 파라미터
            A_bar = torch.exp(delta[:, t, :, None] * A[None, None, :])  # (B, D, N)
            B_bar = delta[:, t, :, None] * B[:, t, None, :]  # (B, D, N)

            # 상태 업데이트
            h = A_bar * h + B_bar * x[:, t, :, None]  # (B, D, N)

            # 출력
            y_t = (h * C[:, t, None, :]).sum(-1)  # (B, D)
            outputs.append(y_t)

        return torch.stack(outputs, dim=1)  # (B, L, D)

위 코드는 개념 이해를 위한 단순화 버전입니다. 실제 Mamba 구현에서는 selective_scan 함수가 커스텀 CUDA 커널과 병렬 스캔 알고리즘으로 대체되어 GPU에서 효율적으로 동작합니다.


마무리

Mamba는 시퀀스 모델링의 패러다임에 중요한 전환점을 제시한 연구입니다. 핵심 기여를 정리하면:

  1. Selection Mechanism: SSM 파라미터(B\mathbf{B}, C\mathbf{C}, Δ\Delta)를 입력 의존적으로 만들어 LTI 제약을 극복하고, 내용 기반 추론을 가능하게 함

  2. Hardware-Aware Algorithm: 커널 퓨전, 병렬 스캔, 역전파 재계산을 결합하여 Selection의 연산 비용을 실용적 수준으로 관리

  3. 통합 아키텍처: Attention + MLP를 하나의 Mamba 블록으로 대체하는 단순하고 효과적인 설계

  4. 다중 도메인 SOTA: 언어, DNA, 오디오에서 Transformer와 동등하거나 우수한 성능을 선형 시간복잡도로 달성

  5. 5배 빠른 추론: 고정 크기 상태 덕분에 긴 시퀀스에서 Transformer 대비 최대 5배 빠른 생성 처리량

Mamba는 “Attention이 필요한 전부인가?”라는 근본적인 질문에 대해 설득력 있는 대안을 제시합니다. 물론 대규모 스케일링 검증과 정확한 정보 검색 태스크에서의 한계가 남아있지만, 후속 연구인 Mamba-2와 Jamba 같은 하이브리드 아키텍처들이 이러한 한계를 빠르게 보완해 나가고 있습니다.

O(L2)O(L^2)에서 O(L)O(L)로의 전환은 단순한 효율성 개선이 아니라, 100만 토큰 이상의 초장문 시퀀스를 실용적으로 처리할 수 있게 하는 질적 변화입니다. Mamba가 열어놓은 이 방향이 앞으로 시퀀스 모델링의 미래를 어떻게 바꿔갈지 주목할 필요가 있습니다.

Did you find this helpful?

☕ Buy me a coffee

Comments

Leave a Reply

Your email address will not be published. Required fields are marked *

TODAY 432 | TOTAL 2,655