본문 바로가기

AI

4-2. 머신러닝(SVM, Cross Validation, GridSearch)

SVM은 머신러닝 방법 중 하나이다. 아래 사이트의 학습 알고리즘들을 제공한다.

scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html

 

Classifier comparison — scikit-learn 0.23.2 documentation

Note Click here to download the full example code or to run this example in your browser via Binder Classifier comparison A comparison of a several classifiers in scikit-learn on synthetic datasets. The point of this example is to illustrate the nature of

scikit-learn.org

1. 랜덤 포레스트

랜덤포레스트란 학습데이터로 여러 개의 의사결정트리를 만든 후, 테스트 데이터가 오면 다수결로 의사결정을 하는 알고리즘이다.

 

다음은 버섯이 독이 있는지 확인하는 코드이다.

https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data

import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn import metrics
from sklearn.model_selection import train_test_split

mr = pd.read_csv("mushroom.csv", name=["eat", "b", "c" .... "q"])

data =mr.values
label = mr[0] #먹을 수 있는가 항목

#데이터 분리
data_train, data_test, label_train, label_test = train_test_split(data, label)

#데이터 학습
clf = RandomForestClassifier()#이부분을 바꾸어서 학습방법 교체 가능
clf.fit(data_train, label_train)

#예측
predict = clf.predict(data_test)

#테스트
ac_score = metrics.accuracy_score(label_test, predict)
cl_report = metrics.classification_report(label_test, predict)
print("정답률:", ac_score)
print("리포트:\n", cl_report )

위 링크에서 다운받은 mushroom.csv는 (독의유무, 머리모양, 색상....) 등의 데이터가 알파벳으로 담겨있다. 독버섯이면 'p', 식용이면 'e' 이런 식으로 표시한다.

이러한 기호를 for문에서 숫자로 변형하여 데이터셋을 만든 뒤 RandomForestClassifier()를 사용하여 학습한 것이다.

다른 알고리즘을 쓰고 싶다면 이 부분을 변경하면 된다.

 

데이터를 숫자로 변형할 때 갈색은 0, 흰색은 1, 회색은 2... 이런식으로 변형할 수도 있지만, 이 색깔들 간에는 서로 연속관계가 없으므로 조금 길더라도 000000000001, 000010000000 이런 식으로 변환하는 것이 더 좋다. LabelEncoder는 이러한 변환을 지원하는 클래스이다.

from sklearn.preprocessing import LabelEncoder

e = LabelEncoder()
e.fit(Label)
Y = e.transform(Label)
Y_encoded = tf.keras.utils.to_categorical(Y) #원핫인코딩 Y 값을 0, 1로만 이루어진 형태로 바꿔준다.

 

2. CrossValidation-K분할 교차검증

학습데이터의 결과를 검증하는 방법이다. 학습을 한 뒤 1-10까지의 데이터 중 1-5까지만 정확도가 100%고 나머지는 50%가 되는 상황이 발생할 수 있기때문에 데이터를 K개로 분리하여 학습을 한 뒤 정확도의 평균을 내는 것이다.

 

예를들어 100개의 데이터를 25개씩 5개로 묶어 각각을 A/B/C/D 그룹이라 하자. 

 1) A/B+C+D로 분리하여 A를 테스트, 나머지를 훈련 데이터로 쓰고 정확도 S1을 구한다.

 2) B/A+C+D로 분리하여 A를 테스트, 나머지를 훈련 데이터로 쓰고 정확도 S2을 구한다.

 3) C/A+B+D로 분리하여 A를 테스트, 나머지를 훈련 데이터로 쓰고 정확도 S3을 구한다.

 4) D/A+B+C로 분리하여 A를 테스트, 나머지를 훈련 데이터로 쓰고 정확도 S4을 구한다.

결국 평균정확도 = (S1+S2+S3+S4)/4 가 되는 것이다.

import pandas as pd
from sklearn import svm, metrics, model_selection
import random, re

csv = pd.read_csv("iris.csv")

data = csv[["SepalLength","SepalLength","PetalLength","PetalWidth"]]
label = csv["Name"]

clf = svm.SVC()
scores = model_selection.cross_val_score(clf, data, label, cv=5)
print("각 정답률: ", scores)
print("평균 정답률: ", scores.mean())

학습 데이터 그룹에 따라 정답률의 차이가 난다

위처럼 cross_val_score() 함수를 통해 구현할 수 있다.

 

 

 

3. 그리드서치

학습알고리즘에는 학습데이터, 테스트데이터 뿐만 아니라 다양한 매개변수가 들어간다. 이 매개변수의 값에 따라 정확도가 달라질 수 있다. 따라서 이 매개변수를 잘 넣는 것이 중요하다.

scikit-learn에서는 어떤 매개변수가 적절한지 자동으로 조사하는 메소드를 제공한다.

import pandas as pd
from sklearn import svm, metrics, model_selection
from sklearn.model_selection import GridSearchCV, train_test_split#nomodule -> 위치 변경됨

csv = pd.read_csv("iris.csv")

csv_data = csv[["SepalLength","SepalLength","PetalLength","PetalWidth"]]
csv_label = csv["Name"]

train_data, test_data, train_label, test_label = train_test_split(csv_data, csv_label)

params = [
    {"C": [1,10,100,1000], "kernel":["linear"]},
    {"C": [1,10,100,1000], "kernel":["rbf"], "gamma":[0.001, 0.0001]}
]

#crossvalidation
#n_jobs: 병렬계산할 프로세스 수, -1은 자동지정
clf = GridSearchCV(svm.SVC(), params, n_jobs=-1)
clf.fit(train_data, train_label)
print("학습기=", clf.best_estimator_)

pre = clf.predict(test_data)
ac_score = metrics.accuracy_score(pre, test_label)
print("정답률: ", ac_score)

iris의 정답률은 평균 0.96정도였으나 그리드서치를 통해 1.0에 도달하였다.

단점은 시간이 오래 걸린다는 점이다.

 

 

 

+. 그래프 그리기

import matplotlib.pyplot as plt
import pandas as pd

#pandas로 csv 읽기
tbl = pd.read_csv("color.csv")

# 데이터 개괄 보기
print(df.info())

# 데이터의 일부분 미리 보기
print(df.head())

#그래프로 확인
sns.pairplot(df, hue='colors');
plt.show()

데이터 개괄
그래프

'AI' 카테고리의 다른 글

6-2. 자연어분석(베이즈정리)  (0) 2020.09.19
6-1. 자연어분석(NoMLPy, Gensim)  (0) 2020.09.17
5. 딥러닝(Keras)  (0) 2020.09.12
4-1. 머신러닝  (0) 2020.09.10