[TPS] Joint Token Pruning & Squeezing Towards More Aggressive Compression of Vision Transformers

Pruning CVPR 2023

Siyuan Wei, Tianzhu Ye, Shen Zhang, Yao Tang, Jiajun Liang · MEGVII Technology / Tsinghua University

arXiv GitHub

한 줄 요약. Pruning은 토큰을 그냥 버려서 정보 손실이 생기고, 특히 aggressive하게 자를수록 정확도가 급락한다. TPS는 버릴 토큰을 살아남는 host 토큰에 짜 넣는(squeeze) 단계를 추가한다 — 추가 파라미터 없는 nearest-neighbor matching + similarity 기반 fusing. 어떤 pruning 기법에도 끼워 넣을 수 있고, DeiT-S를 35% FLOPs로 줄여도 baseline 대비 정확도 하락을 1~6% 막아낸다.

배경

Pruning 계열(DynamicViT, EViT 등)의 공통 한계는 분명하다 — 버린 토큰의 정보가 영영 사라진다. 점수가 낮다고 판단해 잘라낸 토큰이라도 약간의 정보는 담고 있는데, 점수 예측이 완벽하지 않으면 그 손실이 곧바로 오답으로 이어진다. 압축률을 키울수록(= 더 많이 자를수록) 이 문제가 심해진다.

  • DynamicViT — 학습된 점수 예측 head로 토큰을 버린다. 버린 토큰은 그대로 폐기.
  • EViT — class attention으로 점수를 매겨 자르되, 버릴 토큰들을 하나의 토큰으로 평균 내어 덧붙인다(token reorganization). 손실을 일부 줄이지만, 서로 다른 배경 토큰까지 한 토큰에 뭉뚱그려 정보가 흐려진다.
Figure 1. 같은 입력에서 token pruning(가운데)은 토큰을 통째로 버리지만, TPS(아래)는 버릴 토큰의 정보를 남는 토큰으로 흡수한다. aggressive하게 자를수록 정보 보존 여부가 정확도를 가른다.

점수가 낮은 토큰도 정보를 가지고 있다 → 버리지 말고 비슷한 생존 토큰에 합쳐 넣으면 어떨까?

핵심 아이디어

TPS는 압축을 두 단계로 나눈다 — 먼저 토큰을 자르고(pruning), 잘린 토큰을 남는 토큰에 짜 넣는다(squeezing). 핵심은 squeezing이며, 두 개의 하위 단계 matching → fusing으로 이뤄진다. 이 과정에 추가 파라미터가 전혀 없고, 어떤 pruning 기법 위에도 올릴 수 있다.

토큰을 점수에 따라 두 집합으로 나눈다 — 살리는 reserved set $S^r$ 과 버리는 pruned set $S^p$. squeezing은 $S^p$의 정보를 $S^r$로 옮긴 뒤 $S^r$만 다음 블록으로 보낸다. 즉 출력 토큰 수 = reserved 토큰 수로 일정해, inference graph 최적화의 이점을 그대로 누린다.

Figure 2. (a) Token Pruning: 버릴 토큰을 폐기 → 정보 손실. (b) Token Reorganization(EViT): 버릴 토큰을 한 토큰으로 평균. (c) TPS: 버릴 토큰을 가장 비슷한 host 토큰에 matching 후 fusing해 흡수. host로 뽑히지 않은 reserved 토큰은 그대로 유지된다.

방법

1) Token Pruning — 어떤 기법이든

TPS는 pruning 방식 자체를 새로 만들지 않는다. 공정한 비교를 위해 두 baseline의 pruning을 그대로 가져와 두 변형을 만든다.

  • dTPS (inter-block) — DynamicViT의 학습된 점수 예측 head를 쓰고, 학습 시 Straight-Through Gumbel-Softmax로 binary mask를 미분 가능하게 샘플링한다.
  • eTPS (intra-block) — EViT처럼 class token attention을 토큰 중요도로 쓴다.

추론 시에는 고정 keep ratio $\rho$로 Top-k를 적용해 토큰을 $S^r$(reserved)과 $S^p$(pruned)로 가른다. 두 변형 모두 토큰 수가 고정돼 계산 그래프가 일정하다.

2) Matching — 버릴 토큰을 host에 짝지운다

각 pruned 토큰 $x_i \in S^p$ 에 대해, reserved set에서 가장 비슷한 host 토큰을 찾는다:

