endaaman.com

2021-10-12

Tips

PyTorchでEfficientDetを訓練して使う

effdetを自前データで訓練する手順を解説します

tl;dr

rwightman/efficientdet-pytorch: A PyTorch impl of EfficientDet faithful to the original Google impl w/ ported weights を使う最低限の訓練・評価コードの例を作った。

全体はendaaman/detector-example に掲載している。

はじめに

実はobject detectionタスクは、比較対象となりやすいsegmentaionに比べ、次元の調整やlossの解釈の面で煩雑になやすく、実装が難しいのはあまり知られていない。 本記事では「プロジェクトcloneしてきてtrain.pyにパラメーターを渡したらよくわからないけど訓練できた」というレベルを超えて、EfficientDetをライブラリとして導入して自前のプロジェクトに組み込む手順を解説する。

EfficientDetとは

2019年に発表され2020にCVPRに採択された、EfficientNetをバックボーンに持つ矩形検出モデルである。パラメーター量のレベルごとにd0からd7まで存在し、レベルが上がるに従って層の深さやパラメーターの量が増える。d7はMS COCOのsotaを主張している。最近ではYOLO v4やv5はこれに匹敵もしくは上回るとされているが、論文やネーミングライツの関係で近寄りがたい雰囲気があり、事足りるならこれで済ませたい。

PyTorchに関してはrwightman/efficientdet-pytorch: A PyTorch impl of EfficientDet faithful to the original Google impl w/ ported weightsというプロジェクトがあり、親切なことにPyPIにも登録されているので、これを使う手順など解説する。

ポイント

  • 訓練に入力するbboxの形式はyxyx(単位はピクセル)。下図のPASCAL VOCのxyxyのインデックスをyxyxに入れ替えたような形式で入力する
  • 訓練にはDetBenchTrain、推論にはDetBenchPredictというヘルパーモジュールがそれぞれあり、今回はこれらを使う
  • DetBenchTrainloss = bench(inputs, targets)という形式で、targets['bbox']にyxyx形式のbboxを(N, 4, )で入れ、targets['cls']に対応するラベルを(N, )で入れる
  • 画像ごとにbboxの数が違う場合は、余った要素にはダミーのbboxで埋める
  • 各bboxのclass idは0のとき訓練対象から無視される(有効なclass idはid >= 1

bbox_formats.jpg

Bounding boxes augmentation for object detection - Albumentations Documentationより

学習の概要

円の書かれた画像を自動で生成し、その位置を矩形で識別するモデルを構築する。

元画像 対応するラベルを描画したもの
example_x.png] example_xy.png

データセット

__get__()毎に画像を生成する。

import torch
import numpy as np
from PIL import Image, ImageDraw
from torch.utils.data import Dataset
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

np.random.seed(42)


