00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef _neuralnetwork_h
00018 #define _neuralnetwork_h
00019
00020 #include <spl/collection/Array.h>
00021 #include <spl/data/Connection.h>
00022 #include <spl/Exception.h>
00023 #include <spl/Memory.h>
00024
00030 class NeuralNetworkException : Exception
00031 {
00032 public:
00033 NeuralNetworkException(const char *msg) : Exception(msg)
00034 {
00035 }
00036 };
00037
/**
 * Layer identifiers, numbered consecutively from 0 (input) to 3 (output).
 *
 * Only the first value is spelled out; each following enumerator takes the
 * previous value plus one, giving 0, 1, 2, 3.
 */
enum NetworkLayer
{
    NET_INPUT = 0,  /**< value 0 */
    NET_L1,         /**< value 1 */
    NET_L2,         /**< value 2 */
    NET_OUTPUT      /**< value 3 */
};
00045
00049 class Network : public IMemoryValidate
00050 {
00051 private:
00052 Network();
00053 void BuildNetwork( int inlen, int len1, int len2, int outlen );
00054 void DeleteArrays();
00055 void RandomizeWeights();
00056
00057 int wt1ToInLen() { return m_inputLen * m_h1len; }
00058 int wt2To1Len() { return m_h1len * m_h2len; }
00059 int wtOutTo2Len() { return m_outputLen * m_h2len; }
00060
00061 int m_inputLen;
00062 int m_h1len;
00063 int m_h2len;
00064 int m_outputLen;
00065
00066 Array<double> m_layer1;
00067 Array<double> m_layer2;
00068
00069 Array<double> m_errors1;
00070 Array<double> m_errors2;
00071 Array<double> m_errorsOut;
00072
00073 Array<double> m_wt1ToIn;
00074 Array<double> m_delta1ToIn;
00075
00076 Array<double> m_wt2To1;
00077 Array<double> m_delta2To1;
00078
00079 Array<double> m_wtOutTo2;
00080 Array<double> m_deltaOutTo2;
00081
00082 protected:
00083 void Init();
00084 void ActivateLayer( Array<double>& layer1, Array<double>& layer2, Array<double>& wt );
00085
00086 inline void Activate()
00087 {
00088 ActivateLayer(m_input, m_layer1, m_wt1ToIn);
00089 ActivateLayer(m_layer1, m_layer2, m_wt2To1);
00090 ActivateLayer(m_layer2, m_output, m_wtOutTo2);
00091 }
00092
00093 int m_networkId;
00094 double m_learnRate;
00095 double m_momentum;
00096 double m_wtRange;
00097 bool m_useAdaptiveLR;
00098 bool m_useAdaptiveMom;
00099 double m_previousError;
00100
00101 Array<double> m_input;
00102 Array<double> m_output;
00103
00104 public:
00105 Network( int inlen, int len1, int len2, int outlen );
00106 Network( int id, Connection& conn );
00107 Network( const Network& net );
00108 virtual ~Network();
00109
00110 void Write( Connection& conn );
00111 double CalcError( const Array<double>& input, const Array<double>& output );
00112 double Train(const Array<double>& input, const Array<double>& output, double dErrTarget, int maxCycles);
00113
00114 inline void Activate( const Array<double>& input, Array<double>& output )
00115 {
00116 input.CopyToBinary(m_input);
00117 Activate();
00118 m_output.CopyToBinary(output);
00119 }
00120
00121 inline int NetworkId() const
00122 {
00123 return m_networkId;
00124 }
00125
00126 inline int InputCount() const
00127 {
00128 return m_inputLen;
00129 }
00130
00131 inline int Layer1Count() const
00132 {
00133 return m_h1len;
00134 }
00135
00136 inline int Layer2Count() const
00137 {
00138 return m_h2len;
00139 }
00140
00141 inline int OutputCount() const
00142 {
00143 return m_outputLen;
00144 }
00145
00146 inline void ResetID()
00147 {
00148 m_networkId = -1;
00149 }
00150
00151 Network& operator =(const Network& net);
00152
00153 #ifdef DEBUG
00154 void ValidateMem() const;
00155 void CheckMem() const;
00156 #endif
00157 };
00158
00161 #endif