ARCS6 AR6-REV.24062600
読み取り中…
検索中…
一致する文字列を見つけられません
RecurrentNeuralNet3.hh
[詳解]
1
8//
9// Copyright (C) 2011-2021 Yokokura, Yuki
10// This program is free software;
11// you can redistribute it and/or modify it under the terms of the FreeBSD License.
12// For details, see the License.txt file.
13
14#ifndef RECURRENTNEURALNET3
15#define RECURRENTNEURALNET3
16
17#include <cassert>
19#include "TimeSeriesDatasets.hh"
20
21// ARCS組込み用マクロ
22#ifdef ARCS_IN
23 // ARCSに組み込まれる場合
24 #include "ARCSassert.hh"
25 #include "ARCSeventlog.hh"
26#else
27 // ARCSに組み込まれない場合
28 #define arcs_assert(a) (assert(a))
29 #define PassedLog()
30 #define EventLog(a)
31 #define EventLogVar(a)
32#endif
33
34namespace ARCS { // ARCS名前空間
37template <
38 size_t Nin, // ユニット数(入力層)
39 size_t NL1, // ユニット数(内部層1)
40 size_t Nout, // ユニット数(出力層)
41 size_t Tlen, // 時刻データの長さ
42 size_t Wind, // 入力ウィンドウ幅
43 size_t Mbat, // ミニバッチサイズ
44 size_t Epch, // エポック数
45 ActvFunc FuncIn = ActvFunc::TANH, // 入力層の活性化関数
46 ActvFunc FuncL1 = ActvFunc::TANH, // 内部層1の活性化関数
47 ActvFunc FuncOut = ActvFunc::TANH, // 出力層の活性化関数
48 NnInitTypes InitType = NnInitTypes::HE, // 重み初期化のタイプ
49 NnDescentTypes GradDesType = NnDescentTypes::SGD, // 勾配降下法のタイプ
50 NnDropout DropOutEnable = NnDropout::DISABLE // ドロップアウトイネーブル
51>
53 public:
56 : Dataset(TSD), RNNin(), RNNL1(), RNNout()
57 {
58 PassedLog();
59 }
60
64 : Dataset(r.Dataset)
65 {
66
67 }
68
73
75 void Train(void){
76 // 重み行列の初期化
77 RNNin.InitWeight(Nin); // 入力層の重み行列の乱数による初期化
78 //RNNL1.InitWeight(Nin); // 内部層1の重み行列の乱数による初期化
79 RNNout.InitWeight(30); // 出力層の重み行列の乱数による初期化
80
81 //RNNin.DispWeightAndBias();
82 //RNNL1.DispWeightAndBias();
83 //RNNout.DispWeightAndBias();
84
85 // エポック数分のループ
86 for(size_t i = 1; i <= Epch; ++i){
87 //RNNin.GenerateDropMask();
88 //RNNL1.GenerateDropMask();
89 //RNNout.GenerateDropMask();
90
91 // 順伝播計算
92 for(size_t t = 1; t <= Wind; ++t){
93 RNNin.PropagateForward(Dataset.InputData, t); // 入力層の順伝播計算
94 }
95 RNNout.PropagateForwardForOutput(RNNin.z); // 出力層の順伝播計算
96
97 // 逆伝播計算
98 RNNout.PropagateBackwardForOutput(Dataset.TrainData); // 出力層の逆伝播計算
99 for(size_t t = Wind; 1 <= t; --t){
100 RNNin.PropagateBackward(RNNout.dLde, t); // 出力層の逆伝播計算
101 }
102
103 // 重み・バイアス更新計算
104 RNNin.UpdateWeightAndBias(Dataset.InputData);
105 RNNout.UpdateWeightAndBiasForOutput(RNNin.z);
106
107 if(i % 10 == 1) RNNout.DispError();
108
109 RNNin.ClearStateVars();
110 RNNout.ClearStateVars();
111
112 }
113
114 //RNNin.DispWeightAndBias();
115 //RNNL1.DispWeightAndBias();
116 //RNNout.DispWeightAndBias();
117 }
118
119 private:
121 const RecurrentNeuralNet3& operator=(const RecurrentNeuralNet3&) = delete;
122
123 static constexpr size_t EpchDisp = 10;
125
129};
130}
131
132#endif
133
ARCS イベントログクラス
#define PassedLog()
イベントログ用マクロ(ファイルと行番号のみ記録版)
Definition ARCSeventlog.hh:26
ARCS用ASSERTクラス
機械学習用 時系列データセットクラス
ActvFunc
活性化関数のタイプの定義
Definition ActivationFunctions.hh:35
再帰ニューラルレイヤクラス
NnInitTypes
重み初期化のタイプの定義
Definition NeuralNetParamDef.hh:19
NnDropout
ドロップアウトの定義
Definition NeuralNetParamDef.hh:35
NnDescentTypes
勾配降下法のタイプの定義
Definition NeuralNetParamDef.hh:25
再帰ニューラルレイヤクラス
Definition RecurrentNeuralLayer.hh:62
クラステンプレート
Definition RecurrentNeuralNet3.hh:52
RecurrentNeuralNet3(RecurrentNeuralNet3 &&r)
ムーブコンストラクタ
Definition RecurrentNeuralNet3.hh:63
~RecurrentNeuralNet3()
デストラクタ
Definition RecurrentNeuralNet3.hh:70
void Train(void)
誤差逆伝播法による訓練をする関数
Definition RecurrentNeuralNet3.hh:75
RecurrentNeuralNet3(TimeSeriesDatasets< Nin, Nout, Tlen, Wind, Mbat > &TSD)
コンストラクタ
Definition RecurrentNeuralNet3.hh:55
機械学習用 時系列データセットクラス
Definition TimeSeriesDatasets.hh:50
Matrix< M, K > TrainData
ベクトル配列版の標準化済み訓練データ(範囲 t = 1 … W, t = 0 と W + 1 の分も確保)
Definition TimeSeriesDatasets.hh:54
std::array< Matrix< M, N >, W+2 > InputData
ベクトル配列版の標準化済み入力データ(範囲 t = 1 … W, t = 0 と W + 1 の分も確保)
Definition TimeSeriesDatasets.hh:53