Switch Transformers 논문 리뷰: 1.6조 파라미터 MoE 모델의 Sparse Routing 기법 완전 분석

Updated Feb 6, 2026

들어가며: 왜 Switch Transformers인가?

딥러닝 모델의 성능은 파라미터 수에 비례하여 향상된다는 것이 여러 연구를 통해 입증되어 왔습니다. GPT-3가 1,750억 개의 파라미터로 놀라운 성능을 보여준 이후, “더 큰 모델 = 더 좋은 성능”이라는 스케일링 법칙(Scaling Law)은 AI 연구의 핵심 패러다임이 되었습니다.

그러나 여기에는 근본적인 문제가 있습니다. 파라미터 수를 늘리면 연산량(FLOPs)도 비례하여 증가합니다. Dense 모델에서는 모든 입력이 모든 파라미터를 거쳐야 하기 때문입니다. 이는 학습 시간과 비용의 기하급수적 증가를 의미합니다.

핵심 질문: 연산량은 고정한 채로 파라미터 수만 늘려서 성능을 향상시킬 수 있을까?

Google Brain 팀의 William Fedus, Barret Zoph, Noam Shazeer가 2021년 발표한 “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity”는 바로 이 질문에 대한 답을 제시합니다. 이 논문은 Mixture-of-Experts(MoE) 아키텍처를 극단적으로 단순화하여, 1.6조(trillion) 개의 파라미터를 가진 모델을 효율적으로 학습시키는 방법을 보여줍니다.

이 글에서는 Switch Transformers의 핵심 아이디어부터 아키텍처 설계, 학습 안정화 기법, 실험 결과까지 논문의 전체 내용을 깊이 있게 분석합니다.


배경 지식: Mixture-of-Experts (MoE)

MoE의 기본 개념

Mixture-of-Experts는 1991년 Jacobs 등이 처음 제안한 아키텍처로, 여러 전문가(Expert) 네트워크 중 일부만 선택적으로 활성화하는 조건부 연산(Conditional Computation) 방식입니다.

기존 Dense 모델에서는 입력 토큰이 네트워크의 모든 파라미터를 통과합니다. 반면 MoE 모델에서는 라우터(Router)가 각 입력에 대해 적합한 전문가만 선택하여 연산을 수행합니다.

Shazeer et al. (2017)의 MoE Layer

Transformer 이전에 LSTM 기반으로 제안된 MoE 레이어의 출력은 다음과 같이 정의됩니다:

y=i=1Ngi(x)Ei(x)y = \sum_{i=1}^{N} g_i(x) \cdot E_i(x)

여기서:
NN: 전문가의 총 수
Ei(x)E_i(x): ii번째 전문가 네트워크의 출력
gi(x)g_i(x): 게이팅 함수(gating function)가 ii번째 전문가에 부여하는 가중치

게이팅 함수는 일반적으로 Top-k 방식으로 구현됩니다:

g(x)=Softmax(TopK(xWg))g(x) = \text{Softmax}(\text{TopK}(x \cdot W_g))

여기서 WgW_g는 학습 가능한 게이팅 가중치 행렬이며, TopK는 상위 kk개의 값만 남기고 나머지를 -\infty로 마스킹합니다.

기존 연구들은 대부분 Top-2 라우팅(각 토큰을 2개의 전문가에게 보내는 방식)을 사용했습니다. Shazeer et al. (2017)은 Top-2에 노이즈를 추가하여 로드 밸런싱을 달성했고, Lepikhin et al. (2020)의 GShard는 이를 Transformer에 적용하여 6,000억 파라미터 MoE 모델을 학습시켰습니다.

기존 MoE의 한계

한계점 설명
학습 불안정성 MoE 모델은 Dense 모델보다 학습이 불안정하여 발산하기 쉬움
복잡한 라우팅 Top-2 이상의 라우팅은 통신 비용과 구현 복잡도를 높임
로드 불균형 특정 전문가에 토큰이 집중되는 현상 발생
미세 조정 어려움 거대 MoE 모델을 작은 Dense 모델로 증류하기 어려움

핵심 기여 (Key Contributions)

Switch Transformers는 다음과 같은 핵심 기여를 통해 기존 MoE의 한계를 극복합니다:

1. Top-1 라우팅으로의 단순화

