[go: up one dir, main page]

Menu

[r74]: / trunk / rlgo / RlTrainer.h  Maximize  Restore  History

Download this file

138 lines (100 with data), 4.1 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
//----------------------------------------------------------------------------
/** @file RlTrainer.h
Classes to train from historic experience
*/
//----------------------------------------------------------------------------
#ifndef RLTRAINER_H
#define RLTRAINER_H
#include "RlUtils.h"
class RlEvaluator;
class RlHistory;
class RlLearningRule;
class RlState;
//----------------------------------------------------------------------------
/** Abstract base class for trainers that learn from historic experience
    (experience replay).
    Train() is pure virtual, so a concrete trainer has to derive from this
    class and supply the actual training procedure. */
class RlTrainer
    : public RlAutoObject
{
public:

    /** Construct a trainer for the given board.
        The learning rule, history and evaluator pointers all default to
        null.
        NOTE(review): the constructor body is not part of this header;
        presumably the arguments initialise m_learningRule, m_history and
        m_evaluator -- confirm against the implementation file. */
    RlTrainer(
        GoBoard& board,
        RlLearningRule* rule = 0,
        RlHistory* history = 0,
        RlEvaluator* evaluator = 0);

    /** Takes the input stream holding the settings (defined out of line).
        NOTE(review): which settings are read is decided by the
        implementation, not visible here. */
    virtual void LoadSettings(std::istream& settings);

    /** Run training; every concrete trainer must implement this. */
    virtual void Train() = 0;

    /** Read-only access to the history pointer held by this trainer. */
    const RlHistory* GetHistory() const
    {
        return m_history;
    }

    /** Access to the learning rule pointer held by this trainer. */
    RlLearningRule* GetLearningRule() const
    {
        return m_learningRule;
    }

protected:

    /** Helper for derived trainers, defined out of line.
        NOTE(review): presumably re-evaluates the given state (see
        m_evaluator / m_refreshValues) -- confirm in the implementation. */
    void RefreshValue(RlState& state);

    /** Helper for derived trainers, defined out of line.
        NOTE(review): presumably maps a replay number to an episode index
        according to m_episodes -- confirm in the implementation. */
    int SelectEpisode(int replay);

    /** Episode selection modes.
        The numeric values are spelled out; they are the same values the
        enumerators receive implicitly. */
    enum
    {
        /** Train on the current episode only */
        EP_CURRENT = 0,

        /** Train on the most recent episodes in the history */
        EP_LAST = 1,

        /** Train on randomly selected episodes in the history */
        EP_RANDOM = 2
    };

    /** Learning rule that is applied to the experience */
    RlLearningRule* m_learningRule;

    /** History that holds the training experience */
    RlHistory* m_history;

    /** Evaluator used for refreshing the value of experience */
    RlEvaluator* m_evaluator;

    /** Which episodes are trained on
        (presumably one of the EP_* values above -- TODO confirm) */
    int m_episodes;

    /** Number of experience replays to perform */
    int m_numReplays;

    /** Whether training is done from the root position */
    bool m_updateRoot;

    /** Number of time-steps between the pairs of states that are updated:
        1 = update from black to white
        2 = update from black to black
        n = update from n steps into the future */
    int m_temporalDifference;

    /** Whether values are refreshed online (true) or the historic values
        are used (false) */
    bool m_refreshValues;

    /** Whether updates are interleaved, or a single trajectory is done.
        Only relevant for episodic updates with temporal difference > 1. */
    bool m_interleave;

    /** Whether weights are updated during learning (e.g. training stage),
        or the error is only measured (e.g. testing stage) */
    bool m_updateWeights;
};
//----------------------------------------------------------------------------
/** Abstract trainer that trains in sweeps through episodes of experience.
    Train() is declared here; the sweep itself is left to derived classes
    through the pure virtual Sweep(). */
class RlEpisodicTrainer
    : public RlTrainer
{
public:

    /** Takes the same arguments, with the same null defaults, as the
        RlTrainer constructor. */
    RlEpisodicTrainer(
        GoBoard& board,
        RlLearningRule* rule = 0,
        RlHistory* history = 0,
        RlEvaluator* evaluator = 0);

    /** Implementation of RlTrainer::Train(), defined out of line. */
    virtual void Train();

    /** Perform one sweep; must be supplied by each derived class.
        NOTE(review): the exact meaning of the parameters is not visible in
        this header. Presumably episode identifies an episode in the
        history and start / offset / gap select the time-steps visited --
        confirm against the implementation file. */
    virtual void Sweep(int episode, int start, int offset, int gap) = 0;
};
//----------------------------------------------------------------------------
/** Episodic trainer that trains in forward sweeps through episodes of
    experience. */
class RlForwardTrainer
    : public RlEpisodicTrainer
{
public:

    /** Macro defined outside this header.
        NOTE(review): presumably declares the object-registration
        boilerplate for this class -- confirm where the macro is defined. */
    DECLARE_OBJECT(RlForwardTrainer);

    /** Takes the same arguments, with the same null defaults, as the
        RlEpisodicTrainer constructor. */
    RlForwardTrainer(
        GoBoard& board,
        RlLearningRule* rule = 0,
        RlHistory* history = 0,
        RlEvaluator* evaluator = 0);

    /** Implementation of RlEpisodicTrainer::Sweep(), defined out of line. */
    virtual void Sweep(int episode, int start, int offset, int gap);
};
//----------------------------------------------------------------------------
/** Episodic trainer that trains in backward sweeps through episodes of
    experience. */
class RlBackwardTrainer
    : public RlEpisodicTrainer
{
public:

    /** Macro defined outside this header.
        NOTE(review): presumably declares the object-registration
        boilerplate for this class -- confirm where the macro is defined. */
    DECLARE_OBJECT(RlBackwardTrainer);

    /** Takes the same arguments, with the same null defaults, as the
        RlEpisodicTrainer constructor. */
    RlBackwardTrainer(
        GoBoard& board,
        RlLearningRule* rule = 0,
        RlHistory* history = 0,
        RlEvaluator* evaluator = 0);

    /** Implementation of RlEpisodicTrainer::Sweep(), defined out of line. */
    virtual void Sweep(int episode, int start, int offset, int gap);
};
//----------------------------------------------------------------------------
/** Trainer that trains on randomly selected transitions from the
    history. */
class RlRandomTrainer
    : public RlTrainer
{
public:

    /** Macro defined outside this header.
        NOTE(review): presumably declares the object-registration
        boilerplate for this class -- confirm where the macro is defined. */
    DECLARE_OBJECT(RlRandomTrainer);

    /** Takes the same arguments, with the same null defaults, as the
        RlTrainer constructor. */
    RlRandomTrainer(
        GoBoard& board,
        RlLearningRule* rule = 0,
        RlHistory* history = 0,
        RlEvaluator* evaluator = 0);

    /** Implementation of RlTrainer::Train(), defined out of line. */
    virtual void Train();
};
//----------------------------------------------------------------------------
#endif // RLTRAINER_H