[ToMe] Token Merging: Your ViT But Faster

Merging ICLR 2023

Daniel Bolya, Cheng-Yang Fu, Xiaoliang Dai, Peizhao Zhang, Christoph Feichtenhofer, Judy Hoffman · Georgia Tech / Meta AI

arXiv GitHub

한 줄 요약. 토큰을 버리지 말고 비슷한 것끼리 합치자(merge). attention의 Key로 유사도를 재고, bipartite soft matching으로 매 block마다 r개씩 병합 — 정렬·반복 없이 pruning만큼 빠르면서 학습 없이도(off-the-shelf) ViT-L/H를 2× 가속, 정확도는 0.2~0.3%만 하락.

배경

이 노트부터 Merging 계열이다. 앞선 Pruning(DynamicViT·EViT·ATS…)들은 강력하지만 공통된 약점이 있다.

  • 버리면 정보가 사라진다 — 토큰을 많이 줄일수록 손실이 커져 압축 한계가 생긴다.
  • 대부분 재학습 필요 — 추가 파라미터·prediction module을 학습해야 효과가 난다.
  • batched inference 곤란 — 입력마다 남는 토큰 수가 달라(dynamic) 배치로 묶기 어렵다. 그래서 학습 중엔 실제로 안 버리고 mask만 씌워, pruning의 속도 이점이 사라지기도 한다.

토큰을 합치면(merge) 정보를 덜 잃는다. 게다가 학습 없이 기존 모델에 바로 꽂을 수 있다면?

기존에 합치는 시도(Token Pooling 등)는 느린 k-means라 학습 없이는 정확도가 10~40% 떨어졌다. ToMe는 학습 없이도 쓸 만한 trade-off를 처음으로 보여준 merging 방법이다.

핵심 아이디어

매 transformer block에서 비슷한 토큰 r개를 합쳐 토큰 수를 줄인다. L개 layer를 지나면 총 rL개가 줄어든다. (r은 비율이 아니라 개수 — 입력과 무관하게 항상 같은 수를 줄이므로 batched inference가 가능)

Figure 1. (a) block마다 비슷한 패치가 점진적으로 병합(개 털이 한 토큰으로). (b) ToMe는 attention과 MLP 사이에 끼우는 간단한 모듈. (c) bipartite soft matching — A·B 분할 → 각 A가 가장 비슷한 B로 edge → 상위 r개만 병합.

세 가지 설계가 핵심이다: (1) 무엇이 ‘비슷한가’를 Key로 정의, (2) 빠른 매칭(bipartite soft matching), (3) 합친 토큰의 크기를 추적(proportional attention).

방법

1) Token similarity — Key로 잰다

토큰의 feature(X)는 overparameterized라 노이즈가 섞여 유사도 계산에 부적합하다. 대신 self-attention이 이미 만든 Key(K) 를 쓴다 — K는 각 토큰의 정보를 dot-product 비교용으로 요약한 값이기 때문. head별 K를 평균낸 뒤 cosine similarity로 잰다. (ablation: K > X > Q,V, cosine가 최선)

또 병합 모듈을 block 시작이 아니라 attention과 MLP 사이에 둔다 → attention을 거친 feature로 무엇을 합칠지 판단할 수 있고, 합쳐질 토큰의 정보가 미리 전파돼 정확도가 오른다.

2) Bipartite Soft Matching — 정렬 없는 빠른 병합

k-means·graph cut 같은 반복형 군집화는 너무 느리다(L번 × 수천 토큰). 또 군집화는 한 그룹에 토큰이 무한정 몰릴 수 있어 위험하다. 그래서 매칭으로 한다 — 대부분 토큰은 안 건드리고 일부만 점진적으로 합쳐지도록:

  1. 토큰을 두 집합 A, B로 나눈다 (번갈아 배정 — alternating).
  2. A의 각 토큰에서 가장 비슷한 B 토큰으로 edge 하나를 긋는다.
  3. 그중 유사도 상위 r개 edge만 남긴다.
  4. 연결된 토큰을 합친다 (크기로 가중 평균).
  5. 두 집합을 다시 합쳐 출력.

