Hello Computer Vision

비전공생의 Adversarial Learning for semi supervised semantic segmentation 코드 살펴보기 본문

Self,Semi-supervised learning

비전공생의 Adversarial Learning for semi supervised semantic segmentation 코드 살펴보기

지웅쓰 2023. 6. 25. 15:43

해당 논문에 대한 흥미를 느껴 코드를 한번 살펴보려고 한다. 

논문에 대한 리뷰는 다음과 같다.

https://keepgoingrunner.tistory.com/146

 

비전공생의 Adversarial Learning for Semi supervised Semantic Segmentation(2018)논문 리뷰

https://arxiv.org/pdf/1802.07934v2.pdf SSL을 segmentation에 적용하고 싶어서 읽을 첫 논문이다. 시험기간인데 공부안하고 논문 읽어 버리기~ Introduction FCN 방법과 추가적인 모듈은 semantic segmentation 에서 SOTA를

keepgoingrunner.tistory.com

 

 

https://github.com/hfslyc/AdvSemiSeg/blob/master/model/discriminator.py

 

GitHub - hfslyc/AdvSemiSeg: Adversarial Learning for Semi-supervised Semantic Segmentation, BMVC 2018

Adversarial Learning for Semi-supervised Semantic Segmentation, BMVC 2018 - GitHub - hfslyc/AdvSemiSeg: Adversarial Learning for Semi-supervised Semantic Segmentation, BMVC 2018

github.com

코드 출처는 다음과 같다.

 

backbone segmenatation 모델은 Resnet을 기반으로 한 Deeplab을 사용하였고 Discriminator는 기본 CNN을 사용하였다. generator에서 만든 class개수의 채널을 1개의 채널로 변환하는(확률로 변환) 모델이다.

 

 내가 분해해볼 코드는 Train 관련하여 loss를 어떻게 적용하냐이다. 코드가 길어서 필요한 부분만 발췌해서 가져왔다.

optimizer = optim.SGD(model.optim_parameters(args),
                lr=args.learning_rate, momentum=args.momentum,weight_decay=args.weight_decay)
optimizer.zero_grad()

    # optimizer for discriminator network
optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9,0.99))
optimizer_D.zero_grad()

generator와 Discriminator 최적화 함수이다. SGD, Adam 을 사용하였다.

 

bce_loss = BCEWithLogitsLoss2d()
interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

#example :  pred = interp(model(images))

여기서 정의된 bce loss는 torch함수가 아닌 따로 정의한 함수이다. 그리고 interp 는 연산과정에서 크기가 달라지는 것에 대해 보간법을 해준다.

 

for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False
                
                
            # do semi first

            images, _, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            
            pred = interp(model(images))
            pred_remain = pred.detach()
            
            D_out = interp(model_D(F.softmax(pred)))
            D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

처음에는 Generator를 훈련하기 위해서 discriminator의 gradient가 학습되지 않도록 off시킨다. 그리고 gt가 없는 unlabeled 데이터에 대해서 먼저 loss를 계산한다.

input으로 받은 image를 generator에 넣어 segmentation map을 생성한다. 그리고 이 map이 기존 이미지의 크기와 같도록 맞추어주고 pred라는 변수에 저장한다.

그리고 해당 map을 discriminator에 넣어준다.

 

Semi Adversarial Loss

ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(np.bool)

loss_semi_adv = args.lambda_semi_adv * bce_loss(D_out, make_D_label(gt_label, ignore_mask_remain))
loss_semi_adv = loss_semi_adv/args.iter_size

 

def make_D_label(label, ignore_mask):
    ignore_mask = np.expand_dims(ignore_mask, axis=1)
    D_label = np.ones(ignore_mask.shape)*label
    D_label[ignore_mask] = 255
    D_label = Variable(torch.FloatTensor(D_label)).cuda(args.gpu)

    return D_label

여기서 make D label 함수를 통해 gt label(1 값) 크기를 만들어주고, discriminator의 map과 해당 크기의 1로 이루어진 행렬과 BCE loss를 활용해 Adversarial loss를 발생시켜준다. 이 부분이 논문읽을 때도 잘 이해안됐는데 코드를 보아도 정확히 이해가 잘 되지는 않는다. 한번 정리해보면 이미지를 받은 generator는 segmentation map을 생성할 것이고 discriminator는 이 map을 받고 map에서 sigmoid를 통해(D_out) 각 픽셀들이 generator가 생성한 것인지, gt값인지 구분을 할 것이다(0이면은 generator 가 만든 것이라고 판단, 1이면은 gt라고 판단). 그리고 이 해당 부분은 generator 만 훈련하는 부분이고 해당 discriminator 를 속일 수 있도록, 그럴싸 하도록 만들어야 하기 때문에 bce_loss에 들어간 인자 make_D_label 함수를 통해 1로 가득찬 행렬을 만들고 이를 모두다 gt가 만들었다고 속이는 것이다. 

