[STAR] Synergistic Patch Pruning for Vision Transformer: Unifying Intra- & Inter-Layer Patch Importance

Pruning ICLR 2024

Yuyao Zhang, Lan Wei, Nikolaos M. Freris · University of Science and Technology of China (USTC)

OpenReview

한 줄 요약. 기존 patch pruning은 한 layer 안의 중요도([CLS] attention)만 보고 잘라, 얕은 layer에서 버린 패치가 깊은 layer에선 중요해지는 경우를 놓친다. STAR는 online intra-layer(각 layer [CLS] attention)와 offline inter-layer(LRP로 전 과정 기여도) 두 중요도를 KL 발산 최소화로 융합(닫힌 해 = 두 점수의 가중 거듭제곱 곱)하고, layer별 retention rate를 패치 유사도(ACS)로 자동 결정한다. DeiT-S를 손실 0.4%로 FLOPs 43.5%↓, throughput 4,080 img/s.

배경

patch(=토큰) pruning은 한 번 버린 패치를 되살릴 수 없다. 그런데 대부분의 방법은 현재 layer 안에서의 중요도(주로 [CLS] attention)만으로 자른다 — 이게 두 가지 문제를 낳는다.

  • intra-layer 점수만으론 위험하다. [CLS] attention은 “이 layer에서 [CLS]에 얼마나 기여하나”만 잰다. 얕은 layer에서 [CLS]에 덜 기여하는 패치가 깊은 layer에선 핵심일 수 있는데, 얕은 곳에서 미리 잘리면 복구 불가. ablation에서 intra-layer만 쓰면 같은 FLOPs에서 Top-1이 2.4%p 더 떨어진다.
  • retention rate를 손으로 정한다. layer마다 몇 %를 남길지 경험적으로 튜닝하거나(데이터 의존), sensitivity 분석·반복 탐색(연산 부담)에 의존했다.
Figure 1. DeiT-Base에서 LRP로 본 패치 중요도. (b) 3번째 layer에선 빨간 사각형 패치가 안 중요(파랑)하지만 (c) 12번째 layer에선 중요(빨강)해진다 — 패치 중요도는 layer마다 변하므로, 얕은 layer 평가만으로 자르면 깊은 layer에 필요한 패치를 잘못 버린다.

한 layer만 보지 말고 전 과정에서의 기여도(inter-layer)를 같이 보면, 깊은 layer에 필요한 패치를 지킬 수 있지 않을까?

핵심 아이디어

두 종류의 중요도를 상호보완으로 융합한다.

  • Intra-layer (online) — 각 layer의 [CLS] attention. 입력에 적응적이지만 단일 layer 시야.
  • Inter-layer (offline)LRP(Layer-wise Relevance Propagation) 로, 각 패치가 마지막 [CLS]까지 미치는 기여도를 학습 데이터에서 미리 계산. 전역 시야를 주지만 입력엔 비적응(통계적).

intra는 inter의 “입력 적응성”을, inter는 intra의 “전역 시야”를 메운다. 추가로 layer별 retention rate를 패치 유사도로 자동 정한다.

Figure 2. (d) STAR 개요. (b) inter-layer 중요도는 학습셋에서 LRP backprop으로 offline 계산, (c) intra-layer 중요도는 추론 중 각 layer [CLS] attention으로 online 계산. 둘을 S_intra^α × S_inter^(1−α)로 융합해 top-k 패치만 남긴다.

방법

Intra-layer 중요도

[CLS] attention을 value의 선형결합 계수로 보면, 그 계수가 곧 각 패치의 분류 기여도다. head 평균으로:

\[S^{\text{intra}}_{(l,i)} = \frac{1}{H}\sum_{h=1}^{H} \text{attn}^h_{(l,i)}\]

추론 중 순차적으로 계산되고, 잘린 패치는 이후 모든 layer에서 빠져 연산이 바로 준다.

Inter-layer 중요도 (LRP)

[CLS] 전용으로 변형한 LRP로, 각 패치가 마지막 layer [CLS] 에 주는 “to-go” 기여도를 구한다(positive relevance만 사용, 픽셀이 아닌 패치 단위 집계, 다중 클래스 대응). 학습셋에서 offline(T=10 라운드, top-γ=6 클래스 평균)으로 미리 계산 → 추론 비용 0.

KL 융합 — 닫힌 해

두 점수 모두 양수라 정규화하면 확률분포(pmf) $p^1$(intra), $p^2$(inter)로 본다. 두 분포에 대한 KL 발산의 가중 평균을 최소화하는 분포를 찾는다:

