Hello Computer Vision

비전공생의 WGAN(Wasserstein GAN, 2017) 코드 구현 본문

Generative

비전공생의 WGAN(Wasserstein GAN, 2017) 코드 구현

지웅쓰 2022. 11. 16. 12:45

WGAN에 관해서 간략한 리뷰를 지난번에 했는데요, 이번에는 코드구현을 해보겠습니다.

훈련과정에서의 손실함수 부분을 주의깊게 봐주시면 될 거 같습니다.

WGAN 설명

https://keepgoingrunner.tistory.com/32

 

WGAN(Wasserstein GAN, 2017) 에 대한 간단한 이해

이번에 읽으려고 했던 논문은 WGAN이었다. 그러나 너무나 많은 수식으로 인해 저 혼자서 이해하기에는 힘들다고 생각했고 많은 블로그들을 참조해 WGAN 에서 필요한 부분들에 대한 저만의 이해를

keepgoingrunner.tistory.com

 

1. 구분자 정의

class Critic(nn.Module): #구분자 정의
  def __init__(self, channels_img, features_d): #이미지 채널, 곱해나갈 채널을 parameter로 받는다.
    super(Critic, self).__init__()
    self.disc = nn.Sequential(
        #input : N x channels_img x 64 x 64
        nn.Conv2d(channels_img, features_d, kernel_size = 4, stride = 2, padding = 1), #4,2,1로 맞춰주면 이미지의 크기가 2씩 줄어든다.
        nn.LeakyReLU(0.2),
        self._block(features_d, features_d * 2, 4, 2, 1),
        self._block(features_d * 2, features_d * 4, 4, 2, 1),
        self._block(features_d * 4, features_d * 8, 4, 2, 1),
        nn.Conv2d(features_d * 8, 1, kernel_size = 4, stride = 2, padding = 0)
    ) #판별자와 달리 sigmoid가 없다는 것이 큰 차이점이다.
  def _block(self, in_channels, out_channels, kernel_size, stride, padding): #중첩할 블럭들을 미리 정의해준다.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias = False),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2)
    )
  def forward(self, x):
    return self.disc(x) #output 값은 (b, 1, 1, 1)

기존 GAN에서는 discriminator라고 정의해주던 것이 여기서는 Critic(구분자)로 정의되었습니다.

다른 점은 마지막 output에 대해 sigmoid가 없습니다.

이미지를 받고 CNN 블럭들을 통해 (b, 1, 1, 1)크기를 output으로 내뱉습니다.

 

2. 생성자 정의

class Generator(nn.Module):
  def __init__(self, z_dim, channels_img, features_g):
    super(Generator, self).__init__()
    #input : N x z_dim x 1 x 1
    self.gen = nn.Sequential(
        self._block(z_dim, features_g * 16, 4, 1, 0),
        self._block(features_g * 16, features_g * 8, 4, 2, 1),
        self._block(features_g * 8, features_g * 4, 4, 2, 1),
        self._block(features_g * 4, features_g * 2, 4, 2, 1),
        nn.ConvTranspose2d(features_g * 2, channels_img, kernel_size = 4, stride = 2, padding = 1),
        nn.Tanh()
    )
  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(  #이미지의 크기를 늘리기 위해 convtranspose2d 를 사용하였다.
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias = False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU()
    )

  def forward(self, x):
    return self.gen(x)

1x64x64 크기의 이미지를 만듭니다.

3. 최적화함수 정의 및 훈련

opt_gen = optim.RMSprop(gen.parameters(), lr = LEARNING_RATE) #최적화는 Adam 대신 RMSprop을 사용해준다.
opt_critic = optim.RMSprop(critic.parameters(), lr = LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
  for batch_idx, (real, _) in enumerate(trainloader):
    real = real.to(device)

    for _ in range(CRITIC_ITERATIONS): #1번의 생성자 훈련동안 5번의 구분자 훈련을 해준다.
      noise = torch.rand(BATCH_SIZE, Z_DIM, 1, 1).to(device)  #노이즈 생성, 크기 100의 벡터이다.
      fake = gen(noise)                     #생성자를 통해 이미지 생성
      critic_real = critic(real).reshape(-1) #구분자에 진짜 이미지를 넣어준다. output : (b, 1, 1, 1)
      critic_fake = critic(fake).reshape(-1) #구분자에 가짜 이미지를 넣어준다.
      loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) #Wasserstein 손실함수를 이용해준다.(기존에는 BCE Loss를 사용하였다.)
      critic.zero_grad() 
      loss_critic.backward(retain_graph = True)
      opt_critic.step()

      for p in critic.parameters(): #구분자의 가중치 파라미터를 [-0.01, 0.01]로 한정한다.
        p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

    
    output = critic(fake).reshape(-1) #여기서 fake는 5번의 critic훈련 중 마지막 fake img
    loss_gen = -torch.mean(output) 
    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()
  print('Epoch : ', epoch, 'loss_Gen : ', loss_gen, 'loss_critic : ', loss_critic)

논문에서 나왔다시피 5번의 구분자 훈련 이후 1번의 생성자 훈련을 수행함을 알 수 있습니다.

기존 GAN에서는 판별자를 통해 0,1을 통해 훈련을 했다면 WGAN에서는 단순히 

평균값을 빼줌을 알 수 있습니다. (이러한 값 최소만들기 위해 앞에 -붙였다)

그리고 생성자를 훈련할 때 역시 이러한 값을 최소로 해줍니다.

이러한 손실을 최소로 해주어야하는데 논문에서는 일정한 변형을 통해max 값을 내뱉도록 했으니
(gradient ascent)

-부호를 붙여줍니다.

noize = torch.rand(32, 100, 1, 1).to(device)

plt.figure(figsize = (10, 10))
imgs = gen(noize)
for i in range(32):
  plt.subplot(4,8,i+1)
  plt.imshow(make_grid(imgs[i], normalize = True).permute(1,2,0))

5번밖에 시행하지 않았는데 결과가 나쁘지않음을 알 수 있습니다.

사실 훈련시간이 금방 끝나지는 않았습니다.

 

전체코드는 제 github에서 확인할 수 있습니다.

https://github.com/JiWoongCho1/Gernerative-Model_Paper_Review/blob/main/wGAN.ipynb

 

GitHub - JiWoongCho1/Gernerative-Model_Paper_Review: 생성모델 논문 review 및 코드 구현입니다.

생성모델 논문 review 및 코드 구현입니다. Contribute to JiWoongCho1/Gernerative-Model_Paper_Review development by creating an account on GitHub.

github.com

틀린점 있다면 지적해주시면 감사하겠습니다.