bipartite 그래프라 연결 컴포넌트 찾기가 자명하고, 모든 쌍을 비교할 필요도 없다. 반복 루프가 전혀 없어 병렬화되고, 랜덤 drop만큼 빠르면서 greedy 매칭만큼 정확하다 (PyTorch 몇 줄로 구현).

3) Proportional attention — 합친 토큰의 크기 추적

토큰을 합치면 더 이상 패치 1개가 아니다. 같은 key를 가진 두 토큰을 합치면 softmax에서 영향력이 부당하게 줄어든다. 크기 s(토큰이 대표하는 패치 수)를 attention에 더해 보정:

\[A = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d}} + \log s\right)\]

이는 key를 s번 복제한 것과 같은 효과. 병합·집계 때도 s로 가중한다. (지도학습 모델엔 필수지만, MAE pretraining 모델은 이미 토큰을 빼며 학습돼 불필요)

학습할 수도 있다 (선택)

ToMe는 학습 없이 붙이는 게 기본이지만, fine-tuning에 끼우면 정확도 손실을 더 줄이고 학습 자체를 2× 빠르게 한다. 이때 병합을 average pooling처럼 보고 그냥 backprop — Gumbel-softmax 같은 gradient trick이 필요 없다(pruning과의 큰 차이).

결과

Table 1. ViT-L/16(MAE) off-the-shelf ablation — Key + cosine + 크기 가중 평균 + alternating 분할이 최선. r=8로 24 layer에서 토큰 98%를 줄여도 84.25%.
  • 학습 없이 2× throughput: ViT-L @512 (SWAG) −0.3%, ViT-H @518 −0.3%, ViT-L video −0.2%(2.2×). 큰 모델일수록 손실이 작다 — 깊어서 feature 변화가 점진적이라 병합 충격이 작기 때문.
  • pruning보다 우수: 토큰 98% 제거 상황에서 랜덤·attention pruning은 79%대로 무너지지만, bipartite merging은 84.25% 유지. 합치기는 비슷한 토큰만 합칠 때 정보를 거의 안 잃는다.
Table 3(좌)·4(우). 좌: ViT가 한 tier 위 모델(Swin·MViTv2)급 속도로 advancing a tier. 우: DeiT-S에서 DynamicViT·A-ViT·SP-ViT 등 pruning과 동급 정확도 + 더 높은 throughput, 게다가 학습도 1.5× 빠름.
  • token pruning과 비교(DeiT-S): gradient trick·추가 파라미터 없이 DynamicViT·A-ViT·SP-ViT의 성능을 따라잡고 throughput은 능가. pruning은 학습 중 padding/mask가 필요해 학습이 안 빨라지지만 ToMe는 학습 1.5× 가속. off-the-shelf AugReg ViT-S에 같은 schedule만 적용해도 학습형 pruning과 동급.
  • 이미지·비디오·오디오 전부에 코드 변경 없이 적용 — 비디오는 2.2× throughput에 학습시간 절반, 오디오는 2× throughput에 mAP −0.4%.
Figure 4. 병합 시각화 — 같은 색=한 토큰으로 병합. pruning과 달리 전경·배경 가리지 않고 비슷한 부분을 합쳐, 마치 part segmentation처럼 동작(개의 다리·몸·얼굴이 각각 한 토큰).
  • merging schedule: 기본은 layer마다 일정한 r(constant)인데, 15,000개 랜덤 schedule과 비교해도 거의 최적. 더 공격적으로 줄일 땐 초반에 많이 줄이는 decreasing schedule이 유리.

한 줄 정리 & 의의

  • Pruning → Merging의 전환점. “버려서 정보를 잃는다”는 pruning의 근본 한계를, 비슷한 토큰을 합쳐 정보를 보존하는 방식으로 우회. 게다가 학습 없이 동작하는 첫 실용적 merging 방법.
  • Key 기반 유사도 + bipartite soft matching이라는 단순·병렬·학습불요 조합이 정체성. 정렬도 모듈 학습도 없이 pruning 속도를 낸다.
  • 한계 / 이후. 합칠지 말지를 정해진 r·고정 schedule로만 제어한다. 이후 연구는 pruning과 merging을 함께·미분가능하게 자동 결정하는 Hybrid(DiffRate, ToFu)나, merging 전용 임베딩을 따로 학습하는 방향(DTEM)으로 확장. → Token Reduction 개요