Generative

비전공생의 Pix2Pix(Image to Image translation with Conditional Adversarial Networks, 2016) 코드 구현

지웅쓰 2022. 11. 28. 00:56

지난번 논문 리뷰에 이은 코드리뷰이다.

 

생성자와 판별자, 손실함수 위주로 진행해보겠습니다.

생성자 정의

class UNetDown(nn.Module): #UNet class 정의하기
  def __init__(self, in_channels, out_channels, normalize = True, dropout = 0.0):
    super(UNetDown, self).__init__()
    layers = [nn.Conv2d(in_channels, out_channels,4, stride = 2, padding = 1, bias = False)]

    if normalize:
      layers.append(nn.InstanceNorm2d(out_channels))
    layers.append(nn.LeakyReLU(0.2))

    if dropout:
      layers.append(nn.Dropout(dropout))
    self.down = nn.Sequential(*layers)

  def forward(self, x):
    x = self.down(x)
    return x

Pix2Pix의 생성자 구조는 UNet의 구조를 띄고 있습니다.

U자 모양의 네트워크를 가지고 있으며 skip connection을 사용합니다.

그 중 Down Block은 이미지가 들어왔을 때 인코딩하는 block을 의미합니다

기존 UNet에서는 패딩을 없애고 crop함수를 사용하였지만 편리함을 위해 패딩을 추가하였습니다.

(기존 UNet에서는 kernel size = 3, padding = 0)

추가로 논문에 나왔던 대로 이미지 생성에 효과적인 Instance Norm을 추가하고

활성화 함수는 LeakyReLU를 사용하였습니다.

 

class UNetUp(nn.Module):
  def __init__(self, in_channels, out_channels, dropout = 0.0):
    super(UNetUp, self).__init__()
    layers = [
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
        ,nn.InstanceNorm2d(out_channels)
        ,nn.LeakyReLU()
    ]
    if dropout:
      layers.append(nn.Dropout(dropout))
    self.up = nn.Sequential(*layers)
  def forward(self, x, skip):
    x = self.up(x)
    x = torch.cat((x, skip), 1)
    return x

Up block은 Down block과 반대편에서 사용이 되며 인코딩된 feature들을 이용해

알맞은 이미지를 생성하는데 사용됩니다. 그리고 UNet에서는 skip connection을 활용해

더 많은 low level feature들을 전달할 수 있다고 합니다. Pix2Pix논문 리뷰

forward파트에서 feature들을 합쳐줍니다. 이 부분은 논문리뷰에서 이미지를 보시면 더 확실히 이해할 수 있을 것이라 생각합니다.

class GeneratorUNet(nn.Module):
  def __init__(self, in_channels = 3, out_channels = 3):
    super(GeneratorUNet, self).__init__()
    self.down1 = UNetDown(in_channels, 64, normalize = False)
    self.down2 = UNetDown(64, 128)
    self.down3 = UNetDown(128, 256)
    self.down4 = UNetDown(256, 512, dropout = 0.5)
    self.down5 = UNetDown(512, 512, dropout = 0.5)
    self.down6 = UNetDown(512, 512, dropout = 0.5)
    self.down7 = UNetDown(512, 512, dropout = 0.5)
    self.down8 = UNetDown(512, 512, normalize = False, dropout = 0.5)

    self.up1 = UNetUp(512, 512, dropout = 0.5)
    self.up2 = UNetUp(1024, 512, dropout = 0.5)
    self.up3 = UNetUp(1024, 512, dropout = 0.5)
    self.up4 = UNetUp(1024, 512, dropout = 0.5)
    self.up5 = UNetUp(1024, 256)
    self.up6 = UNetUp(512, 128)
    self.up7 = UNetUp(256, 64)
    self.up8 = nn.Sequential(
        nn.ConvTranspose2d(128, 3, 4, stride = 2, padding = 1),
        nn.Tanh()
    )

  def forward(self, x):
    d1 = self.down1(x)
    d2 = self.down2(d1)
    d3 = self.down3(d2)
    d4 = self.down4(d3)
    d5 = self.down5(d4)
    d6 = self.down6(d5)
    d7 = self.down7(d6)
    d8 = self.down8(d7)

    u1 = self.up1(d8, d7)
    u2 = self.up2(u1, d6)
    u3 = self.up3(u2, d5)
    u4 = self.up4(u3, d4)
    u5 = self.up5(u4, d3)
    u6 = self.up6(u5, d2)
    u7 = self.up7(u6, d1)
    u8 = self.up8(u7)

    return u8

그렇게 만든 up, down block들을 합쳐줍니다. self.up1 이 512를 내뱉는데 self.up2가 1024를 받는 것은

up 하는 과정에서 skip connection을 활용해 feature들을 합쳐주었기 때문입니다.

이러한 것이 계속 반복되며 마지막에는 이미지를 생성합니다.

 

판별자 정의

