1 분 소요

Class를 이용해서 선형 회귀 모델 구현하기

이번에는 Class를 이용해서 선형 회귀 모델을 구현 해보도록 하겠습니다.

앞에서 우리가 배웠던 nn.module을 이용한 것과 똑같습니다.

다른 점이 있다면 Class를 선언해서 사용한다는 것 빼고는 없기 때문에 Class 활용 부분만 따로 주의 깊게 보면 되겠습니다.

???:그냥 nn.Linear()써서 하면 되는거아님? 왜 굳이 Class를 씀?

클래스를 사용한 모델 구현 방식은 파이토치 구현체에서 사용하고 있는 방식이기 때문에 숙지해야할 필요가 있기 때문에 Class를 사용해서 하는 것 입니다.


  • 클래스를 이용해서 단순 선형 회귀 모델 만들기

Untitled

기본적인 라이브러리 셋팅과 시드 설정은 기존과 똑같이 해주시면 됩니다.

Untitled 1

클래스 부분이 중요하기 때문에 클래스 부분만 조금 자세하게 설명하면서 다루도록 하겠습니다.

먼저 nn.Module을 상속 받는 클래스로 선언을 해줍니다.

그리고 def init(self)를 통해 생성자를 정의합니다.

이 생성자의 역할은 super().init()을 이용하여 부모 클래스 즉 nn.Module의 생성자를 호출한 뒤 초기화를 진행 해줍니다.

self.linear에 nn.Linear(1,1)을 저장해줌으로써 self.linear에 입력 feature가 1개, 출력 feature가 1개인 선형 회귀 모델을 저장해주도록 합니다.

다음은 foward함수를 봐보도록 하겠습니다.

foward함수는 입력 데이터인 x를 받은 뒤 해당 데이터를 선형 회귀 모델에 입력 데이터로 넣어 나온 결과 값은 반환 해주는 함수입니다.

이 클래스를 요약하자면 nn.Module을 상속받는 클래스이며 해당 클래스를 사용할 때 입력 데이터 x를 넣어주면 x를 자동으로 선형 회귀 모델에 넣은 뒤 데이터 x에 대한 결과 값 y 자동으로 반환을 해주는 클래스라고 생각하시면 됩니다.

Untitled 2

이후로는 우리가 앞서 했던 것과 똑같습니다.

optimizer는 경사 하강법을 사용해주시면 됩니다.

이 부분은 자세히 설명하지 않겠습니다.


  • 클래스를 이용해 다중 선형 회귀 모델 만들기

이번에는 클래스를 활용해서 다중 선형 회귀 모델을 만들어 보겠습니다.

사실 다중 선형 회귀 모델이라고 해서 크게 달라지는 것은 없습니다.

Untitled 3

데이터는 우리가 앞에서 사용했던 데이터를 그대로 사용하도록 하겠습니다.

Untitled 4

그리고 다중 선형 회귀라고 해서 달라지는 것은 없습니다.

model=nn.Linear(1,1)
#->
model=nn.Linear(3,1)

위와 같이 model이라는 변수 부분만 바꾸어 주시면 됩니다.

왜냐하면 다중 선형 회귀 모델에서 사용하는 데이터의 feature가 3개 이기 때문에 단순 선형 회귀 모델에서의 입력 데이터의 feature의 개수만 1에서 3으로 바꾸어주시면 클래스를 이용한 다중 선형 모델을 만들 수 있습니다.

Untitled 5

그 이외에는 learning rate만 바뀌었을 뿐 그 외에는 단순 선형 회귀 모델과 같은 코드입니다.

댓글남기기