Notice
Recent Posts
Recent Comments
Link
«   2024/05   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
Tags
more
Archives
Today
Total
관리 메뉴

올라프의 AI 공부

Taming Transformers for High-Resolution Image Synthesis 논문 리뷰 (VQ-GAN) 본문

논문 리뷰

Taming Transformers for High-Resolution Image Synthesis 논문 리뷰 (VQ-GAN)

jioniee 2024. 3. 17. 22:55
💡 정리
  1. Problem
    - Transformer : 시퀀셜 데이터에 long-range interaction 학습
    (+) expressive (context-rich)
    (-) computationally infeasible for long sequences (e.g., 고해상도 이미지)

    - CNN : local interaction을 더 우선시하는 inductive bias 가짐
    (+) efficient ($\because$ 이미지 내 강한 지역적인 상호작용에 대한 prior 가지기 때문)

    ⇒ Transformer + CNN 장점 결합, 단점 보완한 effective (from CNN) and expressive (from Transformer) model

  2. Idea
    - 목적 : 트랜스포머의 학습 능력을 최대 메가픽셀 범위의 고해상도 이미지 합성에 담아내는 것
    -> 트랜스포머에 입력되는 인풋 길이를 줄이면 됨
    - CNN + Transformer + GAN
    - VQ-VAE + Perceptual reconstruction loss, adversarial training + sliding window-based sampling

  3. Solution
    (1) 컨볼루션 접근 : context-rich visual parts의 코드북을 효율적으로 학습
    (2) 트랜스포머 접근 : global composition 시각적 분포를 학습하는 모델 학습
    (3) 적대적 접근 : local parts의 dictionary가 지각적으로 중요한 local 구조를 포착할 수 있도록 하여, 트랜스포머 구조가 low-level statistics 모델링할 필요가 없도록 만듦
    - conditional synthesis tasks(i.e., non-spatial information 또는 spatial information)에 쉽게 적용 가능

 

1. Introduction


  • Transformer
    • 적용 범위가 넓어지고 있음
    • 비전 분야의 CNN과 달리, 지역적인 상호작용에 대한 prior가 없기에 built-in inductive prior가 없어, 모든 관계를 다 학습
      • (+) 향상된 표현능력
      • (-) 모든 상호작용을 학습 (모든 pairwise interaction 고려) → quadratically increasing computational costs : pixel이 더 커지는 고해상도 이미지에 갈수록 문제가 됨
    • 트랜스포머는 컨볼루션 구조를 학습하는 경향[16]이 있음. 그렇다면, 비전 모델을 훈련할 때마다 이미지의 local 구조와 규칙성을 처음부터 학습해야 할까? 혹은, 트랜스포머 유연성을 유지하면서 inductive image biases를 가지도록 효율적으로 인코딩할 수 있을까?
      • 가정1. 저해상도 이미지 구조는 local connectivity, 컨볼루션 아키텍처에 의해 에서 잘 설명될 것이다. 하지만, 고해상도에서는 더 높은 semantic level을 가질 때는 유효하지 않을 것이다.
    • CNN은 강한 locality bias를 가질 뿐만 아니라, 모든 포지션에 대해 공유하는 가중치 때문에 spatial invariance에 대해 편향 또한 가지는데, 이것은 input에 대한 더 전반적인 이해가 필요한 경우, 비효율적이다.
  • effective and expressive model (Transformer + CNN)
    • (1) 컨볼루션 접근 : context-rich visual parts의 코드북을 효율적으로 학습
    • (2) 트랜스포머 접근 : global composition 시각적 분포를 학습하는 모델 학습
    • (3) 적대적 접근 : local parts의 dictionary가 지각적으로 중요한 local 구조를 포착할 수 있도록 하여, 트랜스포머 구조가 low-level statistics 모델링할 필요가 없도록 만듦
      • 즉, low-level 은 모두 CNN이 해결하도록 하고, 트랜스포머는 그것의 장점인 long-range relations만 모델링하면 되도록 만들어 고해상도 이미지를 만들어낼 수 있도록 함

 

3. Approach


  • 고해상도 이미지 합성
    • global composition of images 이해가 필요 (locally realistic + globally consistent patterns)
    • IDEA: 이미지를 픽셀로 표현하는 대신 코드북에서 지각적으로 풍부한 이미지 구성 요소의 composition으로 표현
    • composition 설명 길이 크게 줄이며, 트랜스포머 아키텍처로 글로벌 상관 관계를 효율적으로 모델링 가능

 

3.1 Learning an Effective Codebook of Image Constituents for Use in Transformers


(코드북 학습은 CNN으로)

트랜스포머의 입력은 시퀀스 형태

개별적인 픽셀을 만드는 것 대신, 학습된 표현의 discrete codebook을 사용하는 접근을 통해 복잡도 줄일 수 있음

$x \in \mathbb{R}^{H \times W \times 3}$ → $z_\mathbf{q} \in \mathbb{R}^{h \times w \times n_z}$ where $n_z$ =dimensionality of codes

