비전공생의 Pix2Pix(Image to Image translation with Conditional Adversarial Networks, 2016) 코드 구현
지난번 논문 리뷰에 이은 코드리뷰이다.
생성자와 판별자, 손실함수 위주로 진행해보겠습니다.
생성자 정의
class UNetDown(nn.Module): #UNet class 정의하기
def __init__(self, in_channels, out_channels, normalize = True, dropout = 0.0):
super(UNetDown, self).__init__()
layers = [nn.Conv2d(in_channels, out_channels,4, stride = 2, padding = 1, bias = False)]
if normalize:
layers.append(nn.InstanceNorm2d(out_channels))
layers.append(nn.LeakyReLU(0.2))
if dropout:
layers.append(nn.Dropout(dropout))
self.down = nn.Sequential(*layers)
def forward(self, x):
x = self.down(x)
return x
Pix2Pix의 생성자 구조는 UNet의 구조를 띄고 있습니다.
U자 모양의 네트워크를 가지고 있으며 skip connection을 사용합니다.
그 중 Down Block은 이미지가 들어왔을 때 인코딩하는 block을 의미합니다
기존 UNet에서는 패딩을 없애고 crop함수를 사용하였지만 편리함을 위해 패딩을 추가하였습니다.
(기존 UNet에서는 kernel size = 3, padding = 0)
추가로 논문에 나왔던 대로 이미지 생성에 효과적인 Instance Norm을 추가하고
활성화 함수는 LeakyReLU를 사용하였습니다.
class UNetUp(nn.Module):
def __init__(self, in_channels, out_channels, dropout = 0.0):
super(UNetUp, self).__init__()
layers = [
nn.ConvTranspose2d(in_channels, out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
,nn.InstanceNorm2d(out_channels)
,nn.LeakyReLU()
]
if dropout:
layers.append(nn.Dropout(dropout))
self.up = nn.Sequential(*layers)
def forward(self, x, skip):
x = self.up(x)
x = torch.cat((x, skip), 1)
return x
Up block은 Down block과 반대편에서 사용이 되며 인코딩된 feature들을 이용해
알맞은 이미지를 생성하는데 사용됩니다. 그리고 UNet에서는 skip connection을 활용해
더 많은 low level feature들을 전달할 수 있다고 합니다. Pix2Pix논문 리뷰
forward파트에서 feature들을 합쳐줍니다. 이 부분은 논문리뷰에서 이미지를 보시면 더 확실히 이해할 수 있을 것이라 생각합니다.
class GeneratorUNet(nn.Module):
def __init__(self, in_channels = 3, out_channels = 3):
super(GeneratorUNet, self).__init__()
self.down1 = UNetDown(in_channels, 64, normalize = False)
self.down2 = UNetDown(64, 128)
self.down3 = UNetDown(128, 256)
self.down4 = UNetDown(256, 512, dropout = 0.5)
self.down5 = UNetDown(512, 512, dropout = 0.5)
self.down6 = UNetDown(512, 512, dropout = 0.5)
self.down7 = UNetDown(512, 512, dropout = 0.5)
self.down8 = UNetDown(512, 512, normalize = False, dropout = 0.5)
self.up1 = UNetUp(512, 512, dropout = 0.5)
self.up2 = UNetUp(1024, 512, dropout = 0.5)
self.up3 = UNetUp(1024, 512, dropout = 0.5)
self.up4 = UNetUp(1024, 512, dropout = 0.5)
self.up5 = UNetUp(1024, 256)
self.up6 = UNetUp(512, 128)
self.up7 = UNetUp(256, 64)
self.up8 = nn.Sequential(
nn.ConvTranspose2d(128, 3, 4, stride = 2, padding = 1),
nn.Tanh()
)
def forward(self, x):
d1 = self.down1(x)
d2 = self.down2(d1)
d3 = self.down3(d2)
d4 = self.down4(d3)
d5 = self.down5(d4)
d6 = self.down6(d5)
d7 = self.down7(d6)
d8 = self.down8(d7)
u1 = self.up1(d8, d7)
u2 = self.up2(u1, d6)
u3 = self.up3(u2, d5)
u4 = self.up4(u3, d4)
u5 = self.up5(u4, d3)
u6 = self.up6(u5, d2)
u7 = self.up7(u6, d1)
u8 = self.up8(u7)
return u8
그렇게 만든 up, down block들을 합쳐줍니다. self.up1 이 512를 내뱉는데 self.up2가 1024를 받는 것은
up 하는 과정에서 skip connection을 활용해 feature들을 합쳐주었기 때문입니다.
이러한 것이 계속 반복되며 마지막에는 이미지를 생성합니다.
판별자 정의
class Dis_block(nn.Module):
def __init__(self, in_channels, out_channels, normalize = True):
super(Dis_block, self).__init__()
layers = [nn.Conv2d(in_channels, out_channels, 3, stride = 2, padding = 1)]
if normalize:
layers.append(nn.InstanceNorm2d(out_channels))
layers.append(nn.LeakyReLU(0.2))
self.block = nn.Sequential(*layers)
def forward(self, x):
x = self.block(x)
return x
판별할 때는 PatchGAN을 사용합니다.
단순히 전체 이미지를 보고 판별자가 판단하는 것이 아니라 부분부분 이미지를 보고
이미지가 진짜 이미지인지 가짜이미지인지 파악합니다.
논문에서는 70x70 의 PatchGAN이 가장 효과가 좋았다고 합니다. 저는 16x16 을 사용하였습니다.
class Discriminator(nn.Module):
def __init__(self, in_channels = 3):
super(Discriminator, self).__init__()
self.stage_1 = Dis_block(in_channels * 2, 64, normalize = False)
self.stage_2 = Dis_block(64, 128)
self.stage_3 = Dis_block(128, 256)
self.stage_4 = Dis_block(256, 512)
self.path = nn.Conv2d(512, 1, 3, padding = 1) # 16 * 16 의 패치 생성
def forward(self, a, b):
x = torch.cat((a, b), 1)
x = self.stage_1(x)
x = self.stage_2(x)
x = self.stage_3(x)
x = self.stage_4(x)
x = self.path(x)
x = torch.sigmoid(x)
return x
마지막에는 sigmoid를 넣어줍니다.
모델 훈련 및 결과
loss_func_gan = nn.BCELoss()
loss_func_pix = nn.L1Loss()
lambda_pixel = 100 #loss_func_pix 가중치
patch = (1, 256 // 2 ** 4, 256 // 2 ** 4) #(1, 16, 16)
optim_dis = optim.Adam(model_dis.parameters(), lr = 2e-4, betas = (0.5, 0.999))
optim_gen = optim.Adam(model_gen.parameters(), lr = 2e-4, betas = (0.5, 0.999))
기존의 GAN과는 다르게 L1 Loss를 추가하여 줍니다.
논문에서 증명했듯이 GAN Loss와 L1 Loss를 같이 사용하는 것이 가장 효율적이었다고 말합니다.
model_gen.train()
model_dis.train()
batch_count = 0
num_epochs = 50
loss_hist = {'gen' : [], 'dis' : []}
for epoch in range(num_epochs):
for a, b in train_loader:
batch_size = a.size(0)
real_a = a.to(device)
real_b = b.to(device)
real_label = torch.ones(batch_size, *patch, requires_grad = False).to(device)
fake_label = torch.zeros(batch_size, *patch, requires_grad = False).to(device)
model_gen.zero_grad() #생성자 훈련 위한 가중치 초기화
fake_b = model_gen(real_a) #선 이미지로 가짜 이미지 생성
out_dis = model_dis(fake_b, real_b) #가짜 이미지인지 진짜 이미지인지 판별
gen_loss = loss_func_gan(out_dis, real_label) #최대한 1에 가깝게 만들 수 있도록 generator 훈련
pixel_loss = loss_func_pix(fake_b, real_b) #진짜 이미지와 가짜이미지간의 L1 loss계산
g_loss = gen_loss + lambda_pixel * pixel_loss
g_loss.backward()
optim_gen.step()
model_dis.zero_grad() #구별자 훈련 위한 가중치 초기화
out_dis = model_dis(real_b, real_a) #진짜 이미지 인식
real_loss = loss_func_gan(out_dis, real_label) #최대한 1에 가깝게 훈련
out_dis = model_dis(fake_b.detach(), real_a) #가짜 이미지 인식, 구별자 훈련 중이므로 생성자 훈련 x
fake_loss = loss_func_gan(out_dis, fake_label) #0에 가깝게 훈련
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optim_dis.step()
loss_hist['gen'].append(g_loss.item())
loss_hist['dis'].append(d_loss.item())
batch_count += 1
if batch_count % 100 == 0:
print(f'Epoch : {epoch}, G_loss : {g_loss.item():.3f}, D_loss : {d_loss.item():.3f}')
Epoch은 50을 주고 훈련시킵니다.
최정적으로 생성한 이미지는 다음과같습니다.
이미지 중간중간 검은 부분 같은건 원본 이미지에도 있었기에 생기는 부분 같습니다.
중간중간 blurry한 이미지들도 있지만 꽤 잘 생성하는 것을 확인할 수 있습니다.
전체코드는 github에서 확인하실 수 있습니다.
https://github.com/JiWoongCho1/Gernerative-Model_Paper_Review/blob/main/Pix2Pix%20(1).ipynb
GitHub - JiWoongCho1/Gernerative-Model_Paper_Review: 생성모델 논문 review 및 코드 구현입니다.
생성모델 논문 review 및 코드 구현입니다. Contribute to JiWoongCho1/Gernerative-Model_Paper_Review development by creating an account on GitHub.
github.com