PyTorchでのパラメータチューニングにおけるargparseの使用例

概要

PyTorchで学習パラメータを柔軟に設定するために argparse を使ってコマンドライン引数が処理されることがよくあります。私自身もよく argparse を使ってパラメータを扱うことがあるので、引数の設定方法や処理方法の例をまとめておきたいと思います。

argparse を使ったサンプルコード

以下で作成したCIFAR10の画像分類のコードに argparse による引数処理を追加します。

以下のようなコードになりました。

import argparse
import datetime
import pathlib
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def pred_per_epoch(model, device, dataloader, criterion, optimizer=None, scheduler=None):

    total_loss = 0.0
    preds_list = []
    labels_list = []

    with torch.set_grad_enabled(model.training):
        for data_batch, label_batch in dataloader:
            data_batch = data_batch.to(device)
            label_batch = label_batch.to(device)

            output = model(data_batch)
            loss = criterion(output, label_batch)

            if model.training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            pred = output.max(dim=1)[1]
            total_loss += loss.item() * data_batch.shape[0]
            labels_list.extend(label_batch.cpu().tolist())
            preds_list.extend(pred.detach().cpu().tolist())

        if model.training and not scheduler is None:
            scheduler.step()

    preds_list = np.array(preds_list)
    labels_list = np.array(labels_list)
    scores = {
        "loss": total_loss / len(preds_list),
        "accuracy": np.mean(preds_list == labels_list),
    }

    return scores


def main(args=None):

    parser = argparse.ArgumentParser()
    parser.add_argument("--name", default="train")
    parser.add_argument("--cpu", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--pretrained_model", type=pathlib.Path, default="")
    args = parser.parse_args(args)

    fix_seed(args.seed)

    logdir = pathlib.Path("logs", f"{datetime.datetime.now():%Y%m%d-%H%M%S}_{args.name}")
    logdir.mkdir(parents=True)

    if torch.cuda.is_available() and not args.cpu:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
    )

    test_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,
    )

    model = Net().to(device)

    if args.pretrained_model.is_file():
        model.load_state_dict(torch.load(args.pretrained_model))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    best_accuracy = 0.0

    for epoch in range(args.epochs):

        # 訓練用の予測処理
        model.train()
        train_scores = pred_per_epoch(
            model=model,
            device=device,
            dataloader=train_dataloader,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
        )
        print(f"Epoch {epoch:3}, Train, loss {train_scores['loss']:.5f}, accuracy {train_scores['accuracy']:.5f}")

        # 検証用の予測処理
        model.eval()
        test_scores = pred_per_epoch(
            model=model,
            device=device,
            dataloader=test_dataloader,
            criterion=criterion,
        )
        print(f"Epoch {epoch:3}, Test , loss {test_scores['loss']:.5f}, accuracy {test_scores['accuracy']:.5f}")

        if test_scores["accuracy"] >= best_accuracy:
            torch.save(model.state_dict(), logdir.joinpath("model.pth"))
            best_accuracy = test_scores["accuracy"]


if __name__ == "__main__":
    main()

このように実装することで、以下のようにコマンドラインから各引数を指定することができます。

python main.py --name seed01_batch32 --seed 1 --batch_size 32

また、このコードでは メイン関数 main(args=None) や引数解析 parser.parse_args(args) の引数に args を指定しています。このようにしておくことで、メイン関数にリストで渡した引数を解析することもできるようになります。これを利用して以下のように main() の呼び出しを複数記述することで、別途シェルスクリプト等を用意することなくバッチジョブを組むこともできます。

if __name__ == "__main__":
    main(["--name", "seed_00", "--seed", "0"])
    main(["--name", "seed_01", "--seed", "1"])
    main(["--name", "seed_02", "--seed", "2"])

設定した各引数の parser.add_argument() でのオプションの指定方法や引数の処理方法については以下にて解説します。

各パラメータの扱い方

学習ケース名 --name

学習ケースの名前を設定するため、 --name 引数を以下のように設定しています。

parser.add_argument("--name", default="train")

取得した引数は torch.save() で重みファイルを出力する際のフォルダ名に使用します。以下のように ./logs 配下に [日時]_[学習ケース名] のフォルダを作成し、学習ケース名が同じでも実行ごとに異なるフォルダに保存されるように設定しています。

logdir = pathlib.Path("logs", f"{datetime.datetime.now():%Y%m%d-%H%M%S}_{args.name}")
logdir.mkdir(parents=True)

このサンプルコードではファイル出力としては重みファイルしか行っていませんが、logging.FileHandler()torch.utils.tensorboard.writer.SummaryWriter() によるログのファイル出力を行う際にもここで作成したフォルダを使っています。

CPUによる実行指定 --cpu

CPUによる実行を指定するため、 --cpu 引数を以下のように設定しています。

parser.add_argument("--cpu", action="store_true")

action="store_true" を指定することで、デフォルトでは False となり、--cpu を指定した際に True が設定されます。

この引数と torch.cuda.is_available() に基づいて、各テンソルの転送先となるデバイスを以下のように指定しています。

if torch.cuda.is_available() and not args.cpu:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

シード値の指定 --seed

実行時のシード値を指定するため、 --seed 引数を以下のように設定しています。

parser.add_argument("--seed", type=int, default=0)

ここで取得した引数は fix_seed() 関数に渡し、各種乱数のシード値を固定しています。

def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

学習パラメータの指定 --batch_size, --num_workers, --epochs, --lr

各種学習パラメータを指定するための引数を以下のように設定しています。

parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--lr", type=float, default=0.01)

これらの引数はそのままデータローダやオプティマイザのパラメータの指定に使用しています。

事前学習の重みファイルの指定 --pretrained_model

事前学習の重みを使用する際のファイル指定のため、--pretrained_model 引数を以下のように設定しています。

parser.add_argument("--pretrained_model", type=pathlib.Path, default="")

type=pathlib.Path を指定することで、pathlib のパスオブジェクトとして引数を解析します。こうしておくことで、取得した引数に対して is_file() メソッドを使ってスマートにファイルの存在チェックができます。

このサンプルコードでは以下のようにファイルの存在有無を確認して、load_state_dict() によりモデルの重みを初期化しています。

if args.pretrained_model.is_file():
    model.load_state_dict(torch.load(args.pretrained_model))

参考文献