히비스서커스의 블로그

[Pytorch] RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size 본문

Programming/Python

[Pytorch] RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size

HibisCircus 2021. 12. 31. 17:49
728x90

 

상황

torch로 segmentation model을 돌리는 상황에서 발생하였다.

 

대략적인 코드

# import unet from seg_model.py
from seg_model import unet
import torch

# output 3 classes
model = unet(class=3)

# loss 4 classes
weights = torch.tensor([1.2, 2.6, 7.5, 17.0], dtype=torch.float32)
loss = torch.nn.CrossEntropyLoss(weight=weights)

 

에러메시지

RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size

 

원인

segmentation model은 3개의 class를 가지도록 학습하였으나 loss function에서 class weight을 줄 때 4개의 class로 하였기 때문

 

해결방법

segmentation model의 클래스와 loss function에서 class weight의 리스트 수 (클래스 수)를 동일하게 해준다.

 

수정된 코드 

# import unet from seg_model.py
from seg_model import unet
import torch

# output 4 classes
model = unet(class=4)

# loss 4 classes
weights = torch.tensor([1.2, 2.6, 7.5, 17.0], dtype=torch.float32)
loss = torch.nn.CrossEntropyLoss(weight=weights)

 

 

잘 해결되었다.

 

- 히비스서커스 - 

728x90