1 minute read

[SSL][CLS] CCSSL: Class-Aware Contrastive Semi-Supervised Learning

  • paper: 69회
  • github: https://github.com/TencentYoutuResearch/Classification-SemiCLS
  • CVPR 2022 accepted (인용수: 69회, ‘24-01-09 기준)
  • downstream task: SSL for (noisy) CLS

1. Motivation

  • unbalanced, out-of-distribution data가 공존하는 unlabeled data를 가정한 real-world data에서 SSL for classification 성능을 높여보자!

2. Contribution

  • In-distributionclass-wise contrastive (pull) & out-of-distributionimage-wise contrastive (push)로 학습하는 CCSSL (Class-aware Contrastive Semi-Supervised Learning)을 제안함

  • Target-reweighting을 통해 noisy label에 대한 implicit noise를 제거함

  • 다른 SOTA SSL과 비교하여 성능 향상을 보임

3. CCSSL

  • Framework

    • Encoder $r=F(Aug(x))$

    • Semi-supervised module : FixMatch 사용

    • Class-Aware contrastive module

      • In-distirbution : pusedo label의 confidence가 threshold $T_{thres}$보다 큰 경우. 같은 class image 간에 pull하도록 contrastive learning (class-wise)
      • Out-of-distribution : pseudo label의 confidence가 threshold $T_{thres}$보다 작은 경우. 자기 자신의 이미지만 pull하고, 나머진 push하도록 consrastive learning (image-wise) $\to$ 기존의 contrastive learning과 동일

      • $P(i)$ : i번째 sample과 positive pair를 이루는 집합의 갯수. mini-batch 단위로 계산
      • $W_{target}$: contrastive matrix $\in \mathbb{R}^{2N \times 2N}$

        • in distirbution 일 경우

        • out-of-distribution 일 경우

    • Foreground Re-weighting

      • Class-aware contrastive module을 통과하더라도, $p > T_{push}$인 경우에도 noise label이 있을 수 있다. 이를 완화 하기 위해 contrastive matrix에 probability score로 weight를 준다. (?)

  • 최종 식

4. Experiments

  • CIFAR-10/CIFAR-100/STL-10

    • 4-shot/25-shot/40-shot learning
  • Semi-iNat 2021 (noisy dataset)

    • 더 noisy한 데이터셋에서 CCSSL의 성능이 더욱 크게 향상된다.
  • Convergence graph

    • SubCon. Loss가 학습 속도를 빠르게 수렴시킨다.
  • Ablation Study

    • Grad-CAM

    • unlabel vs. label 비율 $\mu$에 따른 ablation. 클수록 unlabel의 비중이 증가

      • $T_{push}=0$: 동일 class면 모두 pull하다 보니 noisy label이 많아 unlabel data 비율이 많아지면 오히려 성능 하락 발생함
      • $T_{push}=0.9$: 동일 class 내 p>0.9만 pull하다 보니 많아 unlabel data 비율이 많아지면 성능이 증가함
    • $T_{push}$ 의 따른 ablation

      • Semi-iNat과 같이 noisy level이 높은 data는 threshold = 0.9에서 최적
      • CIFAR100처럼 easy level인 dataset은 threshold = 0.4에서 최적

Updated: