Skip to content

7. Optimizing Model Parameters


이제 모델과 데이터가 있으므로 데이터에 대한 매개변수를 최적화하여 모델을 학습, 검증 및 테스트할 차례입니다. 모델 학습은 반복적인 프로세스입니다. 각 반복(에포크라고 함)에서 모델은 출력에 대해 추측하고, 추측의 오류(손실)를 계산하고, 매개변수와 관련하여 오류의 도함수를 수집하고, 경사 하강법을 사용하여 이러한 매개변수를 최적화합니다.

Now that we have a model and data it’s time to train, validate and test our model by optimizing its parameters on our data. Training a model is an iterative process; in each iteration (called an epoch) the model makes a guess about the output, calculates the error in its guess (loss), collects the derivatives of the error with respect to its parameters, and optimizes these parameters using gradient descent.

Prerequisite Code

Datasets & DataLoadersBuild Model에 대한 이전 섹션의 코드를 로드합니다.

We load the code from the previous sections on Datasets & DataLoaders and Build Model.

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz

  0%|          | 0/26421880 [00:00<?, ?it/s]
  0%|          | 32768/26421880 [00:00<01:25, 309534.64it/s]
  0%|          | 65536/26421880 [00:00<01:26, 303391.10it/s]
  0%|          | 131072/26421880 [00:00<00:59, 438983.86it/s]
  1%|          | 229376/26421880 [00:00<00:42, 620624.28it/s]
  2%|1         | 491520/26421880 [00:00<00:20, 1260496.68it/s]
  4%|3         | 950272/26421880 [00:00<00:11, 2254055.51it/s]
  7%|7         | 1933312/26421880 [00:00<00:05, 4448960.69it/s]
 15%|#4        | 3833856/26421880 [00:00<00:02, 8552444.02it/s]
 26%|##6       | 6881280/26421880 [00:00<00:01, 14471774.06it/s]
 38%|###7      | 9994240/26421880 [00:01<00:00, 18764882.53it/s]
 50%|####9     | 13139968/26421880 [00:01<00:00, 21723599.14it/s]
 60%|######    | 15958016/26421880 [00:01<00:00, 22857354.38it/s]
 72%|#######2  | 19070976/26421880 [00:01<00:00, 24471353.99it/s]
 84%|########3 | 22151168/26421880 [00:01<00:00, 25595610.49it/s]
 96%|#########5| 25296896/26421880 [00:01<00:00, 26528844.94it/s]
100%|##########| 26421880/26421880 [00:01<00:00, 15991645.11it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz

  0%|          | 0/29515 [00:00<?, ?it/s]
100%|##########| 29515/29515 [00:00<00:00, 272639.92it/s]
100%|##########| 29515/29515 [00:00<00:00, 270974.31it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz

  0%|          | 0/4422102 [00:00<?, ?it/s]
  1%|          | 32768/4422102 [00:00<00:14, 298809.57it/s]
  1%|1         | 65536/4422102 [00:00<00:14, 297479.91it/s]
  3%|2         | 131072/4422102 [00:00<00:09, 432738.47it/s]
  5%|5         | 229376/4422102 [00:00<00:06, 613473.57it/s]
 11%|#1        | 491520/4422102 [00:00<00:03, 1247686.67it/s]
 21%|##1       | 950272/4422102 [00:00<00:01, 2235778.83it/s]
 44%|####3     | 1933312/4422102 [00:00<00:00, 4412540.76it/s]
 87%|########6 | 3833856/4422102 [00:00<00:00, 8488590.50it/s]
100%|##########| 4422102/4422102 [00:00<00:00, 4989037.54it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz

  0%|          | 0/5148 [00:00<?, ?it/s]
100%|##########| 5148/5148 [00:00<00:00, 19452501.79it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Hyperparameters

하이퍼파라미터는 모델 최적화 프로세스를 제어할 수 있는 조정 가능한 매개변수입니다. 서로 다른 하이퍼파라미터 값은 모델 학습 및 수렴 속도에 영향을 줄 수 있습니다.

