[DiffRate] Differentiable Compression Rate for Efficient Vision Transformers

Hybrid ICCV 2023

Mengzhao Chen, Wenqi Shao, Peng Xu, Mingbao Lin, Kaipeng Zhang, Fei Chao, Rongrong Ji, Yu Qiao, Ping Luo · Xiamen University / OpenGVLab (Shanghai AI Lab) / HKU / Tencent

arXiv GitHub

한 줄 요약. 지금까지 손으로 정하던 layer별 압축률(compression rate)을 미분 가능하게 만들어 gradient로 자동 학습한다. 게다가 pruning과 merging을 한 forward pass에 함께 수행(prune→merge) — off-the-shelf ViT-H를 FLOPs 40%↓, 정확도 −0.16%로, fine-tuning 없이도 EViT·ToMe를 능가.

배경

이 노트부터 Hybrid 계열이다 — pruning과 merging을 함께 쓴다. 앞선 방법들의 공통 약점은 압축률(layer마다 몇 개를 버릴지)을 사람이 정한다는 것이었다.

  • 고정 비율의 한계 — EViT·ToMe는 layer마다 keep ratio / r을 수작업으로 정한다. FLOPs 제약을 맞추려면 layer별 비율을 일일이 튜닝해야 해 번거롭고, 잘못 정하면 informative한 전경 토큰까지 잘려 정확도가 급락한다.
  • pruning과 merging이 분리돼 있었다 — pruning은 배경 토큰 제거에, merging은 비슷한 (전경 포함) 토큰 병합에 강한데, 기존 연구는 둘 중 하나만 썼다.
Figure 1. (a) 같은 2.3G FLOPs에서 DeiT-S를 fine-tuning 없이 압축 — EViT 73.8%, ToMe 78.0%, DiffRate 78.8%. (b) 기존: pruning '또는' merging을 hand-picked 비율로. (c) DiffRate: pruning '과' merging을 미분 가능한 압축률로 동시에.

압축률이 미분 불가능한 하이퍼파라미터라 일일이 정해야 했다 → 이걸 학습 대상으로 만들면?

핵심 아이디어

압축률 $\alpha$를 loss에 대해 미분 가능하게 만들어, layer마다 다른 최적 압축률을 gradient로 자동 학습한다. 그리고 pruning과 merging을 하나의 연산 $f_c = f_m \circ f_p$ 로 통합 — 먼저 덜 중요한 토큰을 자르고(prune), 남은 것 중 비슷한 토큰을 합친다(merge).

Figure 2. token compression은 transformer block의 Attention과 MLP 사이에 들어간다 (ToMe와 같은 위치).

방법 — Differentiable Discrete Proxy (DDP)

핵심 장치는 DDP다. 모든 이미지에서 top-K 중요 토큰을 보존하게 해 (1) batch 병렬 계산이 가능하고 (2) 원본 성능을 최대한 유지한다. 두 부분으로 구성된다.

Figure 3. (a) Token Sorting: class attention Aᶜ로 정렬 → 하위 Nαₚ개 prune → 남은 것 중 N(αₘ−αₚ)개를 비슷한 토큰과 merge. (b) Re-parameterization: 압축률 α를 discrete 후보율 C와 학습 확률 ρ의 조합으로 표현, π→0/1 mask→attention masking.

1) Token Sorting — 중요도로 정렬

EViT처럼 class attention $A_c = \mathrm{Softmax}(q_c K^\top/\sqrt{D})$ 을 중요도로 쓴다 (CLS가 각 토큰에 주는 attention). 추가 연산이 없어 기본값으로 채택. (ablation: ATS의 $A_c\cdot\lVert V\rVert$ 과 거의 동급)

  • 정렬 후 하위 $N\alpha_p$개를 prune, 그다음 남은 토큰 중 $N(\alpha_m-\alpha_p)$개를 cosine similarity로 비슷한 토큰과 merge(평균). 이 sorting→pruning→merging 파이프라인으로 두 기법을 자연스럽게 합친다.

2) Compression Rate Re-parameterization — α를 미분 가능하게

