001/* 002 * Copyright (c) 2009 The openGion Project. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 013 * either express or implied. See the License for the specific language 014 * governing permissions and limitations under the License. 015 */ 016package org.opengion.penguin.math.statistics; 017 018import java.util.Arrays; 019import java.util.Collections; 020import java.util.List; 021 022/** 023 * 多項ロジスティック回帰の実装です。 024 * 確率的勾配降下法(SGD)を利用します。 025 * 026 * ロジスティック回帰はn次元の情報からどのグループに所属するかの予測値を得るための手法の一つです。 027 * 028 * 実装は 029 * http://nbviewer.jupyter.org/gist/mitmul/9283713 030 * https://yusugomori.com/projects/deep-learning/ 031 * を参考にしています。 032 */ 033public class HybsLogisticRegression { 034// private final int n_N; // データ個数 035 /** データ個数 */ 036 private final int n_cnt; // 8.5.4.2 (2024/01/12) PMD 7.0.0 FieldNamingConventions 対応 n_N ⇒ n_cnt 037 038 /** データ次元 */ 039 private final int n_in; 040 /** ラベル種別数 */ 041 private final int n_out; 042 043 // 8.5.4.2 (2024/01/12) PMD 7.0.0 ImmutableField 対応 044 /** 写像変数ベクトル f(x) = Wx + b */ 045 private final double[][] vW; 046 private final double[] vb; 047 048 /** 049 * コンストラクタ。 050 * 051 * 学習もしてしまう。 052 * 053 * xはデータセット各行がn次元の説明変数となっている。 054 * trainはそれに対する{0,1,0},{1,0,0}のようなラベルを示すベクトルとなる。 055 * 学習率は通常、0.1程度を設定する。 056 * このロジックではループ毎に0.95をかけて徐々に学習率が下がるようにしている。 057 * 全データを利用すると時間がかかる場合があるので、確率的勾配降下法を利用しているが、 058 * 選択個数はデータに対する割合を与える。 059 * データ個数が少ない場合は1をセットすればよい。 060 * 061 * @og.rev 8.5.4.2 (2024/01/12) PMD 7.0.0 FieldNamingConventions 対応 n_N ⇒ n_cnt 062 * 063 * @param data データセット配列 064 * @param label データに対応したラベルを示す配列 065 * @param learning_rate 学習係数(0から1の間の数値) 066 * @param loop 学習のループ回数(ミニバッチを作る回数) 067 * @param minibatch_rate 全体に対するミニバッチの割合(0から1の間の数値) 068 * 069 */ 070 public HybsLogisticRegression(final double data[][], final int label[][], final double learning_rate ,final int loop, final double minibatch_rate ) { 071 // List<Integer> indexList; //シャッフル用 072 073// this.n_N = data.length; 074 this.n_cnt = data.length; // 8.5.4.2 (2024/01/12) PMD 7.0.0 075 this.n_in = data[0].length; 076 this.n_out = label[0].length; // ラベル種別 077 078 vW = new double[n_out][n_in]; 079 vb = new double[n_out]; 080 081 // 確率勾配に利用するための配列インデックス配列 082// final Integer[] random_index = new Integer[n_N]; //プリミティブ型だとasListできないため 083 // 8.5.5.1 (2024/02/29) PMD 7.0.0 LocalVariableNamingConventions 084// final Integer[] random_index = new Integer[n_cnt]; //プリミティブ型だとasListできないため 085 final Integer[] randomIdx = new Integer[n_cnt]; //プリミティブ型だとasListできないため 086// for( int i=0; i<n_N; i++) { 087 for( int i=0; i<n_cnt; i++) { // 8.5.4.2 (2024/01/12) PMD 7.0.0 088// random_index[i] = i; 089 randomIdx[i] = i; 090 } 091// final List<Integer> indexList = Arrays.asList( random_index ); 092 final List<Integer> indexList = Arrays.asList( randomIdx ); 093 094 double localRate = learning_rate; 095 for(int epoch=0; epoch<loop; epoch++) { 096 Collections.shuffle( indexList ); 097 // random_index = indexList.toArray(new Integer[indexList.size()]); 098 099 //random_indexの先頭からn_N*minibatch_rate個のものを対象に学習をかける(ミニバッチ) 100// for(int i=0; i< n_N * minibatch_rate; i++) { 101 for(int i=0; i< n_cnt * minibatch_rate; i++) { // 8.5.4.2 (2024/01/12) PMD 7.0.0 102 // final int idx = random_index[i]; 103 final int idx = indexList.get(i); 104 train(data[idx], label[idx], localRate); 105 } 106 localRate *= 0.95; //徐々に学習率を下げて振動を抑える。 107 } 108 } 109 110 /** 111 * データを与えて学習をさせます。 112 * パラメータの1行を与えています。 113 * 114 * 0/1のロジスティック回帰の場合は 115 * ラベルc(0or1)が各xに対して与えられている時 116 * s(x)=σ(Wx+b)=1/(1+ exp(-Wx-b))として、 117 * 確率の対数和L(W,b)の符号反転させたものの偏導関数 118 * ∂L/∂w=-∑x(c-s(x)) 119 * ∂L/∂b=-∑=(c-s(x)) 120 * が最小になるようなW,bの値をパラメータを変えながら求める。 121 * というのが実装になる。(=0を求められないため) 122 * 多次元の場合はシグモイド関数σ(x)の代わりにソフトマックス関数π(x)を利用して 123 * 拡張したものとなる。(以下はソフトマックス関数利用) 124 * 125 * @og.rev 8.5.4.2 (2024/01/12) PMD 7.0.0 FieldNamingConventions 対応 n_N ⇒ n_cnt 126 * 127 * @param in_x 1行分のデータ 128 * @param in_y xに対するラベル 129 * @param lr 学習率 130 * @return 差分配列 131 */ 132 private double[] train( final double[] in_x, final int[] in_y, final double lr ) { 133 // 8.5.5.1 (2024/02/29) PMD 7.0.0 LocalVariableNamingConventions 134// final double[] p_y_given_x = new double[n_out]; 135 final double[] givenX = new double[n_out]; 136 final double[] dy = new double[n_out]; 137 138 for(int i=0; i<n_out; i++) { 139// p_y_given_x[i] = 0; 140 givenX[i] = 0; 141 for(int j=0; j<n_in; j++) { 142// p_y_given_x[i] += vW[i][j] * in_x[j]; 143 givenX[i] += vW[i][j] * in_x[j]; 144 } 145// p_y_given_x[i] += vb[i]; 146 givenX[i] += vb[i]; 147 } 148// softmax( p_y_given_x ); 149 softmax( givenX ); 150 151 // 勾配の平均で更新? 152 for(int i=0; i<n_out; i++) { 153// dy[i] = in_y[i] - p_y_given_x[i]; 154 dy[i] = in_y[i] - givenX[i]; 155 156 for(int j=0; j<n_in; j++) { 157// vW[i][j] += lr * dy[i] * in_x[j] / n_N; 158 vW[i][j] += lr * dy[i] * in_x[j] / n_cnt; // 8.5.4.2 (2024/01/12) PMD 7.0.0 159 } 160 161// vb[i] += lr * dy[i] / n_N; 162 vb[i] += lr * dy[i] / n_cnt; // 8.5.4.2 (2024/01/12) PMD 7.0.0 163 } 164 165 return dy; 166 } 167 168 /** 169 * ソフトマックス関数。 170 * π(xi) = exp(xi)/Σexp(x) 171 * @param in_x 変数X 172 */ 173 private void softmax( final double[] in_x ) { 174 // double max = 0.0; 175 double sum = 0.0; 176 177 // for(int i=0; i<n_out; i++) { 178 // if(max < x[i]) { 179 // max = x[i]; 180 // } 181 // } 182 183 for(int i=0; i<n_out; i++) { 184 //x[i] = Math.exp(x[i] - max); // maxとの差分を取ると利点があるのか分からなかった 185 in_x[i] = Math.exp(in_x[i]); 186 sum += in_x[i]; 187 } 188 189 for(int i=0; i<n_out; i++) { 190 in_x[i] /= sum; 191 } 192 } 193 194 /** 195 * 写像式 Wx+b のW、係数ベクトル。 196 * @return 係数ベクトル 197 */ 198 public double[][] getW() { 199 return vW; 200 } 201 202 /** 203 * 写像式 Wx + bのb、バイアス。 204 * @return バイアスベクトル 205 */ 206 public double[] getB() { 207 return vb; 208 } 209 210 /** 211 * 出来た予測式に対して、データを入力してyを出力する。 212 * (yは各ラベルに対する確率分布となる) 213 * @param in_x 予測したいデータ 214 * @return 予測結果 215 */ 216 public double[] predict(final double[] in_x) { 217 // 8.5.5.1 (2024/02/29) PMD 7.0.0 LocalVariableNamingConventions 218// final double[] out_y = new double[n_out]; 219 final double[] outY = new double[n_out]; 220 221 for(int i=0; i<n_out; i++) { 222// out_y[i] = 0.; 223 outY[i] = 0.; 224 for(int j=0; j<n_in; j++) { 225// out_y[i] += vW[i][j] * in_x[j]; 226 outY[i] += vW[i][j] * in_x[j]; 227 } 228// out_y[i] += vb[i]; 229 outY[i] += vb[i]; 230 } 231 232// softmax(out_y); 233 softmax(outY); 234 235// return out_y; 236 return outY; 237 } 238 239 // ================ ここまでが本体 ================ 240 241 /** 242 * ここからテスト用mainメソッド 。 243 * 244 * @og.rev 8.5.4.2 (2024/01/12) PMD 7.0.0 LocalVariableNamingConventions 対応 245 * 246 * @param args 引数 247 */ 248 public static void main( final String[] args ) { 249 // 3つの分類で分ける 250// final double[][] train_X = { 251 final double[][] trainX = { 252 {-2.0, 2.0} 253 ,{-2.1, 1.9} 254 ,{-1.8, 2.1} 255 ,{0.0, 0.0} 256 ,{0.2, -0.2} 257 ,{-0.1, 0.1} 258 ,{2.0, -2.0} 259 ,{2.2, -2.1} 260 ,{1.9, -2.0} 261 }; 262 263// final int[][] train_Y = { 264 final int[][] trainY = { 265 {1, 0, 0} 266 ,{1, 0, 0} 267 ,{1, 0, 0} 268 ,{0, 1, 0} 269 ,{0, 1, 0} 270 ,{0, 1, 0} 271 ,{0, 0, 1} 272 ,{0, 0, 1} 273 ,{0, 0, 1} 274 }; 275 276 // test data 277// final double[][] test_X = { 278 final double[][] testX = { 279 {-2.5, 2.0} 280 ,{0.1, -0.1} 281 ,{1.5,-2.5} 282 }; 283 284// final double[][] test_Y = new double[test_X.length][train_Y[0].length]; 285 final double[][] testY = new double[testX.length][trainY[0].length]; 286 287// final HybsLogisticRegression hlr = new HybsLogisticRegression( train_X, train_Y, 0.1, 500, 1 ); 288 final HybsLogisticRegression hlr = new HybsLogisticRegression( trainX, trainY, 0.1, 500, 1 ); 289 290 // テスト 291 // このデータでは2番目の条件には入りにくい? 292// for(int i=0; i<test_X.length; i++) { 293// test_Y[i] = hlr.predict(test_X[i]); 294// System.out.print( Arrays.toString(test_Y[i]) ); 295// } 296 for(int i=0; i<testX.length; i++) { 297 testY[i] = hlr.predict(testX[i]); 298 System.out.print( Arrays.toString(testY[i]) ); 299 } 300 } 301} 302