Hello Computer Vision

비전공생의 Pix2Pix(Image to Image translation with Conditional Adversarial Networks, 2016) 코드 구현 본문

Generative

비전공생의 Pix2Pix(Image to Image translation with Conditional Adversarial Networks, 2016) 코드 구현

지웅쓰 2022. 11. 28. 00:56

지난번 논문 리뷰에 이은 코드리뷰이다.

 

생성자와 판별자, 손실함수 위주로 진행해보겠습니다.

생성자 정의

class UNetDown(nn.Module): #UNet class 정의하기
  def __init__(self, in_channels, out_channels, normalize = True, dropout = 0.0):
    super(UNetDown, self).__init__()
    layers = [nn.Conv2d(in_channels, out_channels,4, stride = 2, padding = 1, bias = False)]

    if normalize:
      layers.append(nn.InstanceNorm2d(out_channels))
    layers.append(nn.LeakyReLU(0.2))

    if dropout:
      layers.append(nn.Dropout(dropout))
    self.down = nn.Sequential(*layers)

  def forward(self, x):
    x = self.down(x)
    return x

Pix2Pix의 생성자 구조는 UNet의 구조를 띄고 있습니다.

U자 모양의 네트워크를 가지고 있으며 skip connection을 사용합니다.

그 중 Down Block은 이미지가 들어왔을 때 인코딩하는 block을 의미합니다

기존 UNet에서는 패딩을 없애고 crop함수를 사용하였지만 편리함을 위해 패딩을 추가하였습니다.

(기존 UNet에서는 kernel size = 3, padding = 0)

추가로 논문에 나왔던 대로 이미지 생성에 효과적인 Instance Norm을 추가하고

활성화 함수는 LeakyReLU를 사용하였습니다.

 

class UNetUp(nn.Module):
  def __init__(self, in_channels, out_channels, dropout = 0.0):
    super(UNetUp, self).__init__()
    layers = [
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
        ,nn.InstanceNorm2d(out_channels)
        ,nn.LeakyReLU()
    ]
    if dropout:
      layers.append(nn.Dropout(dropout))
    self.up = nn.Sequential(*layers)
  def forward(self, x, skip):
    x = self.up(x)
    x = torch.cat((x, skip), 1)
    return x

Up block은 Down block과 반대편에서 사용이 되며 인코딩된 feature들을 이용해

알맞은 이미지를 생성하는데 사용됩니다. 그리고 UNet에서는 skip connection을 활용해

더 많은 low level feature들을 전달할 수 있다고 합니다. Pix2Pix논문 리뷰

forward파트에서 feature들을 합쳐줍니다. 이 부분은 논문리뷰에서 이미지를 보시면 더 확실히 이해할 수 있을 것이라 생각합니다.

class GeneratorUNet(nn.Module):
  def __init__(self, in_channels = 3, out_channels = 3):
    super(GeneratorUNet, self).__init__()
    self.down1 = UNetDown(in_channels, 64, normalize = False)
    self.down2 = UNetDown(64, 128)
    self.down3 = UNetDown(128, 256)
    self.down4 = UNetDown(256, 512, dropout = 0.5)
    self.down5 = UNetDown(512, 512, dropout = 0.5)
    self.down6 = UNetDown(512, 512, dropout = 0.5)
    self.down7 = UNetDown(512, 512, dropout = 0.5)
    self.down8 = UNetDown(512, 512, normalize = False, dropout = 0.5)

    self.up1 = UNetUp(512, 512, dropout = 0.5)
    self.up2 = UNetUp(1024, 512, dropout = 0.5)
    self.up3 = UNetUp(1024, 512, dropout = 0.5)
    self.up4 = UNetUp(1024, 512, dropout = 0.5)
    self.up5 = UNetUp(1024, 256)
    self.up6 = UNetUp(512, 128)
    self.up7 = UNetUp(256, 64)
    self.up8 = nn.Sequential(
        nn.ConvTranspose2d(128, 3, 4, stride = 2, padding = 1),
        nn.Tanh()
    )

  def forward(self, x):
    d1 = self.down1(x)
    d2 = self.down2(d1)
    d3 = self.down3(d2)
    d4 = self.down4(d3)
    d5 = self.down5(d4)
    d6 = self.down6(d5)
    d7 = self.down7(d6)
    d8 = self.down8(d7)

    u1 = self.up1(d8, d7)
    u2 = self.up2(u1, d6)
    u3 = self.up3(u2, d5)
    u4 = self.up4(u3, d4)
    u5 = self.up5(u4, d3)
    u6 = self.up6(u5, d2)
    u7 = self.up7(u6, d1)
    u8 = self.up8(u7)

    return u8

그렇게 만든 up, down block들을 합쳐줍니다. self.up1 이 512를 내뱉는데 self.up2가 1024를 받는 것은

up 하는 과정에서 skip connection을 활용해 feature들을 합쳐주었기 때문입니다.

이러한 것이 계속 반복되며 마지막에는 이미지를 생성합니다.

 

판별자 정의