기존의 Top-2 라우팅을 Top-1 라우팅(Switch Routing)으로 단순화했습니다. 각 토큰은 단 하나의 전문가에게만 라우팅됩니다.

Switch Routing의 핵심: 라우팅을 극단적으로 단순화하면서도 성능은 유지하거나 오히려 향상된다.

2. 학습 안정화 기법

대규모 Sparse 모델의 학습 불안정성을 해결하기 위해 다음 기법들을 제안합니다:
선택적 정밀도 캐스팅 (Selective Precision Casting)
개선된 로드 밸런싱 손실 (Simplified Load Balancing Loss)
초기화 스케일링 (Smaller Initialization)

3. 1.6조 파라미터 스케일링

위 기법들을 조합하여 1.6조 개의 파라미터를 가진 Switch-C 모델을 성공적으로 학습시켰습니다.

4. 효율적 증류 및 미세 조정

거대 Sparse 모델의 지식을 작은 Dense 모델로 증류(Distillation)하는 방법과, 다운스트림 태스크에 대한 효율적 미세 조정 전략을 제시합니다.


아키텍처 상세 분석

Switch Layer의 구조

Switch Transformers는 표준 Transformer의 FFN(Feed-Forward Network) 레이어를 Switch Layer로 대체합니다. Self-Attention 레이어는 변경하지 않습니다.

표준 Transformer의 FFN:

FFN(x)=max(0,xW1+b1)W2+b2\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2

Switch Layer에서는 이 FFN이 NN개의 독립적인 전문가 E1,E2,,ENE_1, E_2, \dots, E_N으로 복제됩니다. 각 전문가는 동일한 구조의 FFN이지만, 서로 다른 가중치를 가집니다.

Switch Routing 메커니즘

입력 토큰 xx에 대해, 라우터는 다음과 같이 동작합니다:

Step 1: 라우터 확률 계산

pi(x)=eh(x)<em>i</em>j=1Neh(x)jp_i(x) = \frac{e^{h(x)<em>i}}{\sum</em>{j=1}^{N} e^{h(x)_j}}

여기서 h(x)=xWrh(x) = x \cdot W_r이며, WrRdmodel×NW_r \in \mathbb{R}^{d_{\text{model}} \times N}은 학습 가능한 라우터 가중치 행렬입니다.

Step 2: 최적 전문가 선택

i=argmaxipi(x)i^* = \arg\max_i \, p_i(x)

Step 3: 전문가 출력 계산

y=pi<em>(x)Ei</em>(x)y = p_{i^<em>}(x) \cdot E_{i^</em>}(x)

주목할 점은 출력에 라우터 확률 pi<em>(x)p_{i^<em>}(x)가 곱해진다는 것입니다. 이 게이팅 값*은 라우터의 신뢰도를 반영하며, 역전파 시 라우터 가중치 WrW_r의 그래디언트가 전파되도록 합니다.

Top-1 vs Top-2: 왜 하나의 전문가인가?

Top-1 라우팅이 Top-2보다 나은 이유는 다음과 같습니다:

비교 항목 Top-1 (Switch) Top-2 (기존 MoE)
라우팅 연산량 xWrx \cdot W_r 1회 xWrx \cdot W_r 1회 + Top-2 선택
전문가 호출 수 1개 2개
통신 비용 토큰당 1회 전송 토큰당 2회 전송
배치 크기 Expert Capacity 내 더 많은 토큰 처리 가능 같은 Capacity에서 절반의 토큰
구현 복잡도 단순 두 전문가 출력의 가중 합산 필요

논문은 동일한 연산량(FLOPs) 기준에서 Top-1 라우팅이 Top-2를 능가함을 실험적으로 보여줍니다. 그 핵심 이유는 같은 연산 예산 내에서 더 많은 전문가를 사용할 수 있기 때문입니다.

Expert Capacity와 토큰 드롭

분산 학습에서 각 전문가가 처리할 수 있는 토큰의 최대 수를 Expert Capacity라 합니다:

Expert Capacity=(tokens per batchnumber of experts)×capacity factor\text{Expert Capacity} = \left( \frac{\text{tokens per batch}}{\text{number of experts}} \right) \times \text{capacity factor}

Capacity Factor는 중요한 하이퍼파라미터입니다:

