Cifar-100とResNet18を用いてデータ拡張手法を比較してみた。

ResNet18 (PyTorch) を用いて、Augmentationなし、簡単なAugmentation、mixup、mixcutを比較してみました。

Koya Tango
Koya Tango
Cifar-100とResNet18を用いてデータ拡張手法を比較してみた。

はじめに

近年、Deep Learningより様々な計算機上のタスクの精度の向上を実現しました。そのタスクの一つに画像のクラス分類が挙げられます。画像のクラス分類というのは、与えられた画像に対して正しいクラスを推論するというタスクです。Deep Learningを用いたConvolutional Neural Network (CNN) を用いたモデルを用いると、画像と正しいラベルのペアを訓練するだけで高精度のクラス分類が可能になりました。

しかし、データセットの画像枚数が少なかったり、モデルのテスト時のデータが訓練時のデータと大きくかけ離れてたりしていた場合、あまり精度が出ない場合があります。そういった問題を解決するための手段の一つとしてData Augmentation(データ拡張・水増し)があげられます。Data Augmentationというのは、与えられたデータを前もって処理をすることにより訓練時のデータの多様性を上げて、モデルの汎化性能を上げる手法です。この投稿では、Cifar-100とResNet18を用いてData Augmentation手法であるMixup、Cutout、CutMixを実装し、従来手法と比較しました。

Cifar10/Cifar100データセットとは

cifer10

https://www.cs.toronto.edu/~kriz/cifar.html

Cifar-10/Cifar-100というのはトロント大学が公開している画像データセットです。

  • 解像度は32x32
  • 60000枚の画像 (Train:50000, Test:10000)
  • 10 or 100クラス分類

Cifar-10とCifar-100の両方とも60000枚の画像で出来ているので、簡単な問題には10クラス分類であるCifar-10、複雑な問題には100クラス分類であるCifar-100といった使い分けが出来るのが便利です。

State of The Art (SoTA)

SoTA (現時点の最高精度) はCifar-10ではAccuracyが99.370

Cifar-100ではAccuracyが93.510

本投稿では、より複雑な問題かつ1クラス当たりの画像枚数が少ないCifar-100を用いて実装・実験しました。

モデルの説明

ResNet

resnet18

画像はLearn Open CVより

