[DTEM] Learning to Merge Tokens via Decoupled Embedding for Efficient Vision Transformers

Merging NeurIPS 2024

Dong Hoon Lee, Seunghoon Hong · KAIST

arXiv GitHub

한 줄 요약. ToMe 같은 merging은 합칠 토큰을 ViT의 중간 feature(key·token embedding)로 고른다 — 그런데 이 feature는 본래 contextual encoding용이라 merging에 딱 맞지 않고, 개선하려면 전체를 end-to-end 학습해야 한다. DTEM은 merging 전용 임베딩을 ViT forward pass와 분리(decoupled)해 따로 학습한다. 핵심 장치는 미분 가능하게 완화한(soft) grouping·merging — 이걸로 그 작은 임베딩 모듈만(파라미터 <1%) 학습하면 되고, 사전학습 모델을 얼린 채 모듈만(modular) 학습할 수 있다. DeiT-S에서 FLOPs 37.2%↓에 79.85%를 유지.

배경

이 노트는 ToMe 계열 merging의 한 가지 구조적 약점을 파고든다 — merging은 “어떤 토큰을 합칠지” 정하는 grouping이 성능을 좌우하는데, 기존 방법은 이 grouping을 ViT의 중간 feature(token embedding $X$ 또는 key $K$)의 유사도로만 결정한다.

  • feature가 merging 전용이 아니다. 같은 중간 feature가 contextual encoding(본업)과 merging(부업)을 겸한다. 역할이 다른데 feature를 공유하니, merging에 최적화된 별도 feature를 쓰는 것보다 손해다.
  • 개선하려면 전체를 학습해야 한다. grouping은 clustering·bipartite matching 같은 이산(discrete) 연산이라 gradient가 흐르지 않는다(Eq 1). merging을 나아지게 하는 유일한 길은 merging 연산(Eq 2)을 거쳐 중간 feature $X$를 역전파로 갱신하는 것 → 결국 네트워크 전체 end-to-end 학습이 필요해, off-the-shelf 활용이 어렵고 “merging에 좋은 feature”와 “task에 좋은 feature”가 충돌한다.
Figure 1. (a) 기존 token merging: grouping이 이산 연산이라 gradient가 끊겨(빨간 X), merging 연산을 통해 중간 feature로만 학습이 흐른다. (b) DTEM: ViT forward와 분리된 decoupled embedding module이 grouping을 담당하고, soft grouping·merging으로 연속 완화해 임베딩 모듈로 gradient가 직접 흐른다(초록). 추론 시엔 hard 연산으로 되돌려 ToMe와 동일하게 동작.

grouping에 쓸 feature를 ViT 본체에서 떼어내 merging 전용으로 따로 학습하면, 본체를 건드리지 않고도 merging을 개선할 수 있지 않을까?

핵심 아이디어

토큰별 decoupled embedding $Z = f(X; \phi)$ 를 만든다. ViT forward pass와 분리된 가벼운 모듈로, 오직 grouping의 유사도 $s_{ij} = \cos(z_i, z_j)$ 계산에만 쓰인다. grouping이 유사도에 전적으로 의존하므로, 이 임베딩만 학습해도 merging 정책을 직접 조절할 수 있고, ViT 파라미터는 건드리지 않는다.

문제는 grouping이 이산 연산이라 $Z$를 학습할 길이 없다는 것. DTEM은 grouping·merging을 연속 완화(continuous relaxation) 해 미분 가능하게 만든다 — 학습 때는 soft 버전으로 $Z$를 배우고, 추론 때는 hard 버전으로 되돌려 ToMe의 BSM과 동일하게 빠르게 동작한다.

방법

1) Decoupled Embedding Module

각 transformer block의 self-attention과 FFN 사이(ToMe와 같은 위치)에 토큰별 projection을 끼운다: $z_i = f(x_i; \phi_l)$, 출력 차원 $d’ \ll d$. self-attention과 병렬 계산돼 직렬화 오버헤드가 없다. ablation상 hidden layer 없는 단순 affine 변환(linear) 만으로 충분해, 추가 파라미터는 1% 미만이라 적은 데이터로도 학습된다(ViT-S/B는 출력 64차원).

2) Soft Grouping — 미분 가능한 top-r

target은 ToMe의 BSM(Bipartite Soft Matching): 토큰을 두 집합 $\mathbb{A}, \mathbb{B}$로 나누고 A의 각 노드가 B로 향하는 가장 비슷한 엣지를 골라, 가장 유사한 r쌍을 합친다. 이를 연속 인접행렬 $\tilde{E} \in [0,1]^{ \mathbb{A} \times \mathbb{B} }$ 로 근사하되 두 조건을 만족해야 한다 — ① 값이 큰 r개의 엣지를 흉내내고, ② A의 각 노드는 엣지 최대 1개($\sum_j \tilde e_{ij} \le 1$).

미분 가능 top-k(reparameterizable subset sampling)를 변형해, $S^1=S$에서 시작해 $t=1,\dots,r$ 반복:

\[A^t = \sigma(S^t/\tau), \qquad s^{t+1}_{ij} = s^t_{ij} + \log\!\Big(1 - \sum_{j\in\mathbb{B}} a^t_{ij}\Big)\]

$\sigma$는 global softmax(soft-argmax), $\tau$는 완화를 조절하는 온도. 매 스텝 가장 비슷한 한 쌍을 soft하게 뽑고(Eq 6), 이미 뽑힌 A 노드의 남은 엣지는 억제한다(Eq 7 — $\tau\to0$이면 $\log(1-\cdot)\to-\infty$로 노드당 1회 선택 보장). 누적 $A^* = \sum_t A^t$ 를 clipping(max + stop-gradient)해 유효한 $\tilde E$ 를 얻는다.

