[Flamingo] Flamingo: a Visual Language Model for Few-Shot Learning

VLM NeurIPS 2022

Jean-Baptiste Alayrac, Jeff Donahue, Pauline Luc, Antoine Miech, … Karen Simonyan · DeepMind

arXiv

한 줄 요약. 이미 학습된 비전 인코더(frozen)LLM(frozen)을 그대로 두고, 그 사이에 Perceiver Resampler(가변 길이 시각 특징 → 고정 64개 토큰)와 GATED XATTN-DENSE(LLM 층 사이에 끼우는 gated cross-attention)만 새로 학습한다. 텍스트·이미지가 임의로 섞인(interleaved) 시퀀스를 받아 텍스트를 생성하는 구조라, GPT-3처럼 few-shot in-context learning이 가능. 가중치를 전혀 바꾸지 않고 예시 몇 개만 주는 것으로 16개 멀티모달 task에서 few-shot SOTA, 그중 6개는 수천 배 많은 데이터로 fine-tune한 SOTA까지 능가.

배경

LLM(GPT-3 등)은 텍스트 인터페이스만으로 새 task에 적응한다 — 프롬프트에 예시 몇 개(input/output)와 질의를 넣으면, 모델이 그 뒤를 이어 생성하며 답을 낸다(in-context few-shot learning). 가중치 갱신이 없다.

Figure 1. Flamingo-80B의 few-shot 프롬프팅 예시 — 예시 몇 개만 주면 분류·캡셔닝·VQA에 가중치 변경 없이 즉석 적응(위), out-of-the-box 멀티이미지 시각 대화(아래).

이미지/비디오 이해(분류·캡셔닝·VQA)도 “시각 입력에 조건을 건 텍스트 예측” 으로 보면, 같은 few-shot 방식을 쓸 수 있지 않을까?

문제는 두 가지다.

  • 무거운 사전학습을 버리면 안 된다. 비전 모델과 LLM은 각자 막대한 연산으로 사전학습됐다. 처음부터 멀티모달로 다시 학습하는 건 비싸고, 기존 지식도 잃는다.
  • 입력이 멀티모달 프롬프트여야 한다. 텍스트만이 아니라 이미지·비디오가 텍스트와 임의로 섞인 시퀀스를 그대로 먹고, 그 길이(shot 수)에 유연해야 한다.

핵심 아이디어

Flamingo는 두 개의 frozen 모델을 다리로 잇는다. 비전 인코더와 LLM은 얼린 채, 사이에 새 모듈만 from-scratch로 학습한다 — 사전학습 지식을 보존하면서 시각 조건화를 더한다.

① Vision Encoder (frozen)

contrastive로 사전학습한 NFNet-F6. 픽셀 → 시공간 특징 그리드. 비디오는 1 FPS 프레임 + 학습된 temporal embedding.

② Connector (trainable)

Perceiver Resampler + GATED XATTN-DENSE. 가변 특징을 고정 토큰으로 줄이고, LLM에 시각 정보를 주입.

③ LLM (frozen)

Chinchilla 1.4B / 7B / 70B → Flamingo-3B / 9B / 80B. 시각 조건 하에 다음 토큰을 자기회귀 생성.

Figure 3. Flamingo 아키텍처. interleaved 시각+텍스트 입력을 받아 자유형 텍스트를 생성한다. frozen Vision Encoder → Perceiver Resampler가 고정 개수 시각 토큰을 만들고, frozen LM 층 사이에 끼운 GATED XATTN-DENSE가 그 토큰을 cross-attention으로 주입.

모델은 시각에 조건을 건 자기회귀 텍스트 생성기다. interleaved 이미지/비디오 $x$ 에 대해 텍스트 $y$ 의 likelihood를

\[p(y \mid x) = \prod_{\ell} p\big(y_\ell \mid y_{<\ell},\, x_{\le \ell}\big)\]

로 모델링한다($x_{\le\ell}$ = 토큰 $\ell$ 앞에 등장한 이미지/비디오들).

방법

1) Perceiver Resampler — 가변 특징 → 고정 64토큰

비전 인코더가 내는 특징 수는 이미지·비디오마다 다르다. Perceiver Resampler는 미리 정한 개수의 latent query를 두고, 이들이 시각 특징에 cross-attend 해 항상 고정된 64개 시각 토큰을 출력한다.

  • 덕분에 뒤따르는 vision-text cross-attention 비용이 입력 크기와 무관하게 일정해진다.
  • ablation상 단순 Transformer나 MLP보다 이 resampler가 낫다.

2) GATED XATTN-DENSE — frozen LLM에 시각 주입

얼린 LLM 블록 사이사이에 새 cross-attention + dense(FFW) 블록을 끼운다. query는 텍스트, key·value는 시각 토큰에서 온다.

핵심은 tanh 게이팅: 새 층의 출력에 $\tanh(\alpha)$ 를 곱해 residual에 더하는데, $\alpha$ 는 0으로 초기화된 학습 스칼라다. 따라서 학습 초기에는 $\tanh(0)=0$ 이라 모델이 원래 LLM과 정확히 같은 출력을 내고, 거기서부터 점진적으로 시각 정보를 섞어간다 → 안정성·최종 성능 ↑.