Capacity Factor 효과
1.0 완벽한 균등 분배 가정 시 딱 맞는 용량. 실제로는 토큰 드롭 발생
1.25 25% 여유 버퍼. 논문의 기본 설정
1.5 이상 드롭 거의 없지만, 메모리 낭비 증가

Capacity를 초과하는 토큰은 드롭(drop)되어 해당 레이어를 건너뛰고 잔차 연결(residual connection)을 통해 다음 레이어로 전달됩니다.


학습 안정화 기법

1. 로드 밸런싱 손실 (Auxiliary Load Balancing Loss)

MoE 모델의 가장 큰 문제 중 하나는 로드 불균형입니다. 특정 전문가에 토큰이 집중되면, 다른 전문가는 학습 기회를 잃고 모델 전체의 용량이 낭비됩니다.

Switch Transformers는 기존의 복잡한 로드 밸런싱 메커니즘을 단순화한 보조 손실(auxiliary loss)을 제안합니다:

L<em>aux=αN</em>i=1NfiPi\mathcal{L}<em>{\text{aux}} = \alpha \cdot N \cdot \sum</em>{i=1}^{N} f_i \cdot P_i

여기서:
α\alpha: 로드 밸런싱 손실의 가중치 (하이퍼파라미터, 논문에서는 α=102\alpha = 10^{-2} 사용)
NN: 전문가의 수
fif_i: 배치 내에서 전문가 ii실제로 라우팅된 토큰의 비율

fi=1TxB1argmaxjpj(x)=if_i = \frac{1}{T} \sum_{x \in \mathcal{B}} \mathbf{1}{\arg\max_j \, p_j(x) = i}

  • PiP_i: 배치 내 모든 토큰에 대한 전문가 ii평균 라우터 확률

Pi=1TxBpi(x)P_i = \frac{1}{T} \sum_{x \in \mathcal{B}} p_i(x)

  • TT: 배치 내 토큰의 총 수
  • B\mathcal{B}: 현재 배치

fiPif_i \cdot P_i인가? fif_i는 이산적(argmax에 의해 결정)이므로 미분 불가능합니다. 반면 PiP_i는 softmax 출력의 평균이므로 미분 가능합니다. 두 항의 곱을 최소화하면, 그래디언트는 PiP_i를 통해 전파되면서 각 전문가로의 라우팅이 균등해지도록 유도합니다.

완벽하게 균형 잡힌 경우, fi=Pi=1/Nf_i = P_i = 1/N이 되어 Laux=α\mathcal{L}_{\text{aux}} = \alpha가 됩니다.

2. 선택적 정밀도 캐스팅 (Selective Precision Casting)

bfloat16 학습은 메모리와 속도 면에서 유리하지만, MoE 모델에서는 라우터의 소프트맥스 연산이 수치적으로 불안정해질 수 있습니다.

Switch Transformers의 해결책:

  1. 라우터로 들어가는 입력을 float32로 업캐스팅
  2. 라우터 연산(행렬 곱, softmax)을 float32로 수행
  3. 결과를 다시 bfloat16으로 다운캐스팅하여 전문가에 전달
# 의사 코드: 선택적 정밀도 캐스팅
def switch_router(x_bf16, W_r):
    # Step 1: 라우터 입력만 float32로 업캐스팅
    x_f32 = x_bf16.to(torch.float32)

    # Step 2: 라우터 연산을 float32로 수행
    logits = x_f32 @ W_r.to(torch.float32)
    probs = torch.softmax(logits, dim=-1)

    # Step 3: 최적 전문가 선택
    expert_idx = torch.argmax(probs, dim=-1)
    gate_value = probs.gather(-1, expert_idx.unsqueeze(-1))

    # Step 4: 결과를 bfloat16으로 다운캐스팅
    gate_value = gate_value.to(torch.bfloat16)

    return expert_idx, gate_value

이 방법은 전체 모델을 float32로 학습시키는 것보다 훨씬 효율적이면서도, 라우터의 수치 안정성을 확보합니다.

3. 초기화 스케일링 (Smaller Initialization)

표준 Transformer는 초기화 스케일을 ss로 사용하지만, Switch Transformers는 이를 s/10s / \sqrt{10}으로 줄입니다. 이는 학습 초기 단계에서 라우터의 출력 분산을 줄여 더 균등한 초기 라우팅을 유도합니다.

