001/*
002 * Copyright (c) 2009 The openGion Project.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
013 * either express or implied. See the License for the specific language
014 * governing permissions and limitations under the License.
015 */
016package org.opengion.penguin.math.statistics;
017
018import java.util.Arrays;
019
020import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
021
022/**
023 * apache.commons.mathを利用したOLS重回帰計算のクラスです。
024 * y = c0 + x1c1 + x2c2 + x3c3 ...の係数を求めます。
025 * c0の切片を考慮するかどうかはnoInterceptで決めます。
026 *
027 */
028// 8.5.5.1 (2024/02/29) spotbugs CT_CONSTRUCTOR_THROW(コンストラクタで、Excweptionを出さない) class を final にすれば、警告は消える。
029// public class HybsMultiRegression implements HybsRegression {
030public final class HybsMultiRegression implements HybsRegression {
031        private double cnst[];                  // 各係数(xの種類+1になる?)
032        private double rsquare;                 // 決定係数
033        private boolean noIntercept;    //切片を利用するかどうか
034
035        /**
036         * コンストラクタ。
037         * 与えた二次元データを元に重回帰を計算します。
038         * xデータとして二次元配列を与えます。
039         * noInterceptで切片有り無しを選択します。
040         *
041         * @param in_x 説明変数
042         * @param in_y 目的変数
043         * @param noIntercept 切片利用有無(trueで利用しない)
044         */
045        public HybsMultiRegression( final double[][] in_x, final double[] in_y, final boolean noIntercept ) {
046                train( in_x, in_y, noIntercept );
047        }
048
049        /**
050         * 与えた二次元データを元に重回帰を計算します。
051         * xデータとして二次元配列を与えます。
052         * noInterceptで切片有り無しを選択します。
053         *
054         * @param in_x 説明変数
055         * @param in_y 目的変数
056         * @param noIntercept 切片利用有無(trueで利用しない)
057         */
058        private void train( final double[][] in_x, final double[] in_y, final boolean noIntercept ) {
059                this.noIntercept = noIntercept;
060
061                // ここで重回帰計算
062                final OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
063                regression.setNoIntercept(noIntercept);
064                regression.newSampleData(in_y, in_x);
065
066                cnst    = regression.estimateRegressionParameters();
067                rsquare = regression.calculateRSquared();
068        }
069
070        /**
071         * 係数をセットした配列を返します。
072         *
073         * @return 係数の配列
074         */
075        @Override       // HybsRegression
076        public double[] getCoefficient() {
077                return Arrays.copyOf( cnst,cnst.length );
078        }
079
080        /**
081         * 決定係数の取得。
082         * @return 決定係数
083         */
084        @Override       // HybsRegression
085        public double getRSquare() {
086                return rsquare;
087        }
088
089        /**
090         * 計算( c0 + c1x1...)を行う。
091         * noInterceptによってc0の利用を決める。
092         * xの大きさが足りない場合は0を返す。
093         *
094         * @param in_x 必要な大きさの変数配列
095         * @return 計算結果
096         */
097        @Override       // HybsRegression
098        public double predict( final double... in_x ) {
099                double rtn = 0;
100                final int itr = noIntercept ? 0 : 1;
101                // 8.5.5.1 (2024/02/29) PMD 7.0.0 OnlyOneReturn メソッドには終了ポイントが 1 つだけ必要
102//              if( in_x.length < cnst.length-itr ) {
103//                      return rtn;
104//              }
105                if( in_x.length >= cnst.length-itr ) {
106                        for( int i=0; i < in_x.length; i++ ) {
107                                rtn = rtn + in_x[i] * cnst[i+itr];
108                        }
109                        if( !noIntercept ) { rtn = rtn + cnst[0]; }
110                }
111                return rtn;
112        }
113
114        // ================ ここまでが本体 ================
115
116        /**
117         * ここからテスト用mainメソッド 。
118         *
119         * @param args 引数
120         */
121        public static void main( final String[] args ) {
122                // データはhttp://mjin.doshisha.ac.jp/R/14.htmlより
123                // 8.5.4.2 (2024/01/12) PMD 7.0.0 UseShortArrayInitializer
124//              final double[] y = new double[] { 50, 60, 65, 65, 70, 75, 80, 85, 90, 95 };
125                // 8.5.4.2 (2024/01/12) PMD 7.0.0 ShortVariable x ⇒ xx , y ⇒ yy
126                final double[] yy = { 50, 60, 65, 65, 70, 75, 80, 85, 90, 95 };
127                final double[][] xx = new double[10][];
128                xx[0] = new double[] { 165, 65 };
129                xx[1] = new double[] { 170, 68 };
130                xx[2] = new double[] { 172, 70 };
131                xx[3] = new double[] { 175, 65 };
132                xx[4] = new double[] { 170, 80 };
133                xx[5] = new double[] { 172, 85 };
134                xx[6] = new double[] { 183, 78 };
135                xx[7] = new double[] { 187, 79 };
136                xx[8] = new double[] { 180, 95 };
137                xx[9] = new double[] { 185, 97 };
138
139                final HybsMultiRegression mr = new HybsMultiRegression(xx,yy,true);
140
141                System.out.println( mr.getRSquare() );
142                System.out.println( Arrays.toString( mr.getCoefficient()) );
143
144                System.out.println( mr.predict( new double[] { 169,85 } ));
145        }
146}
147