히비스서커스의 블로그

[Python] AttributeError: 'model' object has no attribute 'predict' (feat. pytorch에서 직접 설계한 모델) 본문

Programming/Python

[Python] AttributeError: 'model' object has no attribute 'predict' (feat. pytorch에서 직접 설계한 모델)

HibisCircus 2021. 11. 10. 14:28
728x90

pytorch로 모델을 직접 설계하여 모델을 학습시킨 후 model.predict()을 하였으나 model에는 predict 특성이 없다는 에러가 발생하였다.

 

상황

설계한 모델형식 (간단하게 줄여서 나타냄)

import torch
import torch.nn as nn

class simple_model(nn.Module):
     def __int__(self, inchannels, out_channels, kernel_size, stride, padding):
    	super().__init__()
        
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
    	x = self.conv(x)
        x = self.relu(x)
        
        return x

다음 모델 학습을 시키고 가중치를 저장 후 불러와 탑재하였다. 자세한 과정은 생략

 

simple_model = torch.load('./model_path.pth')
simple_model.predict()

결과는 

AttributeError: 'model' object has no attribute 'predict' 

이와 같은 에러가 발생하였다.

 

 

해결방법

class를 만들 때 predict 함수를 따로 지정해주지 않았기 때문이다. keras와는 달리 pytorch에서는 predict함수를 따로 지정해주지 않는다면 사용할 수 없다. predict 함수를 만드는 방법은 간단하게 아래를 class에 포함시켜주면 될 것이다.

 

def predict(self, x):
    x = self.forward(x)
    
    return x

 

이 방식보다는 아래와 같이 model에 입력으로 tensor를 주는 것이 좀 더 좋을 것이다.

simple_model(tensor)

 

여기서 tensor는 data generator 단에서 생성된 형식과 동일한 형식이다!

728x90