4. 전문가 드롭아웃 (Expert Dropout)

미세 조정 시 과적합을 방지하기 위해, 전문가 내부의 FFN에만 높은 드롭아웃(예: 40%)을 적용하고 나머지 레이어(Attention 등)에는 낮은 드롭아웃을 유지합니다.


모델 스케일링 전략

Switch-Base: 같은 FLOPs에서의 스케일링

논문은 T5-Base 모델과 동일한 연산량을 유지하면서 전문가 수를 늘리는 실험을 수행합니다.

모델 전문가 수 파라미터 수 사전학습 속도 향상
T5-Base 1 (Dense) 223M 1.0x (기준)
Switch-Base 8 8 ~1.1B 3.2x
Switch-Base 16 16 ~2.0B 3.7x
Switch-Base 32 32 ~3.8B 4.0x
Switch-Base 64 64 ~7.4B 4.2x
Switch-Base 128 128 ~14.7B 4.7x
Switch-Base 256 256 ~29.3B 4.4x

128개 전문가까지는 속도 향상이 지속적으로 증가하지만, 256개에서는 약간 감소합니다. 이는 통신 비용과 로드 불균형의 영향입니다.

Switch-Large / XL / XXL: 대규모 모델

모델 파라미터 수 전문가 수 T5 대비 사전학습 속도 향상
T5-Large 739M 1.0x
Switch-Large 128 ~50B 128 4.0x
T5-XL 3B 1.0x
Switch-XL 128 ~159B 128 3.5x
T5-XXL 11B 1.0x
Switch-XXL 128 ~395B 128 2.8x

모델 크기가 커질수록 MoE의 상대적 이점이 줄어드는 경향이 있지만, 절대적으로는 여전히 유의미한 성능 향상을 달성합니다.

Switch-C: 1.6조 파라미터 모델

Switch-C는 전문가 수를 2,048개로 확장하면서, 각 전문가의 FFN 차원은 줄여 디바이스당 하나의 전문가만 배치하는 전략을 사용합니다.

설계 철학 Switch-XXL Switch-C
목표 품질 최적화 파라미터 수 최대화
전문가 수 128 2,048
전문가 FFN 차원 크게 유지 축소
총 파라미터 ~395B 1.6T
디바이스 수 512 TPU v3 2,048 TPU v3

분산 학습과 병렬화 전략

데이터 병렬 vs 모델 병렬 vs Expert 병렬

Switch Transformers는 세 가지 병렬화 전략을 조합합니다:

데이터 병렬 (Data Parallelism): 모든 코어에 동일한 모델을 복제하고, 데이터를 분할하여 학습합니다. 가장 단순하지만, 모델 전체가 각 디바이스 메모리에 들어가야 합니다.

모델 병렬 (Model Parallelism): 모델의 가중치를 여러 디바이스에 분할합니다. 큰 모델을 다룰 수 있지만, 디바이스 간 통신 오버헤드가 큽니다.

Expert 병렬 (Expert Parallelism): MoE의 핵심 병렬화 전략입니다. 각 전문가를 서로 다른 디바이스에 배치하고, 라우팅 결과에 따라 토큰을 해당 디바이스로 전송합니다.

[입력 토큰 배치]
       
   [라우터] → 전문가 할당 결정
       
  All-to-All 통신
   ↓    ↓    ↓    ↓
[E1]  [E2]  [E3]  [E4]  (각 디바이스)
   ↓    ↓    ↓    ↓
  All-to-All 통신
       
   [출력 수집]

통신 비용 분석

Expert 병렬화에서 가장 큰 비용은 All-to-All 통신입니다. Switch Routing(Top-1)이 Top-2보다 유리한 핵심 이유가 여기에 있습니다:

  • Top-1: 토큰당 All-to-All 1회 (전송 → 전문가 연산 → 수신)
  • Top-2: 토큰당 All-to-All 2회 (2개 전문가로 각각 전송)

실험 결과 분석

사전학습 성능

논문은 C4(Colossal Clean Crawled Corpus) 데이터셋에서 Masked Language Modeling으로 사전학습을 수행합니다.

