Do a tutorial - Data Augmentation with pytorch

Ulf Hamster 6 min.
python image classification image processing pytorch torchvision transformation feature engineering

Load Packages

%%capture 
# (the %%capture cell magic above suppresses pip's noisy install output)
!pip install torchvision==0.4.2
# load packages
import torch
import numpy as np
import torchvision as tv

# check version
print(f"torch version: {torch.__version__}")

# set GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"device type: {device}")

# reproducibility
np.random.seed(42)  # numpy seed
torch.manual_seed(42)  # pytorch seed
if torch.backends.cudnn.enabled:  # CuDNN deterministic mode
    torch.backends.cudnn.deterministic = True  # force deterministic conv algorithms
    torch.backends.cudnn.benchmark = False  # disable autotuning (would reintroduce nondeterminism)
torch version: 1.3.1
device type: cpu
import os

# image processing
import PIL  # Pillow
import requests
from io import BytesIO  # NOTE(review): not used in the visible cells — confirm before removing

# text processing
import json  # NOTE(review): not used in the visible cells — confirm before removing

# visualization
import matplotlib.pyplot as plt
%matplotlib inline

Load Demo Image

# create image folder; tv.datasets.ImageFolder expects the layout root/class_name/image
DATA_DIR = "tmp/"
CLASS_DIR = "chair/"
os.makedirs(DATA_DIR + CLASS_DIR, exist_ok=True)

# download the demo image into the class folder
FILE_URL = "https://upload.wikimedia.org/wikipedia/commons/9/9a/Fat_spiderman_in_Madrid.jpg"
response = requests.get(FILE_URL)  # download image
response.raise_for_status()  # fail loudly on a bad HTTP status instead of saving an error page
# use a context manager so the file handle is closed (and the bytes flushed)
# before the image is read back in the next cell
with open(DATA_DIR + CLASS_DIR + "test123.jpg", "wb") as f:
    f.write(response.content)

!ls {DATA_DIR + CLASS_DIR}
test123.jpg
# Eyeballing the image reveals two objects.
img = PIL.Image.open(DATA_DIR + CLASS_DIR + "test123.jpg", "r")
print(img.size)
fig = plt.figure(figsize=(8, 8))
plt.imshow(img)
(1944, 2592)

png

Before Cropping - Black Pixels

Apply these before cropping and resizing (e.g. RandomResizedCrop). The following transformations leave empty space, i.e. black pixels, at the image borders.

Rotate (RandomRotation)

The default in RandomRotation

from torchvision.transforms import Compose, RandomRotation

# Rotate by a small random angle; expand=False keeps the original canvas,
# so the rotated corners leave empty (black) border pixels.
rotation = RandomRotation(
    degrees=(-5, 5), expand=False, resample=PIL.Image.BILINEAR)
trans = Compose([rotation])

# draw n_augm random augmentations of the same image and tile them side by side
n_augm = 4
dataset = tv.datasets.ImageFolder("tmp/", transform=trans)
%time augm_images = np.hstack([np.asarray(dataset[0][0]) for _ in range(n_augm)])
plt.figure(figsize=(2*n_augm, 2)); plt.imshow(augm_images);
CPU times: user 1.69 s, sys: 95.1 ms, total: 1.79 s
Wall time: 1.82 s

png

RandomRotation(degrees=(-5, 5), expand=False) is equivalent to RandomAffine(degrees=(-5, 5))

# Randomly crop non-fixed sized boxes that don't include black pixels

Perspective (RandomPerspective)

from torchvision.transforms import (Compose, RandomPerspective)

# random perspective distortion, applied with probability p=0.85
trans = Compose([
    RandomPerspective(p=0.85, distortion_scale=0.2, 
                      interpolation=PIL.Image.BICUBIC)
])

# draw n_augm random augmentations of the same image and tile them side by side
n_augm = 8
dataset = tv.datasets.ImageFolder("tmp/", transform=trans)
%time augm_images = np.hstack([np.asarray(dataset[0][0]) for _ in range(n_augm)])
plt.figure(figsize=(2*n_augm, 2)); plt.imshow(augm_images);
CPU times: user 4.63 s, sys: 38 ms, total: 4.67 s
Wall time: 4.68 s

png

Redundant (duplicated, useless) transformations

Shear (RandomAffine)

Shear is a special case of Perspective.

from torchvision.transforms import (Compose, RandomAffine)

# shear only: degrees=(0, 0) disables the rotation part of the affine transform
trans = Compose([
    RandomAffine(degrees=(0, 0), shear=(-20, 20),
                 resample=False, fillcolor=0)
])

# draw n_augm random augmentations of the same image and tile them side by side
n_augm = 5
dataset = tv.datasets.ImageFolder("tmp/", transform=trans)
%time augm_images = np.hstack([np.asarray(dataset[0][0]) for _ in range(n_augm)])
plt.figure(figsize=(2*n_augm, 2)); plt.imshow(augm_images);
CPU times: user 647 ms, sys: 9.91 ms, total: 656 ms
Wall time: 674 ms

