[CrossGET] Cross-Guided Ensemble of Tokens for Accelerating Vision-Language Transformers

Encoder ICML 2024

Dachuan Shi, Chaofan Tao, … Jiaqi Wang · Tsinghua University / Shanghai AI Lab 외

arXiv GitHub

한 줄 요약. Vision-Language Transformer를 토큰을 병합(ensemble)해 가속하는 범용 프레임워크. 인코더(비전·언어 분기)의 Self-Attention과 FFN 사이에 끼워 토큰을 줄인다. 핵심 둘 — ① Cross-Guided Matching & Ensemble: learnable cross token이 상대 모달리티 정보를 학습해, cross-modal 중요도로 어떤 토큰을 합칠지 가이드하고 가중합으로 병합. ② Complete-Graph Soft Matching: 비반복·병렬화 가능한 토큰 매칭 알고리즘. CLIP(modality-independent)·BLIP-2(modality-dependent)·멀티모달 LLM에 모두 적용.

배경

VLM 가속에 토큰 축소가 중요하지만, 기존 방법은 적용 범위가 좁았다.

  • 모달리티 의존성 — 모달리티가 독립인 모델(CLIP)과 의존인 모델(BLIP-2)은 구조가 달라, 한 방법이 양쪽에 두루 통하기 어렵다.
  • 계산 순서 의존 — 한 모달리티가 다른 모달리티 정보를 쓰려면 그 모달리티 추론이 끝나길 기다려야 한다(순서 제약).

두 모달리티 모두에 통하면서, 순서에 묶이지 않고 상대 모달리티의 단서로 토큰을 합칠 수 없을까?

Figure 1. CrossGET 개요. intra-modal complete-graph soft matching의 토큰 유사도 + cross-modal 가이드의 토큰 중요도를 함께 고려해 어떤 토큰을 합칠지 정하고, 그 중요도로 가중합해 ensemble을 출력한다. modality-independent·dependent 모델 모두에 적용.

핵심 아이디어

① Cross-Guided Matching & Ensemble

learnable cross token을 각 모달리티에 주입 — 서로의 cross token 거리를 좁혀 상대 모달리티 정보를 학습한다. 추론 때 cross token이 상대 모달리티의 대리(proxy)cross-modal 중요도를 제공 → 그 중요도로 토큰 매칭을 가이드하고 가중합으로 병합(ensemble). 순서 제약(B 추론을 기다림)을 끊는다.

② Complete-Graph Soft Matching

어떤 토큰끼리 합칠지 정하는 매칭을 완전 그래프에서 비반복·병렬화 가능하게 푼다. 신뢰성 있는 매칭 결과를 효율적으로 얻는다.

  • 병합 vs 가지치기 — CrossGET은 토큰을 버리는(prune) 게 아니라 합쳐서(ensemble) 수를 줄여 정보 손실을 줄인다.
  • 학습 — cross token이 learnable이라 가벼운 fine-tuning이 필요(training-free 아님).
  • 삽입 위치 — 비전·언어 분기의 Self-Attention과 FFN 사이.

적용·평가

항목 내용
적용 모델 CLIP(modality-independent) · BLIP · BLIP-2(BLIP2-OPT 6.7B) — BLIP-2식 MLLM 파이프라인(LLaVA·MiniGPT-4·mPLUG-Owl 계열)에 token ensemble을 처음 적용
데이터셋 Flickr30K · NLVR2 · COCO Caption · NoCaps · VQA2.0
Task Image-Text Retrieval(Flickr30K) · Visual Reasoning(NLVR2) · Image Captioning(COCO) · Novel Object Captioning(NoCaps) · VQA(VQA2.0)
대표 결과 BLIP·NLVR2 GFLOPs 57%↓(throughput +93%, −0.2%) · BLIP2-OPT 33~64% 가속(−0.97%)

결과

Table 1. 다양한 비전-언어 task(retrieval·reasoning·captioning·VQA)에서 CrossGET의 연산 절감 대비 성능. 고전 멀티모달 구조와 멀티모달 LLM 모두에서 큰 가속에도 성능 하락이 미미.
  • 효율 — BLIP·NLVR2에서 GFLOPs 57%↓(throughput +93%, −0.2%), BLIP2-OPT 33~64% 가속(−0.97%).
  • 범용성 — CLIP·BLIP·BLIP-2부터 멀티모달 LLM까지, retrieval·reasoning·captioning·VQA 전반에서 가속에도 성능 하락이 미미.

한 줄 정리 & 의의

  • cross-modal 가이드로 토큰을 “합치는(ensemble)” 가속 프레임워크. ① cross token으로 상대 모달리티 중요도를 받아 매칭·가중합 병합, ② complete-graph soft matching으로 신뢰성·병렬성 확보.
  • 차별점. 토큰을 버리는 가지치기와 달리 병합으로 손실을 줄이고, cross token으로 modality-independent(CLIP)·dependent(BLIP-2)·MLLM에 두루 적용 + 순서 제약 해소. (cf. ToMe의 단일 모달 병합을 cross-modal로 확장한 격)
  • 위치. Encoder — VLT 인코더(비전·언어 분기) 안에서 줄인다. → Efficient VLM 개요