Hello Computer Vision

비전공생의 DCGAN(Deep Convolutional Generative Adversarial Network, 2016)코드 구현 본문

Generative

비전공생의 DCGAN(Deep Convolutional Generative Adversarial Network, 2016)코드 구현

지웅쓰 2022. 11. 4. 21:28

저번 DCGAN 논문 리뷰에 이은 코드 실습입니다. 

원래 코드는 1주일 전부터 만들었는데 결과가 너무 안좋아서 여러 시도 끝에 그나마 괜찮은 결과

얻어서 지금에야 올립니다..

https://keepgoingrunner.tistory.com/10

 

비전공생의 DCGAN(Deep Convolutional Generative Adversarial Networks, 2016) 논문 리뷰

지난번 GAN에 이은 두번째 논문 리뷰이다. DCGAN을 코드로 구현해본 적 있는데 GAN이 MLP로 이미지를 생성했다면 DCGAN은 CNN을 이용해 이미지를 생성한 것이 가장 큰 특징이다. GAN 논문을 읽을 때도 익

keepgoingrunner.tistory.com

 

전체 코드는 아래 들어가시면 볼 수 있습니다.

https://github.com/JiWoongCho1/Gernerative-Model_Paper_Review/blob/main/DCGAN.ipynb

 

GitHub - JiWoongCho1/Gernerative-Model_Paper_Review: 생성모델 논문 review 및 코드 구현입니다.

생성모델 논문 review 및 코드 구현입니다. Contribute to JiWoongCho1/Gernerative-Model_Paper_Review development by creating an account on GitHub.

github.com

 

1. Generator 정의

class Generator(nn.Module):
    def __init__(self, params):
        super().__init__()
        nz = params['nz'] # noise 수, 100
        ngf = params['ngf'] # conv filter 수
        img_channel = params['img_channel'] # 이미지 채널 수

        self.dconv1 = nn.ConvTranspose2d(nz,ngf*8,4, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(ngf*8)
        self.dconv2 = nn.ConvTranspose2d(ngf*8,ngf*4, 4, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(ngf*4)
        self.dconv3 = nn.ConvTranspose2d(ngf*4,ngf*2,4,stride=2,padding=1,bias=False)
        self.bn3 = nn.BatchNorm2d(ngf*2)
        self.dconv4 = nn.ConvTranspose2d(ngf*2,ngf * 1,4,stride=2,padding=1,bias=False)
        self.bn4 = nn.BatchNorm2d(ngf*1)
        self.dconv5 = nn.ConvTranspose2d(ngf,img_channel,4,stride=2,padding=1,bias=False)


    def forward(self,x):
        x = F.relu(self.bn1(self.dconv1(x)))
        x = F.relu(self.bn2(self.dconv2(x)))
        x = F.relu(self.bn3(self.dconv3(x)))
        x = F.relu(self.bn4(self.dconv4(x)))
        x = torch.tanh(self.dconv5(x)) #마지막 tanh
        return x

기존 GAN과 구조는 크게 다른 바 없습니다.

단지 64x64 의 크기를 위해 ConvTranspose2d를 5번 반복하였습니다.

2. Discriminator 정의

class Discriminator(nn.Module):
    def __init__(self,params):
        super().__init__()
        img_channel = params['img_channel'] # 3
        ndf = params['ndf'] # 64

        self.conv1 = nn.Conv2d(img_channel,ndf,4,stride=2,padding=1,bias=False)
        self.conv2 = nn.Conv2d(ndf,ndf*2,4,stride=2,padding=1,bias=False)
        self.bn2 = nn.BatchNorm2d(ndf*2)
        self.conv3 = nn.Conv2d(ndf*2,ndf*4,4,stride=2,padding=1,bias=False)
        self.bn3 = nn.BatchNorm2d(ndf*4)
        self.conv4 = nn.Conv2d(ndf*4,ndf*8,4,stride=2,padding=1,bias=False)
        self.bn4 = nn.BatchNorm2d(ndf*8)
        self.conv5 = nn.Conv2d(ndf*8,1,4,stride=1,padding=0,bias=False)


    def forward(self,x):
        x = F.leaky_relu(self.conv1(x),0.2)
        x = F.leaky_relu(self.bn2(self.conv2(x)),0.2)
        x = F.leaky_relu(self.bn3(self.conv3(x)),0.2)
        x = F.leaky_relu(self.bn4(self.conv4(x)),0.2)
        x = torch.sigmoid(self.conv5(x))
        return x.view(-1,1)

활성화함수는 LeakyReLU를 사용하였습니다.

배치 정규화를 사용한 것도 기존 GAN과는 다른 점입니다.

 

3. 가중치 초기화

def initialize_weights(model): #가중치 초기화 함수
    classname = model.__class__.__name__
    if classname.find('Conv') != -1: #ConvTrans까지 가중치초기화
        nn.init.normal_(model.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(model.weight.data, 1.0, 0.02)
        nn.init.constant_(model.bias.data, 0)


generator.apply(initialize_weights);
discriminator.apply(initialize_weights);

기존 CNN 에서는 레이어를 만들 때부터 어느 정도 가중치가 정규분포를 따르게 초기화가 됩니다.

아마 지금 따로 해주는 이 파라미터들로 하면 학습이 더 잘됐다는 것이겠죠?

4. 모델 학습

generator.train()
discriminator.train()


num_epochs=100
start_time = time.time()
nz = params['nz'] # 노이즈 크기 100
loss_hist = {'dis':[],
             'gen':[]}

for epoch in range(num_epochs):

    for xb, yb in train_loader:

        ba_si = xb.shape[0]

        xb = xb.to(device)
        yb_real = torch.Tensor(ba_si,1).fill_(1.0).to(device)
        yb_fake = torch.Tensor(ba_si,1).fill_(0.0).to(device)

        # generator
        generator.zero_grad()

        z = torch.randn(ba_si,nz,1,1).to(device) # 노이즈생성
        out_gen = generator(z) # 가짜 이미지 생성
        out_dis = discriminator(out_gen) # 가짜 이미지 식별

        g_loss = loss_func(out_dis,yb_real)
        g_loss.backward()
        optim_gen.step()

        # discriminator
        discriminator.zero_grad()
        
        out_dis = discriminator(xb) # 진짜 이미지 식별
        loss_real = loss_func(out_dis,yb_real)

        out_dis = discriminator(out_gen.detach()) #discriminator 훈련 중이므로 generator영향없게 detach
        loss_fake = loss_func(out_dis,yb_fake)

        d_loss = (loss_real + loss_fake) / 2
        d_loss.backward()
        optim_dis.step()

        loss_hist['gen'].append(g_loss.item())
        loss_hist['dis'].append(d_loss.item())

    print('Epoch: %.0f, G_Loss: %.6f, D_Loss: %.6f, time: %.2f min' %(epoch, g_loss.item(), d_loss.item(), (time.time()-start_time)/60))

 

원래 모든 코드를 혼자 짰는데 훈련 시간에서 너무 오래 걸리더라고요..

그래서 다른 분들의 코드를 참고해서 싹다 갈아엎었습니다.

그리고 보통 Discriminator가 학습이 더 잘되는 편이라 항상 생각했는데

이번에는 Generator가 학습이 더 잘되었습니다.(결과는 그렇게 좋지는 않은게 함정입니다)

 

학습 결과

막 엄청 훌륭한 모습은 아닙니다.

저는 epoch 100회를 3번 돌린 결과물인데 더 많은 epoch으로 학습하면 더 좋은 결과가 있을 거 같습니다.

 

궁금하신 점이나 틀린 점 지적해주시면 감사하겠습니다.