본문 바로가기
Python/ML

[머신러닝] k-최근접 이웃(생선 분류 문제)

by JooRi 2023. 1. 13.
728x90
SMALL
* 교재: 혼자 공부하는 머신러닝+딥러닝 (hanbit.co.kr)

 

* 문제: k-최근접 이웃을 사용하여 2개의 생선을 분류한다.

각 생선의 특징(무게, 길이)을 통해 도미와 빙어를 구분하는 머신러닝 프로그램을 만든다.

 

* 문제 해결 과정

1. 도미 데이터 준비하기

2. 빙어 데이터 준비하기

3. 도미와 빙어 데이터를 하나로 합치기

4. 2차원 리스트 만들기

5. 정답 데이터 준비하기

6. k-최근접 이웃 알고리즘을 사용해 머신러닝 프로그램 만들기

  1) KNeighborsClassifier 클래스 임포트하기

  2) KNeighborsClassifier 클래스의 객체 만들기

  3) 객체 훈련하기

  4) 훈련된 객체(모델)의 성능 평가하기

 

 

1. 도미 데이터 준비하기

35마리의 도미의 길이와 무게 리스트를 준비해 보자.

 

#bream_length는 도미의 길이, #bream_weight은 도미의 무게

bream_length = [25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7, 31.0, 31.0, 
                31.5, 32.0, 32.0, 32.0, 33.0, 33.0, 33.5, 33.5, 34.0, 34.0, 34.5, 35.0, 
                35.0, 35.0, 35.0, 36.0, 36.0, 37.0, 38.5, 38.5, 39.5, 41.0, 41.0]
bream_weight = [242.0, 290.0, 340.0, 363.0, 430.0, 450.0, 500.0, 390.0, 450.0, 500.0, 475.0, 500.0, 
                500.0, 340.0, 600.0, 600.0, 700.0, 700.0, 610.0, 650.0, 575.0, 685.0, 620.0, 680.0, 
                700.0, 725.0, 720.0, 714.0, 850.0, 1000.0, 920.0, 955.0, 925.0, 975.0, 950.0]

 

 

위의 리스트를 그래프로 표현해 보자.

 

import matplotlib.pyplot as plt  # matplotlib의 pyplot 함수를 plt로 줄여서 사용

plt.scatter(bream_length, bream_weight)
plt.xlable('length')  # x축은 길이
plt.ylable('weight')  # y축은 무게

 

파이썬에서 과학계산용 그래프를 그리는 대표적인 패키지는 맷플롯립이다. 맷플롯립을 임포트 하고 산점도를 그리는 scatter() 함수를 사용하였다. 여기서 임포트란 따로 만들어둔 파이썬 패키지(함수 묶음)를 사용하기 위해 불러오는 명령이다.

 

 

그래프를 출력해 보자.

 

plt.show()  # 그래프 출력

 

도미 그래프 출력

 

x축을 길이로 하고  y축을 무게로 정한 다음 도미를 그래프에 점으로 표시했다. 이런 그래프를 산점도라고 부른다. 또한 2개의 특성을 사용해 그린 그래프를 2차원 그래프라고 한다.

도미의 길이가 길수록 무게가 많이 나갔기 때문에 산점도 그래프가 일직선에 가까운 형태로 나타났다. 이런 경우를 선형이라고 한다.

 

 

2. 빙어 데이터 준비하기

14마리의 빙어의 길이와 무게 리스트를 준비하고 그래프로 표현해 보자.

 

smelt_length = [9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2, 12.4, 13.0, 14.3, 15.0]
smelt_weight = [6.7, 7.5, 7.0, 9.7, 9.8, 8.7, 10.0, 9.9, 9.8, 12.2, 13.4, 12.2, 19.7, 19.9]

plt.scatter(bream_length, bream_weight)
plt.scatter(smelt_length, smelt_weight)
plt.xlable('length')
plt.ylable('weight')
plt.show()

 

