* 교재: 혼자 공부하는 머신러닝+딥러닝 (hanbit.co.kr)
* 문제: 선형 회귀를 사용하여 훈련 세트 범위 밖의 샘플을 예측한다.
k-최근접 이웃 회귀와 선형 회귀의 문제점을 경험해 보고 이를 해결하기 위해 다항 회귀를 사용하여 농어의 무게를 예측하는 모델을 만든다.
* 문제 해결 과정
1. 농어 데이터 준비하기
2. 훈련 세트와 테스트 세트 만들기
3. KNeighborsRegressor 클래스를 임포트 한 후 객체 만들기
4. 모델 훈련하기
5. 선형 회귀
6. 다항 회귀
1. 농어 데이터 준비하기
농어 56마리의 길이와 무게 리스트를 준비해 보자.
import numpy as np
perch_length = np.array([8.4, 13.7, 15.0, 16.2, 17.4, 18.0, 18.7, 19.0, 19.6, 20.0, 21.0,
21.0, 21.0, 21.3, 22.0, 22.0, 22.0, 22.0, 22.0, 22.5, 22.5, 22.7,
23.0, 23.5, 24.0, 24.0, 24.6, 25.0, 25.6, 26.5, 27.3, 27.5, 27.5,
27.5, 28.0, 28.7, 30.0, 32.8, 34.5, 35.0, 36.5, 36.0, 37.0, 37.0,
39.0, 39.0, 39.0, 40.0, 40.0, 40.0, 40.0, 42.0, 43.0, 43.0, 43.5,
44.0])
perch_weight = np.array([5.9, 32.0, 40.0, 51.5, 70.0, 100.0, 78.0, 80.0, 85.0, 85.0, 110.0,
115.0, 125.0, 130.0, 120.0, 120.0, 130.0, 135.0, 110.0, 130.0,
150.0, 145.0, 150.0, 170.0, 225.0, 145.0, 188.0, 180.0, 197.0,
218.0, 300.0, 260.0, 265.0, 250.0, 250.0, 300.0, 320.0, 514.0,
556.0, 840.0, 685.0, 700.0, 700.0, 690.0, 900.0, 650.0, 820.0,
850.0, 900.0, 1015.0, 820.0, 1100.0, 1000.0, 1100.0, 1000.0,
1000.0])
2. 훈련 세트와 테스트 세트 만들기
# 훈련 세트와 테스트 세트 나누기
from sklearn.model_selection import train_test_split
train_input, test_input, train_target, test_target = train_test_split(perch_length, perch_weight, random_state=42)
# 훈련 세트와 테스트 세트를 2차원 배열로 만들기
train_input = train_input.reshape(-1, 1)
test_input = test_input.reshape(-1, 1)
3. KNeighborsRegressor 클래스를 임포트 한 후 객체 만들기
k-최근접 이웃 회귀 모델을 만드는 사이킷런 클래스인 KNeighborsRegressor를 임포트 한 후 KNeighborsRegressor 클래스의 객체를 만들어보자.
from sklearn.neighbors import KNeighborsRegressor # KNeighborsRegressor 임포트
knr = KNeighborsRegressor(n_neighbors=3) # knr 객체 생성
4. 모델 훈련하기
최근접 이웃 개수를 3으로 하는 모델을 훈련해 보자.
knr.fit(train_input, train_target)
print(knr.predict([[50]]))
# 출력: [1033.33333333]
50cm의 농어의 무게를 1033g 정도로 예측했다.
하지만 실제 이 농어의 무게는 훨씬 더 많이 나간다고 한다...
문제가 무엇인지 알아보기 위해 산점도를 그려 확인해 보자.
import matplotlib.pyplot as plt
distances, indexes = knr.kneighbors([[50]]) # 샘플 [50]에서 가까운 이웃 찾기
plt.scatter(train_input, train_target)
plt.scatter(train_input[indexes], train_target[indexes], marker='D') # 아웃은 마름모
plt.scatter(50, 1033, marker='^') # 50cm 농어는 삼각형
plt.xlable('length')
plt.ylable('weight')
plt.show()
kneighbors() 메서드는 k-최근접 이웃 객체의 메서드로, 입력한 데이터에 가장 가까운 이웃을 찾아 거리와 이웃 샘플의 인덱스를 반환한다.
길이 50cm에 무게 1033g인 농어를 구분하기 쉽게 marker 매개변수를 '^'으로 지정하여 삼각형으로 표현하였고, 이 농어의 최근접 이웃을 마름모(marker='D')로 표현하였다.
50cm 농어에서 가장 가까운 것은 45cm 근방의 샘플들이기 때문에 k-최근접 이웃 회귀는 이 샘플들의 타깃을 평균한다. 즉 100cm의 농어도 1033g으로 예측할 것이다.
k-최근접 이웃을 사용해 이 문제를 해결하려면 가장 큰 농어가 포함되도록 훈련 세트를 다시 만들어야 한다.
k-최근접 이웃이 아닌 선형 회귀로 문제를 해결해 보자.
5. 선형 회귀
선형 회귀는 대표적인 회귀 알고리즘으로, 특성이 하나인 경우 어떤 직선을 학습하는 알고리즘이다.
사이킷런은 sklearn.linear_model 패키지 아래에 LinearRegression 클래스로 선형 회귀 알고리즘을 구현해 놓았다.
LinearRegression 클래스를 임포트 한 후 객체를 만들고 모델을 훈련해 보자.
from sklearn.linear_model import LinearRegression # LinearRegression 임포트
lr = LinearRegression() # lr 객체 생성
lr.fit(train_input, train_target) # 선형 회귀 모델 훈련
print(lr.predict([[50]])) # 50cm 농어 예측
# 출력: [1241.83860323]
[1241.83860323]이 출력되었다.
k-최근접 이웃 회귀를 사용했을 때와 달리 선형 회귀는 50cm 농어의 무게를 높게 예측하였다!
선형 회귀가 예측한 값 [1241.83860323]이 어떻게 나왔는지 확인해 보자.
하나의 직선을 그리려면 기울기와 절편이 있어야 한다. 즉 y = a * x + b로 나타낼 수 있다.
여기서 x를 농어의 길이, y를 농어의 무게로 바꾸면 위의 그림과 같다.
LinearRegression 클래스가 데이터에 가장 잘 맞는 a와 b를 찾았는지 확인해 보자.
print(lr.coef_, lr.intercept_)
# 출력: [39.01714496] -709.0186449535477
LinearRegression 클래스가 찾은 a와 b는 lr 객체의 coef_와 intercept_ 속성에 저장되어 있다.
coef_와 intercept_는 머신러닝 알고리즘이 찾은 값이라는 의미로 모델 피라미터라고 한다.
모델 피라미터는 선형 회귀가 찾은 가중치처럼 머신러닝 모델이 특성에서 학습한 피라미터를 말한다.
훈련 세트의 산점도와 선형 회귀가 학습한 직선을 그려보자.
# 훈련 세트의 산점도
plt.scatter(train_input, train_target)
# 15에서 50까지의 1차 방정식 그래프
plt.plot([15, 50], [15*lr.coef_+lr.intercept_, 50*lr.coef_+lr.intercept_])
# 50cm 농어 데이터
plt.scatter(50, 1241.8, marker='^')
plt.xlable('length')
plt.ylable('weight')
plt.show()
농어의 길이 15에서 50까지 직선을 그리기 위해 기울기와 절편을 사용하여 (15, 15*39-709)와 (50, 50*39-709) 두 점을 이었더니 선형 회귀 알고리즘이 찾은 최적의 직선이 나왔다. 이제 훈련 세트를 벗어난 농어의 무게도 예측할 수 있게 되었다.
이제 모델을 평가해 보자.
print(lr.score(train_input, train_target)) # 훈련 세트 점수 출력: 0.939846333997604
print(lr.score(test_input, test_target)) # 테스트 세트 점수 출력: 0.8247503123313558
테스트 세트가 훈련 세트보다 점수가 낮게 나왔으므로 과소적합되었다.
더 좋은 모델을 만들기 위해 선형 회귀가 만든 직선의 문제점을 찾아보자.
직선이 왼쪽 아래로 뻗어 있기 때문에 농어의 무게가 0g 이하로 내려갈 수밖에 없다.
또한 산점도를 보면 일직선이 아닌 곡선에 가깝다.
다항 회귀를 사용하여 문제를 해결해 보자.
6. 다항 회귀
다항회귀는 다항식을 사용한 선형 회귀이다.
다항회귀를 사용하여 최적의 직선이 아닌 최적의 곡선을 그려보자.
이런 2차 방정식의 그래프를 그리려면 길이를 제곱한 항이 훈련 세트에 추가되어야 한다.
농어의 길이를 제곱하여 원래 데이터 앞에 붙여 보자.
train_poly = np.column_stack((train_input ** 2, train_input))
test_poly = np.column_stack((test_input ** 2, test_input))
np.column_stack() 함수는 리스트를 일렬로 세운 다음 차례대로 나란히 연결한다.
np.column_stack() 함수로 train_input을 제곱한 것과 train_input 배열을 나란히 연결하고, test_inpput을 제곱한 것과 test_input 배열을 나란히 연결하였다.
새롭게 만든 데이터셋의 크기를 확인해 보자.
print(train_poly.shape, test_poly.shape)
# 출력: (42, 2) (14, 2)
길이를 제곱하여 왼쪽 열에 추가했기 때문에 훈련 세트와 테스트 세트 모두 열이 2개로 늘어났다.
이제 train_poly로 선형 회귀 모델을 다시 훈련해 보자.
lr = LinearRegression()
lr.fit(train_poly, train_target) # 훈련
print(lr.predict([[50**2, 50]])) # 50cm 농어 예측
# 출력: [1573.98423528]
위의 선형 회귀가 예측한 [1241.83860323]보다 더 높은 값인 [1573.98423528]을 예측하였다.
훈련 세트의 산점도와 다항 회귀가 학습한 곡선을 그려보자.
# 구간별 직선을 그리기 위한 15에서 49까지의 정수 배열 만들기
point = np.arange(15, 50)
# 훈련 세트의 산점도
plt.scatter(train_input, train_target)
# 15에서 49까지의 2차 방정식 그래프
plt.plot(point, 1.01*point**2 - 21.6*point + 116.05)
# 50cm 농어 데이터
plt.scatter(50, 1574, marker='^')
plt.xlable('length')
plt.ylable('weight')
plt.show()
농어의 무게가 0g 이하로 내려갈 수 없는 곡선 그래프가 나왔다!
마지막으로 모델을 평가해 보자.
print(lr.score(train_poly, train_target)) # 훈련 세트 점수 출력: 0.9706807451768623
print(lr.score(test_poly, test_target)) # 테스트 세트 점수 출력: 0.9775935108325122
훈련 세트와 테스트 세트에 대한 점수가 크게 높아졌다!
하지만 여전히 테스트 세트의 점수가 조금 더 높으므로 과소적합이 남아있다. 조금 더 복잡한 모델이 필요할 것 같다.
* 용어 정리
1. 선형 회귀(linear regression): 특성이 하나인 경우 어떤 직선을 학습하는 대표적인 회귀 알고리즘으로, 특성과 타깃 사이의 관계를 가장 잘 나타내는 선형 방정식을 찾음
2. LinearRegression: 사이킷런의 선형 회귀 클래스
3. 모델 피라미터: 선형 회귀가 찾은 가중치처럼 머신러닝 모델이 특성에서 학습한 피라미터를 말함
4. 다항 회귀(polynomial): 다항식을 사용한 선형 회귀로, 다항식을 사용하여 특성과 타깃 사이의 관계를 나타냄
'Python > ML' 카테고리의 다른 글
[머신러닝] 로지스틱 회귀(생선 확률 예측) (0) | 2023.02.10 |
---|---|
[머신러닝] 릿지, 라쏘(농어 무게 예측3) (0) | 2023.01.31 |
[머신러닝] k-최근접 이웃 회귀(농어 무게 예측) (0) | 2023.01.28 |
[머신러닝] 데이터 전처리(생선 분류 문제3) (2) | 2023.01.20 |
[머신러닝] 샘플링 편향, 넘파이(생선 분류 문제2) (0) | 2023.01.15 |
댓글