일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
Tags
- CoMatch
- tent paper
- UnderstandingDeepLearning
- ConMatch
- mocov3
- CGAN
- conjugate pseudo label paper
- dann paper
- BYOL
- dcgan
- SSL
- 최린컴퓨터구조
- adamatch paper
- mme paper
- Pseudo Label
- remixmatch paper
- 컴퓨터구조
- shrinkmatch paper
- Pix2Pix
- WGAN
- shrinkmatch
- 백준 알고리즘
- CycleGAN
- 딥러닝손실함수
- GAN
- Meta Pseudo Labels
- cifar100-c
- semi supervised learnin 가정
- Entropy Minimization
- simclrv2
Archives
- Today
- Total
Hello Computer Vision
비전공생의 DCGAN(Deep Convolutional Generative Adversarial Network, 2016)코드 구현 본문
Generative
비전공생의 DCGAN(Deep Convolutional Generative Adversarial Network, 2016)코드 구현
지웅쓰 2022. 11. 4. 21:28저번 DCGAN 논문 리뷰에 이은 코드 실습입니다.
원래 코드는 1주일 전부터 만들었는데 결과가 너무 안좋아서 여러 시도 끝에 그나마 괜찮은 결과
얻어서 지금에야 올립니다..
https://keepgoingrunner.tistory.com/10
전체 코드는 아래 들어가시면 볼 수 있습니다.
https://github.com/JiWoongCho1/Gernerative-Model_Paper_Review/blob/main/DCGAN.ipynb
1. Generator 정의
class Generator(nn.Module):
def __init__(self, params):
super().__init__()
nz = params['nz'] # noise 수, 100
ngf = params['ngf'] # conv filter 수
img_channel = params['img_channel'] # 이미지 채널 수
self.dconv1 = nn.ConvTranspose2d(nz,ngf*8,4, stride=1, padding=0, bias=False)
self.bn1 = nn.BatchNorm2d(ngf*8)
self.dconv2 = nn.ConvTranspose2d(ngf*8,ngf*4, 4, stride=2, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(ngf*4)
self.dconv3 = nn.ConvTranspose2d(ngf*4,ngf*2,4,stride=2,padding=1,bias=False)
self.bn3 = nn.BatchNorm2d(ngf*2)
self.dconv4 = nn.ConvTranspose2d(ngf*2,ngf * 1,4,stride=2,padding=1,bias=False)
self.bn4 = nn.BatchNorm2d(ngf*1)
self.dconv5 = nn.ConvTranspose2d(ngf,img_channel,4,stride=2,padding=1,bias=False)
def forward(self,x):
x = F.relu(self.bn1(self.dconv1(x)))
x = F.relu(self.bn2(self.dconv2(x)))
x = F.relu(self.bn3(self.dconv3(x)))
x = F.relu(self.bn4(self.dconv4(x)))
x = torch.tanh(self.dconv5(x)) #마지막 tanh
return x
기존 GAN과 구조는 크게 다른 바 없습니다.
단지 64x64 의 크기를 위해 ConvTranspose2d를 5번 반복하였습니다.
2. Discriminator 정의
class Discriminator(nn.Module):
def __init__(self,params):
super().__init__()
img_channel = params['img_channel'] # 3
ndf = params['ndf'] # 64
self.conv1 = nn.Conv2d(img_channel,ndf,4,stride=2,padding=1,bias=False)
self.conv2 = nn.Conv2d(ndf,ndf*2,4,stride=2,padding=1,bias=False)
self.bn2 = nn.BatchNorm2d(ndf*2)
self.conv3 = nn.Conv2d(ndf*2,ndf*4,4,stride=2,padding=1,bias=False)
self.bn3 = nn.BatchNorm2d(ndf*4)
self.conv4 = nn.Conv2d(ndf*4,ndf*8,4,stride=2,padding=1,bias=False)
self.bn4 = nn.BatchNorm2d(ndf*8)
self.conv5 = nn.Conv2d(ndf*8,1,4,stride=1,padding=0,bias=False)
def forward(self,x):
x = F.leaky_relu(self.conv1(x),0.2)
x = F.leaky_relu(self.bn2(self.conv2(x)),0.2)
x = F.leaky_relu(self.bn3(self.conv3(x)),0.2)
x = F.leaky_relu(self.bn4(self.conv4(x)),0.2)
x = torch.sigmoid(self.conv5(x))
return x.view(-1,1)
활성화함수는 LeakyReLU를 사용하였습니다.
배치 정규화를 사용한 것도 기존 GAN과는 다른 점입니다.
3. 가중치 초기화
def initialize_weights(model): #가중치 초기화 함수
classname = model.__class__.__name__
if classname.find('Conv') != -1: #ConvTrans까지 가중치초기화
nn.init.normal_(model.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(model.weight.data, 1.0, 0.02)
nn.init.constant_(model.bias.data, 0)
generator.apply(initialize_weights);
discriminator.apply(initialize_weights);
기존 CNN 에서는 레이어를 만들 때부터 어느 정도 가중치가 정규분포를 따르게 초기화가 됩니다.
아마 지금 따로 해주는 이 파라미터들로 하면 학습이 더 잘됐다는 것이겠죠?
4. 모델 학습
generator.train()
discriminator.train()
num_epochs=100
start_time = time.time()
nz = params['nz'] # 노이즈 크기 100
loss_hist = {'dis':[],
'gen':[]}
for epoch in range(num_epochs):
for xb, yb in train_loader:
ba_si = xb.shape[0]
xb = xb.to(device)
yb_real = torch.Tensor(ba_si,1).fill_(1.0).to(device)
yb_fake = torch.Tensor(ba_si,1).fill_(0.0).to(device)
# generator
generator.zero_grad()
z = torch.randn(ba_si,nz,1,1).to(device) # 노이즈생성
out_gen = generator(z) # 가짜 이미지 생성
out_dis = discriminator(out_gen) # 가짜 이미지 식별
g_loss = loss_func(out_dis,yb_real)
g_loss.backward()
optim_gen.step()
# discriminator
discriminator.zero_grad()
out_dis = discriminator(xb) # 진짜 이미지 식별
loss_real = loss_func(out_dis,yb_real)
out_dis = discriminator(out_gen.detach()) #discriminator 훈련 중이므로 generator영향없게 detach
loss_fake = loss_func(out_dis,yb_fake)
d_loss = (loss_real + loss_fake) / 2
d_loss.backward()
optim_dis.step()
loss_hist['gen'].append(g_loss.item())
loss_hist['dis'].append(d_loss.item())
print('Epoch: %.0f, G_Loss: %.6f, D_Loss: %.6f, time: %.2f min' %(epoch, g_loss.item(), d_loss.item(), (time.time()-start_time)/60))
원래 모든 코드를 혼자 짰는데 훈련 시간에서 너무 오래 걸리더라고요..
그래서 다른 분들의 코드를 참고해서 싹다 갈아엎었습니다.
그리고 보통 Discriminator가 학습이 더 잘되는 편이라 항상 생각했는데
이번에는 Generator가 학습이 더 잘되었습니다.(결과는 그렇게 좋지는 않은게 함정입니다)
학습 결과
막 엄청 훌륭한 모습은 아닙니다.
저는 epoch 100회를 3번 돌린 결과물인데 더 많은 epoch으로 학습하면 더 좋은 결과가 있을 거 같습니다.
궁금하신 점이나 틀린 점 지적해주시면 감사하겠습니다.