001/* 002 * Copyright (c) 2009 The openGion Project. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 013 * either express or implied. See the License for the specific language 014 * governing permissions and limitations under the License. 015 */ 016package org.opengion.penguin.math.statistics; 017 018import java.util.Arrays; 019 020import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression; 021 022/** 023 * apache.commons.mathを利用したOLS重回帰計算のクラスです。 024 * y = c0 + x1c1 + x2c2 + x3c3 ...の係数を求めます。 025 * c0の切片を考慮するかどうかはnoInterceptで決めます。 026 * 027 */ 028// 8.5.5.1 (2024/02/29) spotbugs CT_CONSTRUCTOR_THROW(コンストラクタで、Excweptionを出さない) class を final にすれば、警告は消える。 029// public class HybsMultiRegression implements HybsRegression { 030public final class HybsMultiRegression implements HybsRegression { 031 private double cnst[]; // 各係数(xの種類+1になる?) 032 private double rsquare; // 決定係数 033 private boolean noIntercept; //切片を利用するかどうか 034 035 /** 036 * コンストラクタ。 037 * 与えた二次元データを元に重回帰を計算します。 038 * xデータとして二次元配列を与えます。 039 * noInterceptで切片有り無しを選択します。 040 * 041 * @param in_x 説明変数 042 * @param in_y 目的変数 043 * @param noIntercept 切片利用有無(trueで利用しない) 044 */ 045 public HybsMultiRegression( final double[][] in_x, final double[] in_y, final boolean noIntercept ) { 046 train( in_x, in_y, noIntercept ); 047 } 048 049 /** 050 * 与えた二次元データを元に重回帰を計算します。 051 * xデータとして二次元配列を与えます。 052 * noInterceptで切片有り無しを選択します。 053 * 054 * @param in_x 説明変数 055 * @param in_y 目的変数 056 * @param noIntercept 切片利用有無(trueで利用しない) 057 */ 058 private void train( final double[][] in_x, final double[] in_y, final boolean noIntercept ) { 059 this.noIntercept = noIntercept; 060 061 // ここで重回帰計算 062 final OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression(); 063 regression.setNoIntercept(noIntercept); 064 regression.newSampleData(in_y, in_x); 065 066 cnst = regression.estimateRegressionParameters(); 067 rsquare = regression.calculateRSquared(); 068 } 069 070 /** 071 * 係数をセットした配列を返します。 072 * 073 * @return 係数の配列 074 */ 075 @Override // HybsRegression 076 public double[] getCoefficient() { 077 return Arrays.copyOf( cnst,cnst.length ); 078 } 079 080 /** 081 * 決定係数の取得。 082 * @return 決定係数 083 */ 084 @Override // HybsRegression 085 public double getRSquare() { 086 return rsquare; 087 } 088 089 /** 090 * 計算( c0 + c1x1...)を行う。 091 * noInterceptによってc0の利用を決める。 092 * xの大きさが足りない場合は0を返す。 093 * 094 * @param in_x 必要な大きさの変数配列 095 * @return 計算結果 096 */ 097 @Override // HybsRegression 098 public double predict( final double... in_x ) { 099 double rtn = 0; 100 final int itr = noIntercept ? 0 : 1; 101 // 8.5.5.1 (2024/02/29) PMD 7.0.0 OnlyOneReturn メソッドには終了ポイントが 1 つだけ必要 102// if( in_x.length < cnst.length-itr ) { 103// return rtn; 104// } 105 if( in_x.length >= cnst.length-itr ) { 106 for( int i=0; i < in_x.length; i++ ) { 107 rtn = rtn + in_x[i] * cnst[i+itr]; 108 } 109 if( !noIntercept ) { rtn = rtn + cnst[0]; } 110 } 111 return rtn; 112 } 113 114 // ================ ここまでが本体 ================ 115 116 /** 117 * ここからテスト用mainメソッド 。 118 * 119 * @param args 引数 120 */ 121 public static void main( final String[] args ) { 122 // データはhttp://mjin.doshisha.ac.jp/R/14.htmlより 123 // 8.5.4.2 (2024/01/12) PMD 7.0.0 UseShortArrayInitializer 124// final double[] y = new double[] { 50, 60, 65, 65, 70, 75, 80, 85, 90, 95 }; 125 // 8.5.4.2 (2024/01/12) PMD 7.0.0 ShortVariable x ⇒ xx , y ⇒ yy 126 final double[] yy = { 50, 60, 65, 65, 70, 75, 80, 85, 90, 95 }; 127 final double[][] xx = new double[10][]; 128 xx[0] = new double[] { 165, 65 }; 129 xx[1] = new double[] { 170, 68 }; 130 xx[2] = new double[] { 172, 70 }; 131 xx[3] = new double[] { 175, 65 }; 132 xx[4] = new double[] { 170, 80 }; 133 xx[5] = new double[] { 172, 85 }; 134 xx[6] = new double[] { 183, 78 }; 135 xx[7] = new double[] { 187, 79 }; 136 xx[8] = new double[] { 180, 95 }; 137 xx[9] = new double[] { 185, 97 }; 138 139 final HybsMultiRegression mr = new HybsMultiRegression(xx,yy,true); 140 141 System.out.println( mr.getRSquare() ); 142 System.out.println( Arrays.toString( mr.getCoefficient()) ); 143 144 System.out.println( mr.predict( new double[] { 169,85 } )); 145 } 146} 147