Switch Transformers는 동일한 연산량 대비 T5보다 일관되게 낮은 perplexity를 달성합니다. 특히:

  • Switch-Base 128은 T5-Base 대비 4.7배 빠른 사전학습 속도
  • Switch-Large 128은 T5-Large 대비 4.0배 빠른 사전학습 속도
  • Switch-XXL 128은 T5-XXL 대비 2.8배 빠른 사전학습 속도

다운스트림 미세 조정 성능

SuperGLUE 벤치마크

모델 SuperGLUE 평균
T5-Base 74.6
Switch-Base 128 80.0
T5-Large 82.7
Switch-Large 128 84.7

Switch-Base 128(~14.7B 파라미터, T5-Base FLOPs)이 T5-Base를 5.4포인트 상회하며, 심지어 3배 큰 T5-Large에 근접하는 성능을 보여줍니다.

다국어 번역 (mC4)

101개 언어에 대한 다국어 사전학습 후, 번역 태스크에서의 성능:

언어쌍 T5-Base Switch-Base 128 향상
En→De 26.2 27.8 +1.6
En→Fr 34.1 35.6 +1.5
En→Ro 26.0 27.2 +1.2
91개 언어 평균 +2.1 BLEU

특히 저자원 언어에서의 성능 향상이 두드러집니다. 이는 각 전문가가 특정 언어 그룹에 특화되면서 발생하는 효과로 해석됩니다.

지식 집약적 태스크

대규모 Sparse 모델이 더 많은 세계 지식을 저장할 수 있는지 검증하기 위해, Closed-Book Question Answering 태스크에서의 성능을 평가합니다:

모델 파라미터 수 TriviaQA Natural Questions WebQuestions
T5-Base 223M 27.0 20.9 19.1
Switch-Base 128 ~14.7B 30.7 23.6 21.5
T5-Large 739M 32.3 25.1 22.3
Switch-Large 128 ~50B 35.2 27.6 24.8
T5-XXL 11B 42.9 34.8 30.6
Switch-XXL 128 ~395B 44.1 35.9 31.8

Sparse 모델의 추가 파라미터가 지식 저장소 역할을 하여, 검색 없이도 더 많은 사실 지식을 기억할 수 있습니다.


Ablation Study

Capacity Factor의 영향

Capacity Factor 학습 품질 속도 메모리
0.5 낮음 (토큰 드롭 많음) 빠름 적음
1.0 중간 중간 중간
1.25 높음 (기본값) 중간 중간
1.5 매우 높음 느림 많음
2.0 최고 가장 느림 가장 많음

Capacity Factor 1.25가 품질과 효율의 최적 균형점으로, 논문 전체에서 기본값으로 사용됩니다.

라우팅 전략 비교

논문은 여러 대안적 라우팅 전략과 비교 실험을 수행합니다:

라우팅 전략 Neg. Log Perplexity ↑ 비고
Switch (Top-1) -1.554 가장 우수
Top-2 (tokens split) -1.550 Top-1보다 약간 낮음
Top-2 (tokens routed) -1.548 통신 비용 2배
Expert Choice -1.543 토큰이 드롭될 수 있음

동일 FLOPs 기준에서 Switch(Top-1)가 가장 우수한 성능을 보여줍니다.

로드 밸런싱 손실 가중치 α\alpha

α\alpha 효과
00 밸런싱 없음 → 특정 전문가에 집중, 성능 저하
10310^{-3} 약한 밸런싱 → 불균형 잔존
10210^{-2} 적절한 밸런싱 → 최적 성능
10110^{-1} 과도한 밸런싱 → 라우팅 품질 저하
11 밸런싱 손실이 지배적 → 학습 실패

α\alpha가 너무 크면 라우터가 “모든 전문가에 균등하게 보내기”에만 집중하여 전문화(specialization)가 일어나지 않습니다. 반대로 너무 작으면 일부 전문가만 과도하게 사용됩니다.


증류 (Distillation)

거대 Sparse 모델을 실제 서비스에 배포하려면, 작은 Dense 모델로 증류하는 것이 현실적입니다.

증류 방법

Switch Transformers의 증류 전략:

  1. 교사 모델: Switch-Base 128 (~14.7B 파라미터)
  2. 학생 모델: T5-Base (223M 파라미터)
  3. 손실 함수:

L<em>distill=(1λ)L</em>MLM+λLKD\mathcal{L}<em>{\text{distill}} = (1 – \lambda) \cdot \mathcal{L}</em>{\text{MLM}} + \lambda \cdot \mathcal{L}_{\text{KD}}