\[\min_{p}\; \alpha\,\text{KL}(p \parallel p^1) + (1-\alpha)\,\text{KL}(p \parallel p^2), \quad \text{s.t. } p_i \ge 0,\ \textstyle\sum_i p_i = 1\]

닫힌 해는 두 점수의 가중 거듭제곱 곱(가중 기하평균):

\[S^{\text{fused}}_{(l,i)} \propto \big(S^{\text{intra}}_{(l,i)}\big)^{\alpha} \times \big(S^{\text{inter}}_{(l,i)}\big)^{1-\alpha}\]

이 융합 점수로 각 layer에서 top-k만 남긴다([CLS]는 제외). $\alpha$는 DeiT 0.15, LV-ViT 0.3이 최적이며 $\alpha\in[0.15,0.3]$에서 안정적 — $\alpha\in{0,1}$(한쪽만)이 가장 나쁘다.

Retention rate 자동 결정 (ACS)

깊은 layer일수록 패치 유사도가 커져 중복이 많다는 관찰에서, 각 layer의 평균 코사인 유사도 ACS를 offline(500장)으로 잰다. retention rate를 ACS에 선형으로 묶되, 두 경계조건(가장 민감한 첫 layer는 $p_1=1$, 마지막은 $p_L=\rho$)으로 풀면 단일 하이퍼파라미터 $\rho$ 하나로:

\[p_l = \frac{\rho-1}{\text{ACS}_L - \text{ACS}_1}\text{ACS}_l + \frac{\text{ACS}_L - \rho\,\text{ACS}_1}{\text{ACS}_L - \text{ACS}_1}\]

단조성(ACS↑이면 $p_l$↓, $\rho$↑이면 $p_l$↑)이 있어 목표 FLOPs를 이분탐색으로 쉽게 맞춘다.

  • a-STAR (adaptive) — 위 $p_l$ 를 고정 비율이 아니라 가중 백분위(percentile) 로 써서, 융합 점수 분포에 따라 입력별로 다른 개수를 남긴다(CDF가 $p_l$ 넘는 지점까지). STAR보다 압축·throughput이 항상 같거나 높고(증명), 정확도는 STAR가 약간 우위.

결과

ImageNet, DeiT-T/S/B·LV-ViT-S/M. 사전학습 모델에 적용 후 120 epoch hard distillation fine-tune.

Table 1. STAR(static)·a-STAR(adaptive) 결과. DeiT·LV-ViT 전반에서 같은 FLOPs 대비 더 높은 정확도와 throughput. a-STAR는 압축·속도↑, STAR는 정확도↑.
  • DeiT-Base — FLOPs 42.0%↓에 Top-1 −0.3%(throughput +69.5%). 47.2%↓에서도 Top-1 −0.8%.
  • DeiT-Small — FLOPs 43.5%↓/50%↓에 Top-1 −0.4%/−0.6%, throughput +83.9%/+104.5%. 헤드라인: pruned DeiT-S가 4,080 img/s로 (압축 안 한) DeiT-Tiny보다 60%+ 빠르면서 Top-1 +7.0%p.
  • online이라 throughput 우위 — SCOP처럼 FLOPs 감축이 같아도, STAR는 online pruning이라 실제 throughput이 더 높다.
  • 요소 분해(fine-tune 전): intra-only는 급락, ACS 자동 스케줄이 수작업 대비 +0.3%p, KL 융합이 intra-only 대비 Top-1 +0.6%p. a-STAR가 추가 +0.2%p.
  • DiffRate 대비 — 같은 정확도에서 STAR가 FLOPs 0.3↓·throughput +14.3%, a-STAR는 +24.4%(부록).

한 줄 정리 & 의의

  • patch pruning의 약점인 “한 layer 시야”를, inter-layer 중요도(LRP) 를 더해 보강한 Pruning 계열. 두 점수를 KL 융합의 닫힌 해(가중 기하평균) 로 결합하는 게 깔끔한 핵심.
  • retention rate 자동화 — 압축률을 미분가능하게 학습한 DiffRate와 달리, STAR는 패치 유사도(ACS) + 단일 $\rho$ 라는 무학습 규칙으로 layer별 비율을 정한다(재학습 없이 압축 목표 조정).
  • 위치. “[CLS] attention으로 자른다”는 점은 EViT·Evo-ViT와 같지만, 거기에 모델 해석(LRP) 기반 전역 중요도를 처음으로 결합했다. Zero-TPrune이 attention graph로 전역 신호를 학습 없이 뽑았다면, STAR는 LRP라는 다른 해석 도구로 전역 기여도를 (offline) 뽑아 intra와 섞는다. → Token Reduction 개요