히비스서커스의 블로그

[Pytorch] RuntimeError: mat1 and mat2 shapes cannot be multiplied (16x204800 and 2048x4) 본문

Programming/Python

[Pytorch] RuntimeError: mat1 and mat2 shapes cannot be multiplied (16x204800 and 2048x4)

HibisCircus 2023. 2. 1. 16:33
728x90

 

상황

 

classification model 학습을 위해 pretrained 모델을 사용하기 위해 output을 바꿔주려는 중 다음과 같은 에러를 마주쳤다.

 

코드

 

# model의 코드는 https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py 참조
from model      import se_resnext101_32x4d
import torch.nn as nn

CLASS = 4

model = se_resnext101_32x4d(pretrained='imagenet')
model.last_linear = nn.Linear(in_features=2048, out_features=CLASS)

 

에러메시지

 

RuntimeError: mat1 and mat2 shapes cannot be multiplied (16x204800 and 2048x4)

 

해결방법

 

위의 에러 메시지의 16x204800에서 16은 batch size이고 204800은 tensor의 가로,세로,채널을 곱해준 값이며,

2048X4에서 2048은 변경 전의 tensor크기와 동일해야 하므로 204800이 되어야 하며 4는 바꾸려는 class의 크기이다.

 

따라서, in_features의 2048을 204800으로 변경해주면 된다.

 

 

 

 

728x90