여기서:
L<em>MLM\mathcal{L}<em>{\text{MLM}}: 원래의 Masked Language Modeling 손실
L</em>KD=KL(pteacherpstudent)\mathcal{L}</em>{\text{KD}} = \text{KL}(p_{\text{teacher}} | p_{\text{student}}): 교사-학생 간 KL 발산
λ\lambda: 혼합 비율 (논문에서는 λ=0.25\lambda = 0.25 사용)

증류 결과

모델 파라미터 수 SuperGLUE
T5-Base (기준) 223M 74.6
T5-Base (Switch에서 증류) 223M 77.3
Switch-Base 128 (교사) ~14.7B 80.0

증류를 통해 동일한 크기의 Dense 모델에서 2.7포인트 향상을 달성합니다. 이는 Sparse 교사 모델이 학습한 지식의 약 37%를 성공적으로 전이한 것입니다.


기존 방법론과의 종합 비교

특성 Dense Transformer (T5) GShard MoE Switch Transformer
라우팅 방식 없음 (전체 연산) Top-2 Top-1
전문가 수 수백~수천 수백~수천
FLOPs 대비 파라미터 1:1 비례 ~2x 파라미터/FLOPs 최대 50x+ 파라미터/FLOPs
통신 비용 없음 높음 (Top-2) 낮음 (Top-1)
학습 안정성 높음 보통 높음 (안정화 기법 적용)
로드 밸런싱 불필요 복잡한 손실 함수 단순한 보조 손실
최대 검증 규모 11B (T5-XXL) 600B 1.6T
증류 가능성 해당 없음 미검증 검증됨 (37% 지식 전이)
구현 복잡도 낮음 높음 중간

전문가 특화 분석

논문은 학습된 라우터가 실제로 의미 있는 전문화(specialization)를 보이는지 분석합니다.

토큰 수준 분석

Encoder의 초기 레이어에서, 라우터는 주로 토큰의 해시 값이나 형태적 특성에 기반하여 라우팅합니다. 예를 들어:

  • 전문가 A: 구두점과 특수 문자
  • 전문가 B: 동사와 동사구
  • 전문가 C: 고유명사
  • 전문가 D: 숫자와 날짜

상위 레이어로 갈수록, 라우팅은 의미론적 특성에 더 기반하게 됩니다.

다국어 모델에서의 전문가 특화

다국어 모델에서는 더 명확한 패턴이 관찰됩니다:

  • 특정 전문가가 특정 언어 또는 언어 계통에 특화
  • 형태론적으로 유사한 언어들이 같은 전문가를 공유하는 경향
  • 이는 MoE가 암묵적으로 언어별 파라미터 분할을 학습함을 시사

강점과 한계점

강점

  1. 극단적 단순화의 성공: Top-1 라우팅이라는 가장 단순한 형태가 가장 효과적임을 입증
  2. 실용적 스케일링: 1.6조 파라미터까지 안정적으로 학습 가능함을 시연
  3. 포괄적 분석: 스케일링, 안정화, 증류, 미세 조정까지 전 파이프라인을 다룸
  4. 재현 가능한 레시피: bfloat16 학습, 초기화 스케일링 등 구체적인 학습 레시피 제공
  5. T5와의 직접 비교: 공정한 비교를 위해 동일한 FLOPs 예산에서 실험

한계점

  1. 추론 비용: 파라미터 수가 많아 추론 시 메모리 요구량이 크고, 모든 전문가를 메모리에 올려야 함
  2. 학습 불안정성의 근본 해결 아님: 선택적 정밀도 캐스팅 등은 증상 완화에 가까우며, bfloat16 없이는 여전히 불안정
  3. Fine-tuning 시 과적합: Sparse 모델은 Dense 모델보다 다운스트림 태스크에서 과적합 경향이 강함
  4. 전문가 활용 불균형: 로드 밸런싱 손실에도 불구하고 완벽한 균형은 달성하기 어려움
  5. 통신 오버헤드: 대규모 클러스터에서의 All-to-All 통신은 여전히 병목
  6. 하드웨어 제약: TPU v3 2,048개를 사용한 실험은 대부분의 연구 그룹이 재현하기 어려움

후속 연구 방향

