일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
Tags
- 컴퓨터구조
- Pseudo Label
- Entropy Minimization
- tent paper
- conjugate pseudo label paper
- shrinkmatch paper
- adamatch paper
- CGAN
- WGAN
- semi supervised learnin 가정
- CoMatch
- 딥러닝손실함수
- mocov3
- mme paper
- dann paper
- CycleGAN
- remixmatch paper
- SSL
- cifar100-c
- 백준 알고리즘
- 최린컴퓨터구조
- Pix2Pix
- UnderstandingDeepLearning
- GAN
- simclrv2
- BYOL
- shrinkmatch
- dcgan
- ConMatch
- Meta Pseudo Labels
Archives
- Today
- Total
Hello Computer Vision
pretrained 모델에 fc레이어 추가해보기 본문
한번도 모델에 추가해본적 없는데 한번 직접 해보려고 한다.
import torchvision.models as models
import torch.nn as nn
import torch
model = models.resnet18(pretrained = True)
num_features = model.fc.out_features
new_fc = nn.Linear(num_features, 10)
model.fc = nn.Sequential(
model.fc, new_fc
)
print(model.fc)
Sequential(
(0): Linear(in_features=512, out_features=1000, bias=True)
(1): Linear(in_features=1000, out_features=10, bias=True)
)
이렇게 마지막 단에 10개의 클래스 변경한 것을 알 수 있다.
이번에는 여러개의 레이어도 추가해본다.
num_features = model.fc.out_features
fc_layers = [
nn.Linear(num_features, 512),
nn.ReLU(),
nn.Linear(512, 10)
]
model.fc = nn.Sequential(model.fc,
*fc_layers)
print(model.fc)
Sequential(
(0): Linear(in_features=512, out_features=1000, bias=True)
(1): Linear(in_features=1000, out_features=512, bias=True)
(2): ReLU()
(3): Linear(in_features=512, out_features=10, bias=True)
이렇게 따로 추가한 것을 알 수 있다.
만약 layer를 추가하고 추가된 레이어만 gradient가 흐르게 하고 싶으면 다음과 같다.
model = models.resnet18(pretrained = True)
num_features = model.fc.out_features
fc_layers = [
nn.Linear(num_features, 512),
nn.ReLU(),
nn.Linear(512, 10)
]
model.fc = nn.Sequential(model.fc,
*fc_layers)
for param in model.fc[:-1].parameters():
param.requires_grad = False
'딥러닝 > 파이토치' 카테고리의 다른 글
torchvision.dataset 조금 뜯어보기 (1) | 2023.12.21 |
---|---|
torch.cuda.amp GradScaler 공부해보기 (0) | 2023.07.11 |
F.cross entropy 구현 및 설명해보기 (0) | 2023.05.12 |
lr_scheduler CosineAnnealingLR 알아보기 (0) | 2023.05.04 |
argmax의 keepdim = True, False (0) | 2023.05.01 |