Hello Computer Vision

파이토치 모델 구조 살펴보기(summary) 본문

딥러닝/파이토치

파이토치 모델 구조 살펴보기(summary)

지웅쓰 2022. 11. 15. 10:44

이번에 kaggle 대회 참여를 위해 CycleGAN공부를 하던 도중 완성된 모델에 대해

한눈에 알아보려고 summary기능을 사용해보려고 했는데 (기존에는 keras를 사용해서 model.summary() 사용)

파이토치에서는 동일하게 적용되지 않더라고요

 

그래서 찾아본 결과 3가지 방법이 있었는데요.

 

1. 모델 입력하기

mod = Discriminator(3)
mod
Discriminator(
  (initial): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (model): Sequential(
    (0): Block(
      (conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
        (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (1): Block(
      (conv): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (2): Block(
      (conv): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
        (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (3): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
  )
)

이렇게모델만 입력하더라도 구조를 알 수 있습니다.

 

2. torchinfo 사용

torchinfo 라이브러리를 pip install 을 이용해 다운받고 summary 메소드를 사용해주시면 됩니다.

from torchinfo import summary
summary(mod)
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Discriminator                            --
├─Sequential: 1-1                        --
│    └─Conv2d: 2-1                       3,136
│    └─LeakyReLU: 2-2                    --
├─Sequential: 1-2                        --
│    └─Block: 2-3                        --
│    │    └─Sequential: 3-1              131,200
│    └─Block: 2-4                        --
│    │    └─Sequential: 3-2              524,544
│    └─Block: 2-5                        --
│    │    └─Sequential: 3-3              2,097,664
│    └─Conv2d: 2-6                       8,193
=================================================================
Total params: 2,764,737
Trainable params: 2,764,737
Non-trainable params: 0
=================================================================

파라미터 및 그냥 모델을 입력한 것 보다는 편안하게, 더 쉽게 볼 수 있네요.

 

3. torchsummary 이용

마지막은으 torchsummary를 사용하는 것입니다.

from torchsummary import summary as summary_
mod = mod.to(device)
summary_(mod, (3, 256, 256), batch_size = 5)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [5, 64, 128, 128]           3,136
         LeakyReLU-2          [5, 64, 128, 128]               0
            Conv2d-3           [5, 128, 64, 64]         131,200
    InstanceNorm2d-4           [5, 128, 64, 64]               0
         LeakyReLU-5           [5, 128, 64, 64]               0
             Block-6           [5, 128, 64, 64]               0
            Conv2d-7           [5, 256, 32, 32]         524,544
    InstanceNorm2d-8           [5, 256, 32, 32]               0
         LeakyReLU-9           [5, 256, 32, 32]               0
            Block-10           [5, 256, 32, 32]               0
           Conv2d-11           [5, 512, 31, 31]       2,097,664
   InstanceNorm2d-12           [5, 512, 31, 31]               0
        LeakyReLU-13           [5, 512, 31, 31]               0
            Block-14           [5, 512, 31, 31]               0
           Conv2d-15             [5, 1, 30, 30]           8,193
================================================================
Total params: 2,764,737
Trainable params: 2,764,737
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.75
Forward/backward pass size (MB): 275.11
Params size (MB): 10.55
Estimated Total Size (MB): 289.41
----------------------------------------------------------------

input값을 따로 설정해서 넣어주는 만큼 레이어를 거쳐 나오는 크기까지 확인할 수 있네요.

 


모델의 구조를 알고 싶을 때 귀찮으면 model을 그냥 입력해도 좋지만 크기까지 정확히 알고 싶다면

torchsummary를 사용하면 될 거 같습니다.