discrete spatial codebook 효과적으로 학습하기 위해, CNN의 inductive bias 포함하도록 하고, neural discrete representation learning (VQ-VAE)의 아이디어 들고 옴

  • ENC $E$ and DEC $G$
    • codebook
      • $\mathcal{Z} = {z_k}_{k=1}^{K} \subset \mathbb{R}^{n_z}$
      • $\hat{z} = E(x) \in \mathbb{R}^{h \times w \times n_z}$
      • $\mathbf{q}(\cdot)$ is element-wise quantization (spatial code $\hat{z}_{ij} \in \mathbb{R}^{n_z}$를 가장 가까운 코드북 entry $z_k$로)
      • $$
        \begin{align}
        z_{\mathbf{q}}
        = \mathbf{q}(\hat{z})
        := \bigg( \underset{z_k \in \mathcal{Z}} \argmin \Vert \hat{z}_{ij} - z_k \Vert \bigg)
        \in \mathbb{R}^{h \times w \times n_z}
        \end{align}
        $$
    • reconstruction
      • 미분 불가능한 quantization operation으로 인해, 역전파는 디코더에서 인코더로 그레디언트를 복사해주는 straight-through gradient estimator 활용
    • $$
      \hat{x} = G(z_\mathbf{q})=G(\mathbf{q}(E(x)))
      $$

⇒ Loss를 통해 모델과 코드북은 end-to-end으로 (각각 reconstruction loss, codebook loss, commitment loss)

 

 

$$ \begin{align} \mathcal{L}_{VQ}(E,G,\mathcal{Z}) = \Vert x-\hat{x} \Vert^2 \Vert \text{sg}[E(x)]-z_\mathbf{q} \Vert _2^2 \Vert \text{sg}[z_\mathbf{q}]- E(x) \Vert _2^2 \end{align} $$

 

Learning a Perceptually Rich Codebook

  • VQ-GAN 제안 배경: 트랜스포머를 사용하여 이미지를 잠재 이미지 구성 요소에 대한 분포로 표현하려면, 압축의 한계를 뛰어넘고 풍부한 코드북을 학습해야 함
    • VQ-VAE 변형
    • discriminator $D$ & perceptual loss[40,30,39,17,47] 활용
      • 더 커진 압축 비율에도 좋은 perceptual quality 가져야 하기 때문
    • VQ-VAE loss
      • $$
        \begin{align}
        \mathcal{L}_{\text{VQ}}(E,G,\mathcal{Z})
        = \Vert x-\hat{x} \Vert _2^2
        + \Vert \text{sg}[E(x)]-z_\mathbf{q} \Vert_2^2 
        + \Vert \text{sg}[z_\mathbf{q}]- E(x)\Vert_2^2 
        \end{align}
        $$
  • VQ-GAN
    • VQ-VAE의 reconstruction loss에 쓰인 $L_2$ loss → perceptual loss [40, 30, 39, 17, 47]로 변형
      • Perceptual Loss 는 GAN에서 사용되는 loss중 하나로 MAE(L1), MSE를 보완하기 위해 만들어진 손실함수
      • perceptual loss는 image-to-image transfer와 고해상도 이미지 합성에서 자주 활용되는 loss로, feature reconstruction loss, style reconstrution loss를 정의해 이미지의 style과 context를 잘 보존하고자 하는 loss임
    • patch-based discriminator $D$로 적대적 훈련 방식 도입
    • $$
      \begin{align}
      \mathcal{L}_{\text{GAN}}({E,G,\mathcal{Z} }, D)
      = [\text{log}D(x)+log(1-D(\hat{x}))]
      \end{align}
      $$

 

  • 결과적으로 아래의 식과 같이 정리할 수 있다
    $$
    \begin{align}
    \mathcal{Q}^* 
    & = \{ E^*, G^*, \mathcal{Z}^* \} \\

    & = \argmin_{E,G,\mathcal{Z}} \max_D 
    \mathbb{E}_{x \sim p(x)} 
    \big[ 
    \mathcal{L}_{\text{VQ}}(E,G,\mathcal{Z}) +
    \lambda \mathcal{L}_{\text{GAN}}(\{E,G,\mathcal{Z} \}, D) 
    \big]
    \end{align}
    $$
    • where adaptive weight $\lambda$ according to

          $$
          \begin{align}
          \lambda = 
          \frac
          {\nabla_{G_L}[\mathcal{L}_{\text{rec}}]}
          {\nabla_{G_L}[\mathcal{L}_{\text{GAN}}]+\delta}
          \end{align}
          $$

          - $L_{rec}$ : the perceptual reconstruction loss (변형된 loss)
          - $\nabla_{G_L}[\cdot]$ : gradient of its input w.r.t. the last layer $L$ of the decoder
          - $\delta=10^{−6}$
      - 모든 곳의 context를 취합하기 위해, 가장 낮은 해상도에 하나의 어텐션 레이어를 적용
          - 잠재 코드를 unrolling할 때 시퀀스 길이를 상당히 줄여줌

 