def gated_xattn_dense(y, x, alpha_xattn, alpha_dense):
    # 1. Gated Cross-Attention  (q=텍스트 y, kv=시각 x)
    y = y + tanh(alpha_xattn) * attention(q=y, kv=x)
    # 2. Gated Feed-Forward
    y = y + tanh(alpha_dense) * ffw(y)
    # 3. 기존 frozen self-attention + FFW (그대로)
    y = y + frozen_attention(q=y, kv=y)
    y = y + frozen_ffw(y)
    return y   # 시각 정보가 반영된 언어 표현
Figure 4. GATED XATTN-DENSE. frozen LM 층 사이에 새 cross-attention(+dense)을 삽입한다 — query=언어, key·value=시각 토큰. tanh 게이팅($\alpha$ init 0)으로 학습 초기엔 원래 LM과 동일하게 동작해 안정적.

비전 인코더와 Perceiver Resampler 크기는 고정한 채 LLM만 키워 3B/9B/80B를 만든다(추가 학습 파라미터는 전체 대비 작다).

3) Interleaved 시퀀스 & single-image 마스킹

웹 문서의 이미지를 텍스트 위치에 <image> 태그로 끼우고, 이미지 직전과 문서 끝에 <EOC>(end-of-chunk) 토큰을 둔다. 그리고 cross-attention을 마스킹해, 각 텍스트 토큰은 바로 직전 이미지의 시각 토큰만 본다(이전 이미지들과의 의존성은 LLM의 self-attention으로 간접 전달).

  • 한 번에 한 이미지만 직접 보지만, 이미지 개수에 무관하게 일반화된다.
  • 학습은 시퀀스당 이미지 ≤ 5개로만 하고도, 평가 때는 최대 32-shot까지 활용. (모든 이미지를 직접 cross-attend 하는 방식보다 효과적 — ablation)

4) 학습 데이터 (전부 웹 스크랩, ML 라벨 0)

데이터셋 종류 규모
M3W (MultiModal MassiveWeb) interleaved 이미지+텍스트 (웹페이지 DOM에서 추출) 약 4,300만 페이지
ALIGN 이미지–alt텍스트 쌍 18억
LTIP (Long Text & Image Pairs) 고품질·긴 설명 이미지–텍스트 3.12억
VTP (Video & Text Pairs) 비디오–텍스트 (평균 ~22초) 2,700만

여러 데이터셋을 혼합해 학습한다. interleaved인 M3W가 few-shot 능력의 핵심(ablation에서 빼면 성능 급락).

5) 적응 — few-shot in-context learning

학습이 끝나면, 가중치를 그대로 둔 채 멀티모달 프롬프트만으로 새 task를 푼다 — (이미지, 텍스트) 지원 예시들을 이어 붙이고 질의 이미지를 넣어 프롬프트를 만든다. open-ended는 beam search, close-ended는 후보 랭킹으로 평가.

결과

Table 1. SOTA 비교. 단일 Flamingo가 적은 예시(as few as 4)로 기존 zero-/few-shot을 큰 폭 상회. 예시 32개·가중치 미수정만으로도 수천 배 많은 데이터로 fine-tune한 SOTA를 여러 task에서 능가(볼드=최고 few-shot, 밑줄=전체 최고).
  • few-shot SOTA — 고려한 16개 이미지/비디오 task 전반에서 기존 zero-/few-shot 방법을 큰 폭으로 능가. 예시 4개만으로도 효과적 적응.
  • fine-tuned SOTA까지 추월 — 단일 가중치 + 예시 32개만으로 6개 task에서, 수십만 장으로 fine-tune한 최고 모델(약 1,000배 많은 데이터) 보다 높다.
  • fine-tune 시 추가 SOTA — 더 큰 라벨 예산을 쓰면 VQAv2·VATEX·VizWiz·MSRVTTQA·HatefulMemes 5개에서 새 SOTA.
  • 스케일링 — 모델 크기와 shot 수가 커질수록 성능이 단조 증가.
  • ablation 요지 — ① interleaved M3W가 필수, ② tanh 게이팅이 안정성·성능에 중요, ③ GATED XATTN-DENSE가 대안 구조보다 우수, ④ single-image 마스킹이 전부 cross-attend 하는 것보다 낫다.

한 줄 정리 & 의의

  • frozen 비전 + frozen LLM을 가벼운 다리(Perceiver Resampler + gated cross-attention)로만 잇는다는, 이후 VLM의 지배적 설계(인코더·LLM은 얼리고 connector만 학습)를 연 대표 모델.
  • GPT-3의 텍스트 few-shot ICL을 멀티모달로 옮겨, “예시 몇 개로 새 시각 task를 푸는” 패러다임을 처음 본격적으로 보였다. interleaved 학습이 그 능력의 근원.
  • 위치(connector 계열). Flamingo는 Cross-Attention 방식 — LLM 안에 시각 정보를 cross-attention으로 주입한다. 이후 [BLIP-2]는 Q-Former, LLaVA는 단순 projection(MLP) 으로 갈라진다. → VLM 개요