빙어 그래프 출력

 

맷플롯립은 2개의 산점도를 색깔로 구분해서 나타낸다.

주황색 점이 빙어의 산점도이다. 빙어는 도미에 비해 길이와 무게가 매우 작고, 일직선에 가깝기 때문에  선형적이다.

 

 

3. 도미와 빙어 데이터를 하나로 합치기

 

length = bream_length + smelt_length  # 도미와 빙어의 길이 리스트 합치기
weight = bream_weight + smelt_weight  # 도미와 빙어의 무게 리스트 합치기

 

 

4.  2차원 리스트 만들기 

2차원 리스트란 각 특성의 리스트를 세로 방향으로 늘어뜨린 것이다.

대표적인 머신러닝 라이브러리인 사이킷런을 사용하려면 2차원 리스트를 만들어야 한다.

length와 weight 리스트를 2차원 리스트로 만들어보자.

 

fish_data = [[l, w] for l, w in zip(length, weight)]
print(fish_data)

# 출력: [[25.4, 242.0], [26.3, 290.0], [26.5, 340.0], [29.0, 363.0], [29.0, 430.0], [29.7, 450.0], [29.7, 500.0], [30.0, 390.0], [30.0, 450.0], [30.7, 500.0], [31.0, 475.0], [31.0, 500.0], [31.5, 500.0], [32.0, 340.0], [32.0, 600.0], [32.0, 600.0], [33.0, 700.0], [33.0, 700.0], [33.5, 610.0], [33.5, 650.0], [34.0, 575.0], [34.0, 685.0], [34.5, 620.0], [35.0, 680.0], [35.0, 700.0], [35.0, 725.0], [35.0, 720.0], [36.0, 714.0], [36.0, 850.0], [37.0, 1000.0], [38.5, 920.0], [38.5, 955.0], [39.5, 925.0], [41.0, 975.0], [41.0, 950.0], [9.8, 6.7], [10.5, 7.5], [10.6, 7.0], [11.0, 9.7], [11.2, 9.8], [11.3, 8.7], [11.8, 10.0], [11.8, 9.9], [12.0, 9.8], [12.2, 12.2], [12.4, 13.4], [13.0, 12.2], [14.3, 19.7], [15.0, 19.9]]

 

zip() 함수는 나열된 리스트에서 원소를 하나씩 꺼내주는 일을 한다.

for문은 zip() 함수로 length와 weight 리스트에서 원소를 하나씩 꺼내어 l과 w에 할당한다. 그러면 [l, w]가 하나의 원소로 구성된 리스트가 만들어진다.

첫 번째 생선의 길이 25.4cm와 무게 242.0g이 하나의 리스트를 구성하고 이런 리스트가 모여 전체 리스트가 만들어졌다.

 

 

5. 정답 데이터 준비하기

이제 정답 데이터를 준비해야 한다. 왜냐하면 머신러닝 알고리즘이 스스로 생선의 길이와 무게를 보고 도미와 빙어를 구분하는 규칙을 찾을 수 있어야 한다.

쉽게 말해 머신러닝 알고리즘이 스스로 어떤 생선이 도미인지 빙어인지 정답을 맞힐 수 있게 해야 한다.

 

컴퓨터가 문자를 이해할 수 있게 도미와 빙어를 숫자 1과 0으로 표현해 보자.

 

fish_target = [1]*35 + [0]*14
print(fish_target)

# 출력: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

 

도미를 1로 놓고 빙어를 0으로 놓았다. 

첫 번째 생선은 도미이므로 1이고 마지막 생선은 빙어이므로 0이 된다. 즉 정답 리스트는 1이 35번 등장하고 0이 14번 등장한다.

 

 

6. k-최근접 이웃 알고리즘을 사용해 머신러닝 프로그램 만들기

도미와 빙어의 데이터 준비가 끝났다. 이 두 데이터를 스스로 구분하기 위한 머신러닝 프로그램을 만들어보자.

 

 

