読者です 読者をやめる 読者になる 読者になる

データサイエンスしてみる

新米エンジニアがデータサイエンスを勉強する。機械学習とかRとかPythonとか

scikit-learnのSVMで数字を線形分離してみる

Python 機械学習

先日はanacondaを導入して簡単なプロットをしてみました。
が、anacondaはscikit-learnのような機械学習ライブラリがまとめて入っています。
せっかくなので、簡単な機械学習について手を出してみようと思います。

参考にしたのは下記ブログです。
参考というかそのまんま写経させていただきました。
先人に感謝!

sucrose.hatenablog.com

内容としてはグレースケールの数字画像をSVMで分類する、というものです。
SVMについての説明はまた今度ということで、とりあえずやってみます。

from sklearn.datasets import load_digits
from sklearn.cross_validation import train_test_split
from sklearn.svm import LinearSVC

digits = load_digits(2)

data_train, data_test, label_train, label_test = train_test_split(digits.data, digits.target)

estimator = LinearSVC(C=1.0)
estimator.fit(data_train,label_train)
label_predict = estimator.predict(data_test)

from sklearn.metrics import confusion_matrix
confusion_matrix(label_test, label_predict)

これで結果として下記が返ってきました。

array([[49,  0],
       [ 0, 41]])

実際のlabelと予想したlabelがすべてあっているので100%分離ですね。
と、こんな感じにSVMを簡単に実装することができるscikit-learnさんマジでスゴいという話でした。