Hyperparameters are adjustable parameters that let you control the model optimization process. Different hyperparameter values can impact model training and convergence rates.

학습을 위해 다음과 같은 하이퍼파라미터를 정의합니다:

  • 에포크 수 - 데이터 세트를 반복할 횟수

  • 배치 크기 - 매개변수가 업데이트되기 전에 네트워크를 통해 전파되는 데이터 샘플의 수

  • 학습률 - 각 배치/에포크에서 모델 매개변수를 업데이트하는 양. 값이 작을수록 학습 속도가 느려지고 값이 크면 학습 중에 예측할 수 없는 동작이 발생할 수 있습니다.

We define the following hyperparameters for training:

  • Number of Epochs - the number times to iterate over the dataset

  • Batch Size - the number of data samples propagated through the network before the parameters are updated

  • Learning Rate - how much to update models parameters at each batch/epoch. Smaller values yield slow learning speed, while large values may result in unpredictable behavior during training.

learning_rate = 1e-3
batch_size = 64
epochs = 5

Optimization Loop

하이퍼파라미터를 설정하면 최적화 루프로 모델을 학습하고 최적화할 수 있습니다. 최적화 루프의 각 반복을 에포크라고 합니다.

Once we set our hyperparameters, we can then train and optimize our model with an optimization loop. Each iteration of the optimization loop is called an epoch.

각 에포크는 두 가지 주요 부분으로 구성됩니다:

  • 학습 루프 - 학습 데이터 세트를 반복하고 최적의 매개변수로 수렴하려고 시도합니다.

  • 검증/테스트 루프 - 테스트 데이터 세트를 반복하여 모델 성능이 개선되고 있는지 확인합니다.

Each epoch consists of two main parts:

  • The Train Loop - iterate over the training dataset and try to converge to optimal parameters.

  • The Validation/Test Loop - iterate over the test dataset to check if model performance is improving.

학습 루프에서 사용되는 몇 가지 개념에 대해 간략하게 알아보겠습니다.

Let’s briefly familiarize ourselves with some of the concepts used in the training loop.

Loss Function

일부 학습 데이터가 제시될 때 학습되지 않은 네트워크는 정답을 제공하지 않을 가능성이 높습니다. 손실 함수는 획득한 결과가 목표값과 얼마나 다른지를 측정하는 것으로, 우리가 학습 중 최소화하고자 하는 것은 손실 함수입니다. 손실을 계산하기 위해 주어진 데이터 샘플의 입력을 사용하여 예측을 하고 실제 데이터 레이블 값과 비교합니다.

When presented with some training data, our untrained network is likely not to give the correct answer. Loss function measures the degree of dissimilarity of obtained result to the target value, and it is the loss function that we want to minimize during training. To calculate the loss we make a prediction using the inputs of our given data sample and compare it against the true data label value.

일반적인 손실 함수에는 회귀 작업을 위한 nn.MSELoss(Mean Square Error)와 분류를 위한 nn.NLLLoss(음수 로그 우도)가 포함됩니다. nn.CrossEntropyLossnn.LogSoftmaxnn.NLLLoss를 결합합니다.

Common loss functions include nn.MSELoss (Mean Square Error) for regression tasks, and nn.NLLLoss (Negative Log Likelihood) for classification. nn.CrossEntropyLoss combines nn.LogSoftmax and nn.NLLLoss.

모델의 출력 로짓을 nn.CrossEntropyLoss로 전달하면 로짓을 정규화하고 예측 오류를 계산합니다.

We pass our model’s output logits to nn.CrossEntropyLoss, which will normalize the logits and compute the prediction error.

# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()

Optimizer

최적화는 각 학습 단계에서 모델 오류를 줄이기 위해 모델 매개변수를 조정하는 프로세스입니다. 최적화 알고리즘은 이 프로세스가 수행되는 방식을 정의합니다(이 예에서는 확률적 경사 하강법을 사용함). 모든 최적화 논리는 optimizer 객체에 캡슐화됩니다. 여기서는 SGD 옵티마이저를 사용합니다. 또한 ADAM 및 RMSProp과 같은 PyTorch에는 다양한 종류의 모델 및 데이터에 대해 더 잘 작동하는 다양한 옵티마이저가 있습니다.

