[MCTF] Multi-criteria Token Fusion with One-step-ahead Attention for Efficient Vision Transformers
Sanghyeok Lee, Joonmyung Choi, Hyunwoo J. Kim · Korea University
한 줄 요약. 기존 token fusion은 단일 기준(유사도 또는 중요도)만 봐서, 유사도만 쓰면 전경(foreground) 토큰을 과하게 합치고 중요도만 쓰면 안 닮은 토큰을 합쳐 표현이 붕괴한다. MCTF는 유사도 · 중요도(informativeness) · 크기(size) 세 기준을 곱으로 결합해 합칠 토큰을 고른다. 중요도는 one-step-ahead attention(다음 layer attention)으로 더 정확히 재고, token reduction consistency로 fine-tune한다. DeiT-T/S에서 FLOPs를 약 44% 줄이면서 정확도를 오히려 +0.5%/+0.3% 끌어올린다.
배경
ToMe·Token Pooling처럼 비슷한 토큰을 합치는(fusion) 흐름이 pruning(버리기)을 대체해 왔다. 그런데 대부분의 fusion 방법은 단 하나의 기준으로 합칠 토큰을 고른다 — 그게 문제다.
- 유사도만 보면 — 가장 닮은 토큰을 합치니, 비슷한 픽셀이 많은 전경 객체가 과도하게 뭉쳐(Figure 2b) 핵심 정보를 잃는다.
- 중요도만 보면 — 덜 중요한 토큰을 합치는데, 둘이 전혀 안 닮았어도 합쳐 표현이 붕괴(collapse)한다.
- 합쳐진 토큰의 크기를 무시 — 한 토큰에 너무 많은 토큰이 뭉치면(큰 size), 평균/풀링으로 정보를 보존하기 어려워 손실이 커진다.
토큰 관계는 한 가지 기준으로 다 설명되지 않는다. 여러 기준을 함께 보면 정보 손실을 줄일 수 있지 않을까?
핵심 아이디어
MCTF는 합칠 토큰을 고르는 attraction(끌림) 함수를 여러 기준의 곱으로 정의한다. 점수가 높을수록 합쳐질 확률이 크다.
\[W(x_i, x_j) = \prod_{k=1}^{M} \big(W^k(x_i, x_j)\big)^{\tau_k}\]여기서 $W^k$는 $k$번째 기준의 끌림 함수, $\tau_k$는 그 기준의 영향력을 조절하는 temperature다. MCTF가 쓰는 세 기준은 다음과 같다.
방법
세 가지 기준
- 유사도 (similarity) — 중복 토큰을 합치기 위한 기준. 토큰 간 코사인 유사도를 [0,1]로 정규화한다.
- 중요도 (informativeness) — 중요한 토큰의 합쳐짐을 막기 위한 기준. 각 토큰이 받는 평균 attention $a_j = \frac{1}{N}\sum_i A_{ij}$로 중요도를 재고, 그 역수의 곱을 끌림으로 쓴다. 즉 둘 다 안 중요할수록($a_i, a_j \to 0$) 끌림이 커져 잘 합쳐진다.
- 크기 (size) — 합쳐진 토큰 개수 $s$. 처음엔 모두 1로 두고 합칠 때마다 누적한다. 큰 토큰끼리의 결합을 억제하도록 역수의 곱을 쓴다.
Bidirectional bipartite soft matching
매칭은 ToMe의 BSM을 토대로 한다. 토큰을 두 그룹 $X^\alpha, X^\beta$로 나누고($N’ = \lfloor N/2 \rfloor$), 각 source 토큰마다 끌림이 최대인 edge 하나만 남긴 뒤 그중 top-r을 골라 합친다 — $O(N^2)$ 유사도 계산을 $O(N’^2)$로 줄인다. 합칠 때 풀링은 attention $a$와 size $s$를 가중한 weighted sum으로:
\[\delta(X) = \frac{\sum_i a_i s_i x_i}{\sum_{i'} a_{i'} s_{i'}}\]그런데 단방향 BSM은 target 그룹 $X^\beta$의 토큰 수를 못 줄인다. 그래서 MCTF는 반대 방향으로 한 번 더 매칭한다(Figure 3의 Step 3–4). 이때 양쪽 갱신된 토큰으로 끌림을 다시 계산하면 $O(N’(N’-r))$ 추가 비용이 드는데, $\tilde X^\alpha$가 $X^\alpha$의 부분집합이므로 합치기 전 점수를 재사용해 비용을 거의 0으로 만든다.
One-step-ahead attention
중요도를 잴 때, 기존 방법들(EViT·SPViT 등)은 이전 layer attention $A^l$을 썼다 — “연속 layer attention은 비슷하다”는 가정이다. 하지만 실제로는 연속 layer의 attention이 꽤 다르다(Figure 4). $A^l$로 $X^l$을 합치면, 다음 layer에서 중요해질 토큰을 잘못 합칠 수 있다.
MCTF는 다음 layer의 attention $A^{l+1}$로 중요도를 잰다(one-step-ahead). 합친 뒤에는 self-attention을 다시 계산하지 않고 $A^{l+1}$을 집계(aggregate) 해 fused 토큰의 attention $\hat A^{l+1}$을 근사한다 — query 방향은 합이 1이 되도록 단순 합으로 합친다. 이 근사로 정확도 손실 없이 FLOPs를 0.4G 더 줄인다.
Token reduction consistency
fine-tuning 기법. layer당 reduce 수 $r$이 달라지면 표현도 달라진다. 그래서 고정 $r$ 과 무작위 $r’ \sim \text{uniform}(0, r)$ 두 forward를 돌리고, 마지막 [CLS] 토큰이 서로 닮도록 MSE consistency를 건다:
\[L = L_{\text{CE}}(f_\theta(x; r), y) + L_{\text{CE}}(f_\theta(x; r'), y) + \lambda L_{\text{MSE}}(x^{\text{cls}}_r, x^{\text{cls}}_{r'})\]$r’ < r$이라 $x^{\text{cls}}_{r’}$은 “덜 줄인 = 원본에 가까운” 표현이고, 목표인 $x^{\text{cls}}_r$이 이를 모방하게 해 일종의 token-level augmentation + consistency 정규화로 작동한다.
결과
ImageNet-1K, DeiT-T/S·T2T-ViT·LV-ViT. DeiT는 사전학습 모델을 30 epoch fine-tune. 학습 없이도(ToMe처럼 무파라미터라) 적용 가능.
- base를 능가 — DeiT-T: 1.2→0.7G(−44%)에 72.2→72.7%(+0.5). DeiT-S: 4.6→2.6G(−44%)에 79.8→80.1%(+0.3). 표의 모든 경쟁 방법은 성능이 떨어지는데 MCTF만 올라간다.
- 다른 ViT에도 — T2T-ViT·LV-ViT에서 최소 31% 가속하면서 성능 저하 없음. 특히 LV-ViT에선 다른 모든 reduction이 성능을 깎는데 MCTF만 +0.1%(83.3→83.4).
- 학습 없이도 — pre-trained DeiT에 그대로 적용 시 ToMe를 일관되게 능가. 가장 공격적인 r=20에서 격차가 큼(DeiT-T +7.0%, DeiT-S +3.8%). 무학습 MCTF(r=16, DeiT-S) 79.2%는 학습이 필요한 A-ViT(78.6)·IA-RED²(79.1)보다도 높다.
- 요소 분해(DeiT-S, r=16) — 단일 기준 sim 79.7 / info 79.4 → 이중(sim+info) 79.8 → 삼중(sim+info+size) 80.1. r이 클수록 multi-criteria의 이득이 커진다. one-step-ahead attention과 consistency를 각각 빼면 전 구간에서 정확도가 떨어진다.
- 설계 선택(Table 4) — bidirectional > one-way(80.1 vs 80.0), 풀링은 weighted average가 최선(max·average보다↑), 근사 attention은 precise와 같은 정확도에 0.4G 더 효율적.
한 줄 정리 & 의의
- token fusion을 단일 기준에서 다기준(similarity × informativeness × size)으로 확장한 Merging 계열. 세 기준을 곱(거듭제곱 가중)으로 묶어 “닮았지만 안 중요하고 작은” 토큰을 우선 합치게 한 게 핵심.
- 두 개의 보조 장치 — (1) 중요도를 다음 layer attention으로 재는 one-step-ahead attention, (2) reduce 수에 대한 consistency fine-tuning. 둘 다 빼면 성능이 떨어지는 필수 요소.
- 위치. ToMe의 BSM·무학습 적용성을 그대로 물려받되, 매칭 기준을 다기준으로·양방향으로 키웠다. merging의 norm/정보 손실을 짚은 ToFu와 문제의식이 맞닿지만, ToFu가 layer별로 prune/merge를 갈아 쓴 반면 MCTF는 합칠 토큰 선택 자체를 다기준으로 정교화한다. → Token Reduction 개요