scikit-learnのtrain_test_split関数を使用してデータを分割する
scikit-learnに含まれるtrain_test_split関数を使用するとデータセットを訓練用データと試験用データに簡単に分割することができます。
train_test_split関数を用いることで、訓練用データは80%、試験用データは20%というように分割可能です。
訓練データとテスト用データに分割する理由
なぜ訓練データと試験用データに分割しないといけないのでしょうか?
それは試験用データを使用して試験を行うためでもありますが、過学習を防止して適切な学習モデルを作成するというためです。
過学習とは、文字通り学習しすぎて訓練データに適合しすぎた学習モデルのことです。例えばデータセットを100%訓練データとして使用させたとしましょう。その場合、訓練データは正答率が90%以上の好成績を出しますが、別の試験データを用いた場合は50%とか、それ以下の正答率しか出せなくなってしまう場合があります。
このように学習しすぎて訓練データに適合しすぎた学習モデルのことを過学習と呼びます。
それを防止するために、予め訓練データと試験用データに分割して学習を行う必要があります。
train_test_split関数を使用した場合のメリット
train_test_split関数を使用した場合のメリットは以下の通りです。
- 容易に訓練データとテストデータに分割することが可能
- 分割の割合を指定することが可能
- データセットの分割はランダムに行われる(固定も可)
- 順番を保ったまま分割することも可能
train_test_split関数の引数
train_test_split関数の引数は以下の通りです。
arrays | 分割対象の同じ長さを持った複数のリスト、Numpy の array, matrix, Pandasのデータフレームを指定。 |
test_size | 小数もしくは整数を指定。小数で指定した場合、テストデータの割合を 0.0 〜 1.0 の間で指定します。整数を指定した場合は、テストデータに必ず含めるレコード件数を整数で指定します。指定しなかった場合や None を設定した場合は、 のサイズを補うように設定します。train_size を設定していない場合、デフォルト値として 0.25 を用います。 |
train_size | 小数もしくは整数を指定。小数で指定した場合、トレーニングデータの割合を 0.0 〜 1.0 の間で指定します。整数を指定した場合は、トレーニングデータに必ず含めるレコード件数を整数で指定します。指定しなかった場合や None を設定した場合は、データセット全体から test_size を引いた分のサイズとします。 |
random_state | 乱数生成のシードとなる整数または、RandomState インスタンスを設定します。指定しなかった場合は、Numpy のnp.random を用いて乱数をセットします。 |
shuffle | データを分割する前にランダムに並び替えを行なうかどうか。True または False で指定します。False に設定した場合、stratify を None に設定しなければいけません。(デフォルト値: True) |
stratify | Stratified Sampling (層化サンプリング) を行なう場合に、クラスを示す行列を設定します。 (デフォルト値: None) |
公式サイト
http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html
train_test_split関数の戻り値
train_test_splitでの戻り値は以下の通りです。
- X_train: 訓練データ
- X_test: テストデータ
- Y_train: 訓練データの正解ラベル
- Y_test: テストデータの正解ラベル
ソースコード(例)
使用するデータ
1 2 3 4 5 6 7 8 9 |
from sklearn.model_selection import train_test_split import numpy as np #18個の訓練用データと試験用データをnumpy形式で作成する X = np.asarray([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18],dtype=int) Y = np.asarray([1,1,1,2,2,2,3,3,3,4,4,4,5,5,5,6,6,6],dtype=int) print("訓練用データ") print(X) print("試験用データ") print(Y) |
1 2 3 4 |
訓練用データ [ 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18] 試験用データ [1 1 1 2 2 2 3 3 3 4 4 4 5 5 5 6 6 6] |
訓練用サイズを指定する場合
1 2 3 4 5 6 7 8 9 10 11 12 |
#訓練用サイズを指定する場合(80%) X_train, X_test, Y_train, Y_test = train_test_split(X, Y,train_size=0.8) print("訓練用データの個数") print(X_train.shape) print("試験用データの個数") print(X_test.shape) print("訓練用データの中身") print(X_train) print(Y_train) print("試験用データの中身") print(X_test) print(Y_test) |
1 2 3 4 5 6 7 8 9 10 |
訓練用データの個数 (14,) 試験用データの個数 (4,) 訓練用データの中身 [13 12 8 6 18 16 1 17 7 4 10 11 15 14] [5 4 3 2 6 6 1 6 3 2 4 4 5 5] 試験用データの中身 [3 9 5 2] [1 3 2 1] |
試験用サイズを指定する場合
1 2 3 4 5 6 7 8 9 10 11 12 |
#試験用サイズを指定する場合(80%) X_train, X_test, Y_train, Y_test = train_test_split(X, Y,test_size=0.8) print("訓練用データの個数") print(X_train.shape) print("試験用データの個数") print(X_test.shape) print("訓練用データの中身") print(X_train) print(Y_train) print("試験用データの中身") print(X_test) print(Y_test) |
1 2 3 4 5 6 7 8 9 10 |
訓練用データの個数 (3,) 試験用データの個数 (15,) 訓練用データの中身 [16 1 10] [6 1 4] 試験用データの中身 [ 2 9 15 18 14 4 5 3 7 17 6 12 8 13 11] [1 3 5 6 5 2 2 1 3 6 2 4 3 5 4] |
乱数のシード値を指定する場合
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
#乱数のシード値を指定する場合 X_train, X_test, Y_train, Y_test = train_test_split(X, Y,random_state=1) print("訓練用データの中身(1回目)") print(X_train) print(Y_train) print("試験用データの中身(1回目)") print(X_test) print(Y_test) X_train, X_test, Y_train, Y_test = train_test_split(X, Y,random_state=1) print("訓練用データの中身(2回目)") print(X_train) print(Y_train) print("試験用データの中身(2回目)") print(X_test) print(Y_test) |
1 2 3 4 5 6 7 8 9 10 11 12 |
訓練用データの中身(1回目) [ 8 16 5 2 11 1 18 17 10 9 13 12 6] [3 6 2 1 4 1 6 6 4 3 5 4 2] 試験用データの中身(1回目) [ 7 4 14 3 15] [3 2 5 1 5] 訓練用データの中身(2回目) [ 8 16 5 2 11 1 18 17 10 9 13 12 6] [3 6 2 1 4 1 6 6 4 3 5 4 2] 試験用データの中身(2回目) [ 7 4 14 3 15] [3 2 5 1 5] |
乱数のシード値を指定しない場合
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
#乱数のシード値を指定しない場合 X_train, X_test, Y_train, Y_test = train_test_split(X, Y,random_state=None) print("訓練用データの中身(1回目)") print(X_train) print(Y_train) print("試験用データの中身(1回目)") print(X_test) print(Y_test) X_train, X_test, Y_train, Y_test = train_test_split(X, Y,random_state=None) print("訓練用データの中身(2回目)") print(X_train) print(Y_train) print("試験用データの中身(2回目)") print(X_test) print(Y_test) |
1 2 3 4 5 6 7 8 9 10 11 12 |
訓練用データの中身(1回目) [11 13 7 4 10 16 2 5 15 6 14 8 1] [4 5 3 2 4 6 1 2 5 2 5 3 1] 試験用データの中身(1回目) [ 9 3 12 18 17] [3 1 4 6 6] 訓練用データの中身(2回目) [ 1 5 8 4 15 16 2 11 9 17 3 14 7] [1 2 3 2 5 6 1 4 3 6 1 5 3] 試験用データの中身(2回目) [12 18 13 6 10] [4 6 5 2 4] |
シャッフルを行わない場合
1 2 3 4 5 6 7 8 |
#シャッフルを行わない場合 X_train, X_test, Y_train, Y_test = train_test_split(X, Y,shuffle=False) print("訓練用データの中身") print(X_train) print(Y_train) print("試験用データの中身") print(X_test) print(Y_test) |
1 2 3 4 5 6 |
訓練用データの中身 [ 1 2 3 4 5 6 7 8 9 10 11 12 13] [1 1 1 2 2 2 3 3 3 4 4 4 5] 試験用データの中身 [14 15 16 17 18] [5 5 6 6 6] |
シャッフルを行う場合
1 2 3 4 5 6 7 8 |
#シャッフルを行う場合 X_train, X_test, Y_train, Y_test = train_test_split(X, Y,shuffle=True) print("訓練用データの中身") print(X_train) print(Y_train) print("試験用データの中身") print(X_test) print(Y_test) |
1 2 3 4 5 6 |
訓練用データの中身 [ 5 7 16 18 14 11 15 13 9 17 12 6 1] [2 3 6 6 5 4 5 5 3 6 4 2 1] 試験用データの中身 [ 8 3 10 2 4] [3 1 4 1 2] |
scikit-learnのtrain_test_split関数を用いることでこのように非常に簡単に分割することが出来ます。但し数千万のようなとても多い件数の場合は対応できない場合があるので注意しましょう。