Switch Transformers 이후 MoE 연구는 다음 방향으로 발전하고 있습니다:

1. 라우팅 알고리즘 개선

  • Expert Choice Routing (Zhou et al., 2022): 토큰이 전문가를 선택하는 대신, 전문가가 토큰을 선택하여 완벽한 로드 밸런싱 달성
  • Soft MoE (Puigcerver et al., 2023): 이산적 라우팅을 연속적 가중 합산으로 대체하여 미분 가능한 라우팅 실현
  • Hash Layer (Roller et al., 2021): 학습 가능한 라우터 대신 해시 함수로 결정론적 라우팅

2. 효율적 추론

  • 전문가 오프로딩: 사용하지 않는 전문가를 CPU/SSD로 이동
  • 전문가 가지치기(Pruning): 미세 조정 후 불필요한 전문가 제거
  • 동적 전문가 활성화: 입력 복잡도에 따라 활성 전문가 수 조절

3. 실제 대규모 모델 적용

  • Mixtral 8x7B (Mistral AI, 2024): 8개 전문가 중 Top-2 선택, 오픈소스 MoE 모델
  • DeepSeek-MoE (2024): Fine-grained Expert Segmentation으로 전문가 세분화
  • Grok-1 (xAI, 2024): 314B 파라미터 MoE 모델

4. 이론적 이해 심화

  • MoE의 스케일링 법칙 정립
  • 전문가 특화(specialization)의 이론적 분석
  • 최적 전문가 수와 모델 크기의 관계 규명

핵심 수식 요약

논문의 핵심 수식들을 한눈에 정리합니다:

수식 의미
pi(x)=Softmax(xWr)ip_i(x) = \text{Softmax}(x \cdot W_r)_i 라우터 확률
i=argmaxipi(x)i^* = \arg\max_i \, p_i(x) 최적 전문가 선택
y=pi<em>(x)Ei</em>(x)y = p_{i^<em>}(x) \cdot E_{i^</em>}(x) Switch Layer 출력
Laux=αNfiPi\mathcal{L}_{\text{aux}} = \alpha \cdot N \cdot \sum f_i \cdot P_i 로드 밸런싱 손실
Capacity=(T/N)×CF\text{Capacity} = (T/N) \times \text{CF} 전문가 용량

마무리

Switch Transformers는 “단순함이 최고의 전략”임을 설득력 있게 보여준 논문입니다. 핵심 내용을 요약하면:

  1. Top-1 라우팅의 효과성: 기존의 Top-2 라우팅보다 단순한 Top-1이 동일 연산량에서 더 우수한 성능을 달성합니다. 이는 적은 통신 비용과 더 큰 전문가 수 활용이라는 이중 이점에서 비롯됩니다.

  2. 실용적 스케일링 레시피: 선택적 정밀도 캐스팅, 단순화된 로드 밸런싱 손실, 초기화 스케일링이라는 세 가지 기법의 조합으로 1.6조 파라미터까지 안정적 학습이 가능합니다.

  3. FLOPs 대비 성능: Switch-Base 128은 T5-Base와 동일한 FLOPs를 사용하면서도 4.7배 빠른 사전학습과 SuperGLUE에서 5.4포인트 향상을 달성합니다.

  4. 지식 저장 능력: Sparse 모델의 추가 파라미터는 효과적인 지식 저장소로 기능하여, Closed-Book QA에서 일관된 성능 향상을 보여줍니다.

  5. 증류 가능성: 거대 Sparse 교사 모델의 지식을 작은 Dense 학생 모델로 전이할 수 있으며, 이는 실제 배포 시나리오에서 중요한 의미를 가집니다.

Switch Transformers가 제시한 Sparse MoE 패러다임은 이후 Mixtral, DeepSeek-MoE 등 최신 모델들의 기반이 되었으며, 연산 효율적 스케일링이라는 방향이 LLM 발전의 핵심 축으로 자리잡는 데 결정적인 역할을 했습니다. 모델을 키우되, 연산량은 유지하면서 성능을 높이는 이 접근법은 AI 연구와 산업 모두에서 중요한 설계 원칙으로 남아 있습니다.

Did you find this helpful?

☕ Buy me a coffee

Comments

Leave a Reply

Your email address will not be published. Required fields are marked *

TODAY 432 | TOTAL 2,655