Optimization is the process of adjusting model parameters to reduce model error in each training step. Optimization algorithms define how this process is performed (in this example we use Stochastic Gradient Descent). All optimization logic is encapsulated in the optimizer object. Here, we use the SGD optimizer; additionally, there are many different optimizers available in PyTorch such as ADAM and RMSProp, that work better for different kinds of models and data.

학습이 필요한 모델의 매개변수를 등록하고 학습률 하이퍼파라미터를 전달하여 옵티마이저를 초기화합니다.

We initialize the optimizer by registering the model’s parameters that need to be trained, and passing in the learning rate hyperparameter.

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

학습 루프 내에서 최적화는 세 단계로 이루어집니다:

  • 모델 매개변수의 그래디언트를 재설정하려면 optimizer.zero_grad()를 호출합니다. 기본적으로 그래디언트는 합산됩니다. 중복 계산을 방지하기 위해 반복할 때마다 명시적으로 0으로 지정합니다.

  • loss.backward()를 호출하여 예측 손실을 역전파합니다. PyTorch는 각 매개변수에 대한 손실의 기울기를 저장합니다.

  • 그래디언트가 있으면 optimizer.step()을 호출하여 역방향 패스에서 수집된 그래디언트로 매개변수를 조정합니다.

Inside the training loop, optimization happens in three steps:

  • Call optimizer.zero_grad() to reset the gradients of model parameters. Gradients by default add up; to prevent double-counting, we explicitly zero them at each iteration.

  • Backpropagate the prediction loss with a call to loss.backward(). PyTorch deposits the gradients of the loss with respect to each parameter.

  • Once we have our gradients, we call optimizer.step() to adjust the parameters by the gradients collected in the backward pass.

Full Implementation

최적화 코드를 반복하는 train_loop와 테스트 데이터에 대해 모델의 성능을 평가하는 test_loop를 정의합니다.

We define train_loop that loops over our optimization code, and test_loop that evaluates the model’s performance against our test data.

def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

손실 함수와 옵티마이저를 초기화하고 train_looptest_loop에 전달합니다. 모델의 성능 향상을 추적하려면 에포크 수를 자유롭게 늘리세요.

We initialize the loss function and optimizer, and pass it to train_loop and test_loop. Feel free to increase the number of epochs to track the model’s improving performance.

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")
Epoch 1
-------------------------------
loss: 2.309345  [    0/60000]
loss: 2.296001  [ 6400/60000]
loss: 2.281641  [12800/60000]
loss: 2.278391  [19200/60000]
loss: 2.245569  [25600/60000]
loss: 2.219331  [32000/60000]
loss: 2.224006  [38400/60000]
loss: 2.191759  [44800/60000]
loss: 2.198319  [51200/60000]
loss: 2.151630  [57600/60000]
Test Error:
 Accuracy: 45.6%, Avg loss: 2.155069

Epoch 2
-------------------------------
loss: 2.169956  [    0/60000]
loss: 2.158091  [ 6400/60000]
loss: 2.102398  [12800/60000]
loss: 2.115292  [19200/60000]
loss: 2.062334  [25600/60000]
loss: 1.998322  [32000/60000]
loss: 2.025763  [38400/60000]
loss: 1.949629  [44800/60000]
loss: 1.963059  [51200/60000]
loss: 1.864815  [57600/60000]
Test Error:
 Accuracy: 58.1%, Avg loss: 1.879126

Epoch 3
-------------------------------
loss: 1.919734  [    0/60000]
loss: 1.885475  [ 6400/60000]
loss: 1.768473  [12800/60000]
loss: 1.802812  [19200/60000]
loss: 1.692852  [25600/60000]
loss: 1.642522  [32000/60000]
loss: 1.661180  [38400/60000]
loss: 1.569238  [44800/60000]
loss: 1.596826  [51200/60000]
loss: 1.474900  [57600/60000]
Test Error:
 Accuracy: 63.3%, Avg loss: 1.501845

