일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
Tags
- HookNet
- WSSS
- 도커
- Decision Boundary
- cocre
- 기초확률론
- 사회조사분석사2급
- AIFFEL
- logistic regression
- 프로그래머스
- 티스토리챌린지
- Pull Request
- aiffel exploration
- Jupyter notebook
- GIT
- 오블완
- ssh
- CellPin
- 백신후원
- 히비스서커스
- docker exec
- 코크리
- docker
- Multi-Resolution Networks for Semantic Segmentation in Whole Slide Images
- airflow
- docker attach
- vscode
- cs231n
- IVI
- numpy
Archives
- Today
- Total
히비스서커스의 블로그
[Python] TTA (Test Time Augmentation) 적용하기 본문
728x90
TTA(Test Time Augmentation)란?
TTA이란 Train 과정이 아닌 Test (Inference) 과정에서 Augmentation을 적용하여 나온 결과들에 대해 대표값 (대체로 평균)을 도출하여 결과값이 더욱 Robust하게 만드는 기법을 말한다. segmentation, classification, super-resolution 등과 같은 Computer Vision 문제를 해결하는데 사용된다. 대체로 TTA를 적용하였을 경우 더 결과가 좋아지는 경우가 많아 Kaggle과 같은 경진대회에서도 많이 사용된다.
TTA 적용하기
Pytorch를 Framework로 사용하여 Segmentation Task를 진행한다고 하였을 때 TTA를 적용하는 코드를 간략하게 정리해보았다. 이때, 대표적으로 사용할 수 있는 라이브러리로 qubvel 님의 ttach와 andrewekhalel 님의 edafa가 있다.
segmentation model은 input으로 (H,W,3)의 이미지와 (H,W)의 마스크로 학습하여 (H,W,C)의 output을 가지는 모델을 사용하였다. H는 Height, W는 Width, C는 Class를 뜻한다.
ttach
먼저 ttach를 이용한 TTA는 매우 간단하다.
import ttach as tta
import torch
# device setting
GPU = True
device = 'cuda' if GPU and torch.cuda.is_availabel() else 'cpu'
# TTA setting
transforms = tta.Compose(
[
tta.HorizontalFlip(),
tta.VerticalFlip(),
],
)
# model setting
model = torch.load('./segmentation_model.pth')
model = tta.SegmentationTTAWrapper(model, transforms)
다음과 같이 코드를 구성해주면 Test (Inference)할 시 결과가 TTA가 적용된 결과이다.
edafa
모델의 output이 조금 다르거나 커스텀을 해주고 싶은 경우 이 라이브러리를 사용하면 좋은 듯 하다.
from edafa import SegPredictor
import numpy as np
import torch
import cv2
# model & img setting
model = torch.load('./segmentation_model.pth')
img = cv2.imread('./test_img.jpg')
# configure setting
conf = '{\
"augs" : ["NO", "ROT180", "FLIP_UD", "FLIP_LR"],\
"mean" : "ARITH",\
"bits" : 8\
}'
# Predictor setting
class myPredictor(SegPredictor):
def __init__(self, model, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model = model
def predict_patches(self, patches):
preds = np.zeros((A, H, W, C)) # A: Number of Augmentation (conf: augs)
for i in range(patches.shape[0]):
image = torch.from_numpy(patches[i,...]).float().to('cuda').unsqueeze(0).permute(0,3,1,2)
seg_map = self.model(image)
pred = seg_map.squeeze().permute(1,2,0).cpu().detach().numpy().round()
preds[i,...] = pred
return preds
# output to results
def output2results(model, H, C, conf, img):
p = myPredictor(model, H, C, conf)
output = p.predict_images([img])
results = np.argmax(output.squeeze(), axis=2)
return results
# results
A, H, W, C = 4, 512, 512, 6
results = output2results(model, H, C, conf, img)
마찬가지로 results 부분이 TTA가 적용된 결과이다.
- 히비스서커스 -
728x90