\[x_*^{host} = \mathop{\arg\max}\limits_{x_j \in S^r} c_{i,j}\]
  • unidirectional, many-to-one — 매칭은 $S^p \to S^r$ 한 방향이라, 여러 pruned 토큰이 같은 host를 공유할 수 있고 모든 reserved 토큰이 host가 되는 것도 아니다.
  • 매칭 결과는 mask $M \in \mathbb{R}^{N^p \times N^r}$ 로 기록 — host 쌍이면 $m_{i,j}=1$, 아니면 0. 덕분에 다음 fusing을 정규 행렬 연산으로 한 번에 처리하면서 매칭되지 않은 쌍의 영향은 자동으로 배제한다.
  • 유사도는 attention map이 아니라 cosine similarity를 쓴다 (ablation에서 더 높은 성능):
\[c_{i,j} = \frac{x_i^\top x_j}{\lVert x_i \rVert \lVert x_j \rVert}, \quad i \in I^p,\ j \in I^r\]

유사도가 입력 feature에서 바로 계산되므로 matching에 추가 파라미터가 없다.

3) Fusing — 유사도 가중 합

단순 평균은 서로 다른 토큰들이 섞여 feature가 흐려진다. EViT는 중요도 점수로 가중했지만, TPS는 유사도 기반 가중을 써서 host에 더 가까운 토큰의 영향을 키우고, 불완전한 점수 예측의 약점도 피한다. host 토큰 $x_j$ 는 자기 feature와 매칭된 pruned 토큰들을 섞어 갱신된다:

\[y_j = w_j\, x_j + \sum_{x_i \in S^p} w_i\, x_i\] \[w_i = \frac{\exp(c_{i,j})\, m_{i,j}}{\sum_{x_i \in S^p}\exp(c_{i,j})\, m_{i,j} + \mathrm{e}}, \qquad w_j = \frac{\mathrm{e}}{\sum_{x_i \in S^p}\exp(c_{i,j})\, m_{i,j} + \mathrm{e}}\]
  • host 자신($x_j$)은 유사도가 1이므로 분자에 $\mathrm{e}$가 들어가 항상 가장 큰 가중치를 가진다 → 원본 정보가 지배적이고, 비슷한 pruned 토큰이 보조로 섞인다.
  • host로 뽑히지 않은 reserved 토큰은 그대로 남는다. 결과적으로 처리 후 토큰 수 = reserved 수로 shape이 일정하다.

Hybrid ViT로의 확장

plain transformer block에는 TPS 모듈을 그대로 끼우면 된다. 다만 PVT처럼 conv·pooling으로 공간 구조가 필요한 layer에서는, spatial-reduction 단계에서 버린 토큰 자리를 0으로 채워 구조화된 입력을 유지한다. 이로써 계층형(hybrid) transformer에도 일반화됨을 보인다.

결과

평가는 ImageNet-1K와 fine-grained 데이터셋 iNaturalist 2019에서, baseline DynamicViT·EViT의 pruning 모듈을 dTPS·eTPS로 교체해 비교한다 (입력 224×224, 30 epoch fine-tune).

  • 모든 설정에서 baseline 상회 — pruning이 aggressive해질수록 DynamicViT·EViT는 정확도 하락이 커지는데, DeiT를 35% FLOPs까지 줄여도 TPS는 baseline 대비 정확도 하락을 1~6% 막는다.
  • 작은 모델을 능가 — TPS를 단 DeiT-S는 throughput 1745 images/s로 DeiT-tiny(1686 images/s)보다 빠르면서 정확도는 +4.78% 높다.
Figure 7. DeiT는 맞히지만 token pruning을 적용하면 틀리는 ImageNet 사례들. 불완전한 pruning 정책이 정보를 버리는 바람에 DynamicViT는 오답을, dTPS는 정보를 흡수해 정답을 낸다.

한 줄 정리 & 의의

  • Pruning의 본질적 약점인 정보 손실을, 버릴 토큰을 살아남는 host 토큰에 짜 넣어(squeeze) 보완한 방법. matching(파라미터 0) + similarity-based fusing으로 추가 학습 파라미터 없이 어떤 pruning 위에도 얹을 수 있다.
  • EViT의 token reorganization(버릴 토큰을 한 개로 평균)을 한 단계 정교화 — 여러 host에 유사도 가중으로 분배해 정보 혼탁을 줄인다. “버리지 말고 비슷한 곳에 합치자”는 점에서 merging과 사상이 맞닿지만, 어디까지나 pruning을 보강하는 비대칭(unidirectional) 합치기다.
  • 한계 / 이후. keep ratio·pruning 위치는 여전히 손으로 정해야 하고(→ 압축률 자체를 학습한 DiffRate와 대비), squeezing이 약간의 redundant 연산을 더한다. 이후 흐름은 학습 없이 전역 신호로 점수를 매기거나(Zero-TPrune), 정보·크기까지 고려한 다기준 merging(MCTF)으로 이어진다. → Token Reduction 개요