Skip to content

8. Save and Load the Model


이 섹션에서는 모델 예측을 저장, 로드 및 실행하면서 모델 상태를 유지하는 방법을 살펴보겠습니다.

In this section we will look at how to persist model state with saving, loading and running model predictions.

import torch
import torchvision.models as models

Saving and Loading Model Weights

PyTorch 모델은 state_dict라는 내부 상태 사전에 학습된 매개변수를 저장합니다. 이는 torch.save 메서드를 통해 유지될 수 있습니다:

PyTorch models store the learned parameters in an internal state dictionary, called state_dict. These can be persisted via the torch.save method:

model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')
/opt/conda/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning:

The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.

/opt/conda/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning:

Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /var/lib/jenkins/.cache/torch/hub/checkpoints/vgg16-397923af.pth

  0%|          | 0.00/528M [00:00<?, ?B/s]
  2%|1         | 10.1M/528M [00:00<00:05, 106MB/s]
  6%|5         | 31.0M/528M [00:00<00:03, 173MB/s]
 10%|9         | 52.3M/528M [00:00<00:02, 196MB/s]
 14%|#3        | 73.8M/528M [00:00<00:02, 207MB/s]
 18%|#8        | 95.3M/528M [00:00<00:02, 214MB/s]
 22%|##1       | 116M/528M [00:00<00:02, 214MB/s]
 26%|##5       | 137M/528M [00:00<00:01, 217MB/s]
 30%|##9       | 158M/528M [00:00<00:01, 218MB/s]
 34%|###4      | 180M/528M [00:00<00:01, 220MB/s]
 38%|###8      | 201M/528M [00:01<00:01, 223MB/s]
 42%|####2     | 223M/528M [00:01<00:01, 225MB/s]
 46%|####6     | 245M/528M [00:01<00:01, 226MB/s]
 51%|#####     | 267M/528M [00:01<00:01, 226MB/s]
 55%|#####4    | 288M/528M [00:01<00:01, 225MB/s]
 59%|#####8    | 311M/528M [00:01<00:01, 227MB/s]
 63%|######2   | 332M/528M [00:01<00:00, 227MB/s]
 67%|######7   | 355M/528M [00:01<00:00, 229MB/s]
 71%|#######1  | 377M/528M [00:01<00:00, 230MB/s]
 76%|#######5  | 399M/528M [00:01<00:00, 230MB/s]
 80%|#######9  | 421M/528M [00:02<00:00, 233MB/s]
 84%|########4 | 444M/528M [00:02<00:00, 232MB/s]
 88%|########8 | 466M/528M [00:02<00:00, 229MB/s]
 92%|#########2| 488M/528M [00:02<00:00, 230MB/s]
 97%|#########6| 510M/528M [00:02<00:00, 228MB/s]
100%|##########| 528M/528M [00:02<00:00, 222MB/s]

모델 가중치를 불러오려면 먼저 동일한 모델의 인스턴스를 생성한 다음 load_state_dict() 메서드를 사용하여 매개변수를 불러와야 합니다.

To load model weights, you need to create an instance of the same model first, and then load the parameters using load_state_dict() method.

model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

Note

드롭아웃 및 배치 정규화 레이어를 평가 모드로 설정하기 위해 추론하기 전에 model.eval() 메서드를 호출해야 합니다. 이렇게 하지 않으면 일관되지 않은 추론 결과가 생성됩니다.

be sure to call model.eval() method before inferencing to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results.

Saving and Loading Models with Shapes

모델 가중치를 로드할 때 클래스가 네트워크의 구조를 정의하기 때문에 먼저 모델 클래스를 인스턴스화해야 했습니다. 이 클래스의 구조를 모델과 함께 저장하고 싶을 수 있습니다. 이 경우 model(model.state_dict()가 아님)을 저장 함수에 전달할 수 있습니다:

When loading model weights, we needed to instantiate the model class first, because the class defines the structure of a network. We might want to save the structure of this class together with the model, in which case we can pass model (and not model.state_dict()) to the saving function:

torch.save(model, 'model.pth')

그런 다음 다음과 같이 모델을 로드할 수 있습니다:

We can then load the model like this:

model = torch.load('model.pth')

Note

이 접근 방식은 모델을 직렬화할 때 Python pickle 모듈을 사용하므로 모델을 로드할 때 사용할 수 있는 실제 클래스 정의에 의존합니다.

This approach uses Python pickle module when serializing the model, thus it relies on the actual class definition to be available when loading the model.


References