class CircleDataset(Dataset):
    def __init__(self, use_yxyx=True, image_size=512, bg=(0, 0, 0), fg=(255, 0, 0), normalized=True):
        self.use_yxyx = use_yxyx
        self.bg = bg
        self.fg = fg
        self.image_size = image_size

        # 適当なaugmentaion
        self.albu = A.Compose([
            A.RandomResizedCrop(width=self.image_size, height=self.image_size, scale=[0.8, 1.0]),
            A.GaussNoise(p=0.2),
            A.OneOf([
                A.MotionBlur(p=.2),
                A.MedianBlur(blur_limit=3, p=0.1),
                A.Blur(blur_limit=3, p=0.1),
            ], p=0.2),
            A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=5, p=0.5),
            A.OneOf([
                A.CLAHE(clip_limit=2),
                A.Emboss(),
                A.RandomBrightnessContrast(),
            ], p=0.3),
            A.HueSaturationValue(p=0.3),
            # 可視化するとき正規化されるとnoisyなのでトグれるようにする
            A.Normalize(mean=[0.2, 0.1, 0.1], std=[0.2, 0.1, 0.1]) if normalized else None,
            ToTensorV2(),
        ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels']))

    def __len__(self):
        return 1000 # 1epochあたりの枚数。自動生成なので適当

    def __getitem__(self, idx):
        img = Image.new('RGB', (512, 512), self.bg)

        size = np.random.randint(0, 256)
        left = np.random.randint(0, 256)
        top = np.random.randint(0, 256)

        right = left + size
        bottom = top + size
        draw = ImageDraw.Draw(img)
        draw.ellipse((left, top, right, bottom), fill=self.fg)

        # shapeはbox_count x box_coords (N x 4)。円は常に一つなので、今回は画像一枚に対して(1 x 4)
        bboxes = np.array([
            # albumentationsにはASCAL VOC形式の[x0, y0, x1, y1]をピクセル単位で入力する
            [left, top, right, bottom,],
        ])

        labels = np.array([
            # 検出対象はid>=1である必要あり。0はラベルなしとして無視される。
            1,
        ])

        result = self.albu(
            image=np.array(img),
            bboxes=bboxes,
            labels=labels,
        )
        x = result['image']
        bboxes = torch.FloatTensor(result['bboxes'])
        labels = torch.FloatTensor(result['labels'])

        # albumentationsのrandom cropでbboxが範囲外に出るとラベルのサイズがなくなるのでゼロ埋めしておく
        # 複数のbboxを扱う場合は、足りない要素数分emptyなbboxとclsで補う処理が必要
        if bboxes.shape[0] == 0:
            bboxes = torch.zeros([1, 4], dtype=bboxes.dtype)

        # effdetはデフォルトではyxyxで受け取るので、インデックスを入れ替える
        if self.use_yxyx:
            bboxes = bboxes[:, [1, 0, 3, 2]]

        # effdetのtargetは以下の形式
        y = {
            'bbox': bboxes,
            'cls': labels,
        }
        return x, y

if __name__ == '__main__':
    # draw_bounding_boxesはxyxy形式
    ds = CircleDataset(use_yxyx=False, normalized=False)
    for (x, y) in ds:
        to_pil_image(x).save(f'example_x.png')
        t = draw_bounding_boxes(image=x, boxes=y['bbox'], labels=[str(v.item()) for v in y['cls']])
        img = to_pil_image(t)
        img.save(f'example_xy.png')
        break

ポイントは2つ合って、yxyxのままではalbumentationsに入力できないので、内部では一貫してPASCAL VOC形式で扱って、モデルに食わせるときだけyxyxに変換しているところ。もう一つはbboxがalbumentationsのリサイズ/移動の過程で外にでてしまったとき、bboxの数が0になってしまい、tensorをstackできなくなってしまうので、ダミーのbboxで埋めているところ。

augmentaionでランダムな処理が入ると次元ずれもランダムに発生するのでバグ潰しが少し面倒くさい。レアケースであるほど学習が進行したタイミングで発生するので、より箇所が掴みづらい。ぶっちゃけ機械学習で一番むずかしい部分だと思う。

訓練

import os
import re
import argparse

from tqdm import tqdm
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from effdet import EfficientDet, DetBenchTrain, get_efficientdet_config

from datasets import CircleDataset

parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--cpu', action='store_true')
parser.add_argument('-b', '--batch-size', type=int, default=24)
parser.add_argument('--workers', type=int, default=os.cpu_count()//2)
parser.add_argument('-n', '--network', default='d0', type=str, choices=[f'd{i}' for i in range(8)])
parser.add_argument('-e', '--epoch', type=int, default=50)
parser.add_argument('--lr', type=float, default=0.01)
args = parser.parse_args()
use_gpu = not args.cpu and torch.cuda.is_available()
device = torch.device('cuda' if use_gpu else 'cpu')

# effdetにはyxyxでわたす
dataset = CircleDataset(use_yxyx=True)
loader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    num_workers=args.workers,
)

cfg = get_efficientdet_config(f'tf_efficientdet_{args.network}')
# 識別する対象は一種類
cfg.num_classes = 1
model = EfficientDet(cfg)
bench = DetBenchTrain(model).to(device) # lossも組み込んである便利モジュール

optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, verbose=True)

