일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
Tags
- 백신후원
- vscode
- 사회조사분석사2급
- linear regression
- aiffel exploration
- Decision Boundary
- docker exec
- cocre
- logistic regression
- CellPin
- WSSS
- AIFFEL
- Multi-Resolution Networks for Semantic Segmentation in Whole Slide Images
- IVI
- Pull Request
- 도커
- GIT
- Jupyter notebook
- 프로그래머스
- cs231n
- 기초확률론
- airflow
- numpy
- docker attach
- 머신러닝
- HookNet
- ssh
- 히비스서커스
- docker
- 코크리
Archives
- Today
- Total
히비스서커스의 블로그
[Pytorch] RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size 본문
Programming/Python
[Pytorch] RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size
HibisCircus 2021. 12. 31. 17:49728x90
상황
torch로 segmentation model을 돌리는 상황에서 발생하였다.
대략적인 코드
# import unet from seg_model.py
from seg_model import unet
import torch
# output 3 classes
model = unet(class=3)
# loss 4 classes
weights = torch.tensor([1.2, 2.6, 7.5, 17.0], dtype=torch.float32)
loss = torch.nn.CrossEntropyLoss(weight=weights)
에러메시지
RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size
원인
segmentation model은 3개의 class를 가지도록 학습하였으나 loss function에서 class weight을 줄 때 4개의 class로 하였기 때문
해결방법
segmentation model의 클래스와 loss function에서 class weight의 리스트 수 (클래스 수)를 동일하게 해준다.
수정된 코드
# import unet from seg_model.py
from seg_model import unet
import torch
# output 4 classes
model = unet(class=4)
# loss 4 classes
weights = torch.tensor([1.2, 2.6, 7.5, 17.0], dtype=torch.float32)
loss = torch.nn.CrossEntropyLoss(weight=weights)
잘 해결되었다.
- 히비스서커스 -
728x90