[ATS] Adaptive Token Sampling for Efficient Vision Transformers

Pruning ECCV 2022

Mohsen Fayyaz, Soroush Abbasi Koohpayegani, Farnoush Rezaei Jafari, Sunando Sengupta, Hamid Reza Vaezi Joze, Eric Sommerlade, Hamed Pirsiavash, Jürgen Gall

arXiv GitHub

한 줄 요약. 이미지마다 토큰 수를 다르게 — 별도 파라미터·학습 모듈 없이, 이미 계산된 attention 점수의 분포(CDF)에서 inverse transform sampling으로 중요한 토큰을 뽑는다. top-k와 달리 흩어진 중요 영역을 덜 놓침.

배경

DynamicViT·EViT 같은 방법은 보통 stage마다 고정 keep ratio를 쓴다. 하지만 이미지마다 복잡도가 다르다 — 객체가 단순하고 배경이 넓으면 토큰을 적게 써도 되고, 복잡하면 많이 써야 한다.

모든 이미지에 같은 비율을 강제하는 건 비효율적이다. 입력에 따라 토큰 수가 달라져야 한다.

또 하나의 불편함: 기존 방법들은 대개 학습 가능한 추가 모듈(prediction module 등)이나 별도 학습을 필요로 한다.

핵심 아이디어

Parameter-free + adaptive. 새 파라미터를 전혀 추가하지 않고, self-attention 안에서 이미 계산되는 정보를 재활용해 입력마다 다른 개수의 토큰을 뽑는다.

  • top-k: 개별 점수 상위 k개를 자른다 → 점수가 여러 영역에 분산되면 일부 중요 영역을 통째로 놓칠 수 있음.
  • ATS: 점수 분포 전체를 보고 샘플링 → 대표 토큰을 고르게 뽑아 분산된 중요 영역을 보존하고, 고정 k가 아니라 개수 자체가 입력에 따라 적응.
Figure 1. ATS 모듈 — CLS의 attention(A의 첫 행)을 value 크기로 가중해 significance score를 만들고, 그 CDF에서 inverse transform sampling으로 중요한 토큰을 선택, A'·V로 출력 토큰을 soft downsample.

방법

1) Significance score — 어떤 토큰이 중요한가

CLS token이 각 토큰에 주는 attention에 value 벡터의 크기를 곱해 중요도를 잰다:

\[S_i \;\propto\; A_{1,i}\,\lVert V_i \rVert\]
  • A_{1,i} = CLS → i번째 토큰 attention (출력에 i가 기여하는 정도).
  • ‖V_i‖로 가중 → value가 큰(실제 출력에 영향 큰) 토큰을 반영. 별도 학습 없이 attention·value만으로 계산.

2) Inverse Transform Sampling — 적응적 선택

  • significance score를 정규화해 누적분포(CDF) 를 만들고, 그 위에서 샘플링해 토큰을 고른다.
  • 샘플은 점수가 높은 영역에서 더 자주 뽑히되, 중복을 제거하면 남는 토큰 수가 자연히 입력마다 달라진다 → adaptive token count.
  • 선택된 토큰 행만 모아 A’를 만들고, O = A’·V 로 입력 토큰을 부드럽게 downsample. (hard drop이 아니라 soft)

3) Plug-in & 학습

  • 새 파라미터가 없어 어떤 ViT에도 꽂을 수 있고, pretrained 모델에 붙여 fine-tune 가능. 여러 stage에 넣어 점진적으로 토큰을 줄인다.
Figure 2. 다단계 DeiT-S+ATS의 stage별 토큰 샘플링 — 덜 중요한 토큰은 마스킹되고 예측에 기여한 토큰이 샘플됨. 같은 그림에서 Top-K 선택과 비교(ATS가 분산된 중요 영역을 더 잘 보존).

결과

Figure 4. ImageNet 정확도 vs GFLOPs — ATS가 정확도-연산 trade-off에서 SOTA. DeiT-S의 GFLOPs를 37% 줄이면서 정확도 유지.
  • DeiT-S: 79.8 → ATS+DeiT-S 79.7, GFLOPs −37% (정확도 −0.1%p).
  • DeiT-Ti: −30% GFLOPs로 72.2 → 72.1.
  • 같은 GFLOPs에서 단순히 embedding 차원을 줄인 DeiT 변형(emb258/emb318)보다 정확도가 더 높음 → 토큰을 적응적으로 줄이는 게 모델 폭 축소보다 효율적.

한 줄 정리 & 의의

  • 고정 keep ratio → 입력 적응적 토큰 수”, “추가 파라미터 없이 attention 재활용”, “top-k → 분포 기반 sampling” 세 가지가 ATS의 정체성.
  • DynamicViT(별도 prediction module + distillation)·EViT(CLS top-k + fusion)와 달리, ATS는 모듈을 학습하지 않고 샘플링만으로 adaptive pruning을 달성한다.
  • 한계 / 이후. 분류 중심. 이후 연구는 학습형 threshold(Adaptive Sparse ViT)나 merging 계열로 확장. → Token Reduction 개요