3.2 Learning the Composition of Images with Transformers


Latent Transformer

$\text{sequence } s \in {0,…,|\mathcal{Z}|-1}^{h \times w}$

s에서 코드북을 매핑시킨 후, Autoregressive하게 데이터 분포 최대화하도록 훈련

 

  • $$ \begin{align} s_{ij}=k \text{ such that } (z_\mathbf{q})_{ij}=z_k. \end{align} $$
  • $$ \begin{align}
    \mathcal{L}_{\text{Transformer}} = \mathbb{E}_{x \sim p(x)} [\textrm{log }p(s)].
    \end{align} $$

 

Conditioned Synthesis

  • 전반적인 이미지 클래스에 관한 single label이 될 수도, 다른 이미지 자체가 될 수도 있음
  • $c$가 spatial extent 가지면, 다른 VQGAN를 학습시켜 $c$에 대해 새로운 코드북 $\mathcal{Z}_c$를 가지고 인덱스 기반 표현 $r \in { 0,…,|\mathcal{Z}_c|-1 }^{h_c \times w_c}$ 가지도록 훈련
    • 코드북에 대한 vector가 이미지에 대해, condition에 대해 하나씩 생기는 것
  • r || s prepend하고, 음수 로그 가능성을 $p(s_i|s_{<i},r)$로 제한함
  • “decoder-only” 전략은 텍스트 요약 작업에서도 성공적이었음

$$
\begin{align}
p(s|c)= \prod_i
p(s_i|s_{<i}, c).
\end{align}
$$

 

Generating High-Resolution Images

  • 트랜스포머의 어텐션 매커니즘은 인풋 s에 대해 시퀀스 길이를 $h \times w$로 제한함
    • VQGAN을 활용해서 다운샘플링 블록 수 m을 조정하여 $H \times W$를 $h=H/2^m \times w=W/2^m$ 으로 줄일 수 있으나, 임계값 m을 초과하면 재구성 품질이 저하됨
    • 메가픽셀 영역에서 이미지를 생성하기 위해서는, 훈련 중에 patch-wise 작업하고, s의 길이를 최대한 가능한 크기로 제한해야 함 (이미지 크기를 최대한 유지해야함)
  • 그런 다음, 이미지 샘플링하기 위해 트랜스포머를 슬라이딩 윈도우 방식으로 활용
  • VQGAN은 활용 가능한 컨텍스트가 다음 둘 중 하나의 상황일 경우, 이미지를 모델링하기에 충분함
    • (1) 데이터셋 통계가 대략 거의 공간적으로 invariant
    • (2) 공간 정보를 활용할 수 있는 경우
  • 실제로는, 필수 조건은 아닌데, 왜냐하면 unconditional image synthesis인 경우, 이미지 좌표를 단순히 조건으로 설정하면 되기 때문임

 

4. Experiments


|Z|= 1024

predict sequences of length 16 · 16, as this is the maximum feasible length to train a GPT2-medium architecture (307 M parameters)

4.1 Attention is All You Need in the Latent Space (CNN에 비해 트랜스포머 장점 유지)


  • Q1) 저해상도에서만 성능 좋음 → This raises the question if our approach retains the advantages of transformers over convolutional approaches.
  • A1) 트랜스포머, CNN 기반 접근 간 비교 실험 (conditional, unconditional tasks에 대해)
    • VQGAN with m=4 blocks
    • Unconditional data : IN, RIN, LSUN-CT
    • Conditional data : D-RIN, S-FLCKR
    • CNN 기반 PixelSNAIL이 2배 더 빨라서, 훈련 속도 같을 때와 훈련 단계가 같을 때를 나누어 결과 보고
      • for a fair comparison, report the negative log-likelihood both for the same amount of training time (P-SNAIL time) and for the same amount of training steps (P-SNAIL steps)
      • 모든 경우에 트랜스포머가 더 좋았음
       

4.2 A Unified Model for Image Synthesis Tasks (컨볼루션 아키텍처 효과까지 통합)

image size 256 × 256, latent size 16 × 16

4.3 Building Context-Rich Vocabularies (코드북 품질)

트랜스포머 아키텍처 고정, 표현에 인코딩되는 컨텍스트 양을 VQGAN의 다운샘플링 블록 수에 따라 변화시키는 실험

4.4 Benchmarking Image Synthesis Results (정량적 비교평가)

 

Class-Conditional Synthesis on ImageNet

...

 

How good is the VQGAN?

...

 

5. Conclusion


트랜스포머가 저해상도 이미지에만 국한되어있던 트랜스포머의 근본적인 문제 해결

이미지를 지각적으로 풍부한 이미지의 구성 요소로 표현하고, 픽셀 공간에서 이미지를 직접 모델링할 때 발생하는 quadratic complexity를 극복

CNN 아키텍처로 구성 요소 모델링 → 그 구성 요소 트랜스포머 아키텍처로 모델링