png

Translation (RandomAffine)

Applying RandomResizedCrop later on has the same effect as RandomAffine(translate=...) but without black pixels.

from torchvision.transforms import (Compose, RandomAffine)

# translation only: shift up to 10% of width/height; degrees=(0, 0) disables rotation
trans = Compose([
    RandomAffine(degrees=(0, 0), translate=(.1, .1),
                 resample=False, fillcolor=0)
])

# draw n_augm random augmentations of the same image and tile them side by side
n_augm = 5
dataset = tv.datasets.ImageFolder("tmp/", transform=trans)
%time augm_images = np.hstack([np.asarray(dataset[0][0]) for _ in range(n_augm)])
plt.figure(figsize=(2*n_augm, 2)); plt.imshow(augm_images);
CPU times: user 616 ms, sys: 7.03 ms, total: 623 ms
Wall time: 640 ms

png

Scaling (RandomAffine)

Applying RandomResizedCrop(scale=..., ratio=...) later on has the same effect as RandomAffine(scale=...) but without black pixels.

from torchvision.transforms import (Compose, RandomAffine)

# scaling only: zoom between 80% and 100%; degrees=(0, 0) disables rotation
trans = Compose([
    RandomAffine(degrees=(0, 0), scale=(0.8, 1.0),
                 resample=False, fillcolor=0)
])

# draw n_augm random augmentations of the same image and tile them side by side
n_augm = 5
dataset = tv.datasets.ImageFolder("tmp/", transform=trans)
%time augm_images = np.hstack([np.asarray(dataset[0][0]) for _ in range(n_augm)])
plt.figure(figsize=(2*n_augm, 2)); plt.imshow(augm_images);
CPU times: user 611 ms, sys: 12.9 ms, total: 624 ms
Wall time: 633 ms

png

Cropping

Crop fixed boxes from image (RandomResizedCrop)

from torchvision.transforms import Compose, RandomResizedCrop

# Crop a random box covering 10%-100% of the image area with an aspect
# ratio between 3:4 and 4:3, then resize the crop to 224x224 pixels.
crop = RandomResizedCrop((224, 224), scale=(.1, 1.0), ratio=(3/4, 4/3))
trans = Compose([crop])

# draw n_augm random augmentations of the same image and tile them side by side
n_augm = 5
dataset = tv.datasets.ImageFolder("tmp/", transform=trans)
%time augm_images = np.hstack([np.asarray(dataset[0][0]) for _ in range(n_augm)])
plt.figure(figsize=(2*n_augm, 2)); plt.imshow(augm_images);
CPU times: user 548 ms, sys: 9.06 ms, total: 558 ms
Wall time: 558 ms

png

Independent from Cropping

The following transformations can be applied after the final crop.

Flip left-right (RandomHorizontalFlip)

from torchvision.transforms import (Compose, RandomHorizontalFlip)

# mirror the image left-right with 50% probability
trans = Compose([
    RandomHorizontalFlip(p=0.5)
])

# draw n_augm random augmentations of the same image and tile them side by side
n_augm = 4
dataset = tv.datasets.ImageFolder("tmp/", transform=trans)
%time augm_images = np.hstack([np.asarray(dataset[0][0]) for _ in range(n_augm)])
plt.figure(figsize=(2*n_augm, 2)); plt.imshow(augm_images);
CPU times: user 462 ms, sys: 6.83 ms, total: 469 ms
Wall time: 469 ms

png

Flip upside-down (RandomVerticalFlip)

from torchvision.transforms import (Compose, RandomVerticalFlip)

# mirror the image upside-down with 50% probability
trans = Compose([
    RandomVerticalFlip(p=0.5)
])

# draw n_augm random augmentations of the same image and tile them side by side
n_augm = 4
dataset = tv.datasets.ImageFolder("tmp/", transform=trans)
%time augm_images = np.hstack([np.asarray(dataset[0][0]) for _ in range(n_augm)])
plt.figure(figsize=(2*n_augm, 2)); plt.imshow(augm_images);
CPU times: user 440 ms, sys: 11.1 ms, total: 451 ms
Wall time: 451 ms

png

Brightness (ColorJitter)

The effect is that bright pixels are burned out, and dark pixels are blacked out. It is better to use small values of less than 10%.

from torchvision.transforms import (Compose, ColorJitter)

# random brightness jitter of up to +/-25%
trans = Compose([
    ColorJitter(brightness=0.25)  # better use small proba<0.05
])

# draw n_augm random augmentations of the same image and tile them side by side
n_augm = 6
dataset = tv.datasets.ImageFolder("tmp/", transform=trans)
%time augm_images = np.hstack([np.asarray(dataset[0][0]) for _ in range(n_augm)])
plt.figure(figsize=(2*n_augm, 2)); plt.imshow(augm_images);
CPU times: user 1.1 s, sys: 11.7 ms, total: 1.11 s
Wall time: 1.12 s

