[Token Fusion / ToFu] Bridging the Gap between Token Pruning and Token Merging

Hybrid WACV 2024

Minchul Kim, Shangqian Gao, Yen-Chang Hsu, Yilin Shen, Hongxia Jin · Michigan State University / Samsung Research America

arXiv

한 줄 요약. "merging이 항상 pruning보다 나은가?"를 따져, layer의 functional linearity(입력 보간에 대한 출력의 선형성)에 따라 둘을 갈아 쓴다 — linearity가 낮은 초기 layer는 pruning, 높은 후기 layer는 merging(hybrid merging). 또 average merge가 feature norm을 깎아 분포가 틀어지는 문제를 MLERP(norm 보존 평균)로 고친다. 학습 없이 ToMe보다 더 정확하고 더 빠르다(분류·이미지 생성 모두).

배경

ToMe의 등장으로 merging(비슷한 토큰을 평균)이 pruning(그냥 버리기)을 대체하는 듯 보였다. 이 논문은 한 발 물러서 묻는다 — “average merging이 정말 항상 pruning보다 나은가?”

  • pruning vs merging의 본질 차이 — pruning은 기존 토큰 표현을 그대로 남긴다. merging은 두 토큰을 평균 내 입력을 보간(interpolate) 한다. 그런데 뒤따르는 MLP·ATTN은 비선형이라, 보간된 입력의 출력이 두 원본 출력의 중간에 안 놓일 수 있다 → 출력 공간에서 어긋나(misalignment) 정보 손실·분포 이동이 생긴다.
  • 그 정도가 layer마다 다르다 — 블록의 “선형성”이 낮은 곳에서 평균은 위험하고, 높은 곳에서 평균은 안전하며 오히려 여러 토큰 정보를 종합해 이득이다.
Figure 2. (왼쪽) ViT에 ToFu 삽입 — block마다 reduce 연산 R(MLP 앞)을 두고, BSM으로 가장 비슷한 r쌍을 고른다. layer 깊이에 따라 초기엔 pruned merge, 후기엔 average(MLERP) merge로 전환. (오른쪽) 두 토큰 결합 방식 비교 — pruned(하나 버림)·average(평균, norm 감소)·MLERP(평균 방향에 max norm 부여, norm 보존).

평균(merge)이 좋은 layer와 그냥 버리는(prune) 게 좋은 layer가 따로 있다면, layer별로 갈아 쓰면 되지 않을까?

핵심 아이디어

ToFu는 pruning과 merging을 하나의 알고리즘으로 일반화한다. 토대는 ToMe의 BSM(Bipartite Soft Matching) — 토큰을 SRC·DST로 나눠 가장 비슷한 r쌍을 찾는다. 같은 매칭 결과 위에서 세 가지 결합 방식을 고를 수 있게 한다.

  • Average merging — src를 dst 위치에 평균(scatter-reduce mean). ToMe와 동일.
  • Pruned merging — 비슷하니 src를 그냥 버린다(dst는 그대로). 평균 연산이 없어 average보다 빠르다.
  • MLERP merging — 평균이되 norm을 보존(아래 설명).

어느 layer에 무엇을 쓸지는 functional linearity로 정한다.

방법

Functional Linearity — 어디서 평균이 위험한가

함수 $f$ 와 두 토큰 $X_1, X_2$ 에 대해, 출력 사이의 직선 거리를 보간 경로를 따라간 실제 이동 거리로 나눈 값으로 선형성을 잰다:

\[\text{FL}(f, X_1, X_2) = \frac{\| f(X_1) - f(X_2) \|_2}{\sum_{i=1}^{N-1} \Delta f(t_i)}, \qquad X(t) = (1-t)X_1 + tX_2\]

완전히 선형이면 1. 사전학습 ViT-S의 각 block MLP를 ImageNet 5만 장으로 측정하니, 초기 layer는 FL이 낮고 깊어질수록 1.0에 수렴한다. 즉 초기엔 평균이 위험, 후기엔 안전.

Hybrid Merging — layer별 전환

이 관찰을 그대로 규칙으로 쓴다. 전환 깊이 $d$ 를 두고:

\[\text{method} = \text{PRUNE if } l < d \text{ else AVG(MLERP)}\]

초기 $d$개 layer는 pruned merge, 이후는 (M)LERP merge. ablation에서 $d=6$(ViT-B 12층 기준)이 최적. 순서가 결정적 — PRUNE→AVG가 80.17%인데 거꾸로 AVG→PRUNE은 77.73%로 급락한다. “초기 layer에서 입력 공간 보간을 피하라”는 원칙을 정량적으로 보여준다.

