히비스서커스의 블로그

[error] RuntimeError: Expected object of scalar type Long but got scalar type Float when using CrossEntropyLoss (feat. RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size) 본문

Programming/Python

[error] RuntimeError: Expected object of scalar type Long but got scalar type Float when using CrossEntropyLoss (feat. RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size)

HibisCircus 2022. 6. 14. 16:39
728x90

 

torch로 multi class semantic segmentation model을 학습 중이다. 원래는 DiceLoss를 사용하여 Train을 하다가 CrossEntropyLoss에 Class Weight을 주어 다시 Train하려는데 다음과 같은 에러가 발생하였다.

 

 

RuntimeError: Expected object of scalar type Long but got scalar type Float when using CrossEntropyLoss

 

이를 해결방법을 찾던 중 loss를 구하는 부분에서 y부분 (ground truth 부분)에 

 

y.to(dtype=torch.long)

 

을 해주어 해결이 가능하다고 하여 적용하였으나 바로 다음과 같은 에러를 마주쳤다.

RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size)

 

이는 예전에 내가 마주쳤던 에러 중 하나인데 model의 클래스 설정과 class weight의 클래스 설정이 다를 경우 발생하는 에러다. 그런데 클래스를 같게 해주었음에도 에러가 계속 발생하였다.

 

https://biology-statistics-programming.tistory.com/173

 

[Pytorch] RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size

상황 torch로 segmentation model을 돌리는 상황에서 발생하였다. 대략적인 코드 # import unet from seg_model.py from seg_model import unet import torch # output 3 classes model = unet(class=3) # loss 4..

biology-statistics-programming.tistory.com

 

다른 서버에서는 맨 처음 에러 발생도 없이 잘 Train되는 것을 보고서 이제서야 torch의 버전이 다름을 깨달았다.

 

shell에서 torch 버전확인 코드

$ python
>>> import torch
>>> torch.__version__

 

 

잘 돌아가는 서버의 torch 버전은 1.10.0+cu113 이었고, 잘 돌아가지 않는 서버의 torch 버전은

1.8.2+cu111 버전이었다.

 

간단하게 torch, torchvision, torchaudio를 밀고 다시 재설치 하였다.

$ pip uninstall torch torchvision torchaudio

torch의 cuda와 맞게 설치하려면 이쪽 페이지를 참고하길 바란다.

https://pytorch.org/get-started/locally/

 

PyTorch

An open source machine learning framework that accelerates the path from research prototyping to production deployment.

pytorch.org

 

$ pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

 

설치 후 맨 처음 코드로 실행을 하니 잘 돌아간다!

 

 

 

- 히비스서커스 -

728x90