クラス HybsLogisticRegression

java.lang.Object
org.opengion.penguin.math.statistics.HybsLogisticRegression

public class HybsLogisticRegression extends Object
多項ロジスティック回帰の実装です。 確率的勾配降下法(SGD)を利用します。 ロジスティック回帰はn次元の情報からどのグループに所属するかの予測値を得るための手法の一つです。 実装は http://nbviewer.jupyter.org/gist/mitmul/9283713 https://yusugomori.com/projects/deep-learning/ を参考にしています。
  • コンストラクタの概要

    コンストラクタ
    コンストラクタ
    説明
    HybsLogisticRegression(double[][] data, int[][] label, double learning_rate, int loop, double minibatch_rate)
    コンストラクタ。
  • メソッドの概要

    修飾子とタイプ
    メソッド
    説明
    double[]
    写像式 Wx + bのb、バイアス。
    double[][]
    写像式 Wx+b のW、係数ベクトル。
    static void
    main(String[] args)
    ここからテスト用mainメソッド 。
    double[]
    predict(double[] in_x)
    出来た予測式に対して、データを入力してyを出力する。

    クラスから継承されたメソッド java.lang.Object

    clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
  • コンストラクタの詳細

    • HybsLogisticRegression

      public HybsLogisticRegression(double[][] data, int[][] label, double learning_rate, int loop, double minibatch_rate)
      コンストラクタ。 学習もしてしまう。 xはデータセット各行がn次元の説明変数となっている。 trainはそれに対する{0,1,0},{1,0,0}のようなラベルを示すベクトルとなる。 学習率は通常、0.1程度を設定する。 このロジックではループ毎に0.95をかけて徐々に学習率が下がるようにしている。 全データを利用すると時間がかかる場合があるので、確率的勾配降下法を利用しているが、 選択個数はデータに対する割合を与える。 データ個数が少ない場合は1をセットすればよい。
      パラメータ:
      data - データセット配列
      label - データに対応したラベルを示す配列
      learning_rate - 学習係数(0から1の間の数値)
      loop - 学習のループ回数(ミニバッチを作る回数)
      minibatch_rate - 全体に対するミニバッチの割合(0から1の間の数値)
      変更履歴:
      8.5.4.2 (2024/01/12) PMD 7.0.0 FieldNamingConventions 対応 n_N ⇒ n_cnt
  • メソッドの詳細

    • getW

      public double[][] getW()
      写像式 Wx+b のW、係数ベクトル。
      戻り値:
      係数ベクトル
    • getB

      public double[] getB()
      写像式 Wx + bのb、バイアス。
      戻り値:
      バイアスベクトル
    • predict

      public double[] predict(double[] in_x)
      出来た予測式に対して、データを入力してyを出力する。 (yは各ラベルに対する確率分布となる)
      パラメータ:
      in_x - 予測したいデータ
      戻り値:
      予測結果
    • main

      public static void main(String[] args)
      ここからテスト用mainメソッド 。
      パラメータ:
      args - 引数
      変更履歴:
      8.5.4.2 (2024/01/12) PMD 7.0.0 LocalVariableNamingConventions 対応