PyTorchの学習処理の関数化によるリファクタリング

概要

PyTorchで学習コードを書く際、エポックごとに実行する訓練と検証で類似の処理を行うことになります。 それぞれの処理を愚直に記述していると、変更を加えにくくなったり、変数名の管理が煩雑になってしまったりすることがあります。

類似の処理を関数化することはプログラミング共通のテクニックですが、PyTorchの学習コードで関数化によるリファクタリングを行う例を紹介します。

使用環境

OS Windows 11
Python 3.9
PyTorch 1.13
pytorch-cuda 11.7
torchvision 0.14

公式のインストール手順に従い、以下の通りcondaを用いてインストールしています。

conda create -n py39 python=3.9
conda activate py39
conda install pytorch torchvision pytorch-cuda=11.7 -c pytorch -c nvidia

リファクタリング前のコード

公式チュートリアルのCIFAR10データセットに対する画像分類タスクを例とします。

上記のチュートリアルを参考に、リファクタリング用のサンプルコードを以下のように作成しました。

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# リファクタリング前後で同一の結果が得られていることを確認するためシード値を固定
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def main():
    batch_size = 64
    epochs = 20
    device = torch.device("cuda")

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    test_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = Net().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    best_accuracy = 0.0

    for epoch in range(epochs):

        # 訓練用の予測処理
        model.train()

        total_loss = 0.0
        preds_list = []
        labels_list = []
        for data_batch, label_batch in train_dataloader:
            data_batch = data_batch.to(device)
            label_batch = label_batch.to(device)

            output = model(data_batch)
            loss = criterion(output, label_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pred = output.max(dim=1)[1]
            total_loss += loss.item() * data_batch.shape[0]
            labels_list.extend(label_batch.cpu().tolist())
            preds_list.extend(pred.detach().cpu().tolist())

        scheduler.step()

        train_loss = total_loss / len(preds_list)
        preds_list = np.array(preds_list)
        labels_list = np.array(labels_list)
        train_accuracy = np.mean(preds_list == labels_list)
        print(f"Epoch {epoch:3}, Train, loss {train_loss:.5f}, accuracy {train_accuracy:.5f}")

        # 検証用の予測処理
        model.eval()

        total_loss = 0.0
        preds_list = []
        labels_list = []
        with torch.no_grad():
            for data_batch, label_batch in test_dataloader:
                data_batch = data_batch.to(device)
                label_batch = label_batch.to(device)

                output = model(data_batch)
                loss = criterion(output, label_batch)

                pred = output.max(dim=1)[1]
                total_loss += loss.item() * data_batch.shape[0]
                labels_list.extend(label_batch.cpu().tolist())
                preds_list.extend(pred.cpu().tolist())

        test_loss = total_loss / len(preds_list)
        preds_list = np.array(preds_list)
        labels_list = np.array(labels_list)
        test_accuracy = np.mean(preds_list == labels_list)
        print(f"Epoch {epoch:3}, Test , loss {test_loss:.5f}, accuracy {test_accuracy:.5f}")

        if test_accuracy >= best_accuracy:
            torch.save(model.state_dict(), "model.pth")
            best_accuracy = test_accuracy


if __name__ == "__main__":
    main()

標準出力で得られる学習結果は以下の通りです。

Epoch   0, Train, loss 1.86282, accuracy 0.30380
Epoch   0, Test , loss 1.54407, accuracy 0.43110
Epoch   1, Train, loss 1.43182, accuracy 0.48214
Epoch   1, Test , loss 1.33742, accuracy 0.52070

~~~ 中略 ~~~

Epoch  18, Train, loss 0.41104, accuracy 0.86144
Epoch  18, Test , loss 1.22032, accuracy 0.65330
Epoch  19, Train, loss 0.40051, accuracy 0.86732
Epoch  19, Test , loss 1.22197, accuracy 0.65470

学習は問題なくできているのですが、このコードの欠点としては以下の通りです。

  • 訓練と検証で類似の処理が記述されており、別の評価指標を追加するなど変更を加える際に変更箇所が多くなる。
  • 似た名前の変数が多数宣言されており、各変数の意味の把握や管理が煩雑になっている。
  • インデントが深くなっており、どのループ内での処理かが直感的に把握しにくい。

これらを改善するには、繰り返しになっている部分を関数化することで上手く解決できそうです。

リファクタリング後のコード

関数により類似の処理を統一したコードは以下の通りです。

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# リファクタリング前後で同一の結果が得られていることを確認するためシード値を固定
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def pred_per_epoch(model, device, dataloader, criterion, optimizer=None, scheduler=None):

    total_loss = 0.0
    preds_list = []
    labels_list = []

    with torch.set_grad_enabled(model.training):
        for data_batch, label_batch in dataloader:
            data_batch = data_batch.to(device)
            label_batch = label_batch.to(device)

            output = model(data_batch)
            loss = criterion(output, label_batch)

            if model.training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            pred = output.max(dim=1)[1]
            total_loss += loss.item() * data_batch.shape[0]
            labels_list.extend(label_batch.cpu().tolist())
            preds_list.extend(pred.detach().cpu().tolist())

        if model.training and not scheduler is None:
            scheduler.step()

    preds_list = np.array(preds_list)
    labels_list = np.array(labels_list)
    scores = {
        "loss": total_loss / len(preds_list),
        "accuracy": np.mean(preds_list == labels_list),
    }

    return scores


def main():
    batch_size = 64
    epochs = 20
    device = torch.device("cuda")

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    test_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = Net().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    best_accuracy = 0.0

    for epoch in range(epochs):

        # 訓練用の予測処理
        model.train()
        train_scores = pred_per_epoch(
            model=model,
            device=device,
            dataloader=train_dataloader,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
        )
        print(f"Epoch {epoch:3}, Train, loss {train_scores['loss']:.5f}, accuracy {train_scores['accuracy']:.5f}")

        # 検証用の予測処理
        model.eval()
        test_scores = pred_per_epoch(
            model=model,
            device=device,
            dataloader=test_dataloader,
            criterion=criterion,
        )
        print(f"Epoch {epoch:3}, Test , loss {test_scores['loss']:.5f}, accuracy {test_scores['accuracy']:.5f}")

        if test_scores["accuracy"] >= best_accuracy:
            torch.save(model.state_dict(), "model.pth")
            best_accuracy = test_scores["accuracy"]


if __name__ == "__main__":
    main()

このコードで学習を行ったところ、以下の通りリファクタリング前と同じ結果が得られました。

Epoch   0, Train, loss 1.86282, accuracy 0.30380
Epoch   0, Test , loss 1.54407, accuracy 0.43110
Epoch   1, Train, loss 1.43182, accuracy 0.48214
Epoch   1, Test , loss 1.33742, accuracy 0.52070

~~~ 中略 ~~~

Epoch  18, Train, loss 0.41104, accuracy 0.86144
Epoch  18, Test , loss 1.22032, accuracy 0.65330
Epoch  19, Train, loss 0.40051, accuracy 0.86732
Epoch  19, Test , loss 1.22197, accuracy 0.65470

リファクタリング後のコードでは、pred_per_epoch() 関数に訓練と検証において共通的に利用できるよう統一した処理を記述しています。これにより、main() 関数の見通しが良くなっていたり、宣言される変数が少なくなっており、保守性や拡張性の面で改善されています。

関数化におけるポイントとしては以下の2点です。

model.training プロパティによる訓練時・検証時処理の分岐

model.training プロパティで、model.train()model.eval() で指定されている現在のモードが判定でき、訓練モードで True 、検証モードで False をとります。このプロパティとif文を以下のように組み合わせることで、訓練時と検証時の処理を動的に分岐することができます。

if model.training:
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

これを利用し、 main() 関数内で pred_per_epoch() 関数を呼び出す前に model.train() もしくは model.eval() でモードを切り替え、pred_per_epoch() 関数に model オブジェクトを渡すことでプロパティを介して処理を分岐させています。

# 訓練用の予測処理
model.train()
train_scores = pred_per_epoch(
    model=model,
    device=device,
    dataloader=train_dataloader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
)

torch.set_grad_enabled による自動微分有無の分岐

リファクタリング前のコードでは torch.no_grad() を利用して自動微分を無効化している箇所があります。関数化する場合には、自動微分の有無についても動的に切り替える必要があるため、 torch.set_grad_enabled を利用します。

引数に True を指定した場合に自動微分が有効、False を指定した場合に自動微分が無効となります。コンテキストマネージャを利用することで、ブロック単位で切り替えることができます。ここでも model.training プロパティを利用して以下のように記述することで、自動微分の動的な切り替えができます。

with torch.set_grad_enabled(model.training):
    for data_batch, label_batch in dataloader:
        data_batch = data_batch.to(device)
        label_batch = label_batch.to(device)

        output = model(data_batch)
        loss = criterion(output, label_batch)

        if model.training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        pred = output.max(dim=1)[1]
        total_loss += loss.item() * data_batch.shape[0]
        labels_list.extend(label_batch.cpu().tolist())
        preds_list.extend(pred.detach().cpu().tolist())

    if model.training and not scheduler is None:
        scheduler.step()

あとがき

リファクタリングによってコードの見通しが良くなり、バグが少なくなったり、チューニングがしやすくなることが期待できます。また、GitHubで公開されている実装を利用する際にも、リファクタリングを行う過程でその実装の処理内容がよく理解できます。こういった際のリファクタリングのパターンは身に着けておきたいところです。