여기서 코드를 보면서 헷갈렸던 부분이 있다. 그렇다면 모든 값을 1로 만들도록 하면은 segmentation map에서 모든 map을 마킹하라는 것인가? 아니다. segmentation map도 0~1로 할당이 되고 discriminator도 0~1로 할당이 되서 헷갈리긴 한다. 그러나 generator가 왼쪽 제일 모서리를 일반 배경이라 예측하면 0값이지만 이를 discriminator가 gt 값인줄 알고 속으면 값이 discriminator가 뱉는 값은 1인 것이다. 그렇기 때문에 generator는 이러한 과정을 통해 discrimnator도 훈련을 하므로 더 gt값과 비슷하게 맞춰지는 것이다(그리고 make D label함수에서 255라는 값에 혼동이 될 수 있는데 Semi Adversarial loss를 계산할 때는 이 255값이 전혀 활용되지 않는다. 그이유는 semi adversarial 에서 ignore mask는 다 False처리 되어 있기 때문이다).

 

Semi Segsmentation Loss

semi_ignore_mask = (D_out_sigmoid < args.mask_T)

semi_gt = pred.data.cpu().numpy().argmax(axis=1)
semi_gt[semi_ignore_mask] = 255

loss_semi = args.lambda_semi * loss_calc(pred, semi_gt, args.gpu)

Discriminator에서 나온 D_out 과 threshold값을 통해 False, True 행렬을 만들어주는데, 여기서 True는 sigmoid 값이 낮은 값들이라고 볼 수 있다(여기서 sigmoid 값이 낮은 것은 discriminator가 gt라고 생각하지 않는다는 것이다. 조금 이상하긴 하지만 여기 loss term 에서는 discriminator를 믿고 gt라고 생각되는 것들에 대해 훈련하는 거 같다). semi gt는 generator는 H x W x categories 만큼의 map을 뱉는데 여기서 가장 높은 값들만 남기는 것을 알 수 있고 이것을 정답이라고 본다고 할 수 있다(pseudo labeling). 마지막에는 Cross entropy를 통해 semi gt와 generator가 내뱉은 값에 대해 loss값을 계산한다.

여기서 lamda 값은 하이퍼파리미터이다.

 

Supervised  Loss

images, labels, _, _ = batch
images = Variable(images).cuda(args.gpu)
ignore_mask = (labels.numpy() == 255)
pred = interp(model(images))

loss_seg = loss_calc(pred, labels, args.gpu)

D_out = interp(model_D(F.softmax(pred)))

loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, ignore_mask))

loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

이제 label map이 있는 데이터와 훈련한다. loss calc를 통해 label 값과 pred 에 대한 loss를 내뱉는다.

 

그리고 adversarial loss가 semi, supervised 둘 다 있는 것을 확인할 수 있는데,  여기서는  make D label을 통해 masking된 부분을 제외하고 Discriminator가 내뱉은 값과 함께 loss를 생성한다. 여기서 무시되는 mask를 찾아본다면 배경을 뜻하는 255 부분이 ignore 되는 것을 확인할 수 있다.

 

 

def loss_calc(pred, label, gpu):
    """
    This function returns cross entropy loss for semantic segmentation
    """
    # out shape batch_size x channels x h x w -> batch_size x channels x h x w
    # label shape h x w x 1 x batch_size  -> batch_size x 1 x h x w
    label = Variable(label.long()).cuda(gpu)
    criterion = CrossEntropy2d().cuda(gpu)

    return criterion(pred, label)

 

for param in model_D.parameters():
                param.requires_grad = True

이제는 discriminator를 훈련한다.

 

if args.D_remain:
        pred = torch.cat((pred, pred_remain), 0)
        ignore_mask = np.concatenate((ignore_mask,ignore_mask_remain), axis = 0)
D_out = interp(model_D(F.softmax(pred)))
loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
loss_D = loss_D/args.iter_size/2
loss_D.backward()
loss_D_value += loss_D.data.cpu().numpy()[0]

여기서 pred값과 pred remian값을 axis = 0 으로 붙인다.

여기서 pred label값은 0인데, 이를 실제 discriminator가 예측한 map과의 BCE loss를 계산한다. 여기서의 map은 generator가 생성한 map이므로 다 가짜라고 생각한다. 그리고 ignore mask는 배경인데 이를 mask 처리한 것을 알 수 있다.

 


지금까지 3개의 loss term 을 알아봤다. 그래도 직접 글로 쓰면서 이해하니 개념이 어느 정도 이해되는 거 같다.