Epoch 4
-------------------------------
loss: 1.574462  [    0/60000]
loss: 1.535350  [ 6400/60000]
loss: 1.382391  [12800/60000]
loss: 1.450995  [19200/60000]
loss: 1.329817  [25600/60000]
loss: 1.333179  [32000/60000]
loss: 1.345748  [38400/60000]
loss: 1.272695  [44800/60000]
loss: 1.308814  [51200/60000]
loss: 1.208299  [57600/60000]
Test Error:
 Accuracy: 64.2%, Avg loss: 1.230113

Epoch 5
-------------------------------
loss: 1.308189  [    0/60000]
loss: 1.289927  [ 6400/60000]
loss: 1.119291  [12800/60000]
loss: 1.227006  [19200/60000]
loss: 1.098123  [25600/60000]
loss: 1.129893  [32000/60000]
loss: 1.154455  [38400/60000]
loss: 1.089008  [44800/60000]
loss: 1.131896  [51200/60000]
loss: 1.051155  [57600/60000]
Test Error:
 Accuracy: 65.2%, Avg loss: 1.065313

Epoch 6
-------------------------------
loss: 1.131394  [    0/60000]
loss: 1.138076  [ 6400/60000]
loss: 0.952318  [12800/60000]
loss: 1.091227  [19200/60000]
loss: 0.960416  [25600/60000]
loss: 0.996659  [32000/60000]
loss: 1.036613  [38400/60000]
loss: 0.975237  [44800/60000]
loss: 1.019064  [51200/60000]
loss: 0.953284  [57600/60000]
Test Error:
 Accuracy: 66.4%, Avg loss: 0.961228

Epoch 7
-------------------------------
loss: 1.010152  [    0/60000]
loss: 1.041147  [ 6400/60000]
loss: 0.840924  [12800/60000]
loss: 1.002530  [19200/60000]
loss: 0.875920  [25600/60000]
loss: 0.904912  [32000/60000]
loss: 0.959307  [38400/60000]
loss: 0.902858  [44800/60000]
loss: 0.942745  [51200/60000]
loss: 0.887940  [57600/60000]
Test Error:
 Accuracy: 67.6%, Avg loss: 0.891471

Epoch 8
-------------------------------
loss: 0.922746  [    0/60000]
loss: 0.975288  [ 6400/60000]
loss: 0.762463  [12800/60000]
loss: 0.940795  [19200/60000]
loss: 0.820391  [25600/60000]
loss: 0.839055  [32000/60000]
loss: 0.904622  [38400/60000]
loss: 0.854785  [44800/60000]
loss: 0.888364  [51200/60000]
loss: 0.840878  [57600/60000]
Test Error:
 Accuracy: 69.0%, Avg loss: 0.841620

Epoch 9
-------------------------------
loss: 0.856505  [    0/60000]
loss: 0.926620  [ 6400/60000]
loss: 0.703993  [12800/60000]
loss: 0.895273  [19200/60000]
loss: 0.780757  [25600/60000]
loss: 0.789876  [32000/60000]
loss: 0.863077  [38400/60000]
loss: 0.821058  [44800/60000]
loss: 0.847345  [51200/60000]
loss: 0.804765  [57600/60000]
Test Error:
 Accuracy: 70.3%, Avg loss: 0.803810

Epoch 10
-------------------------------
loss: 0.803816  [    0/60000]
loss: 0.887677  [ 6400/60000]
loss: 0.658199  [12800/60000]
loss: 0.860149  [19200/60000]
loss: 0.750218  [25600/60000]
loss: 0.751755  [32000/60000]
loss: 0.829348  [38400/60000]
loss: 0.795724  [44800/60000]
loss: 0.814821  [51200/60000]
loss: 0.775579  [57600/60000]
Test Error:
 Accuracy: 71.7%, Avg loss: 0.773497

Done!

References