001/* 002 * Copyright (c) 2009 The openGion Project. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 013 * either express or implied. See the License for the specific language 014 * governing permissions and limitations under the License. 015 */ 016package org.opengion.penguin.math.statistics; 017 018import java.util.Arrays; 019 020/** 021 * 独自実装の二次回帰計算クラスです。 022 * f(x) = c1x^2 + c2x + c3 023 * の曲線を求めます。 024 */ 025public class HybsSquadraticRegression implements HybsRegression { 026 private final double[] cnst = new double[3] ; // 係数(0次、1次、2次) 027 private double rsquare; // 決定係数 今のところ求めていない 028 029 /** 030 * コンストラクタ。 031 * 与えた二次元データを元に二次回帰を計算します。 032 * 033 * @param data xとyの組み合わせの配列 034 */ 035 public HybsSquadraticRegression( final double[][] data ) { 036 //二次回帰曲線を求めるが、これはapacheにはなさそうなので自前で計算する。 037 train( data ); 038 } 039 040 /** 041 * 係数計算 042 * 043 * c3Σ+c2Σx+c1Σx^2=Σy 044 * c3Σx+c2Σ(x^2)+c1Σx^3=Σ(xy) 045 * c3Σ(x^2)+c2Σ(x^3)+c1Σ(x^4)=Σ(x^2*y) 046 * この三元連立方程式を解くことになる。 047 * 048 * @param data x,yの配列 049 */ 050 private void train( final double[][] data ) { 051 // xの二乗等の総和用 052 // 8.5.5.1 (2024/02/29) PMD 7.0.0 LocalVariableNamingConventions 053// final int data_n=data.length; 054 final int count=data.length; 055 double sumx2 = 0; 056 double sumx = 0; 057 double sumxy = 0; 058 double sumy = 0; 059 double sumx3 = 0; 060 double sumx2y = 0; 061 double sumx4 = 0; 062 063 // まずは計算に使うための和を計算 064// for( int i=0; i < data_n; i++ ) { 065 for( int i=0; i < count; i++ ) { 066 // 8.5.5.1 (2024/02/29) PMD 7.0.0 LocalVariableNamingConventions 067// final double data_x = data[i][0]; 068// final double data_y = data[i][1]; 069 final double dataX = data[i][0]; 070 final double dataY = data[i][1]; 071// final double x2 = data_x*data_x; 072 final double x2 = dataX*dataX; 073 074// sumx += data_x; 075 sumx += dataX; 076 sumx2 += x2; 077// sumxy += data_x * data_y; 078 sumxy += dataX * dataY; 079 sumy += dataY; 080// sumx3 += x2 * data_x; 081// sumx2y += x2 * data_y; 082 sumx3 += x2 * dataX; 083 sumx2y += x2 * dataY; 084 sumx4 += x2 * x2; 085 } 086 087 // ガウス・ジョルダン法で係数計算 088// final double diffx2 = sumx2 - sumx * sumx / data_n; 089// final double diffxy = sumxy - sumx * sumy / data_n; 090// final double diffx3 = sumx3 - sumx2 * sumx /data_n; 091// final double diffx2y = sumx2y - sumx2 * sumy /data_n; 092// final double diffx4 = sumx4 - sumx2 * sumx2 /data_n; 093 final double diffx2 = sumx2 - sumx * sumx / count; 094 final double diffxy = sumxy - sumx * sumy / count; 095 final double diffx3 = sumx3 - sumx2 * sumx /count; 096 final double diffx2y = sumx2y - sumx2 * sumy /count; 097 final double diffx4 = sumx4 - sumx2 * sumx2 /count; 098 final double diffd = diffx2 * diffx4 - diffx3 * diffx3; 099 100 cnst[2] = ( diffx2y * diffx2 - diffxy * diffx3 ) / diffd; 101 cnst[1] = ( diffxy * diffx4 - diffx2y * diffx3 ) / diffd; 102// cnst[0] = sumy/data_n - cnst[1]*sumx/ data_n - cnst[2]*sumx2/data_n; 103 cnst[0] = sumy/count - cnst[1]*sumx/count - cnst[2]*sumx2/count; 104 105 rsquare = 0; // 決定係数 今のところ求めていない 106 } 107 108 /** 109 * 係数(0次、1次、2次)の順にセットした配列を返します。 110 * 111 * @return 係数の配列 112 */ 113 @Override // HybsRegression 114 public double[] getCoefficient() { 115 return Arrays.copyOf( cnst,cnst.length ); 116 } 117 118 /** 119 * 決定係数の取得。 120 * @return 決定係数 121 */ 122 @Override // HybsRegression 123 public double getRSquare() { 124 return rsquare; 125 } 126 127 /** 128 * c2*x^2 + c1*x + c0を計算。 129 * 130 * @param in_x 必要な大きさの変数配列 131 * @return 計算結果 132 */ 133 @Override // HybsRegression 134 public double predict( final double... in_x ) { 135 return cnst[2] * in_x[0] * in_x[0] + cnst[1] * in_x[0] + cnst[0]; 136 } 137 138 // ================ ここまでが本体 ================ 139 /** 140 * ここからテスト用mainメソッド 。 141 * 142 * @param args 引数 143 */ 144 public static void main( final String[] args ) { 145 final double[][] data = {{1, 2.3}, {2, 5.1}, {3, 9.1}, {4, 16.2}}; 146 147 final HybsSquadraticRegression sr = new HybsSquadraticRegression(data); 148 149 final double[] cnst = sr.getCoefficient(); 150 151 System.out.println(cnst[2]); 152 System.out.println(cnst[1]); 153 System.out.println(cnst[0]); 154 155 System.out.println(sr.predict( 5 )); 156 } 157} 158