2 minute read

[LG] Spot the Error: Non-autoregressive Graphic Layout Generation with Wireframe Locator

  • paper: http://arxiv.org/pdf/2401.16375
  • github: https://github.com/ffffatgoose/SpotError
  • AAAI 2024 accpeted (인용수: 1회, 2024-05-22 기준)
    • MicroSoft
  • downstream task: Layout Generation

1. Motivation

  • AR (Auto Regressitve) model (ex. LayoutTransformer)와 NAR (Non-AutoRegressitve) model (ex. BLT)를 비교하다 보니, NAR 기반 모델이 element (x, y, w, h, c)간의 order 변화에 더 robust하다는 특성을 발견함

  • 또한 NAR은 AR에 비해 causal attention으로 이전 step의 prediction이 다음 step에 영향을 끼치지 않고, 한번에 병렬적으로 모든 masked token을 예측하기에 한 gt에 겹치는 prediction 생성하는 경향이 있음 $\to$ overlap problem

    • 이를 해결하기 위해 BLT는 heuristic하게 prediction score기준으로 masking하고자 하는 token을 찾았음

      • iterative하게 refine을 수행하면 NAR성능이 좋아지는 것을 확인

      • 하지만 이는 decoder의 학습된 distribution에 종속되어, error propagation을 유발할 수 있음

      $\to$ iterative decoding을 learnable locator를 두어 end-to-end로 피드백하면 어떨까?

2. Contribution

  • AR과 NAR의 장단점을 분석함. NAR이 AR에 비해 input element order에 robust한 장점이 있는 반면, repetitive token generation하는 단점이 있음

  • 이를 해결하고자 iterative-based NAR을 기반으로 learnable locator를 두는 방법을 제안함

  • locator의 input으로 pixel-level (Image)와 object-level(entity) 비교 실험 결과, image level input의 locator가 더 noisy label을 잘 찾는 특성이 있음을 발견함 $\to$ image-level 기반으로 locator를 설계

3. Spot the Error

  • preliminaries

    • Layout Generation : Sequence generation task

    • AR : 이전 step의 sequence를 condition으로 현재 step의 sequence를 예측

    • NAR: bi-directional attention을 통해 자기 자신 외의 다른 sequence를 condition으로 자기 sequence를 예측

  • Finding1. NAR은 AR보다 input element order에 robust함 $\to$ 앞서 설명

    • overlap 기준으로 비교

    • D: 전체 layout의 갯수
    • N: layout의 element 갯수?
  • FInding2. NAR은 repetitive token generation 이슈가 있음

    • full bi-directional attention 기반으로 token prediction을 하다 보니 positional encoding만으로 모델이 서로 다른 token을 분류하는게 어려 $\to$ overlap이 커짐
    • iterative-based NAR은 이를 완화시켜줌 (앞서 설명)
  • Layout Representation : Object vs. Pixel

    • Refine하는 주체인 Locator의 입력으로 object, pixel중 비교

      • Object: Transformer 입력으로 object entity을 주고, binary classification head를 두어 real layout, noisy layout 분류하는 task 수행하여 noisy layout일 경우 masking

      • Pixel: Faster-RCNN 입력으로 rendering된 이미지를 주고, RoI를 통과시켜 FC layer를 붙여 마찬가지로 noisy layout일 경우 masking

        $\to$앞서 본 표대로 Pixel기반이 성능이 좋아, learnable Locator로 채택!

  • overall architecture

    • Decoder: BLT을 기반으로 사용

      • Positional Encoder만 3개의 entity 추가

        • $\gamma_1$: token sentence index
        • $\gamma_2$: element-level index
        • $\gamma_3$: element의 갯수
      • Training : BLT와 동일하게 masking된 token을 예측

      • Inference: unknown token은 모두 masking후 예측

    • Wireframe decoder

      • 앞선 BLT decoder 출력과 Faster-RCNN의 image feature 출력간에 Cross attention 수행

        • $\phi(s)$: query matrix로 BLT decoder 출력
        • $\phi(I)$: key matrix로 Faster-RCNN의 출력
    • Locator

      • 기존 BLT의 low predicted token masking은 model의 error propagation 문제가 있음

      • learnable locator의 경우 Faster-RCNN을 사용

        • W$_c$, b$_c$: mask classifier의 weight & bias
      • BLT decoder의 generated (noisy) layout과 GT layout간의 동일 class entity간의 Bipartie matching 후, distance를 산출하여 threshold를 넘는 token을 masking

4. Experiments

  • $C \to SP$

  • $CS \to P$

  • Qualitative result

  • Cross Attention Mask

  • Overlap performance vs. BLT

  • Ablation study

    • Locator

Updated: