[MCTF] Multi-criteria Token Fusion with One-step-ahead Attention for Efficient Vision Transformers

Merging CVPR 2024

Sanghyeok Lee, Joonmyung Choi, Hyunwoo J. Kim · Korea University

arXiv GitHub

한 줄 요약. 기존 token fusion은 단일 기준(유사도 또는 중요도)만 봐서, 유사도만 쓰면 전경(foreground) 토큰을 과하게 합치고 중요도만 쓰면 안 닮은 토큰을 합쳐 표현이 붕괴한다. MCTF는 유사도 · 중요도(informativeness) · 크기(size) 세 기준을 으로 결합해 합칠 토큰을 고른다. 중요도는 one-step-ahead attention(다음 layer attention)으로 더 정확히 재고, token reduction consistency로 fine-tune한다. DeiT-T/S에서 FLOPs를 약 44% 줄이면서 정확도를 오히려 +0.5%/+0.3% 끌어올린다.

배경

ToMe·Token Pooling처럼 비슷한 토큰을 합치는(fusion) 흐름이 pruning(버리기)을 대체해 왔다. 그런데 대부분의 fusion 방법은 단 하나의 기준으로 합칠 토큰을 고른다 — 그게 문제다.

  • 유사도만 보면 — 가장 닮은 토큰을 합치니, 비슷한 픽셀이 많은 전경 객체가 과도하게 뭉쳐(Figure 2b) 핵심 정보를 잃는다.
  • 중요도만 보면 — 덜 중요한 토큰을 합치는데, 둘이 전혀 안 닮았어도 합쳐 표현이 붕괴(collapse)한다.
  • 합쳐진 토큰의 크기를 무시 — 한 토큰에 너무 많은 토큰이 뭉치면(큰 size), 평균/풀링으로 정보를 보존하기 어려워 손실이 커진다.
Figure 1. DeiT-T(왼쪽)·DeiT-S(오른쪽)에서 token reduction 비교. 파란 원이 base 모델. 기존 방법들은 정확도-FLOPs를 맞바꾸지만, MCTF(별)는 FLOPs를 약 44% 줄이고도 base보다 정확도가 높다. r=16(빨간 별)으로 한 번만 fine-tune한 뒤 reduce 수를 바꿔 다양한 FLOPs로 평가한 것.

토큰 관계는 한 가지 기준으로 다 설명되지 않는다. 여러 기준을 함께 보면 정보 손실을 줄일 수 있지 않을까?

핵심 아이디어

MCTF는 합칠 토큰을 고르는 attraction(끌림) 함수를 여러 기준의 으로 정의한다. 점수가 높을수록 합쳐질 확률이 크다.

\[W(x_i, x_j) = \prod_{k=1}^{M} \big(W^k(x_i, x_j)\big)^{\tau_k}\]

여기서 $W^k$는 $k$번째 기준의 끌림 함수, $\tau_k$는 그 기준의 영향력을 조절하는 temperature다. MCTF가 쓰는 세 기준은 다음과 같다.

Figure 2. 합쳐진 토큰 시각화. (b) 유사도(Wsim)만 쓰면 전경 객체가 과하게 뭉친다. (c) 유사도+중요도(Wsim&Winfo)를 더하면 전경은 덜 합쳐지고 배경이 크게 합쳐진다. (d) 크기(Wsize)까지 더한 MCTF는 큰 토큰을 억제해 각 구성요소의 정보를 보존한다.

방법

세 가지 기준

  • 유사도 (similarity) — 중복 토큰을 합치기 위한 기준. 토큰 간 코사인 유사도를 [0,1]로 정규화한다.
\[W^{\text{sim}}(x_i, x_j) = \frac{1}{2}\left(\frac{x_i \cdot x_j}{\lVert x_i\rVert \lVert x_j\rVert} + 1\right)\]
  • 중요도 (informativeness) — 중요한 토큰의 합쳐짐을 막기 위한 기준. 각 토큰이 받는 평균 attention $a_j = \frac{1}{N}\sum_i A_{ij}$로 중요도를 재고, 그 역수의 곱을 끌림으로 쓴다. 즉 둘 다 안 중요할수록($a_i, a_j \to 0$) 끌림이 커져 잘 합쳐진다.
\[W^{\text{info}}(x_i, x_j) = \frac{1}{a_i a_j}\]
  • 크기 (size) — 합쳐진 토큰 개수 $s$. 처음엔 모두 1로 두고 합칠 때마다 누적한다. 큰 토큰끼리의 결합을 억제하도록 역수의 곱을 쓴다.
\[W^{\text{size}}(x_i, x_j) = \frac{1}{s_i s_j}\]

Bidirectional bipartite soft matching

매칭은 ToMe의 BSM을 토대로 한다. 토큰을 두 그룹 $X^\alpha, X^\beta$로 나누고($N’ = \lfloor N/2 \rfloor$), 각 source 토큰마다 끌림이 최대인 edge 하나만 남긴 뒤 그중 top-r을 골라 합친다 — $O(N^2)$ 유사도 계산을 $O(N’^2)$로 줄인다. 합칠 때 풀링은 attention $a$와 size $s$를 가중한 weighted sum으로:

