[Flamingo] Flamingo: a Visual Language Model for Few-Shot Learning
Jean-Baptiste Alayrac, Jeff Donahue, Pauline Luc, Antoine Miech, … Karen Simonyan · DeepMind
한 줄 요약. 이미 학습된 비전 인코더(frozen)와 LLM(frozen)을 그대로 두고, 그 사이에 Perceiver Resampler(가변 길이 시각 특징 → 고정 64개 토큰)와 GATED XATTN-DENSE(LLM 층 사이에 끼우는 gated cross-attention)만 새로 학습한다. 텍스트·이미지가 임의로 섞인(interleaved) 시퀀스를 받아 텍스트를 생성하는 구조라, GPT-3처럼 few-shot in-context learning이 가능. 가중치를 전혀 바꾸지 않고 예시 몇 개만 주는 것으로 16개 멀티모달 task에서 few-shot SOTA, 그중 6개는 수천 배 많은 데이터로 fine-tune한 SOTA까지 능가.
배경
LLM(GPT-3 등)은 텍스트 인터페이스만으로 새 task에 적응한다 — 프롬프트에 예시 몇 개(input/output)와 질의를 넣으면, 모델이 그 뒤를 이어 생성하며 답을 낸다(in-context few-shot learning). 가중치 갱신이 없다.
이미지/비디오 이해(분류·캡셔닝·VQA)도 “시각 입력에 조건을 건 텍스트 예측” 으로 보면, 같은 few-shot 방식을 쓸 수 있지 않을까?
문제는 두 가지다.
- 무거운 사전학습을 버리면 안 된다. 비전 모델과 LLM은 각자 막대한 연산으로 사전학습됐다. 처음부터 멀티모달로 다시 학습하는 건 비싸고, 기존 지식도 잃는다.
- 입력이 멀티모달 프롬프트여야 한다. 텍스트만이 아니라 이미지·비디오가 텍스트와 임의로 섞인 시퀀스를 그대로 먹고, 그 길이(shot 수)에 유연해야 한다.
핵심 아이디어
Flamingo는 두 개의 frozen 모델을 다리로 잇는다. 비전 인코더와 LLM은 얼린 채, 사이에 새 모듈만 from-scratch로 학습한다 — 사전학습 지식을 보존하면서 시각 조건화를 더한다.
contrastive로 사전학습한 NFNet-F6. 픽셀 → 시공간 특징 그리드. 비디오는 1 FPS 프레임 + 학습된 temporal embedding.
Perceiver Resampler + GATED XATTN-DENSE. 가변 특징을 고정 토큰으로 줄이고, LLM에 시각 정보를 주입.
Chinchilla 1.4B / 7B / 70B → Flamingo-3B / 9B / 80B. 시각 조건 하에 다음 토큰을 자기회귀 생성.
모델은 시각에 조건을 건 자기회귀 텍스트 생성기다. interleaved 이미지/비디오 $x$ 에 대해 텍스트 $y$ 의 likelihood를
\[p(y \mid x) = \prod_{\ell} p\big(y_\ell \mid y_{<\ell},\, x_{\le \ell}\big)\]로 모델링한다($x_{\le\ell}$ = 토큰 $\ell$ 앞에 등장한 이미지/비디오들).
방법
1) Perceiver Resampler — 가변 특징 → 고정 64토큰
비전 인코더가 내는 특징 수는 이미지·비디오마다 다르다. Perceiver Resampler는 미리 정한 개수의 latent query를 두고, 이들이 시각 특징에 cross-attend 해 항상 고정된 64개 시각 토큰을 출력한다.
- 덕분에 뒤따르는 vision-text cross-attention 비용이 입력 크기와 무관하게 일정해진다.
- ablation상 단순 Transformer나 MLP보다 이 resampler가 낫다.
2) GATED XATTN-DENSE — frozen LLM에 시각 주입
얼린 LLM 블록 사이사이에 새 cross-attention + dense(FFW) 블록을 끼운다. query는 텍스트, key·value는 시각 토큰에서 온다.
핵심은 tanh 게이팅: 새 층의 출력에 $\tanh(\alpha)$ 를 곱해 residual에 더하는데, $\alpha$ 는 0으로 초기화된 학습 스칼라다. 따라서 학습 초기에는 $\tanh(0)=0$ 이라 모델이 원래 LLM과 정확히 같은 출력을 내고, 거기서부터 점진적으로 시각 정보를 섞어간다 → 안정성·최종 성능 ↑.
def gated_xattn_dense(y, x, alpha_xattn, alpha_dense):
# 1. Gated Cross-Attention (q=텍스트 y, kv=시각 x)
y = y + tanh(alpha_xattn) * attention(q=y, kv=x)
# 2. Gated Feed-Forward
y = y + tanh(alpha_dense) * ffw(y)
# 3. 기존 frozen self-attention + FFW (그대로)
y = y + frozen_attention(q=y, kv=y)
y = y + frozen_ffw(y)
return y # 시각 정보가 반영된 언어 표현
비전 인코더와 Perceiver Resampler 크기는 고정한 채 LLM만 키워 3B/9B/80B를 만든다(추가 학습 파라미터는 전체 대비 작다).
3) Interleaved 시퀀스 & single-image 마스킹
웹 문서의 이미지를 텍스트 위치에 <image> 태그로 끼우고, 이미지 직전과 문서 끝에 <EOC>(end-of-chunk) 토큰을 둔다. 그리고 cross-attention을 마스킹해, 각 텍스트 토큰은 바로 직전 이미지의 시각 토큰만 본다(이전 이미지들과의 의존성은 LLM의 self-attention으로 간접 전달).
- 한 번에 한 이미지만 직접 보지만, 이미지 개수에 무관하게 일반화된다.
- 학습은 시퀀스당 이미지 ≤ 5개로만 하고도, 평가 때는 최대 32-shot까지 활용. (모든 이미지를 직접 cross-attend 하는 방식보다 효과적 — ablation)
4) 학습 데이터 (전부 웹 스크랩, ML 라벨 0)
| 데이터셋 | 종류 | 규모 |
|---|---|---|
| M3W (MultiModal MassiveWeb) | interleaved 이미지+텍스트 (웹페이지 DOM에서 추출) | 약 4,300만 페이지 |
| ALIGN | 이미지–alt텍스트 쌍 | 18억 |
| LTIP (Long Text & Image Pairs) | 고품질·긴 설명 이미지–텍스트 | 3.12억 |
| VTP (Video & Text Pairs) | 비디오–텍스트 (평균 ~22초) | 2,700만 |
여러 데이터셋을 혼합해 학습한다. interleaved인 M3W가 few-shot 능력의 핵심(ablation에서 빼면 성능 급락).
5) 적응 — few-shot in-context learning
학습이 끝나면, 가중치를 그대로 둔 채 멀티모달 프롬프트만으로 새 task를 푼다 — (이미지, 텍스트) 지원 예시들을 이어 붙이고 질의 이미지를 넣어 프롬프트를 만든다. open-ended는 beam search, close-ended는 후보 랭킹으로 평가.
결과
- few-shot SOTA — 고려한 16개 이미지/비디오 task 전반에서 기존 zero-/few-shot 방법을 큰 폭으로 능가. 예시 4개만으로도 효과적 적응.
- fine-tuned SOTA까지 추월 — 단일 가중치 + 예시 32개만으로 6개 task에서, 수십만 장으로 fine-tune한 최고 모델(약 1,000배 많은 데이터) 보다 높다.
- fine-tune 시 추가 SOTA — 더 큰 라벨 예산을 쓰면 VQAv2·VATEX·VizWiz·MSRVTTQA·HatefulMemes 5개에서 새 SOTA.
- 스케일링 — 모델 크기와 shot 수가 커질수록 성능이 단조 증가.
- ablation 요지 — ① interleaved M3W가 필수, ② tanh 게이팅이 안정성·성능에 중요, ③ GATED XATTN-DENSE가 대안 구조보다 우수, ④ single-image 마스킹이 전부 cross-attend 하는 것보다 낫다.
한 줄 정리 & 의의
- frozen 비전 + frozen LLM을 가벼운 다리(Perceiver Resampler + gated cross-attention)로만 잇는다는, 이후 VLM의 지배적 설계(인코더·LLM은 얼리고 connector만 학습)를 연 대표 모델.
- GPT-3의 텍스트 few-shot ICL을 멀티모달로 옮겨, “예시 몇 개로 새 시각 task를 푸는” 패러다임을 처음 본격적으로 보였다. interleaved 학습이 그 능력의 근원.
- 위치(connector 계열). Flamingo는 Cross-Attention 방식 — LLM 안에 시각 정보를 cross-attention으로 주입한다. 이후 [BLIP-2]는 Q-Former, LLaVA는 단순 projection(MLP) 으로 갈라진다. → VLM 개요