class Dis_block(nn.Module):
  def __init__(self, in_channels, out_channels, normalize = True):
    super(Dis_block, self).__init__()
    layers = [nn.Conv2d(in_channels, out_channels, 3, stride = 2, padding = 1)]
    if normalize:
      layers.append(nn.InstanceNorm2d(out_channels))
    layers.append(nn.LeakyReLU(0.2))

    self.block = nn.Sequential(*layers)

  def forward(self, x):
    x = self.block(x)
    return x

판별할 때는 PatchGAN을 사용합니다.

단순히 전체 이미지를 보고 판별자가 판단하는 것이 아니라 부분부분 이미지를 보고

이미지가 진짜 이미지인지 가짜이미지인지 파악합니다.

논문에서는 70x70 의 PatchGAN이 가장 효과가 좋았다고 합니다. 저는 16x16 을 사용하였습니다.  

class Discriminator(nn.Module):
  def __init__(self, in_channels = 3):
    super(Discriminator, self).__init__()
    self.stage_1 = Dis_block(in_channels * 2, 64, normalize = False)
    self.stage_2 = Dis_block(64, 128)
    self.stage_3 = Dis_block(128, 256)
    self.stage_4 = Dis_block(256, 512)

    self.path = nn.Conv2d(512, 1, 3, padding = 1) # 16 * 16 의 패치 생성

  def forward(self, a, b):
    x = torch.cat((a, b), 1)
    x = self.stage_1(x)
    x = self.stage_2(x)
    x = self.stage_3(x)
    x = self.stage_4(x)
    x = self.path(x)
    x = torch.sigmoid(x)
    return x

마지막에는 sigmoid를 넣어줍니다.

 

모델 훈련 및 결과

loss_func_gan = nn.BCELoss()
loss_func_pix = nn.L1Loss()

lambda_pixel = 100 #loss_func_pix 가중치

patch = (1, 256 // 2 ** 4, 256 // 2 ** 4) #(1, 16, 16)

optim_dis = optim.Adam(model_dis.parameters(), lr = 2e-4, betas = (0.5, 0.999))
optim_gen = optim.Adam(model_gen.parameters(), lr = 2e-4, betas = (0.5, 0.999))

기존의 GAN과는 다르게 L1 Loss를 추가하여 줍니다.

논문에서 증명했듯이 GAN Loss와 L1 Loss를 같이 사용하는 것이 가장 효율적이었다고 말합니다.

 

model_gen.train()
model_dis.train()

batch_count = 0
num_epochs = 50

loss_hist = {'gen' : [], 'dis' : []}

for epoch in range(num_epochs):
  for a, b in train_loader:

    batch_size = a.size(0)

    real_a = a.to(device)
    real_b = b.to(device)

    real_label = torch.ones(batch_size, *patch, requires_grad = False).to(device)
    fake_label = torch.zeros(batch_size, *patch, requires_grad = False).to(device)

    model_gen.zero_grad() #생성자 훈련 위한 가중치 초기화

    fake_b = model_gen(real_a) #선 이미지로 가짜 이미지 생성
    out_dis = model_dis(fake_b, real_b) #가짜 이미지인지 진짜 이미지인지 판별

    gen_loss = loss_func_gan(out_dis, real_label) #최대한 1에 가깝게 만들 수 있도록 generator 훈련
    pixel_loss = loss_func_pix(fake_b, real_b) #진짜 이미지와 가짜이미지간의 L1 loss계산

    g_loss = gen_loss + lambda_pixel * pixel_loss
    g_loss.backward()
    optim_gen.step()

    model_dis.zero_grad() #구별자 훈련 위한 가중치 초기화
    out_dis = model_dis(real_b, real_a) #진짜 이미지 인식
    real_loss = loss_func_gan(out_dis, real_label) #최대한 1에 가깝게 훈련
    
    out_dis = model_dis(fake_b.detach(), real_a) #가짜 이미지 인식, 구별자 훈련 중이므로 생성자 훈련 x
    fake_loss = loss_func_gan(out_dis, fake_label) #0에 가깝게 훈련

    d_loss = (real_loss + fake_loss) / 2
    d_loss.backward()
    optim_dis.step()

    loss_hist['gen'].append(g_loss.item())
    loss_hist['dis'].append(d_loss.item())

    batch_count += 1
    if batch_count % 100 == 0:
      print(f'Epoch : {epoch}, G_loss : {g_loss.item():.3f}, D_loss : {d_loss.item():.3f}')

Epoch은 50을 주고 훈련시킵니다.

 

최정적으로 생성한 이미지는 다음과같습니다.

이미지 중간중간 검은 부분 같은건 원본 이미지에도 있었기에 생기는 부분 같습니다.

중간중간 blurry한 이미지들도 있지만 꽤 잘 생성하는 것을 확인할 수 있습니다.

 

전체코드는 github에서 확인하실 수 있습니다.

https://github.com/JiWoongCho1/Gernerative-Model_Paper_Review/blob/main/Pix2Pix%20(1).ipynb 

 

GitHub - JiWoongCho1/Gernerative-Model_Paper_Review: 생성모델 논문 review 및 코드 구현입니다.

생성모델 논문 review 및 코드 구현입니다. Contribute to JiWoongCho1/Gernerative-Model_Paper_Review development by creating an account on GitHub.

github.com