3) Soft Merging — 비대칭 갱신

$\tilde e_{ij}$ 만큼 A의 토큰 $i$가 B의 토큰 $j$로 부분적으로 합쳐지도록 설계한다. B 토큰은 feature와 유효 크기 $m$을 모두 갱신(ToMe의 proportional attention과 호환되는 size 추적):

\[\hat x_j \leftarrow \frac{m_j x_j + \sum_{i\in\mathbb{A}} \tilde e_{ij} m_i x_i}{m_j + \sum_{i\in\mathbb{A}} \tilde e_{ij} m_i}, \qquad \hat m_j \leftarrow m_j + \sum_{i\in\mathbb{A}} \tilde e_{ij} m_i\]

A 토큰은 feature는 그대로 두고 유효 크기만 줄인다: $\hat m_i \leftarrow m_i(1 - \sum_j \tilde e_{ij})$. 이산 $E’$에서는 엣지가 있는 A 토큰의 크기가 0이 되어 이후 merging에서 자연히 빠진다 — 즉 학습 땐 토큰 수를 줄이지 않고 연속적으로 흡수만 시뮬레이션하고, 추론 땐 실제로 줄인다.

흥미롭게도 높은 reduction rate $r$로 학습한 임베딩이 더 낮은 rate $r’ \le r$에도 잘 일반화된다(top-k가 r쌍을 정렬하는 과정에 작은 rate의 정렬이 포함되므로). 덕분에 모델 하나로 여러 압축률을 커버한다.

학습/추론

  • Modular — ViT는 얼리고 임베딩 모듈만 학습. off-the-shelf 모델을 그대로 쓰면서 merging 정책만 task에 맞춘다.
  • End-to-end — ViT까지 함께 학습하되, soft 연산은 토큰 수를 안 줄여 비싸므로 번갈아 갱신(ViT 갱신 땐 discretized 연산으로 실제 토큰 감소, 임베딩 갱신 땐 soft 연산). 임베딩이 작아 빨리 수렴하므로 ViT 갱신을 훨씬 자주(약 9:1) 한다.
  • 추론 — soft를 discretize → BSM으로 빠르게. loss는 task loss만 쓴다(dTPS의 distillation·KL 불필요).

결과

ImageNet 분류 + COCO captioning + ADE20K segmentation. 백본은 DeiT-S/B·MAE-B/L·LV-ViT·GIT·Segmenter.

Table 1. 사전학습 모델을 얼린 modular 설정(ViT 파라미터 불변). 35%·50% FLOPs 감축 모두에서 DTEM이 ToMe·EViT를 일관되게 상회 — 특히 공격적인 50% 구간에서 격차가 커진다(MAE-B에서 ToMe 78.88 → DTEM 80.37).
  • modular만으로 SOTA — 35% 감축에서 ToMe 대비 +0.15~0.47%p, 50% 감축에서 +0.47~1.64%p(추가 FLOPs는 1% 미만). ViT를 전혀 안 건드리고 임베딩 모듈만 학습한 결과다.
  • end-to-end 비교 — 모델 하나로 여러 압축률을 내며 DeiT-S에서 baseline 대비 +0.12~0.2%p. 헤드라인: DeiT-S FLOPs 37.2%↓에 79.85% 유지(표 3, ToMe·ATS·eTPS 등 prior art 상회).
  • 데이터·학습 효율 — modular 덕에 전체의 0.31%(4000장) 으로도 +0.44%p, 1 epoch 만에 빠르게 수렴.
  • 다른 task로 확장 — captioning(GIT) CIDEr +2.3~6.0, segmentation(Segmenter) mIoU +0.32~1.3.
Figure 5. 같은 색=같은 그룹(r=16 → 최종 11토큰). DTEM은 배경을 우선 합쳐 전경 객체에 더 많은 토큰을 배분한다 — ToMe가 배경에 토큰을 더 쓰는 것과 대비된다. decoupled embedding이 merging 전용 feature를 학습했다는 정성적 증거.
  • decoupled가 핵심 — soft 연산을 ToMe의 key에 그냥 적용하면 오히려 성능이 떨어지고(표 6), decoupled embedding을 넣어야 개선된다. 학습 후 임베딩과 key의 유사도 순위 상관(Kendall)이 낮아져, merging 전용 feature가 실제로 분화됐음을 보인다(표 7).

한 줄 정리 & 의의

  • merging의 “grouping을 무엇으로 하느냐”를 ViT 중간 feature에서 떼어내 전용 임베딩으로 학습한 Merging 계열. 이산 grouping을 연속 완화해 미분 가능하게 만든 덕에, 그 임베딩만(파라미터 <1%) 학습하면 된다.
  • modular 학습이 최대 강점 — 사전학습 모델을 얼린 채 적은 데이터·짧은 epoch으로 merging을 개선한다. 점수 함수를 학습하던 [DynamicViT]·squeeze까지 더한 TPS가 무거운 end-to-end였다면, DTEM은 “merging 정책만 가볍게 갈아끼우는” 방향이다.
  • 위치. 추론은 ToMe의 BSM 그대로(=빠름)이되 그 입력 feature를 학습으로 개선한다 → “BSM은 좋은데 feature가 아쉽다”를 정조준. block별 압축률을 학습한 DiffRate와는 직교적(개수 vs feature)이라 결합 여지가 있다고 저자도 짚는다. → Token Reduction 개요