ResNetは従来のモデルに対して残差ブロックを導入したモデルです。 残差ブロック(上図右)を導入することにより、従来の単純に何層も重ねたモデルより効率よく学習でき高性能なモデルが作成できるといった手法です。(参考: https://deepage.net/deep_learning/2016/11/30/resnet.html)

本投稿では、学習時間と計算機リソースの関係上、上図左に示すような18層からなるResNet18を用いて実験をしました。

データ拡張手法

augment

上図がCutMixに掲載されている、Mixup、Cutout、CutMixのデータオーグメンテーション手法の比較です。この3つについて解説していきます。

Mixup (ICLR2018)

https://arxiv.org/abs/1710.09412

mix

MixupはICLR2018で提案されたデータオーグメンテーション手法です。一言で言うならば、 2つの画像を重ね合わせる オーグメンテーション手法です。

具体的には上図のように犬とネコの画像を重ね合わせて、それぞれの画像の透過度によってラベルを決定するという手法です。論文には以下のような図を交えて、判別面が滑らかになることによるという利点を説明されていました。

Mixup

左のEmpirical Risk Minimization (ERM) は経験的リスク最小化という通常の学習のことで、右のmixupが提案手法です。 青色のシェーディングである判別面(p(y=1|x)のこと)が従来のERMでは急に変化しているのに対して、mixupでは滑らかに変化しており、汎化性能が向上しそうなイメージが掴めると思います。

なぜ採択されたか

2つの画像を混ぜ合わせるようなAugmentationの走りであり、画像を透過させて重ねるという奇抜な手法で精度の向上と判別面を滑らかにするという理論的な展開もわかりやすい点が採択された点だと感じた。

筆者はGANsでもこのテクニックが使えるとも仰っており、現に確かWGANかWGAN-GPではこれに似たテクニックが使われており、応用が期待される。

Cutout (Arxiv論文)

https://arxiv.org/abs/1708.04552

cutout

Cutoutは、画像を一部黒塗りにするData Augmentation手法です。

具体的には上図のように犬の画像の一部を黒塗りにして、より難しい問題をモデルに解かせるという手法です。こうすることにより、上図のような場合は、モデルは犬のお腹やしっぽ等の判別が難しい箇所での判別を解くことになります。その結果モデルの汎化性能が上がるというイメージです。

Cutoutとほぼ同一である提案で、RandomErasing(https://arxiv.org/abs/1708.04896)も提案されている。

CutMix (ICCV2019)

https://arxiv.org/abs/1905.04899

cutmix

CutMixはICCV2019で提案された、2つの画像をツギハギにするData Augmentation手法です。

具体的には上図のように犬とネコの画像をツギハギにして、面積の応じてラベルを設定するという手法です。

論文では下図のようなCAMというCNNモデルがどの部分を見て分類しているのか可視化するという手法を用いて説明しています。

comparison

CutMixはバーナード犬もプードルの時も対応する正しい位置を見ていることが分かります。これは、2つの画像を効率よく学習できることを示しています。さらに、CutMixを用いるとCutoutでも述べたようにバーナード犬のお腹の特徴等の難しい特徴を捉えるような学習になり、より汎化性能が上がることがイメージできると思います。

なぜ採択されたか

関連手法を掛け合わせたような(Mixup+Cutout)シンプルな手法で精度が格段に向上し、それに対する考察もしっかりと書けていることが採択された要因であると感じました。

CutMixのようなアイデアをDiscriminatorに組み込ませたGANs (https://arxiv.org/abs/2002.12655) も登場しており、応用にも期待される。

実験

実装にはPyTorchを用いて実装しました。実行環境には、Google Colaboratoryを採用し、200epoch実行しました。

Optimizerには、SGDを用いて学習率は0.01、momentumは0.9に設定しました。SchedulerにはStepLRを用いて50epochごとに学習率が0.1倍になるように設定しました。

実装したコード(ipynb)はGitHubに公開しています。

Data-Augmentation-Method-Comparison-of-Cifar-100

下記に私が実際に実装したAugmentationのキーとなるコードと実際にColab上で動かせることができるOpen in Colabボタンを示します。

Data Augmentation無し

Open In Colab

シンプルなAugmentation (回転、反転、色変化)

Open In Colab

transform = transforms.Compose(\[
     transforms.RandomRotation(degrees=15),
     transforms.RandomHorizontalFlip(p=0.5),
     transforms.RandomVerticalFlip(p=0.5),
     transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# RandomRotationを15度と縦と横のRandomFlip、色変更のAugmentationを入れています。

Mixup

Open In Colab

def mixup(batch, alpha=1.0):
    input, target = batch

    ### Shuffle Minibatch ###
    indices = torch.randperm(input.size(0))
    input_s, target_s = input[indices], target[indices]

    lamb = np.random.beta(alpha, alpha)

    input = lamb * input + (1-lamb) * input_s
    target = (target, target_s, lamb)

    return input, target

# lossはこのようにinterpolateなロスにする
target, target_s, lamb = labels[0].cuda(), labels[1].cuda(), labels[2]
loss = lamb * criterion(outputs, target) + (1 - lamb) * criterion(outputs, target_s)

Cutout

Open In Colab

def cutout(batch):
    input, target = batch

    input_s = torch.zeros_like(input)
    lamb = np.random.uniform(0.0, 1.0)

    H, W = input.shape[2:]
    r_x = np.random.uniform(0, W)
    r_y = np.random.uniform(0, H)
    r_w = W * np.sqrt(1 - lamb)
    r_h = H * np.sqrt(1 - lamb)
    x1 = int(np.round(max(r_x - r_w / 2, 0)))
    x2 = int(np.round(min(r_x + r_w / 2, W)))
    y1 = int(np.round(max(r_y - r_h / 2, 0)))
    y2 = int(np.round(min(r_y + r_h / 2, H)))

    input[:, :, x1:x2, y1:y2] = input_s[:, :, x1:x2, y1:y2]

    return input, target

# lossは通常のロス
loss = criterion(outputs, target)

CutMix

Open In Colab

def cutmix(batch):
    input, target = batch

    ### Shuffle Minibatch ###
    indices = torch.randperm(input.size(0))
    input_s, target_s = input[indices], target[indices]

    lamb = np.random.uniform(0.0, 1.0)

    H, W = input.shape[2:]
    r_x = np.random.uniform(0, W)
    r_y = np.random.uniform(0, H)
    r_w = W * np.sqrt(1 - lamb)
    r_h = H * np.sqrt(1 - lamb)
    x1 = int(np.round(max(r_x - r_w / 2, 0)))
    x2 = int(np.round(min(r_x + r_w / 2, W)))
    y1 = int(np.round(max(r_y - r_h / 2, 0)))
    y2 = int(np.round(min(r_y + r_h / 2, H)))

    input[:, :, x1:x2, y1:y2] = input_s[:, :, x1:x2, y1:y2]
    target = (target, target_s, lamb)

    return input, target

# lossはこのようにinterpolateなロスにする
target, target_s, lamb = labels[0].cuda(), labels[1].cuda(), labels[2]
loss = lamb * criterion(outputs, target) + (1 - lamb) * criterion(outputs, target_s)

結果

テストのデータセットでのAccuracyの比較を以下の表に示します。

Augmentation無し シンプルなAugmentation Mixup Cutout CutMix
29.37% 41.88%  (+12.51) 41.68% (+12.31) 45.3% (+15.93) 47.69% (+18.32)

シンプルなAugmentationでも12.51%も向上しており、Data Augmentationのインパクトが理解できる。CutoutとCutMixはシンプルなAugmentationより精度が向上しており、それぞれ15.93%と18.32%も向上した。しかし、Mixupは精度が0.2%低下している。これはCifar-100の解像度が32x32と低解像度の画像であるのに対して、透過画像を重ね合わせるというMixupのアプローチでは分類タスクが難しくなり精度があまり向上しなかったと考察できる。

考察

テストのデータセットでのAccuracyの推移をプロットしたものを以下に示す。縦軸がAccuracyで横軸がEpoch数を示す。

Data Augmentation無し

60epochになるまでAccuracyが乱れており、データオーグメンテーション無しでは学習が不安定なことが分かる。

シンプルなAugmentation

Augmentation無しと比較すると安定して学習が進んでいることが分かる。これだけでもデータオーグメンテーションの必要性が分かる。

Mixup

学習自体は問題なく進んではいるが、透過画像というタスクでは問題設定が難しいのかシンプルなAugmentationと比較するとやや不安定にも見える。

Cutout

精度の上がり方はなだらかではあるものの、精度が順調に上がっている印象を受ける。これはCutoutという一部を塗りつぶすという複雑な問題設定と、物体の様々な形状を学習できるという汎化性能の向上が要因の一つであると考えられる。

CutMix

Cutoutより精度の向上はよりなだらかではあるが、最も精度が向上している。2つの画像のツギハギ画像を分類するという複雑ではあるが、Mixupの透過画像のような判断が難しいタスクではないので、学習が上手く行えたと考察できる。

まとめ

本投稿では、Mixup、Cutout、CutMixというAugmentation手法を実装・比較を行った。Cifar-100を用いた実験の結果、CutMixが最も良い精度となりAugmentation無しより18.32%も向上し、シンプルなAugmentationより精度が5.81%も向上した。Cutoutも同様にシンプルなAugmentationよりも精度は向上したが、Mixupに関してはシンプルなAugmentationと精度では大きな違いが見られなかった。

これらの手法は単純な分類タスクだけではなくGANs等の生成タスクにも応用できる余地があり今後に期待したい。

謝辞

本投稿は、Visual Media (映像メディア学)の課題の一貫として行われました。この機会を与えてくださった山崎先生に、この場で心より感謝申し上げます。