001/*
002 * Copyright (c) 2009 The openGion Project.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
013 * either express or implied. See the License for the specific language
014 * governing permissions and limitations under the License.
015 */
016package org.opengion.penguin.math.statistics;
017
018import java.util.Arrays;
019
020/**
021 * 独自実装の二次回帰計算クラスです。
022 * f(x) = c1x^2 + c2x + c3
023 * の曲線を求めます。
024 */
025public class HybsSquadraticRegression implements HybsRegression {
026        private final double[] cnst = new double[3] ;           // 係数(0次、1次、2次)
027        private double rsquare;         // 決定係数 今のところ求めていない
028
029        /**
030         * コンストラクタ。
031         * 与えた二次元データを元に二次回帰を計算します。
032         *
033         * @param data xとyの組み合わせの配列
034         */
035        public HybsSquadraticRegression( final double[][] data ) {
036                //二次回帰曲線を求めるが、これはapacheにはなさそうなので自前で計算する。
037                train( data );
038        }
039
040        /**
041         * 係数計算
042         *
043         *      c3Σ+c2Σx+c1Σx^2=Σy
044         *      c3Σx+c2Σ(x^2)+c1Σx^3=Σ(xy)
045         *      c3Σ(x^2)+c2Σ(x^3)+c1Σ(x^4)=Σ(x^2*y)
046         *      この三元連立方程式を解くことになる。
047         *
048         * @param data x,yの配列
049         */
050        private void train( final double[][] data ) {
051                // xの二乗等の総和用
052                // 8.5.5.1 (2024/02/29) PMD 7.0.0 LocalVariableNamingConventions
053//              final int data_n=data.length;
054                final int count=data.length;
055                double sumx2    = 0;
056                double sumx             = 0;
057                double sumxy    = 0;
058                double sumy             = 0;
059                double sumx3    = 0;
060                double sumx2y   = 0;
061                double sumx4    = 0;
062
063                // まずは計算に使うための和を計算
064//              for( int i=0; i < data_n; i++ ) {
065                for( int i=0; i < count; i++ ) {
066                        // 8.5.5.1 (2024/02/29) PMD 7.0.0 LocalVariableNamingConventions
067//                      final double data_x     = data[i][0];
068//                      final double data_y     = data[i][1];
069                        final double dataX      = data[i][0];
070                        final double dataY      = data[i][1];
071//                      final double x2         = data_x*data_x;
072                        final double x2         = dataX*dataX;
073
074//                      sumx    += data_x;
075                        sumx    += dataX;
076                        sumx2   += x2;
077//                      sumxy   += data_x * data_y;
078                        sumxy   += dataX * dataY;
079                        sumy    += dataY;
080//                      sumx3   += x2 * data_x;
081//                      sumx2y  += x2 * data_y;
082                        sumx3   += x2 * dataX;
083                        sumx2y  += x2 * dataY;
084                        sumx4   += x2 * x2;
085                }
086
087                // ガウス・ジョルダン法で係数計算
088//              final double diffx2 = sumx2 - sumx * sumx / data_n;
089//              final double diffxy = sumxy - sumx * sumy / data_n;
090//              final double diffx3 = sumx3 - sumx2 * sumx /data_n;
091//              final double diffx2y = sumx2y - sumx2 * sumy /data_n;
092//              final double diffx4 = sumx4 - sumx2 * sumx2 /data_n;
093                final double diffx2 = sumx2 - sumx * sumx / count;
094                final double diffxy = sumxy - sumx * sumy / count;
095                final double diffx3 = sumx3 - sumx2 * sumx /count;
096                final double diffx2y = sumx2y - sumx2 * sumy /count;
097                final double diffx4 = sumx4 - sumx2 * sumx2 /count;
098                final double diffd = diffx2 * diffx4 - diffx3 * diffx3;
099
100                cnst[2] = ( diffx2y * diffx2 - diffxy * diffx3 ) / diffd;
101                cnst[1] = ( diffxy * diffx4 - diffx2y * diffx3 ) / diffd;
102//              cnst[0] = sumy/data_n - cnst[1]*sumx/ data_n - cnst[2]*sumx2/data_n;
103                cnst[0] = sumy/count - cnst[1]*sumx/count - cnst[2]*sumx2/count;
104
105                rsquare = 0;            // 決定係数 今のところ求めていない
106        }
107
108        /**
109         * 係数(0次、1次、2次)の順にセットした配列を返します。
110         *
111         * @return 係数の配列
112         */
113        @Override       // HybsRegression
114        public double[] getCoefficient() {
115                return Arrays.copyOf( cnst,cnst.length );
116        }
117
118        /**
119         * 決定係数の取得。
120         * @return 決定係数
121         */
122        @Override       // HybsRegression
123        public double getRSquare() {
124                return rsquare;
125        }
126
127        /**
128         * c2*x^2 + c1*x + c0を計算。
129         *
130         * @param in_x 必要な大きさの変数配列
131         * @return 計算結果
132         */
133        @Override       // HybsRegression
134        public double predict( final double... in_x ) {
135                return cnst[2] * in_x[0] * in_x[0] + cnst[1] * in_x[0] + cnst[0];
136        }
137
138        // ================ ここまでが本体 ================
139        /**
140         * ここからテスト用mainメソッド 。
141         *
142         * @param args 引数
143         */
144        public static void main( final String[] args ) {
145                final double[][] data = {{1, 2.3}, {2, 5.1}, {3, 9.1}, {4, 16.2}};
146
147                final HybsSquadraticRegression sr = new HybsSquadraticRegression(data);
148
149                final double[] cnst = sr.getCoefficient();
150
151                System.out.println(cnst[2]);
152                System.out.println(cnst[1]);
153                System.out.println(cnst[0]);
154
155                System.out.println(sr.predict( 5 ));
156        }
157}
158