압축률을 직접 학습하면 이미지마다 버리는 개수가 달라져 병렬화가 안 된다. 그래서 이산 후보율 집합 $C={C_1,…,C_N},\ C_k=\frac{k-1}{N}$ 에 학습 가능한 확률 $\rho_k$ 를 부여하고:

\[\alpha = \sum_{k=1}^{N} C_k \rho_k\]
  • 토큰별 압축 확률 $\pi_k$(누적합, $\pi_1=0$으로 최중요 토큰은 항상 보존)을 만들고, $\pi_k$와 $\alpha$를 비교해 0–1 mask로 변환. 덜 중요한 토큰일수록 압축 확률이 크도록 정렬과 일관됨.
  • block마다 pruning용·merging용 mask 2개를 독립적으로 학습. gradient가 끊기지 않도록 토큰 drop을 attention masking(DynamicViT 방식)으로 대체하고 STE로 역전파 → loss의 gradient가 $\rho_k$ 로 흘러 $\alpha$ 가 학습된다.

3) Budget(FLOPs) 제약 학습

\[\mathcal{L} = \mathcal{L}_{cls} + \lambda_f\,(\,\mathcal{F}(\alpha_p,\alpha_m)-T\,)^2\]
  • 목표 FLOPs $T$ 하나만 주면 끝. FLOPs는 압축률의 미분 가능한 함수로 표현. ($\lambda_f=5$)
  • 매우 효율적: pretrained는 고정하고 $\rho$만 학습 → 3 epoch(ViT-B 2.7 GPU-hour) 면 수렴, 1,000장이면 충분, 추가 파라미터·FLOPs는 무시할 수준. latency·power 같은 하드웨어 지표로도 확장 가능.

결과

Table 1. off-the-shelf(파라미터 미수정) 비교 — 같은 FLOPs에서 DiffRate가 EViT·ToMe를 일관되게 능가. ToMe 대비 +0.14~1.14%, EViT 대비 +0.39~4.93%.
  • 학습 없이 SOTA: 학습된 압축률을 pretrained 모델에 바로 적용해도, 학습이 필요한 기존 방법과 동급 이상. 예) off-the-shelf DeiT-S 79.58% > 학습한 EViT 79.50%·ToMe 79.49%.
  • ViT-H (MAE): FLOPs 40%↓, throughput 50%↑, 정확도 −0.16%(fine-tuning 없이).
  • 선택적 fine-tune(†) 30 epoch이면 추가 향상 (예: DeiT-S 79.58→79.83).
  • 압축 스케줄도 거의 최적: 10,000개 랜덤 스케줄과 비교해 pruning-only·merging-only·둘 다 모두에서 최상위. 특히 FLOPs를 강하게 줄일수록 격차가 커짐.

Ablation. 순서가 중요하다 — prune→merge(81.50) > merge→prune(81.18) > merge-only(81.14) > prune-only(80.83). 배경은 pruning으로 지우고, 전경의 덜 중요한 토큰은 merging으로 합치는 조합이 최선임을 보여준다.

Figure 5. 시각화 — 검정=pruned(주로 배경), 같은 색 패치=merged(주로 전경의 덜 변별적인 영역). block이 깊어질수록 배경을 점진적으로 제거하고 전경을 병합 (block 12엔 34개만 남음).

한 줄 정리 & 의의

  • Pruning + Merging을 통합한 첫 Hybrid이자, 압축률 자체를 학습 대상으로 끌어올린 전환점. “버릴 건 버리고(prune), 남길 것 중 비슷한 건 합친다(merge)”를 미분 가능한 비율로 자동 결정.
  • ToMe(고정 r)·EViT(고정 keep ratio)의 hand-crafted schedule을 gradient로 대체 — Adaptive Sparse ViT가 threshold를 학습했다면, DiffRate는 layer별 압축률 자체를 학습한다.
  • 한계 / 이후. 표준 ViT 분류 중심(계층형엔 uncompression 모듈 필요). 이후 흐름은 더 정교한 merging(정보·크기까지 고려한 다기준 fusion → MCTF)이나, 학습 없이 전역 신호로 점수를 매기는 방향(Zero-TPrune)으로 갈라진다. → Token Reduction 개요