MLERP — norm을 보존하는 평균

average merge의 또 다른 약점: $\lVert \frac{x_1+x_2}{2}\rVert \le \max(\lVert x_1\rVert, \lVert x_2\rVert)$ — 평균이 feature norm을 깎는다. norm은 도메인 적응·style transfer에서 보듯 분포의 핵심 통계라, 줄어들면 학습 분포에서 벗어난다(pruning은 남은 토큰 norm을 그대로 둠). SLERP(구면 보간)는 2개 토큰 전용이라 다중 토큰엔 못 쓴다. 그래서 MLERP(Maximum-norm Linear intERPolation):

\[\bar x = \frac{1}{K}\sum_{k=1}^{K} x_k, \qquad \text{MLERP}(\{x_1\dots x_K\}) = \frac{\bar x}{\lVert \bar x\rVert} \times \max_{k} \lVert x_k\rVert\]

평균의 방향은 쓰되 크기는 개별 토큰 중 최대 norm으로 되돌린다. 학습 없이 average를 MLERP로 바꾸기만 해도 성능이 오른다.

추가 파라미터·학습이 전혀 없다 — BSM 매칭 위에 “어느 layer에 prune/merge를 쓸지”와 “merge를 norm 보존형으로” 두 가지만 얹은 것.

결과

ImageNet 분류(MAE·AugReg ViT) + Stable Diffusion 이미지 생성. throughput은 V100 FP32.

Table 1. ViT-B, r별 정확도·속도·FLOPs. ToFu AVG가 hybrid merging으로 ToMe를 정확도·속도 모두 상회, ToFu MLERP는 약간 느리지만 더 정확. 공격적인 r=20에서 격차가 극적 — ToMe 67.54% → ToFu MLERP 74.06%.
  • ToMe를 정확도·속도 모두 추월 — ViT-B r=8: ToMe 82.86% → ToFu AVG 83.19% → MLERP 83.22%(full 83.74%). 압축이 셀수록 격차 확대 — r=20에서 ToMe 67.54% → MLERP 74.06%(+6.5%p). ViT-L처럼 깊은(24층) 모델에서 효과가 더 크다.
  • 두 기여의 분해 — ToMe→ToFu AVG는 hybrid(초기 prune) 효과, ToFu AVG→MLERP는 norm 보존 효과. 둘 다 학습 없이 동작.
Table 4. SoTA 비교(ViT-S). ToFu는 ToMe와 함께 학습 불필요·batch 호환이면서, 같은 FLOPs(2.7G)에서 DynamicViT·SP-ViT·ToMe보다 높은 79.6%.
  • 학습 없이 SOTA — AugReg r=13에서 79.6%, 같은 2.7 GFLOPs의 DynamicViT(79.3, 학습 필요)·ToMe(79.3~79.4)보다 높다. A-ViT·DynamicViT·SP-ViT가 auxiliary 학습이 필요하고 batch 비호환인 반면, ToFu는 추가 파라미터 0·batch 호환.
  • 이미지 생성(Stable Diffusion) — ToMeSD처럼 ATTN 앞에 끼워 50% 감축. ToMe 대비 FID 15.74→14.72, LPIPS 0.313→0.271, MS-SSIM 0.730→0.762로 전부 개선. 초기 layer ATTN이 이미지의 전체 구조를 잡는데, 거기서 pruned merge가 구조를 더 잘 보존하기 때문(보간을 피함).

한 줄 정리 & 의의

  • pruning과 merging을 대립이 아니라 layer별로 갈아 쓰는 Hybrid 계열. “merging은 입력 보간이라 비선형 layer에서 위험”이라는 functional linearity 분석으로, 초기=prune·후기=merge라는 규칙을 정당화한다.
  • 두 개의 가벼운 개선 — (1) hybrid(prune→merge 순서), (2) MLERP(norm 보존 평균). 둘 다 BSM 위에 얹는 무학습 장치라 ToMe를 그대로 대체 가능.
  • 위치. DiffRate가 prune·merge를 한 layer 안에서 비율로 섞고 그 비율을 학습했다면, ToFu는 layer 단위로 둘 중 하나를 통째로 고르고 그 기준(linearity)을 학습 없이 정한다. merging의 norm 손실을 짚은 점은 이후 다기준 fusion([MCTF]) 흐름과도 맞닿는다. → Token Reduction 개요