[ATS] Adaptive Token Sampling for Efficient Vision Transformers
Pruning ECCV 2022
Mohsen Fayyaz, Soroush Abbasi Koohpayegani, Farnoush Rezaei Jafari, Sunando Sengupta, Hamid Reza Vaezi Joze, Eric Sommerlade, Hamed Pirsiavash, Jürgen Gall
한 줄 요약. 이미지마다 토큰 수를 다르게 — 별도 파라미터·학습 모듈 없이, 이미 계산된 attention 점수의 분포(CDF)에서 inverse transform sampling으로 중요한 토큰을 뽑는다. top-k와 달리 흩어진 중요 영역을 덜 놓침.
배경
DynamicViT·EViT 같은 방법은 보통 stage마다 고정 keep ratio를 쓴다. 하지만 이미지마다 복잡도가 다르다 — 객체가 단순하고 배경이 넓으면 토큰을 적게 써도 되고, 복잡하면 많이 써야 한다.
모든 이미지에 같은 비율을 강제하는 건 비효율적이다. 입력에 따라 토큰 수가 달라져야 한다.
또 하나의 불편함: 기존 방법들은 대개 학습 가능한 추가 모듈(prediction module 등)이나 별도 학습을 필요로 한다.
핵심 아이디어
Parameter-free + adaptive. 새 파라미터를 전혀 추가하지 않고, self-attention 안에서 이미 계산되는 정보를 재활용해 입력마다 다른 개수의 토큰을 뽑는다.
- top-k: 개별 점수 상위 k개를 자른다 → 점수가 여러 영역에 분산되면 일부 중요 영역을 통째로 놓칠 수 있음.
- ATS: 점수 분포 전체를 보고 샘플링 → 대표 토큰을 고르게 뽑아 분산된 중요 영역을 보존하고, 고정 k가 아니라 개수 자체가 입력에 따라 적응.
방법
1) Significance score — 어떤 토큰이 중요한가
CLS token이 각 토큰에 주는 attention에 value 벡터의 크기를 곱해 중요도를 잰다:
\[S_i \;\propto\; A_{1,i}\,\lVert V_i \rVert\]- A_{1,i} = CLS → i번째 토큰 attention (출력에 i가 기여하는 정도).
- ‖V_i‖로 가중 → value가 큰(실제 출력에 영향 큰) 토큰을 반영. 별도 학습 없이 attention·value만으로 계산.
2) Inverse Transform Sampling — 적응적 선택
- significance score를 정규화해 누적분포(CDF) 를 만들고, 그 위에서 샘플링해 토큰을 고른다.
- 샘플은 점수가 높은 영역에서 더 자주 뽑히되, 중복을 제거하면 남는 토큰 수가 자연히 입력마다 달라진다 → adaptive token count.
- 선택된 토큰 행만 모아 A’를 만들고, O = A’·V 로 입력 토큰을 부드럽게 downsample. (hard drop이 아니라 soft)
3) Plug-in & 학습
- 새 파라미터가 없어 어떤 ViT에도 꽂을 수 있고, pretrained 모델에 붙여 fine-tune 가능. 여러 stage에 넣어 점진적으로 토큰을 줄인다.
결과
- DeiT-S: 79.8 → ATS+DeiT-S 79.7, GFLOPs −37% (정확도 −0.1%p).
- DeiT-Ti: −30% GFLOPs로 72.2 → 72.1.
- 같은 GFLOPs에서 단순히 embedding 차원을 줄인 DeiT 변형(emb258/emb318)보다 정확도가 더 높음 → 토큰을 적응적으로 줄이는 게 모델 폭 축소보다 효율적.
한 줄 정리 & 의의
- “고정 keep ratio → 입력 적응적 토큰 수”, “추가 파라미터 없이 attention 재활용”, “top-k → 분포 기반 sampling” 세 가지가 ATS의 정체성.
- DynamicViT(별도 prediction module + distillation)·EViT(CLS top-k + fusion)와 달리, ATS는 모듈을 학습하지 않고 샘플링만으로 adaptive pruning을 달성한다.
- 한계 / 이후. 분류 중심. 이후 연구는 학습형 threshold(Adaptive Sparse ViT)나 merging 계열로 확장. → Token Reduction 개요