class Dis_block(nn.Module):
  def __init__(self, in_channels, out_channels, normalize = True):
    super(Dis_block, self).__init__()
    layers = [nn.Conv2d(in_channels, out_channels, 3, stride = 2, padding = 1)]
    if normalize:
      layers.append(nn.InstanceNorm2d(out_channels))
    layers.append(nn.LeakyReLU(0.2))

    self.block = nn.Sequential(*layers)

  def forward(self, x):
    x = self.block(x)
    return x

판별할 때는 PatchGAN을 사용합니다.

단순히 전체 이미지를 보고 판별자가 판단하는 것이 아니라 부분부분 이미지를 보고

이미지가 진짜 이미지인지 가짜이미지인지 파악합니다.

논문에서는 70x70 의 PatchGAN이 가장 효과가 좋았다고 합니다. 저는 16x16 을 사용하였습니다.  

class Discriminator(nn.Module):
  def __init__(self, in_channels = 3):
    super(Discriminator, self).__init__()
    self.stage_1 = Dis_block(in_channels * 2, 64, normalize = False)
    self.stage_2 = Dis_block(64, 128)
    self.stage_3 = Dis_block(128, 256)
    self.stage_4 = Dis_block(256, 512)

    self.path = nn.Conv2d(512, 1, 3, padding = 1) # 16 * 16 의 패치 생성

  def forward(self, a, b):
    x = torch.cat((a, b), 1)
    x = self.stage_1(x)
    x = self.stage_2(x)
    x = self.stage_3(x)
    x = self.stage_4(x)
    x = self.path(x)
    x = torch.sigmoid(x)
    return x

마지막에는 sigmoid를 넣어줍니다.

 

모델 훈련 및 결과

loss_func_gan = nn.BCELoss()
loss_func_pix = nn.L1Loss()

lambda_pixel = 100 #loss_func_pix 가중치

patch = (1, 256 // 2 ** 4, 256 // 2 ** 4) #(1, 16, 16)

optim_dis = optim.Adam(model_dis.parameters(), lr = 2e-4, betas = (0.5, 0.999))
optim_gen = optim.Adam(model_gen.parameters(), lr = 2e-4, betas = (0.5, 0.999))

기존의 GAN과는 다르게 L1 Loss를 추가하여 줍니다.

논문에서 증명했듯이 GAN Loss와 L1 Loss를 같이 사용하는 것이 가장 효율적이었다고 말합니다.

 

model_gen.train()
model_dis.train()

batch_count = 0
num_epochs = 50

loss_hist = {'gen' : [], 'dis' : []}

for epoch in range(num_epochs):
  for a, b in train_loader:

    batch_size = a.size(0)

    real_a = a.to(device)
    real_b = b.to(device)

    real_label = torch.ones(batch_size, *patch, requires_grad = False).to(device)
    fake_label = torch.zeros(batch_size, *patch, requires_grad = False).to(device)

    model_gen.zero_grad() #생성자 훈련 위한 가중치 초기화

    fake_b = model_gen(real_a) #선 이미지로 가짜 이미지 생성
    out_dis = model_dis(fake_b, real_b) #가짜 이미지인지 진짜 이미지인지 판별

    gen_loss = loss_func_gan(out_dis, real_label) #최대한 1에 가깝게 만들 수 있도록 generator 훈련
    pixel_loss = loss_func_pix(fake_b, real_b) #진짜 이미지와 가짜이미지간의 L1 loss계산

    g_loss = gen_loss + lambda_pixel * pixel_loss
    g_loss.backward()
    optim_gen.step()

    model_dis.zero_grad() #구별자 훈련 위한 가중치 초기화
    out_dis = model_dis(real_b, real_a) #진짜 이미지 인식
    real_loss = loss_func_gan(out_dis, real_label) #최대한 1에 가깝게 훈련
    
    out_dis = model_dis(fake_b.detach(), real_a) #가짜 이미지 인식, 구별자 훈련 중이므로 생성자 훈련 x
    fake_loss = loss_func_gan(out_dis, fake_label) #0에 가깝게 훈련

    d_loss = (real_loss + fake_loss) / 2
    d_loss.backward()
    optim_dis.step()

    loss_hist['gen'].append(g_loss.item())
    loss_hist['dis'].append(d_loss.item())

    batch_count += 1
    if batch_count % 100 == 0:
      print(f'Epoch : {epoch}, G_loss : {g_loss.item():.3f}, D_loss : {d_loss.item():.3f}')

Epoch은 50을 주고 훈련시킵니다.

 

최정적으로 생성한 이미지는 다음과같습니다.

이미지 중간중간 검은 부분 같은건 원본 이미지에도 있었기에 생기는 부분 같습니다.

중간중간 blurry한 이미지들도 있지만 꽤 잘 생성하는 것을 확인할 수 있습니다.

 

전체코드는 github에서 확인하실 수 있습니다.

https://github.com/JiWoongCho1/Gernerative-Model_Paper_Review/blob/main/Pix2Pix%20(1).ipynb 

 

GitHub - JiWoongCho1/Gernerative-Model_Paper_Review: 생성모델 논문 review 및 코드 구현입니다.

생성모델 논문 review 및 코드 구현입니다. Contribute to JiWoongCho1/Gernerative-Model_Paper_Review development by creating an account on GitHub.

github.com