\[\delta(X) = \frac{\sum_i a_i s_i x_i}{\sum_{i'} a_{i'} s_{i'}}\]

그런데 단방향 BSM은 target 그룹 $X^\beta$의 토큰 수를 못 줄인다. 그래서 MCTF는 반대 방향으로 한 번 더 매칭한다(Figure 3의 Step 3–4). 이때 양쪽 갱신된 토큰으로 끌림을 다시 계산하면 $O(N’(N’-r))$ 추가 비용이 드는데, $\tilde X^\alpha$가 $X^\alpha$의 부분집합이므로 합치기 전 점수를 재사용해 비용을 거의 0으로 만든다.

Figure 3. Bidirectional bipartite soft matching. 토큰을 Xα, Xβ로 나눠 Step 1–2에서 한 방향, Step 3–4에서 반대 방향으로 매칭한다. 선의 진하기가 multi-criteria 끌림 Wt를 나타낸다. 양방향이라 두 그룹 모두 토큰 수가 준다.

One-step-ahead attention

중요도를 잴 때, 기존 방법들(EViT·SPViT 등)은 이전 layer attention $A^l$을 썼다 — “연속 layer attention은 비슷하다”는 가정이다. 하지만 실제로는 연속 layer의 attention이 꽤 다르다(Figure 4). $A^l$로 $X^l$을 합치면, 다음 layer에서 중요해질 토큰을 잘못 합칠 수 있다.

MCTF는 다음 layer의 attention $A^{l+1}$로 중요도를 잰다(one-step-ahead). 합친 뒤에는 self-attention을 다시 계산하지 않고 $A^{l+1}$을 집계(aggregate) 해 fused 토큰의 attention $\hat A^{l+1}$을 근사한다 — query 방향은 합이 1이 되도록 단순 합으로 합친다. 이 근사로 정확도 손실 없이 FLOPs를 0.4G 더 줄인다.

Token reduction consistency

fine-tuning 기법. layer당 reduce 수 $r$이 달라지면 표현도 달라진다. 그래서 고정 $r$무작위 $r’ \sim \text{uniform}(0, r)$ 두 forward를 돌리고, 마지막 [CLS] 토큰이 서로 닮도록 MSE consistency를 건다:

\[L = L_{\text{CE}}(f_\theta(x; r), y) + L_{\text{CE}}(f_\theta(x; r'), y) + \lambda L_{\text{MSE}}(x^{\text{cls}}_r, x^{\text{cls}}_{r'})\]

$r’ < r$이라 $x^{\text{cls}}_{r’}$은 “덜 줄인 = 원본에 가까운” 표현이고, 목표인 $x^{\text{cls}}_r$이 이를 모방하게 해 일종의 token-level augmentation + consistency 정규화로 작동한다.

결과

ImageNet-1K, DeiT-T/S·T2T-ViT·LV-ViT. DeiT는 사전학습 모델을 30 epoch fine-tune. 학습 없이도(ToMe처럼 무파라미터라) 적용 가능.

Table 1. ImageNet 분류. MCTF(r=16)는 DeiT-T/S 모두에서 가장 낮은 FLOPs로 최고 정확도. 유일하게 성능 저하 없이(오히려 +0.5%/+0.3%) 압축한다.
  • base를 능가 — DeiT-T: 1.2→0.7G(−44%)에 72.2→72.7%(+0.5). DeiT-S: 4.6→2.6G(−44%)에 79.8→80.1%(+0.3). 표의 모든 경쟁 방법은 성능이 떨어지는데 MCTF만 올라간다.
  • 다른 ViT에도 — T2T-ViT·LV-ViT에서 최소 31% 가속하면서 성능 저하 없음. 특히 LV-ViT에선 다른 모든 reduction이 성능을 깎는데 MCTF만 +0.1%(83.3→83.4).
  • 학습 없이도 — pre-trained DeiT에 그대로 적용 시 ToMe를 일관되게 능가. 가장 공격적인 r=20에서 격차가 큼(DeiT-T +7.0%, DeiT-S +3.8%). 무학습 MCTF(r=16, DeiT-S) 79.2%는 학습이 필요한 A-ViT(78.6)·IA-RED²(79.1)보다도 높다.
  • 요소 분해(DeiT-S, r=16) — 단일 기준 sim 79.7 / info 79.4 → 이중(sim+info) 79.8 → 삼중(sim+info+size) 80.1. r이 클수록 multi-criteria의 이득이 커진다. one-step-ahead attention과 consistency를 각각 빼면 전 구간에서 정확도가 떨어진다.
  • 설계 선택(Table 4) — bidirectional > one-way(80.1 vs 80.0), 풀링은 weighted average가 최선(max·average보다↑), 근사 attention은 precise와 같은 정확도에 0.4G 더 효율적.

한 줄 정리 & 의의

  • token fusion을 단일 기준에서 다기준(similarity × informativeness × size)으로 확장한 Merging 계열. 세 기준을 (거듭제곱 가중)으로 묶어 “닮았지만 안 중요하고 작은” 토큰을 우선 합치게 한 게 핵심.
  • 두 개의 보조 장치 — (1) 중요도를 다음 layer attention으로 재는 one-step-ahead attention, (2) reduce 수에 대한 consistency fine-tuning. 둘 다 빼면 성능이 떨어지는 필수 요소.
  • 위치. ToMe의 BSM·무학습 적용성을 그대로 물려받되, 매칭 기준을 다기준으로·양방향으로 키웠다. merging의 norm/정보 손실을 짚은 ToFu와 문제의식이 맞닿지만, ToFu가 layer별로 prune/merge를 갈아 쓴 반면 MCTF는 합칠 토큰 선택 자체를 다기준으로 정교화한다. → Token Reduction 개요