2013年2月16日土曜日

k-近傍法のRコード

今回は、k-近傍法のRコードについて書きます。

どんな処理をしているのかとかいう話については、
前回の記事をご参照ください。

この記事を書くにあたり、 ここのブログ記事を参考にさせていただきました。



(1)例に使うソースコードの処理概要

k近傍法でクラス分類するにあたり、ラベル付きの2次元データをたくさん作成し、
そのうちの1部をテストデータ、残りのデータを学習用データとして扱います。

ここでは、2次元データを次のように作成しています。

 クラス1のデータ → 平均(1,0), 分散(0.5,0.5)の正規乱数、
 クラス2のデータ → 平均(0,1), 分散(0.5,0.5)の正規乱数

クラス1、クラス2はそれぞれ200個データを生成しています。
そして、それぞれのクラスにおいて、100個分を学習用データ、
そして残りの100個をテストデータとして使います。

そして、学習用データ,テストデータをKNN関数に食わせて
どちらに分類されるのかを見ることによって、ざっと性能を見ます。

そんな内容になっています。


(2) パッケージのインストール

ここでは、classパッケージとMASSパッケージの2つをインストールします。

 ■classパッケージ

  クラス分類器のいくつかが入っているものです。

  マニュアル を読むと、自己組織化マップやLVQも扱えるようですね。

 ■MASSパッケージ

  こちらを見ると、

    Modern Applied Statistics with S』に出てくるデータセット
   と関数パッケージ


  と書いてありますが、これが何者なのかは分かりません。

この2つのパッケージのインストールは、Rのコマンドプロンプト上で
以下のように行いました。

 $ options(CRAN="http://cran.md.tsukuba.ac.jp/")
 $ install.package("MASS");
 $ install.package("class");


(3) KNN関数の使い方


マニュアルによると、使い方はこんな感じです。

 $ res <- knn(train,test,cl,k = 1,l = 0,prob=TRUE,use.all=TRUE)

変数名 内容
train 学習用データが格納された配列
test テストデータが格納された配列
cl 学習用データに対応するラベルデータが格納された配列
k 上位何位まで取るか。k=3で上位3位。
l クラス判定する際、最低限必要とする抽出データ数。
データ数に満たない場合は、doubtを返す。
prob TRUEにすると、推定されたクラスに学習用データが
何割属していたのかを返す。
use.all TRUEの場合は、全部のデータを使用してk近傍法を行う。
FALSEの場合は、ランダムに取ってきた一部のデータで行う。
最後のuse.allですが、FALSEの場合はk近傍法そのものの意味が
破綻してしまう気がするんですが、気にしすぎでしょうかね。。。


(4) ソースコードの例

nato6933さんのブログに書いてあったソースがとてもシンプルで分かりやすいので、
大部分を参考にさせて頂きました。

参考: http://room6933.com/blog/2011/11/08/r_knn_2/



### ここから、学習用データとテストデータを生成する。

# 平均値に該当するベクトルデータを生成しています。
# この場合は、2つの列データ(1,0)をmu1に代入しています。

mu1 <- c(1,0) 

# 共分散行列を生成します。
Sigma1 <- matrix(c(0.5,0,0,0.5),2,2)

# 正規乱数を200個生成します。
dat1 <- mvrnorm(200,mu1,Sigma1)
 
mu2 <- c(0,1) #平均
Sigma2 <- matrix(c(0.5,0,0,0.5),2,2) #分散
dat2 <- mvrnorm(200,mu2,Sigma2)


# 学習用データを生成します。 
train <- rbind(dat1[1:100,],dat2[1:100,])

# テストデータを生成します。
test <- rbind(dat1[101:200,],dat2[101:200,])

# 学習用データの正解ラベルを生成します。 
cl <- factor(c(rep("green",100),rep("red",100)))

# 最上位のデータだけを使って分類する場合 
k <- 1
res1 <- knn(train,train,cl,k,prob=TRUE)
table(cl,res1)
res2 <- knn(train,test,cl,k,prob=TRUE)
table(cl,res2)

# 上位3位のデータを使って分類する場合 
k <- 3 #3近傍
res3 <- knn(train,train,cl,k,prob=TRUE)
table(cl,res3)
 
res4 <- knn(train,test,cl,k,prob=TRUE)
table(cl,res4)


テーブルには、分類結果が表示されるだけなので、
各データどんな風に分類されたのかが知りたい場合は、
Rのコマンドプロンプト上で、res1, res2, res3を
それぞれ参照してみてください。

0 件のコメント:

コメントを投稿