print('Starting training')
for epoch in range(1, args.epoch + 1):
    header = f'[{epoch}/{args.epoch}] '

    lr = optimizer.param_groups[0]['lr']
    print(f'{header}Starting lr={lr:.7f}')

    metrics = {
        'loss': [],
    }
    t = tqdm(loader, leave=False)
    for (inputs, targets) in t:
        inputs = inputs.to(device)
        targets['bbox'] = targets['bbox'].to(device)
        targets['cls'] = targets['cls'].to(device)
        optimizer.zero_grad()
        losses = bench(inputs, targets)
        loss = losses['loss']
        loss.backward()
        optimizer.step()
        iter_metrics = {
            'loss': float(loss.item()),
        }
        message = ' '.join([f'{k}:{v:.4f}' for k, v in iter_metrics.items()])
        t.set_description(f'{header}{message}')
        t.refresh()
        for k, v in iter_metrics.items():
            metrics[k].append(v)
    train_metrics = {k: np.mean(v) for k, v in metrics.items()}
    train_message = ' '.join([f'{k}:{v:.4f}' for k, v in train_metrics.items()])
    print(f'{header}Train: {train_message}')

    # 10エポックごとに保存
    if epoch % 10 == 0:
        state = {
            'epoch': epoch,
            'args': args,
             # multi GPUは考慮しない
            'state_dict': model.state_dict(),
        }
        checkpoint_dir = f'weights/{self.args.network}'
        os.makedirs(checkpoint_dir, exist_ok=True)
        # weights/d1/20.pth みたいな形式で保存
        checkpoint_path = os.path.join(checkpoint_dir, f'{epoch}.pth')
        torch.save(state, checkpoint_path)
        print(f'{header}Saved "{checkpoint_path}"')

    scheduler.step(train_metrics['loss'])
    print()

訓練には python train.py -b 4 -n d1 という感じ。レベルなどはお好みで。

weights/d1/20.pth という形式でモデルの重みを保存する。

評価

import os
import re
import argparse

from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
import numpy as np
import torch
from torch import optim
from torchvision import transforms
from torch.utils.data import DataLoader
from effdet import EfficientDet, DetBenchPredict, get_efficientdet_config

from datasets import CircleDataset

parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('-c', '--checkpoint', type=str)
parser.add_argument('-s', '--src', type=str, required=True)
args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if args.checkpoint:
    checkpoint = torch.load(args.checkpoint)
    network = checkpoint['args'].network
else:
    print('using default weights')
    network = 'd0'

# モデル準備
cfg = get_efficientdet_config(f'tf_efficientdet_{network}')
cfg.num_classes = 1
# cfg.soft_nms = True
model = EfficientDet(cfg).eval()
bench = DetBenchPredict(model).to(device)

# 入力データ準備
img = Image.open(args.src)
original_size = (img.width, img.height)
transform = transforms.Compose([
    transforms.ToTensor(),
])
input_tensor = transform(img.resize(cfg.image_size))
# 先頭にバッチのインデックスをつける
input_tensor = input_tensor[None, :] # CHW -> BCHW

# モデルに入力
with torch.no_grad():
    output_tensor = bench(input_tensor.to(device))

# 出力形式は [[x0, y0, x1, y1, confidence, cls]] となっている
# DetBenchPredictが内部でnmsなのよしなに済ませてくれてくれる
output_tensor = output_tensor.detach().cpu().type(torch.long)

# 一枚だけ入力しているので最初だけ取得
bboxes = output_tensor[0]


# フォントは各自適当なものを使うこと
font = ImageFont.truetype('/usr/share/fonts/ubuntu/Ubuntu-R.ttf', size=16)

draw = ImageDraw.Draw(img)
# 
scale = np.array(original_size) / np.array(cfg.image_size) # [w, h]
# [x0, y0, x1, y1] に掛けやすい形に変形
scale = np.tile(scale, 2) # [w, h, w, h]
print(bboxes.shape)
for bbox in bboxes:
    label = bbox[5].item()
    # 元の画像サイズにスケールし直して四捨五入
    bbox = np.rint(bbox[:4].numpy() * scale).astype(np.int64)
    draw.text((bbox[0], bbox[1]), f'{label}', font=font, fill='yellow')
    draw.rectangle(((bbox[0], bbox[1]), (bbox[2], bbox[3])), outline='yellow', width=1)

img.save('out.png')
print('wrote out.png')

実行は python predict.py -s example.png -c checkpoints/d1/20.pth という感じ。

バッチとしてstackすれば複数枚一気に処理できるが、コードが長くなりexampleとして見通しが悪くなるので端折った。

まとめ

ダミーデータで訓練・評価まで行う最低限のコードを示した。pipで導入できてなおかつ精度も良く、さらにパラメータ量をある程度調節できてとても便利なのでみなさんも使ってみましょう。


©2024 endaaman.com