001/* 002 * Copyright (c) 2009 The openGion Project. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 013 * either express or implied. See the License for the specific language 014 * governing permissions and limitations under the License. 015 */ 016package org.opengion.penguin.math.statistics; 017 018import org.apache.commons.math3.stat.StatUtils; 019import org.apache.commons.math3.linear.RealMatrix; 020import org.apache.commons.math3.linear.Array2DRowRealMatrix; 021import org.apache.commons.math3.linear.LUDecomposition; 022import org.apache.commons.math3.stat.correlation.Covariance; 023 024/** 025 * apache.commons.mathを利用した、マハラノビス距離関係の処理クラスです。 026 * 027 * 相関を考慮した距離が求まります。 028 * 教師無し学習的に、異常値検知に利用可能です。 029 * 閾値は95%区間の2.448がデフォルトです。(3なら99%) 030 * 031 * 「Juan Francisco Quesada-Brizuela」氏の距離計算PGを参照しています。 032 * 学術的には様々な改良が提案されていますが、このクラスでは単純なマハラノビス距離を扱います。 033 */ 034// 8.5.5.1 (2024/02/29) spotbugs CT_CONSTRUCTOR_THROW(コンストラクタで、Excweptionを出さない) class を final にすれば、警告は消える。 035// public class HybsMahalanobis { 036public final class HybsMahalanobis { 037 038 private double[] dataDistance; // 元データの各マハラノビス距離 039 private double[] average; // 平均 040 private RealMatrix covariance; // 共分散 041 private double limen=2.448; // 異常値検知をする際の閾値(初期値は95%信頼楕円) 042 043 /** 044 * コンストラクタ。 045 * 与えたデータマトリクスを元にマハラノビス距離を求めるための準備をします。 046 * (平均と共分散を求めます) 047 * 引数calcにtrueをセットすると各点のマハラノビス距離を計算します。 048 * 049 * データ = { { 90 ,60 }, { 70, 80 } } 050 * のような形としてデータを与えます。 051 * 052 * @param matrix 値のデータ 053 * @param calc 距離計算を行うかどうか 054 */ 055 public HybsMahalanobis( final double[][] matrix, final boolean calc ) { 056 // 一応元データをセットしておく 057 final RealMatrix dataMatrix = new Array2DRowRealMatrix( matrix ); 058 059 // 共分散行列を作成 060 covariance = new Covariance(matrix).getCovarianceMatrix(); 061 //平均の配列を作成 062 average = new double[matrix[0].length]; 063 for( int i=0; i<matrix[0].length; i++) { 064 average[i] = StatUtils.mean(dataMatrix.getColumn(i)); 065 } 066 067 if(calc) { 068 dataDistance = new double[matrix.length]; 069 for( int i=0; i< matrix.length; i++ ) { 070 // dataDistance[i] = distance( matrix[i] ); 071 dataDistance[i] = distance( covariance,matrix[i],average ); // PMD:Overridable method 'distance' called during object construction 072 } 073 // 標準偏差、平均を取る場合 074 //double maxDst = StatUtils.max( dataDistance ); // 最大 075 //double vrDst = StatUtils.variance( dataDistance ); // 分散 076 //double shigma = Math.sqrt(vrDst); // シグマ 077 //double meanDst = StatUtils.mean( dataDistance ); // 平均 078 } 079 } 080 081 /** 082 * 距離計算がtrueの形の簡易版コンストラクタです。 083 * 084 * @param matrix 値データ 085 */ 086 public HybsMahalanobis(final double[][] matrix) { 087 this(matrix,true); 088 } 089 090 /** 091 * コンストラクタ。 092 * 計算済みの共分散と平均、閾値を与えるパターン。 093 * 094 * @param covarianceData 共分散 095 * @param averageData 平均配列 096 */ 097 public HybsMahalanobis(final double[][] covarianceData, final double[] averageData) { 098 this.covariance = new Array2DRowRealMatrix(covarianceData); 099 this.average = averageData; 100 } 101 102 /** 103 * 平均配列を返します。 104 * 105 * @return 平均 106 */ 107 public double[] getAverage() { 108 return average; 109 } 110 111 /** 112 * 共分散配列を返します。 113 * 114 * @return 共分散 115 */ 116 public double[][] getCovariance() { 117 return covariance.getData(); 118 } 119 120 /** 121 * 閾値を返します。 122 * 123 * @return 閾値 124 */ 125 public double getLimen() { 126 return limen; 127 } 128 129 /** 130 * 平均配列をセットします。 131 * 132 * @param ave 平均 133 */ 134 public void setAverage( final double[] ave ) { 135 this.average = ave; 136 } 137 138 /** 139 * 共分散配列をセットします。 140 * 141 * @param cvr 共分散 142 */ 143 public void setCovariance( final double[][] cvr ) { 144 this.covariance = new Array2DRowRealMatrix(cvr); 145 } 146 147 /** 148 * 閾値をセットします。 149 * 距離の二乗がカイ2乗分布となるため、 150 * 初期値は2.448で、95%区間を意味します。 151 * 2が86%、3が99%です。 152 * 153 * @param lim 閾値 154 */ 155 public void setLimen( final double lim ) { 156 this.limen = lim; 157 } 158 159 /** 160 * コンストラクタで元データを与え、計算させた場合のマハラノビス距離の配列を返します。 161 * 162 * @return 各点のマハラノビス距離の配列 163 */ 164 public double[] getDataDistance() { 165 return dataDistance; 166 } 167 168 /** 169 * マハラノビス距離を計算します。 170 * 171 * @param vec 判定する点(ベクトル) 172 * @return マハラノビス距離 173 */ 174 public double distance( final double[] vec) { 175 return distance( covariance, vec, average ); 176 } 177 178 /** 179 * 与えたベクトルが閾値を超えたマハラノビス距離かどうかを判定します。 180 * 閾値以下ならtrue、超えている場合はfalseを返します。 181 * (異常値判定) 182 * 183 * @param vec 判定する点(ベクトル) 184 * @return 閾値以下かどうか 185 */ 186 public boolean check( final double[] vec) { 187// final double dist = distance( covariance, vec, average ); 188// return ( dist <= limen ); 189 return distance( covariance, vec, average ) <= limen ; // 6.9.7.0 (2018/05/14) PMD Useless parentheses. 190 } 191 192 /** 193 * 平均、共分散を利用して対象ベクトルとの距離を測ります。 194 * 195 * @og.rev 6.9.8.0 (2018/05/28) det を削除します。 196 * @og.rev 6.9.9.0 (2018/08/20) ロジック修正ミス 197 * 198 * @param mtx1 共分散行列 199 * @param vec1 距離を測りたいベクトル 200 * @param vec2 平均ベクトル 201 * @return マハラノビス距離 202 */ 203 private double distance(final RealMatrix mtx1, final double vec1[], final double vec2[]) { 204 // マハラノビス距離の公式 205 // マハラノビス距離 = (v1-v2)*inv(m1)*t(v1-v2) 206 // inv():逆行列 207 // t():転置行列 208 209 // ※getDeterminantは行列式(正方行列に対して定義される量)を取得 210 // javaの処理上、v1.lengthが2以上の場合、1/(v1.length)が0になる。 211 // その結果、行列式を0乗になるので、detに1が設定される。 212 // この式はマハラノビス距離を求める公式にない為、不要な処理? 213// final double det = Math.pow((new LUDecomposition(mtx1).getDeterminant()), 1/(vec1.length)); 214 // 6.9.8.0 (2018/05/28) det を削除します。 215 // PMD で、1/(vec1.length) が指摘され、FindBugs で、整数同士の割り算を、double にキャストしている警告が出ます。 216 // vec1 の配列が1の場合のみ有効にするなら、他の方法があるはずで、不要な処理? というコメントとあわせて、 217 // とりあえずコメントアウトしておきます。 218 // final double det = Math.pow( new LUDecomposition(mtx1).getDeterminant() , 1/ vec1.length ); // 6.9.7.0 (2018/05/14) PMD Useless parentheses. 219 220 final double[] temp = new double[vec1.length]; // 8.5.4.2 (2024/01/12) PMD 7.0.0 LocalVariableCouldBeFinal 221 // (x - y)を計算 222 for(int i=0; i < vec1.length; i++) { 223 temp[i] = vec1[i]-vec2[i]; // 6.9.7.0 (2018/05/14) PMD Useless parentheses. 224 } 225 226 // double[] tempSub = new double[vec1.length]; 227 228 // // (x - y)を計算 229 // for(int i=0; i < vec1.length; i++) { 230 // tempSub[i] = vec1[i]-vec2[i]; // 6.9.7.0 (2018/05/14) PMD Useless parentheses. 231 // } 232 233 // double[] temp = new double[vec1.length]; 234 235 // // (x - y) * det 不要な処理? 236 // for(int i=0; i < temp.length; i++) { 237 // temp[i] = tempSub[i]*det; 238 // } 239 240 // m2: (x - y)を行列に変換 241 final RealMatrix m2 = new Array2DRowRealMatrix( new double[][] { temp } ); 242 243 // m3: m2 * 共分散行列の逆行列 244 final RealMatrix m3 = m2.multiply( new LUDecomposition(mtx1).getSolver().getInverse() ); 245 246 // m4: m3 * (x-y)の転置行列 247// final RealMatrix m4 = m3.multiply((new Array2DRowRealMatrix(new double[][] { temp })).transpose()); 248// final RealMatrix m4 = m3.multiply( new Array2DRowRealMatrix( new double[][] { temp } ) ).transpose() ; // 6.9.7.0 (2018/05/14) PMD Useless parentheses. 249 final RealMatrix m4 = m3.multiply( new Array2DRowRealMatrix( new double[][] { temp } ).transpose() ) ; // 6.9.9.0 (2018/08/20) ロジック修正ミス 250 251 // m4の平方根を返す 252 return Math.sqrt(m4.getEntry(0, 0)); 253 } 254 255 // *** ここまでが本体 *** 256 257 /** 258 * ここからテスト用mainメソッド。 259 * 260 * @param args **************************************** 261 */ 262 public static void main( final String [] args ) { 263 // 幾何的には、これらの重心を中心とした楕円の中に入っているかどうかを判定 264 final double[][] data = { 265 {2, 10}, 266 {4, 21}, 267 {6, 27}, 268 {8, 41}, 269 {10, 50} 270 }; 271 272 final double[] test = {12, 50}; 273 final double[] test2 = {12, 59}; 274 275 final HybsMahalanobis rtn = new HybsMahalanobis(data); 276 277 System.out.println( java.util.Arrays.toString(rtn.getDataDistance()) ); 278 279 System.out.println(rtn.check( test )); 280 System.out.println(rtn.check( test2 )); 281 } 282} 283