[ToMe] Token Merging: Your ViT But Faster
Daniel Bolya, Cheng-Yang Fu, Xiaoliang Dai, Peizhao Zhang, Christoph Feichtenhofer, Judy Hoffman · Georgia Tech / Meta AI
한 줄 요약. 토큰을 버리지 말고 비슷한 것끼리 합치자(merge). attention의 Key로 유사도를 재고, bipartite soft matching으로 매 block마다 r개씩 병합 — 정렬·반복 없이 pruning만큼 빠르면서 학습 없이도(off-the-shelf) ViT-L/H를 2× 가속, 정확도는 0.2~0.3%만 하락.
배경
이 노트부터 Merging 계열이다. 앞선 Pruning(DynamicViT·EViT·ATS…)들은 강력하지만 공통된 약점이 있다.
- 버리면 정보가 사라진다 — 토큰을 많이 줄일수록 손실이 커져 압축 한계가 생긴다.
- 대부분 재학습 필요 — 추가 파라미터·prediction module을 학습해야 효과가 난다.
- batched inference 곤란 — 입력마다 남는 토큰 수가 달라(dynamic) 배치로 묶기 어렵다. 그래서 학습 중엔 실제로 안 버리고 mask만 씌워, pruning의 속도 이점이 사라지기도 한다.
토큰을 합치면(merge) 정보를 덜 잃는다. 게다가 학습 없이 기존 모델에 바로 꽂을 수 있다면?
기존에 합치는 시도(Token Pooling 등)는 느린 k-means라 학습 없이는 정확도가 10~40% 떨어졌다. ToMe는 학습 없이도 쓸 만한 trade-off를 처음으로 보여준 merging 방법이다.
핵심 아이디어
매 transformer block에서 비슷한 토큰 r개를 합쳐 토큰 수를 줄인다. L개 layer를 지나면 총 rL개가 줄어든다. (r은 비율이 아니라 개수 — 입력과 무관하게 항상 같은 수를 줄이므로 batched inference가 가능)
세 가지 설계가 핵심이다: (1) 무엇이 ‘비슷한가’를 Key로 정의, (2) 빠른 매칭(bipartite soft matching), (3) 합친 토큰의 크기를 추적(proportional attention).
방법
1) Token similarity — Key로 잰다
토큰의 feature(X)는 overparameterized라 노이즈가 섞여 유사도 계산에 부적합하다. 대신 self-attention이 이미 만든 Key(K) 를 쓴다 — K는 각 토큰의 정보를 dot-product 비교용으로 요약한 값이기 때문. head별 K를 평균낸 뒤 cosine similarity로 잰다. (ablation: K > X > Q,V, cosine가 최선)
또 병합 모듈을 block 시작이 아니라 attention과 MLP 사이에 둔다 → attention을 거친 feature로 무엇을 합칠지 판단할 수 있고, 합쳐질 토큰의 정보가 미리 전파돼 정확도가 오른다.
2) Bipartite Soft Matching — 정렬 없는 빠른 병합
k-means·graph cut 같은 반복형 군집화는 너무 느리다(L번 × 수천 토큰). 또 군집화는 한 그룹에 토큰이 무한정 몰릴 수 있어 위험하다. 그래서 매칭으로 한다 — 대부분 토큰은 안 건드리고 일부만 점진적으로 합쳐지도록:
- 토큰을 두 집합 A, B로 나눈다 (번갈아 배정 — alternating).
- A의 각 토큰에서 가장 비슷한 B 토큰으로 edge 하나를 긋는다.
- 그중 유사도 상위 r개 edge만 남긴다.
- 연결된 토큰을 합친다 (크기로 가중 평균).
- 두 집합을 다시 합쳐 출력.
bipartite 그래프라 연결 컴포넌트 찾기가 자명하고, 모든 쌍을 비교할 필요도 없다. 반복 루프가 전혀 없어 병렬화되고, 랜덤 drop만큼 빠르면서 greedy 매칭만큼 정확하다 (PyTorch 몇 줄로 구현).
3) Proportional attention — 합친 토큰의 크기 추적
토큰을 합치면 더 이상 패치 1개가 아니다. 같은 key를 가진 두 토큰을 합치면 softmax에서 영향력이 부당하게 줄어든다. 크기 s(토큰이 대표하는 패치 수)를 attention에 더해 보정:
\[A = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d}} + \log s\right)\]이는 key를 s번 복제한 것과 같은 효과. 병합·집계 때도 s로 가중한다. (지도학습 모델엔 필수지만, MAE pretraining 모델은 이미 토큰을 빼며 학습돼 불필요)
학습할 수도 있다 (선택)
ToMe는 학습 없이 붙이는 게 기본이지만, fine-tuning에 끼우면 정확도 손실을 더 줄이고 학습 자체를 2× 빠르게 한다. 이때 병합을 average pooling처럼 보고 그냥 backprop — Gumbel-softmax 같은 gradient trick이 필요 없다(pruning과의 큰 차이).
결과
- 학습 없이 2× throughput: ViT-L @512 (SWAG) −0.3%, ViT-H @518 −0.3%, ViT-L video −0.2%(2.2×). 큰 모델일수록 손실이 작다 — 깊어서 feature 변화가 점진적이라 병합 충격이 작기 때문.
- pruning보다 우수: 토큰 98% 제거 상황에서 랜덤·attention pruning은 79%대로 무너지지만, bipartite merging은 84.25% 유지. 합치기는 비슷한 토큰만 합칠 때 정보를 거의 안 잃는다.
- token pruning과 비교(DeiT-S): gradient trick·추가 파라미터 없이 DynamicViT·A-ViT·SP-ViT의 성능을 따라잡고 throughput은 능가. pruning은 학습 중 padding/mask가 필요해 학습이 안 빨라지지만 ToMe는 학습 1.5× 가속. off-the-shelf AugReg ViT-S에 같은 schedule만 적용해도 학습형 pruning과 동급.
- 이미지·비디오·오디오 전부에 코드 변경 없이 적용 — 비디오는 2.2× throughput에 학습시간 절반, 오디오는 2× throughput에 mAP −0.4%.
- merging schedule: 기본은 layer마다 일정한 r(constant)인데, 15,000개 랜덤 schedule과 비교해도 거의 최적. 더 공격적으로 줄일 땐 초반에 많이 줄이는 decreasing schedule이 유리.
한 줄 정리 & 의의
- Pruning → Merging의 전환점. “버려서 정보를 잃는다”는 pruning의 근본 한계를, 비슷한 토큰을 합쳐 정보를 보존하는 방식으로 우회. 게다가 학습 없이 동작하는 첫 실용적 merging 방법.
- Key 기반 유사도 + bipartite soft matching이라는 단순·병렬·학습불요 조합이 정체성. 정렬도 모듈 학습도 없이 pruning 속도를 낸다.
- 한계 / 이후. 합칠지 말지를 정해진 r·고정 schedule로만 제어한다. 이후 연구는 pruning과 merging을 함께·미분가능하게 자동 결정하는 Hybrid(DiffRate, ToFu)나, merging 전용 임베딩을 따로 학습하는 방향(DTEM)으로 확장. → Token Reduction 개요