Commit 449e9ec

Add RL model tests and improve the models
1 parent 4093dbc commit 449e9ec

22 files changed: +295 −59 lines changed
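
Summary of the change: RLPlatform's step() and test() now return the environment's stepInfo object directly instead of unpacking it into a [state, reward, done] tuple; every caller (the view code and the DP/GA models) is updated to destructure { state, reward, done }; episode-history bookkeeping moves out of the Monte Carlo and policy-gradient views into MCAgent and PGAgent via new reset() and per-step update() methods; and Jest tests are added that train the A2C and DQN agents on the CartPole environment.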

js/platform/rl.js
Lines changed: 2 additions & 3 deletions

@@ -193,12 +193,11 @@ export default class RLPlatform extends BasePlatform {
 			this._plotter.printStep()
 			this._plotter.plotRewards()
 		}
-		return [stepInfo.state, stepInfo.reward, stepInfo.done]
+		return stepInfo
 	}
 
 	test(state, action, agent) {
-		const stepInfo = this._env.test(state, action, agent);
-		return [stepInfo.state, stepInfo.reward, stepInfo.done]
+		return this._env.test(state, action, agent)
 	}
 
 	sample_action(agent) {
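
For callers, the upshot is a switch from tuple to object destructuring (a minimal sketch; `env` stands for any RLPlatform instance):

	// before this commit:
	// const [next_state, reward, done] = env.step(action, agent)
	// after — step() forwards the environment's stepInfo object as-is:
	const { state, reward, done } = env.step(action, agent)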

js/view/a2c.js
Lines changed: 3 additions & 3 deletions

@@ -2,7 +2,7 @@ import A2CAgent from '../../lib/model/a2c.js'
 
 class A2CCBAgent {
 	constructor(env, resolution, layers, optimizer, use_worker, cb) {
-		this._agent = new A2CAgent(env, resolution, 50, layers, optimizer)
+		this._agent = new A2CAgent(env.env, resolution, 50, layers, optimizer)
 		cb && cb()
 	}
 
@@ -61,10 +61,10 @@ var dispA2C = function (elm, env) {
 		const learning_rate = +elm.select('[name=learning_rate]').property('value')
 		const batch = +elm.select('[name=batch]').property('value')
 		agent.get_action(cur_state, action => {
-			const [next_state, reward, done] = env.step(action, agent)
+			const { state, done } = env.step(action, agent)
 			agent.update(done, learning_rate, batch, () => {
 				const end_proc = () => {
-					cur_state = next_state
+					cur_state = state
 					cb && cb(done)
 				}
 				if (render) {

js/view/dqn.js
Lines changed: 3 additions & 3 deletions

@@ -70,10 +70,10 @@ var dispDQN = function (elm, env) {
 		const learning_rate = +elm.select('[name=learning_rate]').property('value')
 		const batch = +elm.select('[name=batch]').property('value')
 		agent.get_action(cur_state, Math.max(min_greedy_rate, greedy_rate * greedy_rate_update), action => {
-			let [next_state, reward, done] = env.step(action, agent)
-			agent.update(action, cur_state, next_state, reward, done, learning_rate, batch, () => {
+			const { state, reward, done } = env.step(action, agent)
+			agent.update(action, cur_state, state, reward, done, learning_rate, batch, () => {
 				const end_proc = () => {
-					cur_state = next_state
+					cur_state = state
 					if (done || env.epoch % 1000 === 999) {
 						elm.select('[name=greedy_rate]').property('value', greedy_rate * greedy_rate_update)
 					}

js/view/dynamic_programming.js
Lines changed: 2 additions & 2 deletions

@@ -54,9 +54,9 @@ var dispDP = function (elm, env) {
 	;(function loop() {
 		if (isMoving) {
 			const action = agent.get_action(cur_state)
-			const [next_state, reward, done] = env.step(action, agent)
+			const { state } = env.step(action, agent)
 			env.render(() => agent.get_score())
-			cur_state = next_state
+			cur_state = state
 			setTimeout(loop, 10)
 		}
 	})()

js/view/genetic_algorithm.js
Lines changed: 4 additions & 4 deletions

@@ -62,11 +62,11 @@ var dispGeneticAlgorithm = function (elm, env) {
 		testButton.attr('value', isTesting ? 'Stop' : 'Test')
 		if (isTesting) {
 			const topAgent = agent.top_agent()
-			let state = env.reset(topAgent)
+			let cur_state = env.reset(topAgent)
 			void (function loop() {
-				const action = topAgent.get_action(state)
-				const [next_state, reward, done] = env.step(action, topAgent)
-				state = next_state
+				const action = topAgent.get_action(cur_state)
+				const { state, done } = env.step(action, topAgent)
+				cur_state = state
 				env.render()
 				if (isTesting && !done) {
 					setTimeout(() => loop(), 0)

js/view/monte_carlo.js
Lines changed: 4 additions & 10 deletions

@@ -7,27 +7,21 @@ var dispMC = function (elm, env) {
 	let cur_state = env.reset(agent)
 	env.render(() => agent.get_score())
 
-	let action_history = []
-
 	const step = (render = true) => {
 		const greedy_rate = +elm.select('[name=greedy_rate]').property('value')
 		const action = agent.get_action(cur_state, greedy_rate)
-		const [next_state, reward, done] = env.step(action, agent)
-		action_history.push([action, cur_state, reward])
+		const { state, reward, done } = env.step(action, agent)
+		agent.update(action, cur_state, reward, done)
 		if (render) {
 			env.render()
 		}
-		cur_state = next_state
-		if (done) {
-			agent.update(action_history)
-			action_history = []
-		}
+		cur_state = state
 		return done
 	}
 
 	const reset = () => {
 		cur_state = env.reset(agent)
-		action_history = []
+		agent.reset()
 		env.render(() => agent.get_score())
 	}

js/view/policy_gradient.js
Lines changed: 4 additions & 10 deletions

@@ -7,27 +7,21 @@ var dispPolicyGradient = function (elm, env) {
 	let cur_state = env.reset(agent)
 	env.render(() => agent.get_score())
 
-	let action_history = []
-
 	const step = (render = true) => {
 		const learning_rate = +elm.select('[name=learning_rate]').property('value')
 		const action = agent.get_action(cur_state)
-		const [next_state, reward, done] = env.step(action, agent)
-		action_history.push([action, cur_state, reward])
+		const { state, reward, done } = env.step(action, agent)
+		agent.update(action, cur_state, reward, done, learning_rate)
 		if (render) {
 			env.render()
 		}
-		cur_state = next_state
-		if (done) {
-			agent.update(action_history, learning_rate)
-			action_history = []
-		}
+		cur_state = state
 		return done
 	}
 
 	const reset = () => {
 		cur_state = env.reset(agent)
-		action_history = []
+		agent.reset()
 		env.render(() => agent.get_score())
 	}

js/view/q_learning.js
Lines changed: 3 additions & 3 deletions

@@ -10,16 +10,16 @@ var dispQLearning = function (elm, env) {
 	const step = (render = true) => {
 		const greedy_rate = +elm.select('[name=greedy_rate]').property('value')
 		const action = agent.get_action(cur_state, greedy_rate)
-		const [next_state, reward, done] = env.step(action, agent)
-		agent.update(action, cur_state, next_state, reward)
+		const { state, reward, done } = env.step(action, agent)
+		agent.update(action, cur_state, state, reward)
 		if (render) {
 			if (env.epoch % 10 === 0) {
 				env.render(() => agent.get_score())
 			} else {
 				env.render()
 			}
 		}
-		cur_state = next_state
+		cur_state = state
 		return done
 	}

js/view/sarsa.js
Lines changed: 3 additions & 3 deletions

@@ -10,16 +10,16 @@ var dispSARSA = function (elm, env) {
 	const step = (render = true) => {
 		const greedy_rate = +elm.select('[name=greedy_rate]').property('value')
 		const action = agent.get_action(cur_state, greedy_rate)
-		const [next_state, reward, done] = env.step(action, agent)
-		agent.update(action, cur_state, next_state, reward)
+		const { state, reward, done } = env.step(action, agent)
+		agent.update(action, cur_state, state, reward)
 		if (render) {
 			if (env.epoch % 10 === 0) {
 				env.render(() => agent.get_score())
 			} else {
 				env.render()
 			}
 		}
-		cur_state = next_state
+		cur_state = state
 		if (done) {
 			agent.reset()
 		}

lib/model/a2c.js
Lines changed: 1 addition & 1 deletion

@@ -162,7 +162,7 @@ export default class A2CAgent {
 	constructor(env, resolution, procs, layers, optimizer) {
 		this._net = new ActorCriticNet(env, resolution, layers, optimizer)
 		this._procs = procs
-		this._env = env.env
+		this._env = env
 		this._advanced_step = 5
 		this._gamma = 0.99
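
Paired with the js/view/a2c.js change above, this inverts who unwraps the platform: A2CAgent now takes the raw environment, and the view passes `env.env` instead of the model doing the unwrapping internally. A sketch of the resulting construction (assuming, as the view code suggests, that the platform exposes its underlying environment as `.env`; the layer spec here is illustrative):

	const agent = new A2CAgent(platform.env, 20, 50, layers, 'adam')

The new test below instead passes a bare CartPoleRLEnvironment directly, which this change makes possible.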

lib/model/dynamic_programming.js
Lines changed: 4 additions & 4 deletions

@@ -48,8 +48,8 @@ class DPTable extends QTableBase {
 		let vs = []
 		a.fill(0)
 		do {
-			let [y, reward, done] = this._env.test(this._state_value(x), this._action_value(a))
-			y = this._state_index(y)
+			const { state, reward, done } = this._env.test(this._state_value(x), this._action_value(a))
+			const y = this._state_index(state)
 			const [s, e] = this._to_position(this._state_sizes, y)
 			const v = reward + this._gamma * lastV[s]
 			const [_, ps] = this._q(x, a)
@@ -75,8 +75,8 @@ class DPTable extends QTableBase {
 		a.fill(0)
 		const x_state = this._state_value(x)
 		do {
-			let [y, reward, done] = this._env.test(x_state, this._action_value(a))
-			y = this._state_index(y)
+			const { state, reward, done } = this._env.test(x_state, this._action_value(a))
+			const y = this._state_index(state)
 			const [s, e] = this._to_position(this._state_sizes, y)
 			const v = reward + this._gamma * lastV[s]
 			const [_, ps] = this._q(x, a)
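
In the unchanged surrounding lines, `const v = reward + this._gamma * lastV[s]` is the standard Bellman backup v = r + γ·V(s′); the edit only changes how the successor state index `y` is obtained, destructuring `state` from the object now returned by `this._env.test()` instead of reassigning a tuple element in place.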

lib/model/genetic_algorithm.js
Lines changed: 4 additions & 4 deletions

@@ -113,12 +113,12 @@ class GeneticAlgorithmAgent {
 	}
 
 	run(env) {
-		let state = env.reset(this)
+		let cur_state = env.reset(this)
 		let c = 0
 		while (c++ < this._max_epoch) {
-			const action = this.get_action(state)
-			const [next_state, reward, done] = env.step(action, this)
-			state = next_state
+			const action = this.get_action(cur_state)
+			const { state, reward, done } = env.step(action, this)
+			cur_state = state
 			this._total_reward += reward
 			if (done) break
 		}

lib/model/monte_carlo.js
Lines changed: 19 additions & 3 deletions

@@ -32,6 +32,15 @@ export default class MCAgent {
 	constructor(env, resolution = 20) {
 		this._env = env
 		this._table = new MCTable(env, resolution)
+
+		this._history = []
+	}
+
+	/**
+	 * Reset agent.
+	 */
+	reset() {
+		this._history = []
 	}
 
 	/**
@@ -58,9 +67,16 @@ export default class MCAgent {
 
 	/**
 	 * Update model.
-	 * @param {*[]} actions
+	 * @param {*[]} action
+	 * @param {*[]} state
+	 * @param {number} reward
+	 * @param {boolean} done
 	 */
-	update(actions) {
-		this._table.update(actions)
+	update(action, state, reward, done) {
+		this._history.push([action, state, reward])
+		if (done) {
+			this._table.update(this._history)
+			this._history = []
+		}
 	}
 }
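
A minimal training-loop sketch against the new per-step API (the `env` and `agent` names are illustrative; compare the updated js/view/monte_carlo.js above):

	let curState = env.reset(agent)
	for (;;) {
		const action = agent.get_action(curState, 0.1)
		const { state, reward, done } = env.step(action, agent)
		// the agent buffers [action, state, reward] internally and
		// flushes the whole episode into its table once done is true
		agent.update(action, curState, reward, done)
		curState = state
		if (done) break
	}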

lib/model/policy_gradient.js
Lines changed: 19 additions & 3 deletions

@@ -93,6 +93,15 @@ export default class PGAgent {
 	 */
 	constructor(env, resolution = 20) {
 		this._table = new SoftmaxPolicyGradient(env, resolution)
+
+		this._history = []
+	}
+
+	/**
+	 * Reset agent.
+	 */
+	reset() {
+		this._history = []
 	}
 
 	/**
@@ -114,10 +123,17 @@ export default class PGAgent {
 
 	/**
	 * Update model.
-	 * @param {*[]} actions
+	 * @param {*[]} action
+	 * @param {*[]} state
+	 * @param {number} reward
+	 * @param {boolean} done
 	 * @param {number} learning_rate
 	 */
-	update(actions, learning_rate) {
-		this._table.update(actions, learning_rate)
+	update(action, state, reward, done, learning_rate) {
+		this._history.push([action, state, reward])
+		if (done) {
+			this._table.update(this._history, learning_rate)
+			this._history = []
+		}
 	}
 }
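
PGAgent follows the same buffering pattern as MCAgent; the only signature difference is the trailing learning rate, so a step in the loop sketched above becomes `agent.update(action, curState, reward, done, learning_rate)`.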

tests/lib/model/a2c.test.js
Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+import { jest } from '@jest/globals'
+jest.retryTimes(3)
+
+import A2CAgent from '../../../lib/model/a2c.js'
+import CartPoleRLEnvironment from '../../../lib/rl/cartpole.js'
+
+test('default', () => {
+	const env = new CartPoleRLEnvironment()
+	const agent = new A2CAgent(env, 20, 10, [{ type: 'full', out_size: 5, activation: 'tanh' }], 'adam')
+	for (let i = 0; i < 10000; i++) {
+		agent.update(true, 0.01, 10)
+	}
+
+	let totalReward = 0
+	let curState = env.reset()
+	while (true) {
+		const action = agent.get_action(curState)
+		const { state, reward, done } = env.step(action)
+		totalReward += reward
+		curState = state
+		if (done) {
+			break
+		}
+	}
+	expect(totalReward).toBeGreaterThan(150)
+	const score = agent.get_score()
+	expect(score).toHaveLength(20)
+	expect(score[0]).toHaveLength(20)
+	expect(score[0][0]).toHaveLength(20)
+	expect(score[0][0][0]).toHaveLength(20)
+	expect(score[0][0][0][0]).toHaveLength(2)
+})

tests/lib/model/dqn.test.js
Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+import { jest } from '@jest/globals'
+jest.retryTimes(3)
+
+import DQNAgent from '../../../lib/model/dqn.js'
+import CartPoleRLEnvironment from '../../../lib/rl/cartpole.js'
+
+test('default', () => {
+	const env = new CartPoleRLEnvironment()
+	const agent = new DQNAgent(env, 20, [{ type: 'full', out_size: 10, activation: 'tanh' }], 'adam')
+
+	const totalRewards = []
+	const n = 200
+	for (let i = 0; i < n; i++) {
+		let curState = env.reset()
+		totalRewards[i] = 0
+		while (true) {
+			const action = agent.get_action(curState, 1 - (i / n) ** 2)
+			const { state, reward, done } = env.step(action)
+			agent.update(action, curState, state, reward, done, 0.001, 10)
+			totalRewards[i] += reward
+			curState = state
+			if (done) {
+				break
+			}
+		}
+	}
+	expect(totalRewards.slice(-5).reduce((s, v) => s + v, 0) / 5).toBeGreaterThan(150)
+	const score = agent.get_score()
+	expect(score).toHaveLength(20)
+	expect(score[0]).toHaveLength(20)
+	expect(score[0][0]).toHaveLength(20)
+	expect(score[0][0][0]).toHaveLength(20)
+	expect(score[0][0][0][0]).toHaveLength(2)
+})
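
Both tests call `jest.retryTimes(3)` because RL training is stochastic: an unlucky run can miss the total-reward threshold of 150 on CartPole, so Jest retries each test up to three times before reporting a failure. The trailing `get_score()` assertions pin the score-table shape: resolution 20 along each of CartPole's four state dimensions, times its two actions.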