png

Contrast (ColorJitter)

This removes the photo editor's “look”, e.g. hard contrast (the “anabolic bodybuilder look”) or soft contrast (the “cute toddler look”).

from torchvision.transforms import (Compose, ColorJitter)

# random contrast jitter of up to +/-25%
trans = Compose([
    ColorJitter(contrast=0.25)
])

# draw n_augm random augmentations of the same image and tile them side by side
n_augm = 4
dataset = tv.datasets.ImageFolder("tmp/", transform=trans)
%time augm_images = np.hstack([np.asarray(dataset[0][0]) for _ in range(n_augm)])
plt.figure(figsize=(3*n_augm, 3)); plt.imshow(augm_images);
CPU times: user 801 ms, sys: 9.92 ms, total: 811 ms
Wall time: 811 ms

png

Hue (ColorJitter)

Funky stuff to emulate a broken camera ;)

from torchvision.transforms import (Compose, ColorJitter)

# random hue shift; large values change the image colors drastically
trans = Compose([
    ColorJitter(hue=.25)  # [-0.5, 0.5] is the maximum (better pick something small like 0.01 to 0.1)
])

# draw n_augm random augmentations of the same image and tile them side by side
n_augm = 4
dataset = tv.datasets.ImageFolder("tmp/", transform=trans)
%time augm_images = np.hstack([np.asarray(dataset[0][0]) for _ in range(n_augm)])
plt.figure(figsize=(3*n_augm, 3)); plt.imshow(augm_images);
CPU times: user 2.18 s, sys: 13 ms, total: 2.19 s
Wall time: 2.19 s

png

Saturation (ColorJitter)

The color intensity.

from torchvision.transforms import (Compose, ColorJitter)

# random saturation jitter between 50% and 150% of the original
trans = Compose([
    ColorJitter(saturation=[0.5, 1.5])  # 0.0=grayscale
])

# draw n_augm random augmentations of the same image and tile them side by side
n_augm = 4
dataset = tv.datasets.ImageFolder("tmp/", transform=trans)
%time augm_images = np.hstack([np.asarray(dataset[0][0]) for _ in range(n_augm)])
plt.figure(figsize=(3*n_augm, 3)); plt.imshow(augm_images);
CPU times: user 811 ms, sys: 7.98 ms, total: 819 ms
Wall time: 819 ms

png

Grayscale (RandomGrayscale)

Remove color information (convert to gray scale) with a certain probability. This transformation might be redundant if ColorJitter(saturation=[0.0, ...]) is used.

from torchvision.transforms import (Compose, RandomGrayscale)

# convert to grayscale with 25% probability
trans = Compose([
    RandomGrayscale(p=.25)
])

# draw n_augm random augmentations of the same image and tile them side by side
n_augm = 5
dataset = tv.datasets.ImageFolder("tmp/", transform=trans)
%time augm_images = np.hstack([np.asarray(dataset[0][0]) for _ in range(n_augm)])
plt.figure(figsize=(2*n_augm, 2)); plt.imshow(augm_images);
CPU times: user 570 ms, sys: 17 ms, total: 587 ms
Wall time: 587 ms

png

Everything Together

from torchvision.transforms import (
    Compose, RandomRotation, RandomPerspective, RandomResizedCrop,
    RandomHorizontalFlip, RandomVerticalFlip, ColorJitter, RandomGrayscale)

# Full augmentation pipeline: the geometric edits that create black border
# pixels come first, then a crop that removes those borders, then edits
# that are independent of cropping.
black_pixel_steps = [
    RandomRotation(degrees=(-12, 12), expand=True, resample=PIL.Image.BILINEAR),
    RandomPerspective(p=0.6, distortion_scale=0.1, interpolation=PIL.Image.BICUBIC),
]
crop_steps = [
    RandomResizedCrop((224, 224), scale=(.2, .9), ratio=(3/4, 4/3)),
]
other_steps = [
    RandomHorizontalFlip(p=0.5),
    RandomVerticalFlip(p=0.5),
    ColorJitter(brightness=.025, contrast=.075, hue=0.010, saturation=0.5),
    RandomGrayscale(p=.025),
]
trans = Compose(black_pixel_steps + crop_steps + other_steps)
# draw 10 rows of n_augm augmented samples to eyeball the combined pipeline
n_augm = 6
dataset = tv.datasets.ImageFolder("tmp/", transform=trans)
for _ in range(10):
    %time augm_images = np.hstack([np.asarray(dataset[0][0]) for _ in range(n_augm)])
    plt.figure(figsize=(2*n_augm, 2)); plt.axis('off'); plt.imshow(augm_images);
Output hidden; open in https://colab.research.google.com to view.