[TPS] Joint Token Pruning & Squeezing Towards More Aggressive Compression of Vision Transformers
Siyuan Wei, Tianzhu Ye, Shen Zhang, Yao Tang, Jiajun Liang · MEGVII Technology / Tsinghua University
한 줄 요약. Pruning은 토큰을 그냥 버려서 정보 손실이 생기고, 특히 aggressive하게 자를수록 정확도가 급락한다. TPS는 버릴 토큰을 살아남는 host 토큰에 짜 넣는(squeeze) 단계를 추가한다 — 추가 파라미터 없는 nearest-neighbor matching + similarity 기반 fusing. 어떤 pruning 기법에도 끼워 넣을 수 있고, DeiT-S를 35% FLOPs로 줄여도 baseline 대비 정확도 하락을 1~6% 막아낸다.
배경
Pruning 계열(DynamicViT, EViT 등)의 공통 한계는 분명하다 — 버린 토큰의 정보가 영영 사라진다. 점수가 낮다고 판단해 잘라낸 토큰이라도 약간의 정보는 담고 있는데, 점수 예측이 완벽하지 않으면 그 손실이 곧바로 오답으로 이어진다. 압축률을 키울수록(= 더 많이 자를수록) 이 문제가 심해진다.
- DynamicViT — 학습된 점수 예측 head로 토큰을 버린다. 버린 토큰은 그대로 폐기.
- EViT — class attention으로 점수를 매겨 자르되, 버릴 토큰들을 하나의 토큰으로 평균 내어 덧붙인다(token reorganization). 손실을 일부 줄이지만, 서로 다른 배경 토큰까지 한 토큰에 뭉뚱그려 정보가 흐려진다.
점수가 낮은 토큰도 정보를 가지고 있다 → 버리지 말고 비슷한 생존 토큰에 합쳐 넣으면 어떨까?
핵심 아이디어
TPS는 압축을 두 단계로 나눈다 — 먼저 토큰을 자르고(pruning), 잘린 토큰을 남는 토큰에 짜 넣는다(squeezing). 핵심은 squeezing이며, 두 개의 하위 단계 matching → fusing으로 이뤄진다. 이 과정에 추가 파라미터가 전혀 없고, 어떤 pruning 기법 위에도 올릴 수 있다.
토큰을 점수에 따라 두 집합으로 나눈다 — 살리는 reserved set $S^r$ 과 버리는 pruned set $S^p$. squeezing은 $S^p$의 정보를 $S^r$로 옮긴 뒤 $S^r$만 다음 블록으로 보낸다. 즉 출력 토큰 수 = reserved 토큰 수로 일정해, inference graph 최적화의 이점을 그대로 누린다.
방법
1) Token Pruning — 어떤 기법이든
TPS는 pruning 방식 자체를 새로 만들지 않는다. 공정한 비교를 위해 두 baseline의 pruning을 그대로 가져와 두 변형을 만든다.
- dTPS (inter-block) — DynamicViT의 학습된 점수 예측 head를 쓰고, 학습 시 Straight-Through Gumbel-Softmax로 binary mask를 미분 가능하게 샘플링한다.
- eTPS (intra-block) — EViT처럼 class token attention을 토큰 중요도로 쓴다.
추론 시에는 고정 keep ratio $\rho$로 Top-k를 적용해 토큰을 $S^r$(reserved)과 $S^p$(pruned)로 가른다. 두 변형 모두 토큰 수가 고정돼 계산 그래프가 일정하다.
2) Matching — 버릴 토큰을 host에 짝지운다
각 pruned 토큰 $x_i \in S^p$ 에 대해, reserved set에서 가장 비슷한 host 토큰을 찾는다:
\[x_*^{host} = \mathop{\arg\max}\limits_{x_j \in S^r} c_{i,j}\]- unidirectional, many-to-one — 매칭은 $S^p \to S^r$ 한 방향이라, 여러 pruned 토큰이 같은 host를 공유할 수 있고 모든 reserved 토큰이 host가 되는 것도 아니다.
- 매칭 결과는 mask $M \in \mathbb{R}^{N^p \times N^r}$ 로 기록 — host 쌍이면 $m_{i,j}=1$, 아니면 0. 덕분에 다음 fusing을 정규 행렬 연산으로 한 번에 처리하면서 매칭되지 않은 쌍의 영향은 자동으로 배제한다.
- 유사도는 attention map이 아니라 cosine similarity를 쓴다 (ablation에서 더 높은 성능):
유사도가 입력 feature에서 바로 계산되므로 matching에 추가 파라미터가 없다.
3) Fusing — 유사도 가중 합
단순 평균은 서로 다른 토큰들이 섞여 feature가 흐려진다. EViT는 중요도 점수로 가중했지만, TPS는 유사도 기반 가중을 써서 host에 더 가까운 토큰의 영향을 키우고, 불완전한 점수 예측의 약점도 피한다. host 토큰 $x_j$ 는 자기 feature와 매칭된 pruned 토큰들을 섞어 갱신된다:
\[y_j = w_j\, x_j + \sum_{x_i \in S^p} w_i\, x_i\] \[w_i = \frac{\exp(c_{i,j})\, m_{i,j}}{\sum_{x_i \in S^p}\exp(c_{i,j})\, m_{i,j} + \mathrm{e}}, \qquad w_j = \frac{\mathrm{e}}{\sum_{x_i \in S^p}\exp(c_{i,j})\, m_{i,j} + \mathrm{e}}\]- host 자신($x_j$)은 유사도가 1이므로 분자에 $\mathrm{e}$가 들어가 항상 가장 큰 가중치를 가진다 → 원본 정보가 지배적이고, 비슷한 pruned 토큰이 보조로 섞인다.
- host로 뽑히지 않은 reserved 토큰은 그대로 남는다. 결과적으로 처리 후 토큰 수 = reserved 수로 shape이 일정하다.
Hybrid ViT로의 확장
plain transformer block에는 TPS 모듈을 그대로 끼우면 된다. 다만 PVT처럼 conv·pooling으로 공간 구조가 필요한 layer에서는, spatial-reduction 단계에서 버린 토큰 자리를 0으로 채워 구조화된 입력을 유지한다. 이로써 계층형(hybrid) transformer에도 일반화됨을 보인다.
결과
평가는 ImageNet-1K와 fine-grained 데이터셋 iNaturalist 2019에서, baseline DynamicViT·EViT의 pruning 모듈을 dTPS·eTPS로 교체해 비교한다 (입력 224×224, 30 epoch fine-tune).
- 모든 설정에서 baseline 상회 — pruning이 aggressive해질수록 DynamicViT·EViT는 정확도 하락이 커지는데, DeiT를 35% FLOPs까지 줄여도 TPS는 baseline 대비 정확도 하락을 1~6% 막는다.
- 작은 모델을 능가 — TPS를 단 DeiT-S는 throughput 1745 images/s로 DeiT-tiny(1686 images/s)보다 빠르면서 정확도는 +4.78% 높다.
한 줄 정리 & 의의
- Pruning의 본질적 약점인 정보 손실을, 버릴 토큰을 살아남는 host 토큰에 짜 넣어(squeeze) 보완한 방법. matching(파라미터 0) + similarity-based fusing으로 추가 학습 파라미터 없이 어떤 pruning 위에도 얹을 수 있다.
- EViT의 token reorganization(버릴 토큰을 한 개로 평균)을 한 단계 정교화 — 여러 host에 유사도 가중으로 분배해 정보 혼탁을 줄인다. “버리지 말고 비슷한 곳에 합치자”는 점에서 merging과 사상이 맞닿지만, 어디까지나 pruning을 보강하는 비대칭(unidirectional) 합치기다.
- 한계 / 이후. keep ratio·pruning 위치는 여전히 손으로 정해야 하고(→ 압축률 자체를 학습한 DiffRate와 대비), squeezing이 약간의 redundant 연산을 더한다. 이후 흐름은 학습 없이 전역 신호로 점수를 매기거나(Zero-TPrune), 정보·크기까지 고려한 다기준 merging(MCTF)으로 이어진다. → Token Reduction 개요