001/*
002 * Copyright (c) 2009 The openGion Project.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
013 * either express or implied. See the License for the specific language
014 * governing permissions and limitations under the License.
015 */
016package org.opengion.penguin.math.statistics;
017
018import java.util.Arrays;
019import java.util.Collections;
020import java.util.List;
021
022/**
023 * 多項ロジスティック回帰の実装です。
024 * 確率的勾配降下法(SGD)を利用します。
025 *
026 * ロジスティック回帰はn次元の情報からどのグループに所属するかの予測値を得るための手法の一つです。
027 *
028 * 実装は
029 * http://nbviewer.jupyter.org/gist/mitmul/9283713
030 * https://yusugomori.com/projects/deep-learning/
031 * を参考にしています。
032 */
033public class HybsLogisticRegression {
034//      private final int n_N;                  // データ個数
035        /** データ個数 */
036        private final int n_cnt;                // 8.5.4.2 (2024/01/12) PMD 7.0.0 FieldNamingConventions 対応 n_N ⇒ n_cnt
037
038        /** データ次元 */
039        private final int n_in;
040        /** ラベル種別数 */
041        private final int n_out;
042
043        // 8.5.4.2 (2024/01/12) PMD 7.0.0 ImmutableField 対応
044        /** 写像変数ベクトル f(x) = Wx + b */
045        private final double[][] vW;
046        private final double[] vb;
047
048        /**
049         * コンストラクタ。
050         *
051         * 学習もしてしまう。
052         *
053         * xはデータセット各行がn次元の説明変数となっている。
054         * trainはそれに対する{0,1,0},{1,0,0}のようなラベルを示すベクトルとなる。
055         * 学習率は通常、0.1程度を設定する。
056         * このロジックではループ毎に0.95をかけて徐々に学習率が下がるようにしている。
057         * 全データを利用すると時間がかかる場合があるので、確率的勾配降下法を利用しているが、
058         * 選択個数はデータに対する割合を与える。
059         * データ個数が少ない場合は1をセットすればよい。
060         *
061         * @og.rev 8.5.4.2 (2024/01/12) PMD 7.0.0 FieldNamingConventions 対応 n_N ⇒ n_cnt
062         *
063         * @param data データセット配列
064         * @param label データに対応したラベルを示す配列
065         * @param learning_rate 学習係数(0から1の間の数値)
066         * @param loop 学習のループ回数(ミニバッチを作る回数)
067         * @param minibatch_rate 全体に対するミニバッチの割合(0から1の間の数値)
068         *
069         */
070        public HybsLogisticRegression(final double data[][], final int label[][], final double learning_rate ,final int loop, final double minibatch_rate ) {
071        //      List<Integer> indexList; //シャッフル用
072
073//              this.n_N = data.length;
074                this.n_cnt = data.length;                                       // 8.5.4.2 (2024/01/12) PMD 7.0.0
075                this.n_in = data[0].length;
076                this.n_out = label[0].length;                           // ラベル種別
077
078                vW = new double[n_out][n_in];
079                vb = new double[n_out];
080
081                // 確率勾配に利用するための配列インデックス配列
082//              final Integer[] random_index = new Integer[n_N]; //プリミティブ型だとasListできないため
083                // 8.5.5.1 (2024/02/29) PMD 7.0.0 LocalVariableNamingConventions
084//              final Integer[] random_index = new Integer[n_cnt]; //プリミティブ型だとasListできないため
085                final Integer[] randomIdx = new Integer[n_cnt]; //プリミティブ型だとasListできないため
086//              for( int i=0; i<n_N; i++) {
087                for( int i=0; i<n_cnt; i++) {                           // 8.5.4.2 (2024/01/12) PMD 7.0.0
088//                      random_index[i] = i;
089                        randomIdx[i] = i;
090                }
091//              final List<Integer> indexList = Arrays.asList( random_index );
092                final List<Integer> indexList = Arrays.asList( randomIdx );
093
094                double localRate = learning_rate;
095                for(int epoch=0; epoch<loop; epoch++) {
096                        Collections.shuffle( indexList );
097        //              random_index = indexList.toArray(new Integer[indexList.size()]);
098
099                        //random_indexの先頭からn_N*minibatch_rate個のものを対象に学習をかける(ミニバッチ)
100//                      for(int i=0; i< n_N * minibatch_rate; i++) {
101                        for(int i=0; i< n_cnt * minibatch_rate; i++) {          // 8.5.4.2 (2024/01/12) PMD 7.0.0
102        //                      final int idx = random_index[i];
103                                final int idx = indexList.get(i);
104                                train(data[idx], label[idx], localRate);
105                        }
106                    localRate *= 0.95; //徐々に学習率を下げて振動を抑える。
107                }
108        }
109
110        /**
111         * データを与えて学習をさせます。
112         * パラメータの1行を与えています。
113         *
114         * 0/1のロジスティック回帰の場合は
115         * ラベルc(0or1)が各xに対して与えられている時
116         * s(x)=σ(Wx+b)=1/(1+ exp(-Wx-b))として、
117         * 確率の対数和L(W,b)の符号反転させたものの偏導関数
118         * ∂L/∂w=-∑x(c-s(x))
119         * ∂L/∂b=-∑=(c-s(x))
120         * が最小になるようなW,bの値をパラメータを変えながら求める。
121         * というのが実装になる。(=0を求められないため)
122         * 多次元の場合はシグモイド関数σ(x)の代わりにソフトマックス関数π(x)を利用して
123         * 拡張したものとなる。(以下はソフトマックス関数利用)
124         *
125         * @og.rev 8.5.4.2 (2024/01/12) PMD 7.0.0 FieldNamingConventions 対応 n_N ⇒ n_cnt
126         *
127         * @param in_x 1行分のデータ
128         * @param in_y xに対するラベル
129         * @param lr 学習率
130         * @return 差分配列
131         */
132        private double[] train( final double[] in_x, final int[] in_y, final double lr ) {
133                // 8.5.5.1 (2024/02/29) PMD 7.0.0 LocalVariableNamingConventions
134//              final double[] p_y_given_x = new double[n_out];
135                final double[] givenX   = new double[n_out];
136                final double[] dy               = new double[n_out];
137
138                for(int i=0; i<n_out; i++) {
139//                      p_y_given_x[i] = 0;
140                        givenX[i] = 0;
141                        for(int j=0; j<n_in; j++) {
142//                              p_y_given_x[i] += vW[i][j] * in_x[j];
143                                givenX[i] += vW[i][j] * in_x[j];
144                        }
145//                      p_y_given_x[i] += vb[i];
146                        givenX[i] += vb[i];
147                }
148//              softmax( p_y_given_x );
149                softmax( givenX );
150
151                // 勾配の平均で更新?
152                for(int i=0; i<n_out; i++) {
153//                      dy[i] = in_y[i] - p_y_given_x[i];
154                        dy[i] = in_y[i] - givenX[i];
155
156                        for(int j=0; j<n_in; j++) {
157//                              vW[i][j] += lr * dy[i] * in_x[j] / n_N;
158                                vW[i][j] += lr * dy[i] * in_x[j] / n_cnt;               // 8.5.4.2 (2024/01/12) PMD 7.0.0
159                        }
160
161//                      vb[i] += lr * dy[i] / n_N;
162                        vb[i] += lr * dy[i] / n_cnt;                                            // 8.5.4.2 (2024/01/12) PMD 7.0.0
163                }
164
165                return dy;
166        }
167
168        /**
169         * ソフトマックス関数。
170         * π(xi) = exp(xi)/Σexp(x)
171         * @param in_x 変数X
172         */
173        private void softmax( final double[] in_x ) {
174                // double max = 0.0;
175                double sum = 0.0;
176
177                // for(int i=0; i<n_out; i++) {
178                //      if(max < x[i]) {
179                //              max = x[i];
180                //      }
181                // }
182
183                for(int i=0; i<n_out; i++) {
184                        //x[i] = Math.exp(x[i] - max); // maxとの差分を取ると利点があるのか分からなかった
185                        in_x[i] = Math.exp(in_x[i]);
186                        sum += in_x[i];
187                }
188
189                for(int i=0; i<n_out; i++) {
190                        in_x[i] /= sum;
191                }
192        }
193
194        /**
195         * 写像式 Wx+b のW、係数ベクトル。
196         * @return 係数ベクトル
197         */
198        public double[][] getW() {
199                return vW;
200        }
201
202        /**
203         * 写像式 Wx + bのb、バイアス。
204         * @return バイアスベクトル
205         */
206        public double[] getB() {
207                return vb;
208        }
209
210        /**
211         * 出来た予測式に対して、データを入力してyを出力する。
212         * (yは各ラベルに対する確率分布となる)
213         * @param in_x 予測したいデータ
214         * @return 予測結果
215         */
216        public double[] predict(final double[] in_x) {
217                // 8.5.5.1 (2024/02/29) PMD 7.0.0 LocalVariableNamingConventions
218//              final double[] out_y = new double[n_out];
219                final double[] outY = new double[n_out];
220
221                for(int i=0; i<n_out; i++) {
222//                      out_y[i] = 0.;
223                        outY[i] = 0.;
224                        for(int j=0; j<n_in; j++) {
225//                              out_y[i] += vW[i][j] * in_x[j];
226                                outY[i] += vW[i][j] * in_x[j];
227                        }
228//                      out_y[i] += vb[i];
229                        outY[i] += vb[i];
230                }
231
232//              softmax(out_y);
233                softmax(outY);
234
235//              return out_y;
236                return outY;
237        }
238
239        // ================ ここまでが本体 ================
240
241        /**
242         * ここからテスト用mainメソッド 。
243         *
244         * @og.rev 8.5.4.2 (2024/01/12) PMD 7.0.0 LocalVariableNamingConventions 対応
245         *
246         * @param args 引数
247         */
248        public static void main( final String[] args ) {
249                // 3つの分類で分ける
250//              final double[][] train_X = {
251                final double[][] trainX = {
252                                {-2.0, 2.0}
253                                ,{-2.1, 1.9}
254                                ,{-1.8, 2.1}
255                                ,{0.0, 0.0}
256                                ,{0.2, -0.2}
257                                ,{-0.1, 0.1}
258                                ,{2.0, -2.0}
259                                ,{2.2, -2.1}
260                                ,{1.9, -2.0}
261                };
262
263//              final int[][] train_Y = {
264                final int[][] trainY = {
265                                {1, 0, 0}
266                                ,{1, 0, 0}
267                                ,{1, 0, 0}
268                                ,{0, 1, 0}
269                                ,{0, 1, 0}
270                                ,{0, 1, 0}
271                                ,{0, 0, 1}
272                                ,{0, 0, 1}
273                                ,{0, 0, 1}
274                };
275
276                 // test data
277//              final double[][] test_X = {
278                final double[][] testX = {
279                                {-2.5, 2.0}
280                                ,{0.1, -0.1}
281                                ,{1.5,-2.5}
282                };
283
284//              final double[][] test_Y = new double[test_X.length][train_Y[0].length];
285                final double[][] testY = new double[testX.length][trainY[0].length];
286
287//              final HybsLogisticRegression hlr = new HybsLogisticRegression( train_X, train_Y, 0.1, 500, 1 );
288                final HybsLogisticRegression hlr = new HybsLogisticRegression( trainX, trainY, 0.1, 500, 1 );
289
290                // テスト
291                // このデータでは2番目の条件には入りにくい?
292//              for(int i=0; i<test_X.length; i++) {
293//                       test_Y[i] = hlr.predict(test_X[i]);
294//                       System.out.print( Arrays.toString(test_Y[i]) );
295//              }
296                for(int i=0; i<testX.length; i++) {
297                         testY[i] = hlr.predict(testX[i]);
298                         System.out.print( Arrays.toString(testY[i]) );
299                }
300        }
301}
302