[DiffRate] Differentiable Compression Rate for Efficient Vision Transformers
Mengzhao Chen, Wenqi Shao, Peng Xu, Mingbao Lin, Kaipeng Zhang, Fei Chao, Rongrong Ji, Yu Qiao, Ping Luo · Xiamen University / OpenGVLab (Shanghai AI Lab) / HKU / Tencent
한 줄 요약. 지금까지 손으로 정하던 layer별 압축률(compression rate)을 미분 가능하게 만들어 gradient로 자동 학습한다. 게다가 pruning과 merging을 한 forward pass에 함께 수행(prune→merge) — off-the-shelf ViT-H를 FLOPs 40%↓, 정확도 −0.16%로, fine-tuning 없이도 EViT·ToMe를 능가.
배경
이 노트부터 Hybrid 계열이다 — pruning과 merging을 함께 쓴다. 앞선 방법들의 공통 약점은 압축률(layer마다 몇 개를 버릴지)을 사람이 정한다는 것이었다.
- 고정 비율의 한계 — EViT·ToMe는 layer마다 keep ratio / r을 수작업으로 정한다. FLOPs 제약을 맞추려면 layer별 비율을 일일이 튜닝해야 해 번거롭고, 잘못 정하면 informative한 전경 토큰까지 잘려 정확도가 급락한다.
- pruning과 merging이 분리돼 있었다 — pruning은 배경 토큰 제거에, merging은 비슷한 (전경 포함) 토큰 병합에 강한데, 기존 연구는 둘 중 하나만 썼다.
압축률이 미분 불가능한 하이퍼파라미터라 일일이 정해야 했다 → 이걸 학습 대상으로 만들면?
핵심 아이디어
압축률 $\alpha$를 loss에 대해 미분 가능하게 만들어, layer마다 다른 최적 압축률을 gradient로 자동 학습한다. 그리고 pruning과 merging을 하나의 연산 $f_c = f_m \circ f_p$ 로 통합 — 먼저 덜 중요한 토큰을 자르고(prune), 남은 것 중 비슷한 토큰을 합친다(merge).
방법 — Differentiable Discrete Proxy (DDP)
핵심 장치는 DDP다. 모든 이미지에서 top-K 중요 토큰을 보존하게 해 (1) batch 병렬 계산이 가능하고 (2) 원본 성능을 최대한 유지한다. 두 부분으로 구성된다.
1) Token Sorting — 중요도로 정렬
EViT처럼 class attention $A_c = \mathrm{Softmax}(q_c K^\top/\sqrt{D})$ 을 중요도로 쓴다 (CLS가 각 토큰에 주는 attention). 추가 연산이 없어 기본값으로 채택. (ablation: ATS의 $A_c\cdot\lVert V\rVert$ 과 거의 동급)
- 정렬 후 하위 $N\alpha_p$개를 prune, 그다음 남은 토큰 중 $N(\alpha_m-\alpha_p)$개를 cosine similarity로 비슷한 토큰과 merge(평균). 이 sorting→pruning→merging 파이프라인으로 두 기법을 자연스럽게 합친다.
2) Compression Rate Re-parameterization — α를 미분 가능하게
압축률을 직접 학습하면 이미지마다 버리는 개수가 달라져 병렬화가 안 된다. 그래서 이산 후보율 집합 $C={C_1,…,C_N},\ C_k=\frac{k-1}{N}$ 에 학습 가능한 확률 $\rho_k$ 를 부여하고:
\[\alpha = \sum_{k=1}^{N} C_k \rho_k\]- 토큰별 압축 확률 $\pi_k$(누적합, $\pi_1=0$으로 최중요 토큰은 항상 보존)을 만들고, $\pi_k$와 $\alpha$를 비교해 0–1 mask로 변환. 덜 중요한 토큰일수록 압축 확률이 크도록 정렬과 일관됨.
- block마다 pruning용·merging용 mask 2개를 독립적으로 학습. gradient가 끊기지 않도록 토큰 drop을 attention masking(DynamicViT 방식)으로 대체하고 STE로 역전파 → loss의 gradient가 $\rho_k$ 로 흘러 $\alpha$ 가 학습된다.
3) Budget(FLOPs) 제약 학습
\[\mathcal{L} = \mathcal{L}_{cls} + \lambda_f\,(\,\mathcal{F}(\alpha_p,\alpha_m)-T\,)^2\]- 목표 FLOPs $T$ 하나만 주면 끝. FLOPs는 압축률의 미분 가능한 함수로 표현. ($\lambda_f=5$)
- 매우 효율적: pretrained는 고정하고 $\rho$만 학습 → 3 epoch(ViT-B 2.7 GPU-hour) 면 수렴, 1,000장이면 충분, 추가 파라미터·FLOPs는 무시할 수준. latency·power 같은 하드웨어 지표로도 확장 가능.
결과
- 학습 없이 SOTA: 학습된 압축률을 pretrained 모델에 바로 적용해도, 학습이 필요한 기존 방법과 동급 이상. 예) off-the-shelf DeiT-S 79.58% > 학습한 EViT 79.50%·ToMe 79.49%.
- ViT-H (MAE): FLOPs 40%↓, throughput 50%↑, 정확도 −0.16%(fine-tuning 없이).
- 선택적 fine-tune(†) 30 epoch이면 추가 향상 (예: DeiT-S 79.58→79.83).
- 압축 스케줄도 거의 최적: 10,000개 랜덤 스케줄과 비교해 pruning-only·merging-only·둘 다 모두에서 최상위. 특히 FLOPs를 강하게 줄일수록 격차가 커짐.
Ablation. 순서가 중요하다 — prune→merge(81.50) > merge→prune(81.18) > merge-only(81.14) > prune-only(80.83). 배경은 pruning으로 지우고, 전경의 덜 중요한 토큰은 merging으로 합치는 조합이 최선임을 보여준다.
한 줄 정리 & 의의
- Pruning + Merging을 통합한 첫 Hybrid이자, 압축률 자체를 학습 대상으로 끌어올린 전환점. “버릴 건 버리고(prune), 남길 것 중 비슷한 건 합친다(merge)”를 미분 가능한 비율로 자동 결정.
- ToMe(고정 r)·EViT(고정 keep ratio)의 hand-crafted schedule을 gradient로 대체 — Adaptive Sparse ViT가 threshold를 학습했다면, DiffRate는 layer별 압축률 자체를 학습한다.
- 한계 / 이후. 표준 ViT 분류 중심(계층형엔 uncompression 모듈 필요). 이후 흐름은 더 정교한 merging(정보·크기까지 고려한 다기준 fusion → MCTF)이나, 학습 없이 전역 신호로 점수를 매기는 방향(Zero-TPrune)으로 갈라진다. → Token Reduction 개요