[DTEM] Learning to Merge Tokens via Decoupled Embedding for Efficient Vision Transformers
Dong Hoon Lee, Seunghoon Hong · KAIST
한 줄 요약. ToMe 같은 merging은 합칠 토큰을 ViT의 중간 feature(key·token embedding)로 고른다 — 그런데 이 feature는 본래 contextual encoding용이라 merging에 딱 맞지 않고, 개선하려면 전체를 end-to-end 학습해야 한다. DTEM은 merging 전용 임베딩을 ViT forward pass와 분리(decoupled)해 따로 학습한다. 핵심 장치는 미분 가능하게 완화한(soft) grouping·merging — 이걸로 그 작은 임베딩 모듈만(파라미터 <1%) 학습하면 되고, 사전학습 모델을 얼린 채 모듈만(modular) 학습할 수 있다. DeiT-S에서 FLOPs 37.2%↓에 79.85%를 유지.
배경
이 노트는 ToMe 계열 merging의 한 가지 구조적 약점을 파고든다 — merging은 “어떤 토큰을 합칠지” 정하는 grouping이 성능을 좌우하는데, 기존 방법은 이 grouping을 ViT의 중간 feature(token embedding $X$ 또는 key $K$)의 유사도로만 결정한다.
- feature가 merging 전용이 아니다. 같은 중간 feature가 contextual encoding(본업)과 merging(부업)을 겸한다. 역할이 다른데 feature를 공유하니, merging에 최적화된 별도 feature를 쓰는 것보다 손해다.
- 개선하려면 전체를 학습해야 한다. grouping은 clustering·bipartite matching 같은 이산(discrete) 연산이라 gradient가 흐르지 않는다(Eq 1). merging을 나아지게 하는 유일한 길은 merging 연산(Eq 2)을 거쳐 중간 feature $X$를 역전파로 갱신하는 것 → 결국 네트워크 전체 end-to-end 학습이 필요해, off-the-shelf 활용이 어렵고 “merging에 좋은 feature”와 “task에 좋은 feature”가 충돌한다.
grouping에 쓸 feature를 ViT 본체에서 떼어내 merging 전용으로 따로 학습하면, 본체를 건드리지 않고도 merging을 개선할 수 있지 않을까?
핵심 아이디어
토큰별 decoupled embedding $Z = f(X; \phi)$ 를 만든다. ViT forward pass와 분리된 가벼운 모듈로, 오직 grouping의 유사도 $s_{ij} = \cos(z_i, z_j)$ 계산에만 쓰인다. grouping이 유사도에 전적으로 의존하므로, 이 임베딩만 학습해도 merging 정책을 직접 조절할 수 있고, ViT 파라미터는 건드리지 않는다.
문제는 grouping이 이산 연산이라 $Z$를 학습할 길이 없다는 것. DTEM은 grouping·merging을 연속 완화(continuous relaxation) 해 미분 가능하게 만든다 — 학습 때는 soft 버전으로 $Z$를 배우고, 추론 때는 hard 버전으로 되돌려 ToMe의 BSM과 동일하게 빠르게 동작한다.
방법
1) Decoupled Embedding Module
각 transformer block의 self-attention과 FFN 사이(ToMe와 같은 위치)에 토큰별 projection을 끼운다: $z_i = f(x_i; \phi_l)$, 출력 차원 $d’ \ll d$. self-attention과 병렬 계산돼 직렬화 오버헤드가 없다. ablation상 hidden layer 없는 단순 affine 변환(linear) 만으로 충분해, 추가 파라미터는 1% 미만이라 적은 데이터로도 학습된다(ViT-S/B는 출력 64차원).
2) Soft Grouping — 미분 가능한 top-r
| target은 ToMe의 BSM(Bipartite Soft Matching): 토큰을 두 집합 $\mathbb{A}, \mathbb{B}$로 나누고 A의 각 노드가 B로 향하는 가장 비슷한 엣지를 골라, 가장 유사한 r쌍을 합친다. 이를 연속 인접행렬 $\tilde{E} \in [0,1]^{ | \mathbb{A} | \times | \mathbb{B} | }$ 로 근사하되 두 조건을 만족해야 한다 — ① 값이 큰 r개의 엣지를 흉내내고, ② A의 각 노드는 엣지 최대 1개($\sum_j \tilde e_{ij} \le 1$). |
미분 가능 top-k(reparameterizable subset sampling)를 변형해, $S^1=S$에서 시작해 $t=1,\dots,r$ 반복:
\[A^t = \sigma(S^t/\tau), \qquad s^{t+1}_{ij} = s^t_{ij} + \log\!\Big(1 - \sum_{j\in\mathbb{B}} a^t_{ij}\Big)\]$\sigma$는 global softmax(soft-argmax), $\tau$는 완화를 조절하는 온도. 매 스텝 가장 비슷한 한 쌍을 soft하게 뽑고(Eq 6), 이미 뽑힌 A 노드의 남은 엣지는 억제한다(Eq 7 — $\tau\to0$이면 $\log(1-\cdot)\to-\infty$로 노드당 1회 선택 보장). 누적 $A^* = \sum_t A^t$ 를 clipping(max + stop-gradient)해 유효한 $\tilde E$ 를 얻는다.
3) Soft Merging — 비대칭 갱신
$\tilde e_{ij}$ 만큼 A의 토큰 $i$가 B의 토큰 $j$로 부분적으로 합쳐지도록 설계한다. B 토큰은 feature와 유효 크기 $m$을 모두 갱신(ToMe의 proportional attention과 호환되는 size 추적):
\[\hat x_j \leftarrow \frac{m_j x_j + \sum_{i\in\mathbb{A}} \tilde e_{ij} m_i x_i}{m_j + \sum_{i\in\mathbb{A}} \tilde e_{ij} m_i}, \qquad \hat m_j \leftarrow m_j + \sum_{i\in\mathbb{A}} \tilde e_{ij} m_i\]A 토큰은 feature는 그대로 두고 유효 크기만 줄인다: $\hat m_i \leftarrow m_i(1 - \sum_j \tilde e_{ij})$. 이산 $E’$에서는 엣지가 있는 A 토큰의 크기가 0이 되어 이후 merging에서 자연히 빠진다 — 즉 학습 땐 토큰 수를 줄이지 않고 연속적으로 흡수만 시뮬레이션하고, 추론 땐 실제로 줄인다.
흥미롭게도 높은 reduction rate $r$로 학습한 임베딩이 더 낮은 rate $r’ \le r$에도 잘 일반화된다(top-k가 r쌍을 정렬하는 과정에 작은 rate의 정렬이 포함되므로). 덕분에 모델 하나로 여러 압축률을 커버한다.
학습/추론
- Modular — ViT는 얼리고 임베딩 모듈만 학습. off-the-shelf 모델을 그대로 쓰면서 merging 정책만 task에 맞춘다.
- End-to-end — ViT까지 함께 학습하되, soft 연산은 토큰 수를 안 줄여 비싸므로 번갈아 갱신(ViT 갱신 땐 discretized 연산으로 실제 토큰 감소, 임베딩 갱신 땐 soft 연산). 임베딩이 작아 빨리 수렴하므로 ViT 갱신을 훨씬 자주(약 9:1) 한다.
- 추론 — soft를 discretize → BSM으로 빠르게. loss는 task loss만 쓴다(dTPS의 distillation·KL 불필요).
결과
ImageNet 분류 + COCO captioning + ADE20K segmentation. 백본은 DeiT-S/B·MAE-B/L·LV-ViT·GIT·Segmenter.
- modular만으로 SOTA — 35% 감축에서 ToMe 대비 +0.15~0.47%p, 50% 감축에서 +0.47~1.64%p(추가 FLOPs는 1% 미만). ViT를 전혀 안 건드리고 임베딩 모듈만 학습한 결과다.
- end-to-end 비교 — 모델 하나로 여러 압축률을 내며 DeiT-S에서 baseline 대비 +0.12~0.2%p. 헤드라인: DeiT-S FLOPs 37.2%↓에 79.85% 유지(표 3, ToMe·ATS·eTPS 등 prior art 상회).
- 데이터·학습 효율 — modular 덕에 전체의 0.31%(4000장) 으로도 +0.44%p, 1 epoch 만에 빠르게 수렴.
- 다른 task로 확장 — captioning(GIT) CIDEr +2.3~6.0, segmentation(Segmenter) mIoU +0.32~1.3.
- decoupled가 핵심 — soft 연산을 ToMe의 key에 그냥 적용하면 오히려 성능이 떨어지고(표 6), decoupled embedding을 넣어야 개선된다. 학습 후 임베딩과 key의 유사도 순위 상관(Kendall)이 낮아져, merging 전용 feature가 실제로 분화됐음을 보인다(표 7).
한 줄 정리 & 의의
- merging의 “grouping을 무엇으로 하느냐”를 ViT 중간 feature에서 떼어내 전용 임베딩으로 학습한 Merging 계열. 이산 grouping을 연속 완화해 미분 가능하게 만든 덕에, 그 임베딩만(파라미터 <1%) 학습하면 된다.
- modular 학습이 최대 강점 — 사전학습 모델을 얼린 채 적은 데이터·짧은 epoch으로 merging을 개선한다. 점수 함수를 학습하던 [DynamicViT]·squeeze까지 더한 TPS가 무거운 end-to-end였다면, DTEM은 “merging 정책만 가볍게 갈아끼우는” 방향이다.
- 위치. 추론은 ToMe의 BSM 그대로(=빠름)이되 그 입력 feature를 학습으로 개선한다 → “BSM은 좋은데 feature가 아쉽다”를 정조준. block별 압축률을 학습한 DiffRate와는 직교적(개수 vs feature)이라 결합 여지가 있다고 저자도 짚는다. → Token Reduction 개요