1) KNeighborsClassifier 클래스 임포트하기

사이킷런 패키지에서 k-최근접 이웃 분류 모델을 만드는 클래스인 KNeighborsClassifier를 임포트 해보자.

 

from sklearn.neighbors import KNeighborsClassifier

 

파이썬에서 from ~ import 구문은 모듈 전체를 임포트 하지 않고 특정 클래스만 임포트 하기 위해 사용한다.

 

 

2) KNeighborsClassifier 클래스의 객체 만들기

 

kn = KNeighborsClassifier()

 

KNeighborsClassifier 클래스의 객체 이름은 kn으로 저장하였고 클래스의 이름 뒤에 괄호를 붙여 클래스의 객체를 만들었다.

 

 

3) 객체 훈련하기 

kn 객체(또는 모델)에 fish_data와 fish_target을 전달하여 도미를 찾기 위한 기준을 학습시켜 보자. 이런 과정을 머신러닝에서는 훈련이라고 한다.

 

kn.fit(fish_data, fish_target)

 

kn 객체(또는 모델)의 fit 메서드에 fish_data와 fish_target을 전달하였다.

fit()은 두 데이터를 가지고 모델을 훈련한다.

 

 

4) 훈련된 객체(모델)의 성능 평가하기

kn 객체(또는 모델)가 얼마나 잘 훈련되었는지 평가해 보자.

 

kn.score(fish_data, fish_target)

# 출력: 1.0

 

score()은 훈련된 모델의 성능을 평가하는 메서드이다.

score()은 0에서 1 사이의 값을 반환하며 1은 모든 데이터를 정확히 맞혔다는 것을 의미한다. 예를 들어 0.5가 출력되었다면 절반만 맞혔다는 의미이다.

1.0이 출력되었으므로 이 모델은 정확도가 100%이며 도미와 빙어를 완벽히 분류했다!

 

 

* 용어 정리

1. 분류(classification): 머신러닝에서 여러 개의 종류(또는 클래스) 중 하나를 구별해 내는 문제

2. 이진 분류(binary classification): 머신러닝에서 2개의 클래스 중 하나를 고르는 문제

3. 산점도(scatter plot): x, y축으로 이뤄진 그래프

4. 맷플롯립(matplotlib): 파이썬에서 과학계산용 그래프를 그리는 대표적인 패키지

5. 임포트(import): 파이썬 패키지(함수 묶음)를 사용하기 위해 불러오는 명령

6. scatter(): 맷플롯립을 임포트 하고 산점도를 그리는 맷플롯립 함수

7. 2차원 그래프: 2개의 특성을 사용해 그린 그래프

8. 선형(linear): 산점도 그래프가 일직선에 가까운 형태로 나타나는 경우

9. 2차원 리스트: 각 특성의 리스트를 세로 방향으로 늘어뜨린 리스트로 리스트의 리스트라고도 함

10. KNeighborsClassifier(): k-최근접 이웃 분류 모델을 만드는 사이킷런 클래스

11. 사이킷런(scikit-learn): 대표적인 머신러닝 라이브러리로 이 패키지를 사용하려면 2차원 리스트를 만들어야 함

12. zip(): 나열된 리스트에서 각각 하나씩 원소를 꺼내어 반환하는 함수

13. from ~ import: 파이썬에서 패키지나 모듈 전체를 임포트 하지 않고 특정 클래스만 임포트 하기 위해 사용하는 구문

14. fit(): 주어진 데이터로 사이킷런 모델을 훈련할 때 사용하는 메서드

15. score(): 훈련된 사이킷런 모델의 성능을 측정하는 메서드

16. k-최근접 이웃: 가장 간단한 머신러닝 알고리즘으로 주변에서 가장 가까운 5개의 데이터를 보고 다수를 차지하는 것에 따라 데이터를 예측함

 

728x90
LIST

댓글