일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
Tags
- BYOL
- UnderstandingDeepLearning
- Pix2Pix
- remixmatch paper
- shrinkmatch paper
- mme paper
- mocov3
- WGAN
- ConMatch
- SSL
- cifar100-c
- 최린컴퓨터구조
- Pseudo Label
- CGAN
- semi supervised learnin 가정
- CycleGAN
- Entropy Minimization
- conjugate pseudo label paper
- tent paper
- shrinkmatch
- dann paper
- CoMatch
- simclrv2
- 컴퓨터구조
- dcgan
- Meta Pseudo Labels
- 딥러닝손실함수
- 백준 알고리즘
- GAN
- adamatch paper
Archives
- Today
- Total
Hello Computer Vision
비전공생의 WGAN(Wasserstein GAN, 2017) 코드 구현 본문
WGAN에 관해서 간략한 리뷰를 지난번에 했는데요, 이번에는 코드구현을 해보겠습니다.
훈련과정에서의 손실함수 부분을 주의깊게 봐주시면 될 거 같습니다.
WGAN 설명
https://keepgoingrunner.tistory.com/32
1. 구분자 정의
class Critic(nn.Module): #구분자 정의
def __init__(self, channels_img, features_d): #이미지 채널, 곱해나갈 채널을 parameter로 받는다.
super(Critic, self).__init__()
self.disc = nn.Sequential(
#input : N x channels_img x 64 x 64
nn.Conv2d(channels_img, features_d, kernel_size = 4, stride = 2, padding = 1), #4,2,1로 맞춰주면 이미지의 크기가 2씩 줄어든다.
nn.LeakyReLU(0.2),
self._block(features_d, features_d * 2, 4, 2, 1),
self._block(features_d * 2, features_d * 4, 4, 2, 1),
self._block(features_d * 4, features_d * 8, 4, 2, 1),
nn.Conv2d(features_d * 8, 1, kernel_size = 4, stride = 2, padding = 0)
) #판별자와 달리 sigmoid가 없다는 것이 큰 차이점이다.
def _block(self, in_channels, out_channels, kernel_size, stride, padding): #중첩할 블럭들을 미리 정의해준다.
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias = False),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU(0.2)
)
def forward(self, x):
return self.disc(x) #output 값은 (b, 1, 1, 1)
기존 GAN에서는 discriminator라고 정의해주던 것이 여기서는 Critic(구분자)로 정의되었습니다.
다른 점은 마지막 output에 대해 sigmoid가 없습니다.
이미지를 받고 CNN 블럭들을 통해 (b, 1, 1, 1)크기를 output으로 내뱉습니다.
2. 생성자 정의
class Generator(nn.Module):
def __init__(self, z_dim, channels_img, features_g):
super(Generator, self).__init__()
#input : N x z_dim x 1 x 1
self.gen = nn.Sequential(
self._block(z_dim, features_g * 16, 4, 1, 0),
self._block(features_g * 16, features_g * 8, 4, 2, 1),
self._block(features_g * 8, features_g * 4, 4, 2, 1),
self._block(features_g * 4, features_g * 2, 4, 2, 1),
nn.ConvTranspose2d(features_g * 2, channels_img, kernel_size = 4, stride = 2, padding = 1),
nn.Tanh()
)
def _block(self, in_channels, out_channels, kernel_size, stride, padding):
return nn.Sequential( #이미지의 크기를 늘리기 위해 convtranspose2d 를 사용하였다.
nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias = False),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
def forward(self, x):
return self.gen(x)
1x64x64 크기의 이미지를 만듭니다.
3. 최적화함수 정의 및 훈련
opt_gen = optim.RMSprop(gen.parameters(), lr = LEARNING_RATE) #최적화는 Adam 대신 RMSprop을 사용해준다.
opt_critic = optim.RMSprop(critic.parameters(), lr = LEARNING_RATE)
for epoch in range(NUM_EPOCHS):
for batch_idx, (real, _) in enumerate(trainloader):
real = real.to(device)
for _ in range(CRITIC_ITERATIONS): #1번의 생성자 훈련동안 5번의 구분자 훈련을 해준다.
noise = torch.rand(BATCH_SIZE, Z_DIM, 1, 1).to(device) #노이즈 생성, 크기 100의 벡터이다.
fake = gen(noise) #생성자를 통해 이미지 생성
critic_real = critic(real).reshape(-1) #구분자에 진짜 이미지를 넣어준다. output : (b, 1, 1, 1)
critic_fake = critic(fake).reshape(-1) #구분자에 가짜 이미지를 넣어준다.
loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) #Wasserstein 손실함수를 이용해준다.(기존에는 BCE Loss를 사용하였다.)
critic.zero_grad()
loss_critic.backward(retain_graph = True)
opt_critic.step()
for p in critic.parameters(): #구분자의 가중치 파라미터를 [-0.01, 0.01]로 한정한다.
p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)
output = critic(fake).reshape(-1) #여기서 fake는 5번의 critic훈련 중 마지막 fake img
loss_gen = -torch.mean(output)
gen.zero_grad()
loss_gen.backward()
opt_gen.step()
print('Epoch : ', epoch, 'loss_Gen : ', loss_gen, 'loss_critic : ', loss_critic)
논문에서 나왔다시피 5번의 구분자 훈련 이후 1번의 생성자 훈련을 수행함을 알 수 있습니다.
기존 GAN에서는 판별자를 통해 0,1을 통해 훈련을 했다면 WGAN에서는 단순히
평균값을 빼줌을 알 수 있습니다. (이러한 값 최소만들기 위해 앞에 -붙였다)
그리고 생성자를 훈련할 때 역시 이러한 값을 최소로 해줍니다.
이러한 손실을 최소로 해주어야하는데 논문에서는 일정한 변형을 통해max 값을 내뱉도록 했으니
(gradient ascent)
-부호를 붙여줍니다.
noize = torch.rand(32, 100, 1, 1).to(device)
plt.figure(figsize = (10, 10))
imgs = gen(noize)
for i in range(32):
plt.subplot(4,8,i+1)
plt.imshow(make_grid(imgs[i], normalize = True).permute(1,2,0))
5번밖에 시행하지 않았는데 결과가 나쁘지않음을 알 수 있습니다.
사실 훈련시간이 금방 끝나지는 않았습니다.
전체코드는 제 github에서 확인할 수 있습니다.
https://github.com/JiWoongCho1/Gernerative-Model_Paper_Review/blob/main/wGAN.ipynb
틀린점 있다면 지적해주시면 감사하겠습니다.
'Generative' 카테고리의 다른 글
비전공생의 CycleGAN(Cycle consistent GAN, 2017)코드 구현 (0) | 2022.11.20 |
---|---|
비전공생의 CycleGAN(Cycle-Consistent Adersarial Networks, 2017)논문 리뷰 (1) | 2022.11.19 |
비전공생의 WGAN(Wasserstein GAN, 2017) 에 대한 간단한 이해 (0) | 2022.11.11 |
비전공생의 InfoGAN(Information Maximizing Generative Adversarial Nets, 2016) 논문 리뷰 (0) | 2022.11.07 |
비전공생의 DCGAN(Deep Convolutional Generative Adversarial Network, 2016)코드 구현 (0) | 2022.11.04 |