[STAR] Synergistic Patch Pruning for Vision Transformer: Unifying Intra- & Inter-Layer Patch Importance
Yuyao Zhang, Lan Wei, Nikolaos M. Freris · University of Science and Technology of China (USTC)
한 줄 요약. 기존 patch pruning은 한 layer 안의 중요도([CLS] attention)만 보고 잘라, 얕은 layer에서 버린 패치가 깊은 layer에선 중요해지는 경우를 놓친다. STAR는 online intra-layer(각 layer [CLS] attention)와 offline inter-layer(LRP로 전 과정 기여도) 두 중요도를 KL 발산 최소화로 융합(닫힌 해 = 두 점수의 가중 거듭제곱 곱)하고, layer별 retention rate를 패치 유사도(ACS)로 자동 결정한다. DeiT-S를 손실 0.4%로 FLOPs 43.5%↓, throughput 4,080 img/s.
배경
patch(=토큰) pruning은 한 번 버린 패치를 되살릴 수 없다. 그런데 대부분의 방법은 현재 layer 안에서의 중요도(주로 [CLS] attention)만으로 자른다 — 이게 두 가지 문제를 낳는다.
- intra-layer 점수만으론 위험하다. [CLS] attention은 “이 layer에서 [CLS]에 얼마나 기여하나”만 잰다. 얕은 layer에서 [CLS]에 덜 기여하는 패치가 깊은 layer에선 핵심일 수 있는데, 얕은 곳에서 미리 잘리면 복구 불가. ablation에서 intra-layer만 쓰면 같은 FLOPs에서 Top-1이 2.4%p 더 떨어진다.
- retention rate를 손으로 정한다. layer마다 몇 %를 남길지 경험적으로 튜닝하거나(데이터 의존), sensitivity 분석·반복 탐색(연산 부담)에 의존했다.
한 layer만 보지 말고 전 과정에서의 기여도(inter-layer)를 같이 보면, 깊은 layer에 필요한 패치를 지킬 수 있지 않을까?
핵심 아이디어
두 종류의 중요도를 상호보완으로 융합한다.
- Intra-layer (online) — 각 layer의 [CLS] attention. 입력에 적응적이지만 단일 layer 시야.
- Inter-layer (offline) — LRP(Layer-wise Relevance Propagation) 로, 각 패치가 마지막 [CLS]까지 미치는 기여도를 학습 데이터에서 미리 계산. 전역 시야를 주지만 입력엔 비적응(통계적).
intra는 inter의 “입력 적응성”을, inter는 intra의 “전역 시야”를 메운다. 추가로 layer별 retention rate를 패치 유사도로 자동 정한다.
방법
Intra-layer 중요도
[CLS] attention을 value의 선형결합 계수로 보면, 그 계수가 곧 각 패치의 분류 기여도다. head 평균으로:
\[S^{\text{intra}}_{(l,i)} = \frac{1}{H}\sum_{h=1}^{H} \text{attn}^h_{(l,i)}\]추론 중 순차적으로 계산되고, 잘린 패치는 이후 모든 layer에서 빠져 연산이 바로 준다.
Inter-layer 중요도 (LRP)
[CLS] 전용으로 변형한 LRP로, 각 패치가 마지막 layer [CLS] 에 주는 “to-go” 기여도를 구한다(positive relevance만 사용, 픽셀이 아닌 패치 단위 집계, 다중 클래스 대응). 학습셋에서 offline(T=10 라운드, top-γ=6 클래스 평균)으로 미리 계산 → 추론 비용 0.
KL 융합 — 닫힌 해
두 점수 모두 양수라 정규화하면 확률분포(pmf) $p^1$(intra), $p^2$(inter)로 본다. 두 분포에 대한 KL 발산의 가중 평균을 최소화하는 분포를 찾는다:
\[\min_{p}\; \alpha\,\text{KL}(p \parallel p^1) + (1-\alpha)\,\text{KL}(p \parallel p^2), \quad \text{s.t. } p_i \ge 0,\ \textstyle\sum_i p_i = 1\]닫힌 해는 두 점수의 가중 거듭제곱 곱(가중 기하평균):
\[S^{\text{fused}}_{(l,i)} \propto \big(S^{\text{intra}}_{(l,i)}\big)^{\alpha} \times \big(S^{\text{inter}}_{(l,i)}\big)^{1-\alpha}\]이 융합 점수로 각 layer에서 top-k만 남긴다([CLS]는 제외). $\alpha$는 DeiT 0.15, LV-ViT 0.3이 최적이며 $\alpha\in[0.15,0.3]$에서 안정적 — $\alpha\in{0,1}$(한쪽만)이 가장 나쁘다.
Retention rate 자동 결정 (ACS)
깊은 layer일수록 패치 유사도가 커져 중복이 많다는 관찰에서, 각 layer의 평균 코사인 유사도 ACS를 offline(500장)으로 잰다. retention rate를 ACS에 선형으로 묶되, 두 경계조건(가장 민감한 첫 layer는 $p_1=1$, 마지막은 $p_L=\rho$)으로 풀면 단일 하이퍼파라미터 $\rho$ 하나로:
\[p_l = \frac{\rho-1}{\text{ACS}_L - \text{ACS}_1}\text{ACS}_l + \frac{\text{ACS}_L - \rho\,\text{ACS}_1}{\text{ACS}_L - \text{ACS}_1}\]단조성(ACS↑이면 $p_l$↓, $\rho$↑이면 $p_l$↑)이 있어 목표 FLOPs를 이분탐색으로 쉽게 맞춘다.
- a-STAR (adaptive) — 위 $p_l$ 를 고정 비율이 아니라 가중 백분위(percentile) 로 써서, 융합 점수 분포에 따라 입력별로 다른 개수를 남긴다(CDF가 $p_l$ 넘는 지점까지). STAR보다 압축·throughput이 항상 같거나 높고(증명), 정확도는 STAR가 약간 우위.
결과
ImageNet, DeiT-T/S/B·LV-ViT-S/M. 사전학습 모델에 적용 후 120 epoch hard distillation fine-tune.
- DeiT-Base — FLOPs 42.0%↓에 Top-1 −0.3%(throughput +69.5%). 47.2%↓에서도 Top-1 −0.8%.
- DeiT-Small — FLOPs 43.5%↓/50%↓에 Top-1 −0.4%/−0.6%, throughput +83.9%/+104.5%. 헤드라인: pruned DeiT-S가 4,080 img/s로 (압축 안 한) DeiT-Tiny보다 60%+ 빠르면서 Top-1 +7.0%p.
- online이라 throughput 우위 — SCOP처럼 FLOPs 감축이 같아도, STAR는 online pruning이라 실제 throughput이 더 높다.
- 요소 분해(fine-tune 전): intra-only는 급락, ACS 자동 스케줄이 수작업 대비 +0.3%p, KL 융합이 intra-only 대비 Top-1 +0.6%p. a-STAR가 추가 +0.2%p.
- DiffRate 대비 — 같은 정확도에서 STAR가 FLOPs 0.3↓·throughput +14.3%, a-STAR는 +24.4%(부록).
한 줄 정리 & 의의
- patch pruning의 약점인 “한 layer 시야”를, inter-layer 중요도(LRP) 를 더해 보강한 Pruning 계열. 두 점수를 KL 융합의 닫힌 해(가중 기하평균) 로 결합하는 게 깔끔한 핵심.
- retention rate 자동화 — 압축률을 미분가능하게 학습한 DiffRate와 달리, STAR는 패치 유사도(ACS) + 단일 $\rho$ 라는 무학습 규칙으로 layer별 비율을 정한다(재학습 없이 압축 목표 조정).
- 위치. “[CLS] attention으로 자른다”는 점은 EViT·Evo-ViT와 같지만, 거기에 모델 해석(LRP) 기반 전역 중요도를 처음으로 결합했다. Zero-TPrune이 attention graph로 전역 신호를 학습 없이 뽑았다면, STAR는 LRP라는 다른 해석 도구로 전역 기여도를 (offline) 뽑아 intra와 섞는다. → Token Reduction 개요