히비스서커스의 블로그

[Python] TTA (Test Time Augmentation) 적용하기 본문

Programming/Python

[Python] TTA (Test Time Augmentation) 적용하기

HibisCircus 2022. 8. 11. 10:47
728x90

TTA (Test Time Augmentation) in semantic segmentation

TTA(Test Time Augmentation)란? 

TTA이란 Train 과정이 아닌 Test (Inference) 과정에서 Augmentation을 적용하여 나온 결과들에 대해 대표값 (대체로 평균)을 도출하여 결과값이 더욱 Robust하게 만드는 기법을 말한다. segmentation, classification, super-resolution 등과 같은 Computer Vision 문제를 해결하는데 사용된다. 대체로 TTA를 적용하였을 경우 더 결과가 좋아지는 경우가 많아 Kaggle과 같은 경진대회에서도 많이 사용된다.

 

TTA 적용하기

Pytorch를 Framework로 사용하여 Segmentation Task를 진행한다고 하였을 때 TTA를 적용하는 코드를 간략하게 정리해보았다. 이때, 대표적으로 사용할 수 있는 라이브러리로 qubvel 님의 ttachandrewekhalel 님의 edafa가 있다.

 

segmentation model은 input으로 (H,W,3)의 이미지와 (H,W)의 마스크로 학습하여 (H,W,C)의 output을 가지는 모델을 사용하였다. H는 Height, W는 Width, C는 Class를 뜻한다.

 

ttach

 

먼저 ttach를 이용한 TTA는 매우 간단하다. 

import ttach as tta
import torch

# device setting
GPU = True
device = 'cuda' if GPU and torch.cuda.is_availabel() else 'cpu'

# TTA setting
transforms = tta.Compose(
	[
    	tta.HorizontalFlip(),
        tta.VerticalFlip(),
    ],
)

# model setting
model = torch.load('./segmentation_model.pth')
model = tta.SegmentationTTAWrapper(model, transforms)

다음과 같이 코드를 구성해주면 Test (Inference)할 시 결과가 TTA가 적용된 결과이다.

 

edafa

 

모델의 output이 조금 다르거나 커스텀을 해주고 싶은 경우 이 라이브러리를 사용하면 좋은 듯 하다. 

from edafa         import SegPredictor
import numpy       as np
import torch
import cv2

# model & img setting
model = torch.load('./segmentation_model.pth')
img = cv2.imread('./test_img.jpg')

# configure setting
conf = '{\
	"augs" : ["NO", "ROT180", "FLIP_UD", "FLIP_LR"],\
    "mean" : "ARITH",\
    "bits" : 8\
}'

# Predictor setting
class myPredictor(SegPredictor):
	def __init__(self, model, *args, **kwargs):
    	super().__init__(*args, **kwargs)
        self.model = model
        
    def predict_patches(self, patches):
    	preds = np.zeros((A, H, W, C))  # A: Number of Augmentation (conf: augs)
        for i in range(patches.shape[0]):
        	image = torch.from_numpy(patches[i,...]).float().to('cuda').unsqueeze(0).permute(0,3,1,2)
            seg_map = self.model(image)
            pred = seg_map.squeeze().permute(1,2,0).cpu().detach().numpy().round()
            preds[i,...] = pred
        return preds

# output to results
def output2results(model, H, C, conf, img):
	p = myPredictor(model, H, C, conf)
    output = p.predict_images([img])
    results = np.argmax(output.squeeze(), axis=2)
    return results

# results
A, H, W, C = 4, 512, 512, 6
results = output2results(model, H, C, conf, img)

마찬가지로 results 부분이 TTA가 적용된 결과이다.

 

 

- 히비스서커스 -

728x90