• Main Page
  • Related Pages
  • Modules
  • Namespaces
  • Classes
  • Files
  • File List
  • File Members

spl/math/NeuralNetwork.h

00001 /*
00002  *   This file is part of the Standard Portable Library (SPL).
00003  *
00004  *   SPL is free software: you can redistribute it and/or modify
00005  *   it under the terms of the GNU General Public License as published by
00006  *   the Free Software Foundation, either version 3 of the License, or
00007  *   (at your option) any later version.
00008  *
00009  *   SPL is distributed in the hope that it will be useful,
00010  *   but WITHOUT ANY WARRANTY; without even the implied warranty of
00011  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00012  *   GNU General Public License for more details.
00013  *
00014  *   You should have received a copy of the GNU General Public License
00015  *   along with SPL.  If not, see <http://www.gnu.org/licenses/>.
00016  */
00017 #ifndef _neuralnetwork_h
00018 #define _neuralnetwork_h
00019 
00020 #include <spl/collection/Array.h>
00021 #include <spl/data/Connection.h>
00022 #include <spl/Exception.h>
00023 #include <spl/Memory.h>
00024 
00030 class NeuralNetworkException : Exception
00031 {
00032 public:
00033         NeuralNetworkException(const char *msg) : Exception(msg)
00034         {
00035         }
00036 };
00037 
/**
 * Identifiers for the four layers of a Network, numbered in input-to-output
 * order. The values are spelled out explicitly so the numbering stays stable;
 * NOTE(review): confirm whether callers store or index by these values.
 */
enum NetworkLayer
{
    NET_INPUT  = 0, ///< input layer
    NET_L1     = 1, ///< first hidden layer
    NET_L2     = 2, ///< second hidden layer
    NET_OUTPUT = 3  ///< output layer
};
00045 
00049 class Network : public IMemoryValidate
00050 {
00051 private:
00052         Network();
00053         void BuildNetwork( int inlen, int len1, int len2, int outlen );
00054         void DeleteArrays();
00055         void RandomizeWeights();
00056 
00057         int wt1ToInLen() { return m_inputLen * m_h1len; }
00058         int wt2To1Len() { return m_h1len * m_h2len; }
00059         int wtOutTo2Len() { return m_outputLen * m_h2len; }
00060 
00061         int m_inputLen;
00062         int m_h1len;
00063         int m_h2len;
00064         int m_outputLen;
00065 
00066         Array<double> m_layer1;
00067         Array<double> m_layer2;
00068         
00069         Array<double> m_errors1; // error for each node
00070         Array<double> m_errors2;
00071         Array<double> m_errorsOut;
00072         
00073         Array<double> m_wt1ToIn; // weights from layer 1 to inputs
00074         Array<double> m_delta1ToIn; // deltas from layer 1 to inputs
00075         
00076         Array<double> m_wt2To1; // weights from layer 2 to layer 1
00077         Array<double> m_delta2To1; // deltas from layer 2 to layer 1
00078         
00079         Array<double> m_wtOutTo2; // weights from output to layer 3
00080         Array<double> m_deltaOutTo2; // deltas from output to layer 3
00081 
00082 protected:
00083         void Init();
00084         void ActivateLayer( Array<double>& layer1, Array<double>& layer2, Array<double>& wt );
00085 
00086         inline void Activate()
00087         {
00088                 ActivateLayer(m_input,  m_layer1, m_wt1ToIn);
00089                 ActivateLayer(m_layer1, m_layer2, m_wt2To1);
00090                 ActivateLayer(m_layer2, m_output, m_wtOutTo2);
00091         }
00092 
00093         int m_networkId;
00094         double m_learnRate;
00095         double m_momentum;
00096         double m_wtRange;
00097         bool m_useAdaptiveLR;
00098         bool m_useAdaptiveMom;
00099         double m_previousError;
00100 
00101         Array<double> m_input; // nodes
00102         Array<double> m_output;
00103 
00104 public:
00105         Network( int inlen, int len1, int len2, int outlen );
00106         Network( int id, Connection& conn );
00107         Network( const Network& net );
00108         virtual ~Network();
00109 
00110         void Write( Connection& conn );
00111         double CalcError( const Array<double>& input, const Array<double>& output );
00112         double Train(const Array<double>& input, const Array<double>& output, double dErrTarget, int maxCycles);
00113 
00114         inline void Activate( const Array<double>& input, Array<double>& output )
00115         {
00116                 input.CopyToBinary(m_input);
00117                 Activate();
00118                 m_output.CopyToBinary(output);
00119         }
00120 
00121         inline int NetworkId() const
00122         {
00123                 return m_networkId;
00124         }
00125 
00126         inline int InputCount() const
00127         {
00128                 return m_inputLen;
00129         }
00130 
00131         inline int Layer1Count() const
00132         {
00133                 return m_h1len;
00134         }
00135 
00136         inline int Layer2Count() const
00137         {
00138                 return m_h2len;
00139         }
00140 
00141         inline int OutputCount() const
00142         {
00143                 return m_outputLen;
00144         }
00145 
00146         inline void ResetID()
00147         {
00148                 m_networkId = -1;
00149         }
00150 
00151         Network& operator =(const Network& net);
00152 
00153 #ifdef DEBUG
00154         void ValidateMem() const;
00155         void CheckMem() const;
00156 #endif
00157 };
00158 
00161 #endif