Skip to content

Commit 3964003

Browse files
authored
Improve test (#961)
* Improve tests * Use Python version 3.12 * Improve some tests
1 parent 9304f39 commit 3964003

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

55 files changed

+1655
-53
lines changed

lib/model/a2c.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,15 +231,15 @@ export default class A2CAgent {
231231
for (let i = 0; i < this._advanced_step; i++) {
232232
for (let k = 0; k < this._procs; k++) {
233233
const action = this._net.get_action(this._states[k])
234-
const info = this._envs[i].step(action)
234+
const info = this._envs[k].step(action)
235235
;(actions[k] ||= []).push(action)
236236
;(states[k] ||= []).push(this._states[k])
237237
;(next_states[k] ||= []).push(info.state)
238238
;(rewards[k] ||= []).push(info.reward)
239239
;(dones[k] ||= []).push(info.done)
240240

241241
if (info.done) {
242-
this._states[k] = this._envs[i].reset()
242+
this._states[k] = this._envs[k].reset()
243243
} else {
244244
this._states[k] = info.state
245245
}

lib/model/dqn.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ class DQN {
198198
q[i][a_idx] = data[i][3] + this._gamma * next_t_q[i][argmax(next_q[i])]
199199
}
200200
const loss = this._net.fit(x, q, 1, learning_rate, batch)
201-
if (this._epoch % this._fix_param_update_step) {
201+
if (this._epoch % this._fix_param_update_step === 0) {
202202
this._target = this._net.copy()
203203
}
204204
return loss[0]

lib/model/nns/graph.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ export default class ComputationalGraph {
293293
}
294294
this._order = []
295295
while (s.length > 0) {
296+
s.sort((a, b) => b - a)
296297
const n = s.pop()
297298
this._order.push(n)
298299
for (const i of outputList[n]) {

onnx.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ function makeOnnxFiles () {
8787
fi
8888
export PYENV_ROOT="${WORK_DIR}/.pyenv"
8989
PYENV="${PYENV_ROOT}/bin/pyenv"
90-
PYENV_PYTHON_VERSION=miniconda3-latest
90+
PYENV_PYTHON_VERSION=3.12.9
9191
if [ ! -d "${PYENV_ROOT}" ]; then
9292
git clone https://github.com/pyenv/pyenv.git "${PYENV_ROOT}"
9393
pushd "${PYENV_ROOT}"

tests/lib/model/a2c.test.js

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ jest.retryTimes(20)
44
import A2CAgent from '../../../lib/model/a2c.js'
55
import CartPoleRLEnvironment from '../../../lib/rl/cartpole.js'
66
import InHypercubeRLEnvironment from '../../../lib/rl/inhypercube.js'
7+
import PendulumRLEnvironment from '../../../lib/rl/pendulum.js'
8+
import ReversiRLEnvironment from '../../../lib/rl/reversi.js'
79

810
test('update', () => {
911
const env = new InHypercubeRLEnvironment(2)
@@ -31,6 +33,27 @@ test('update', () => {
3133
}
3234
}
3335
expect(totalReward.slice(Math.max(0, totalReward.length - 10)).every(v => v > 0)).toBeTruthy()
36+
agent.terminate()
37+
})
38+
39+
test('realrange action', () => {
40+
const env = new PendulumRLEnvironment()
41+
const agent = new A2CAgent(env, 10, 10, [{ type: 'full', out_size: 10, activation: 'tanh' }], 'adam')
42+
agent.update(true, 0.01, 10)
43+
44+
let curState = env.reset()
45+
const best_action = agent.get_action(curState, 0)
46+
expect(best_action).toHaveLength(1)
47+
})
48+
49+
test('array state action', () => {
50+
const env = new ReversiRLEnvironment()
51+
const agent = new A2CAgent(env, 10, 7, [{ type: 'full', out_size: 10, activation: 'tanh' }], 'adam')
52+
agent.update(true, 0.01, 10)
53+
54+
let curState = env.reset()
55+
const best_action = agent.get_action(curState, 0)
56+
expect(best_action).toHaveLength(1)
3457
})
3558

3659
test('get_score', () => {
@@ -43,4 +66,6 @@ test('get_score', () => {
4366
expect(score[0][0]).toHaveLength(20)
4467
expect(score[0][0][0]).toHaveLength(20)
4568
expect(score[0][0][0][0]).toHaveLength(2)
69+
70+
agent.get_score()
4671
})

tests/lib/model/c2p.test.js

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,45 @@ import C2P from '../../../lib/model/c2p.js'
33

44
import { randIndex } from '../../../lib/evaluate/clustering.js'
55

6-
test('clustering', () => {
7-
const model = new C2P(10, 50)
8-
const n = 20
9-
const x = Matrix.concat(
10-
Matrix.concat(Matrix.randn(n, 2, 0, 0.1), Matrix.randn(n, 2, 5, 0.1)),
11-
Matrix.randn(n, 2, [0, 5], 0.1)
12-
).toArray()
6+
describe('clustering', () => {
7+
test('default', () => {
8+
const model = new C2P(10, 50)
9+
const n = 20
10+
const x = Matrix.concat(
11+
Matrix.concat(Matrix.randn(n, 2, 0, 0.1), Matrix.randn(n, 2, 5, 0.1)),
12+
Matrix.randn(n, 2, [0, 5], 0.1)
13+
).toArray()
1314

14-
model.fit(x)
15-
const y = model.predict(3)
16-
expect(y).toHaveLength(x.length)
15+
model.fit(x)
16+
const y = model.predict(3)
17+
expect(y).toHaveLength(x.length)
1718

18-
const t = []
19-
for (let i = 0; i < x.length; i++) {
20-
t[i] = Math.floor(i / n)
21-
}
22-
const ri = randIndex(y, t)
23-
expect(ri).toBeGreaterThan(0.9)
19+
const t = []
20+
for (let i = 0; i < x.length; i++) {
21+
t[i] = Math.floor(i / n)
22+
}
23+
const ri = randIndex(y, t)
24+
expect(ri).toBeGreaterThan(0.9)
25+
})
26+
27+
test('no cutoff', () => {
28+
const model = new C2P(10, 50)
29+
model._cutoff_scale = 0
30+
const n = 20
31+
const x = Matrix.concat(
32+
Matrix.concat(Matrix.randn(n, 2, 0, 0.1), Matrix.randn(n, 2, 5, 0.1)),
33+
Matrix.randn(n, 2, [0, 5], 0.1)
34+
).toArray()
35+
36+
model.fit(x)
37+
const y = model.predict(3)
38+
expect(y).toHaveLength(x.length)
39+
40+
const t = []
41+
for (let i = 0; i < x.length; i++) {
42+
t[i] = Math.floor(i / n)
43+
}
44+
const ri = randIndex(y, t)
45+
expect(ri).toBeGreaterThan(0.9)
46+
})
2447
})

tests/lib/model/dqn.test.js

Lines changed: 87 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
import { jest } from '@jest/globals'
1+
import { expect, jest } from '@jest/globals'
22
jest.retryTimes(5)
33

44
import DQNAgent from '../../../lib/model/dqn.js'
5+
import ReversiRLEnvironment from '../../../lib/rl/reversi.js'
56
import CartPoleRLEnvironment from '../../../lib/rl/cartpole.js'
67
import InHypercubeRLEnvironment from '../../../lib/rl/inhypercube.js'
8+
import PendulumRLEnvironment from '../../../lib/rl/pendulum.js'
79

810
test('update dqn', () => {
911
const env = new InHypercubeRLEnvironment(2)
@@ -29,6 +31,7 @@ test('update dqn', () => {
2931
}
3032
}
3133
expect(totalReward.slice(Math.max(0, totalReward.length - 10)).every(v => v > 0)).toBeTruthy()
34+
agent.terminate()
3235
})
3336

3437
test('update ddqn', () => {
@@ -56,16 +59,94 @@ test('update ddqn', () => {
5659
}
5760
}
5861
expect(totalReward.slice(Math.max(0, totalReward.length - 10)).every(v => v > 0)).toBeTruthy()
62+
agent.terminate()
63+
})
64+
65+
test('realrange action', () => {
66+
const env = new PendulumRLEnvironment()
67+
const agent = new DQNAgent(env, 10, [{ type: 'full', out_size: 3, activation: 'tanh' }], 'adam')
68+
agent._net._batch_size = 1
69+
agent._net._fix_param_update_step = 1
70+
agent._net._do_update_step = 1
71+
72+
let curState = env.reset()
73+
const action = agent.get_action(curState, 0.9)
74+
const { state, reward, done } = env.step(action)
75+
agent.update(action, curState, state, reward, done, 0.001, 10)
76+
77+
const best_action = agent.get_action(state, 0)
78+
expect(best_action).toHaveLength(1)
79+
})
80+
81+
test('array state action', () => {
82+
const env = new ReversiRLEnvironment()
83+
const agent = new DQNAgent(env, 20, [{ type: 'full', out_size: 10, activation: 'tanh' }], 'adam')
84+
85+
agent._net._batch_size = 1
86+
agent._net._fix_param_update_step = 1
87+
agent._net._do_update_step = 1
88+
89+
let curState = env.reset()
90+
const action = agent.get_action(curState, 0.9)
91+
const { state, reward, done } = env.step(action)
92+
agent.update(action, curState, state, reward, done, 0.001, 10)
93+
94+
const best_action = agent.get_action(state, 0)
95+
expect(best_action).toHaveLength(1)
96+
})
97+
98+
test('max memory size', () => {
99+
const env = new InHypercubeRLEnvironment(2)
100+
const agent = new DQNAgent(env, 10, [{ type: 'full', out_size: 3, activation: 'tanh' }], 'adam')
101+
agent.method = 'DDQN'
102+
agent._net._batch_size = 1
103+
agent._net._max_memory_size = 10
104+
105+
let curState = env.reset()
106+
const action = agent.get_action(curState, 0.9)
107+
const { state, reward, done } = env.step(action)
108+
for (let i = 0; i < 20; i++) {
109+
agent.update(action, curState, state, reward, done, 0.001, 10)
110+
expect(agent._net._memory.length).toBeLessThanOrEqual(10)
111+
}
112+
})
113+
114+
test('reset to dqn', () => {
115+
const env = new InHypercubeRLEnvironment(2)
116+
const agent = new DQNAgent(env, 10, [{ type: 'full', out_size: 3, activation: 'tanh' }], 'adam')
117+
agent.method = 'DDQN'
118+
agent._net._batch_size = 1
119+
agent._net._fix_param_update_step = 1
120+
agent._net._do_update_step = 1
121+
122+
let curState = env.reset()
123+
const action = agent.get_action(curState, 0.9)
124+
const { state, reward, done } = env.step(action)
125+
agent.update(action, curState, state, reward, done, 0.001, 10)
126+
127+
expect(agent._net._target).toBeDefined()
128+
agent.method = 'DQN'
129+
expect(agent._net._target).toBeNull()
59130
})
60131

61132
test('get_score', () => {
62133
const env = new CartPoleRLEnvironment()
63-
const agent = new DQNAgent(env, 20, [{ type: 'full', out_size: 10, activation: 'tanh' }], 'adam')
134+
const agent = new DQNAgent(env, 12, [{ type: 'full', out_size: 10, activation: 'tanh' }], 'adam')
64135

65136
const score = agent.get_score()
66-
expect(score).toHaveLength(20)
67-
expect(score[0]).toHaveLength(20)
68-
expect(score[0][0]).toHaveLength(20)
69-
expect(score[0][0][0]).toHaveLength(20)
137+
expect(score).toHaveLength(12)
138+
expect(score[0]).toHaveLength(12)
139+
expect(score[0][0]).toHaveLength(12)
140+
expect(score[0][0][0]).toHaveLength(12)
70141
expect(score[0][0][0][0]).toHaveLength(2)
142+
143+
agent.get_score()
144+
})
145+
146+
test('get_action default', () => {
147+
const env = new InHypercubeRLEnvironment(2)
148+
const agent = new DQNAgent(env, 10, [{ type: 'full', out_size: 3, activation: 'tanh' }], 'adam')
149+
150+
const action = agent.get_action(env.state())
151+
expect(action).toHaveLength(1)
71152
})

tests/lib/model/genetic_algorithm.test.js

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,28 @@ jest.retryTimes(3)
44
import GeneticAlgorithmGeneration from '../../../lib/model/genetic_algorithm.js'
55
import CartPoleRLEnvironment from '../../../lib/rl/cartpole.js'
66

7+
describe('constructor', () => {
8+
test('default', () => {
9+
const env = new CartPoleRLEnvironment()
10+
const agent = new GeneticAlgorithmGeneration(env)
11+
12+
expect(agent._resolution).toBe(20)
13+
expect(agent._model._models).toHaveLength(100)
14+
})
15+
16+
test('resolution', () => {
17+
const env = new CartPoleRLEnvironment()
18+
const agent = new GeneticAlgorithmGeneration(env, 6, 8)
19+
20+
expect(agent._resolution).toBe(8)
21+
expect(agent._model._models).toHaveLength(6)
22+
})
23+
})
24+
725
test('next', () => {
826
const env = new CartPoleRLEnvironment()
927
const agent = new GeneticAlgorithmGeneration(env, 100, 10)
28+
agent.next()
1029
for (let i = 0; i < 100; i++) {
1130
agent.run()
1231
agent.next(0.1)

tests/lib/model/gtm.test.js

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,52 @@ describe('clustering', () => {
6868
expect(y[0]).toBeGreaterThan(y[2])
6969
expect(y[1]).toBeGreaterThan(y[2])
7070
})
71+
72+
test('init random', () => {
73+
const model = new GTM(2, 1)
74+
model._init_method = 'random'
75+
const n = 50
76+
const x = Matrix.concat(
77+
Matrix.concat(Matrix.randn(n, 2, 0, 0.1), Matrix.randn(n, 2, 5, 0.1)),
78+
Matrix.randn(n, 2, [0, 5], 0.1)
79+
).toArray()
80+
81+
for (let i = 0; i < 100; i++) {
82+
model.fit(x)
83+
}
84+
const y = model.predictIndex(x)
85+
expect(y).toHaveLength(x.length)
86+
87+
const t = []
88+
for (let i = 0; i < x.length; i++) {
89+
t[i] = Math.floor(i / n)
90+
}
91+
const ri = randIndex(y, t)
92+
expect(ri).toBeGreaterThan(0.8)
93+
})
94+
95+
test('mode fit', () => {
96+
const model = new GTM(2, 1)
97+
model._fit_method = 'mode'
98+
const n = 50
99+
const x = Matrix.concat(
100+
Matrix.concat(Matrix.randn(n, 2, 0, 0.1), Matrix.randn(n, 2, 5, 0.1)),
101+
Matrix.randn(n, 2, [0, 5], 0.1)
102+
).toArray()
103+
104+
for (let i = 0; i < 100; i++) {
105+
model.fit(x)
106+
}
107+
const y = model.predict(x)
108+
expect(y).toHaveLength(x.length)
109+
110+
const t = []
111+
for (let i = 0; i < x.length; i++) {
112+
t[i] = Math.floor(i / n)
113+
}
114+
const ri = randIndex(y, t)
115+
expect(ri).toBeGreaterThan(0.8)
116+
})
71117
})
72118

73119
test('dimensionality reduction', () => {

tests/lib/model/monte_carlo.test.js

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,23 @@ jest.retryTimes(3)
33

44
import MCAgent from '../../../lib/model/monte_carlo.js'
55
import GridRLEnvironment from '../../../lib/rl/grid.js'
6+
import InHypercubeRLEnvironment from '../../../lib/rl/inhypercube.js'
7+
8+
describe('constructor', () => {
9+
test('default', () => {
10+
const env = new InHypercubeRLEnvironment()
11+
const agent = new MCAgent(env)
12+
13+
expect(agent._table.resolution).toBe(20)
14+
})
15+
16+
test('resolution', () => {
17+
const env = new InHypercubeRLEnvironment()
18+
const agent = new MCAgent(env, 6)
19+
20+
expect(agent._table.resolution).toBe(6)
21+
})
22+
})
623

724
test('update', () => {
825
const env = new GridRLEnvironment()
@@ -42,3 +59,11 @@ test('get_score', () => {
4259
expect(score[0]).toHaveLength(10)
4360
expect(score[0][0]).toHaveLength(4)
4461
})
62+
63+
test('get_action default', () => {
64+
const env = new GridRLEnvironment()
65+
const agent = new MCAgent(env, env.size[0])
66+
67+
const action = agent.get_action(env.state())
68+
expect(action).toHaveLength(1)
69+
})

0 commit comments

Comments
 (0)