Skip to content

Improve test #961

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/model/a2c.js
Original file line number Diff line number Diff line change
Expand Up @@ -231,15 +231,15 @@ export default class A2CAgent {
for (let i = 0; i < this._advanced_step; i++) {
for (let k = 0; k < this._procs; k++) {
const action = this._net.get_action(this._states[k])
const info = this._envs[i].step(action)
const info = this._envs[k].step(action)
;(actions[k] ||= []).push(action)
;(states[k] ||= []).push(this._states[k])
;(next_states[k] ||= []).push(info.state)
;(rewards[k] ||= []).push(info.reward)
;(dones[k] ||= []).push(info.done)

if (info.done) {
this._states[k] = this._envs[i].reset()
this._states[k] = this._envs[k].reset()
} else {
this._states[k] = info.state
}
Expand Down
2 changes: 1 addition & 1 deletion lib/model/dqn.js
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ class DQN {
q[i][a_idx] = data[i][3] + this._gamma * next_t_q[i][argmax(next_q[i])]
}
const loss = this._net.fit(x, q, 1, learning_rate, batch)
if (this._epoch % this._fix_param_update_step) {
if (this._epoch % this._fix_param_update_step === 0) {
this._target = this._net.copy()
}
return loss[0]
Expand Down
1 change: 1 addition & 0 deletions lib/model/nns/graph.js
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ export default class ComputationalGraph {
}
this._order = []
while (s.length > 0) {
s.sort((a, b) => b - a)
const n = s.pop()
this._order.push(n)
for (const i of outputList[n]) {
Expand Down
2 changes: 1 addition & 1 deletion onnx.sh
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ function makeOnnxFiles () {
fi
export PYENV_ROOT="${WORK_DIR}/.pyenv"
PYENV="${PYENV_ROOT}/bin/pyenv"
PYENV_PYTHON_VERSION=miniconda3-latest
PYENV_PYTHON_VERSION=3.12.9
if [ ! -d "${PYENV_ROOT}" ]; then
git clone https://github.com/pyenv/pyenv.git "${PYENV_ROOT}"
pushd "${PYENV_ROOT}"
Expand Down
25 changes: 25 additions & 0 deletions tests/lib/model/a2c.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ jest.retryTimes(20)
import A2CAgent from '../../../lib/model/a2c.js'
import CartPoleRLEnvironment from '../../../lib/rl/cartpole.js'
import InHypercubeRLEnvironment from '../../../lib/rl/inhypercube.js'
import PendulumRLEnvironment from '../../../lib/rl/pendulum.js'
import ReversiRLEnvironment from '../../../lib/rl/reversi.js'

test('update', () => {
const env = new InHypercubeRLEnvironment(2)
Expand Down Expand Up @@ -31,6 +33,27 @@ test('update', () => {
}
}
expect(totalReward.slice(Math.max(0, totalReward.length - 10)).every(v => v > 0)).toBeTruthy()
agent.terminate()
})

// Verifies that A2CAgent returns a single (wrapped) action for an
// environment with a continuous, real-range action space (Pendulum).
test('realrange action', () => {
	const env = new PendulumRLEnvironment()
	const agent = new A2CAgent(env, 10, 10, [{ type: 'full', out_size: 10, activation: 'tanh' }], 'adam')
	agent.update(true, 0.01, 10)

	// `const`: the reset state is never reassigned in this test.
	const curState = env.reset()
	// Epsilon 0 forces the greedy (best) action.
	const best_action = agent.get_action(curState, 0)
	expect(best_action).toHaveLength(1)
})

// Verifies that A2CAgent handles an environment whose state/action are
// arrays over discrete board positions (Reversi) and still yields one action.
test('array state action', () => {
	const env = new ReversiRLEnvironment()
	const agent = new A2CAgent(env, 10, 7, [{ type: 'full', out_size: 10, activation: 'tanh' }], 'adam')
	agent.update(true, 0.01, 10)

	// `const`: the reset state is never reassigned in this test.
	const curState = env.reset()
	// Epsilon 0 forces the greedy (best) action.
	const best_action = agent.get_action(curState, 0)
	expect(best_action).toHaveLength(1)
})

test('get_score', () => {
Expand All @@ -43,4 +66,6 @@ test('get_score', () => {
expect(score[0][0]).toHaveLength(20)
expect(score[0][0][0]).toHaveLength(20)
expect(score[0][0][0][0]).toHaveLength(2)

agent.get_score()
})
55 changes: 39 additions & 16 deletions tests/lib/model/c2p.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,45 @@ import C2P from '../../../lib/model/c2p.js'

import { randIndex } from '../../../lib/evaluate/clustering.js'

test('clustering', () => {
const model = new C2P(10, 50)
const n = 20
const x = Matrix.concat(
Matrix.concat(Matrix.randn(n, 2, 0, 0.1), Matrix.randn(n, 2, 5, 0.1)),
Matrix.randn(n, 2, [0, 5], 0.1)
).toArray()
// Clustering quality checks for C2P on three well-separated Gaussian blobs.
describe('clustering', () => {
	test('default', () => {
		const model = new C2P(10, 50)
		const samplesPerBlob = 20
		// Three blobs centered at (0,0), (5,5) and (0,5), each with std 0.1.
		const data = Matrix.concat(
			Matrix.concat(
				Matrix.randn(samplesPerBlob, 2, 0, 0.1),
				Matrix.randn(samplesPerBlob, 2, 5, 0.1)
			),
			Matrix.randn(samplesPerBlob, 2, [0, 5], 0.1)
		).toArray()

		model.fit(data)
		const predicted = model.predict(3)
		expect(predicted).toHaveLength(data.length)

		// Ground-truth labels follow blob order: 0,0,...,1,1,...,2,2,...
		const truth = data.map((_, i) => Math.floor(i / samplesPerBlob))
		expect(randIndex(predicted, truth)).toBeGreaterThan(0.9)
	})

	test('no cutoff', () => {
		const model = new C2P(10, 50)
		// Disable the cutoff heuristic to exercise the alternate code path.
		model._cutoff_scale = 0
		const samplesPerBlob = 20
		const data = Matrix.concat(
			Matrix.concat(
				Matrix.randn(samplesPerBlob, 2, 0, 0.1),
				Matrix.randn(samplesPerBlob, 2, 5, 0.1)
			),
			Matrix.randn(samplesPerBlob, 2, [0, 5], 0.1)
		).toArray()

		model.fit(data)
		const predicted = model.predict(3)
		expect(predicted).toHaveLength(data.length)

		const truth = data.map((_, i) => Math.floor(i / samplesPerBlob))
		expect(randIndex(predicted, truth)).toBeGreaterThan(0.9)
	})
})
93 changes: 87 additions & 6 deletions tests/lib/model/dqn.test.js
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import { jest } from '@jest/globals'
import { expect, jest } from '@jest/globals'
jest.retryTimes(5)

import DQNAgent from '../../../lib/model/dqn.js'
import ReversiRLEnvironment from '../../../lib/rl/reversi.js'
import CartPoleRLEnvironment from '../../../lib/rl/cartpole.js'
import InHypercubeRLEnvironment from '../../../lib/rl/inhypercube.js'
import PendulumRLEnvironment from '../../../lib/rl/pendulum.js'

test('update dqn', () => {
const env = new InHypercubeRLEnvironment(2)
Expand All @@ -29,6 +31,7 @@ test('update dqn', () => {
}
}
expect(totalReward.slice(Math.max(0, totalReward.length - 10)).every(v => v > 0)).toBeTruthy()
agent.terminate()
})

test('update ddqn', () => {
Expand Down Expand Up @@ -56,16 +59,94 @@ test('update ddqn', () => {
}
}
expect(totalReward.slice(Math.max(0, totalReward.length - 10)).every(v => v > 0)).toBeTruthy()
agent.terminate()
})

// Verifies that DQNAgent returns a single action for a continuous
// (real-range) action environment (Pendulum) after one update cycle.
test('realrange action', () => {
	const env = new PendulumRLEnvironment()
	const agent = new DQNAgent(env, 10, [{ type: 'full', out_size: 3, activation: 'tanh' }], 'adam')
	// Shrink internal steps so a single update() triggers fit + target sync.
	agent._net._batch_size = 1
	agent._net._fix_param_update_step = 1
	agent._net._do_update_step = 1

	// `const`: the reset state is never reassigned in this test.
	const curState = env.reset()
	const action = agent.get_action(curState, 0.9)
	const { state, reward, done } = env.step(action)
	agent.update(action, curState, state, reward, done, 0.001, 10)

	// Epsilon 0 forces the greedy (best) action.
	const best_action = agent.get_action(state, 0)
	expect(best_action).toHaveLength(1)
})

// Verifies that DQNAgent handles array-valued states/actions (Reversi)
// and still yields a single action after one update cycle.
test('array state action', () => {
	const env = new ReversiRLEnvironment()
	const agent = new DQNAgent(env, 20, [{ type: 'full', out_size: 10, activation: 'tanh' }], 'adam')

	// Shrink internal steps so a single update() triggers fit + target sync.
	agent._net._batch_size = 1
	agent._net._fix_param_update_step = 1
	agent._net._do_update_step = 1

	// `const`: the reset state is never reassigned in this test.
	const curState = env.reset()
	const action = agent.get_action(curState, 0.9)
	const { state, reward, done } = env.step(action)
	agent.update(action, curState, state, reward, done, 0.001, 10)

	// Epsilon 0 forces the greedy (best) action.
	const best_action = agent.get_action(state, 0)
	expect(best_action).toHaveLength(1)
})

// Verifies the replay memory is capped at _max_memory_size even when more
// transitions are pushed than the cap allows.
test('max memory size', () => {
	const env = new InHypercubeRLEnvironment(2)
	const agent = new DQNAgent(env, 10, [{ type: 'full', out_size: 3, activation: 'tanh' }], 'adam')
	agent.method = 'DDQN'
	agent._net._batch_size = 1
	agent._net._max_memory_size = 10

	// `const`: the reset state is never reassigned in this test.
	const curState = env.reset()
	const action = agent.get_action(curState, 0.9)
	const { state, reward, done } = env.step(action)
	// Push 20 transitions against a cap of 10; the memory must never exceed it.
	for (let i = 0; i < 20; i++) {
		agent.update(action, curState, state, reward, done, 0.001, 10)
		expect(agent._net._memory.length).toBeLessThanOrEqual(10)
	}
})

// Verifies that switching the method from DDQN back to DQN clears the
// target network (DDQN maintains one; plain DQN must not).
test('reset to dqn', () => {
	const env = new InHypercubeRLEnvironment(2)
	const agent = new DQNAgent(env, 10, [{ type: 'full', out_size: 3, activation: 'tanh' }], 'adam')
	agent.method = 'DDQN'
	// Shrink internal steps so a single update() triggers fit + target sync.
	agent._net._batch_size = 1
	agent._net._fix_param_update_step = 1
	agent._net._do_update_step = 1

	// `const`: the reset state is never reassigned in this test.
	const curState = env.reset()
	const action = agent.get_action(curState, 0.9)
	const { state, reward, done } = env.step(action)
	agent.update(action, curState, state, reward, done, 0.001, 10)

	// DDQN has created a target network; reverting to DQN must drop it.
	expect(agent._net._target).toBeDefined()
	agent.method = 'DQN'
	expect(agent._net._target).toBeNull()
})

// Verifies get_score() returns a grid shaped by the agent resolution
// (12 per state dimension, 2 actions) and that a second call succeeds
// (exercises the cached-score path).
// NOTE: the original span contained both the pre- and post-diff versions of
// this test interleaved (duplicate `const agent` declaration — a syntax
// error); this is the post-diff (resolution 12) version.
test('get_score', () => {
	const env = new CartPoleRLEnvironment()
	const agent = new DQNAgent(env, 12, [{ type: 'full', out_size: 10, activation: 'tanh' }], 'adam')

	const score = agent.get_score()
	expect(score).toHaveLength(12)
	expect(score[0]).toHaveLength(12)
	expect(score[0][0]).toHaveLength(12)
	expect(score[0][0][0]).toHaveLength(12)
	expect(score[0][0][0][0]).toHaveLength(2)

	agent.get_score()
})

// Calling get_action without an explicit epsilon must still yield one action.
test('get_action default', () => {
	const environment = new InHypercubeRLEnvironment(2)
	const dqnAgent = new DQNAgent(environment, 10, [{ type: 'full', out_size: 3, activation: 'tanh' }], 'adam')

	expect(dqnAgent.get_action(environment.state())).toHaveLength(1)
})
19 changes: 19 additions & 0 deletions tests/lib/model/genetic_algorithm.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,28 @@ jest.retryTimes(3)
import GeneticAlgorithmGeneration from '../../../lib/model/genetic_algorithm.js'
import CartPoleRLEnvironment from '../../../lib/rl/cartpole.js'

// Constructor checks: default and explicit population size / resolution.
describe('constructor', () => {
	test('default', () => {
		const environment = new CartPoleRLEnvironment()
		const generation = new GeneticAlgorithmGeneration(environment)

		// Defaults: resolution 20, population of 100 models.
		expect(generation._resolution).toBe(20)
		expect(generation._model._models).toHaveLength(100)
	})

	test('resolution', () => {
		const environment = new CartPoleRLEnvironment()
		const generation = new GeneticAlgorithmGeneration(environment, 6, 8)

		// Explicit arguments: population 6, resolution 8.
		expect(generation._resolution).toBe(8)
		expect(generation._model._models).toHaveLength(6)
	})
})

test('next', () => {
const env = new CartPoleRLEnvironment()
const agent = new GeneticAlgorithmGeneration(env, 100, 10)
agent.next()
for (let i = 0; i < 100; i++) {
agent.run()
agent.next(0.1)
Expand Down
46 changes: 46 additions & 0 deletions tests/lib/model/gtm.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,52 @@ describe('clustering', () => {
expect(y[0]).toBeGreaterThan(y[2])
expect(y[1]).toBeGreaterThan(y[2])
})

// Clustering with the 'random' initialization strategy must still separate
// three well-separated Gaussian blobs.
test('init random', () => {
	const model = new GTM(2, 1)
	model._init_method = 'random'
	const samplesPerBlob = 50
	const data = Matrix.concat(
		Matrix.concat(
			Matrix.randn(samplesPerBlob, 2, 0, 0.1),
			Matrix.randn(samplesPerBlob, 2, 5, 0.1)
		),
		Matrix.randn(samplesPerBlob, 2, [0, 5], 0.1)
	).toArray()

	for (let epoch = 0; epoch < 100; epoch++) {
		model.fit(data)
	}
	const predicted = model.predictIndex(data)
	expect(predicted).toHaveLength(data.length)

	// Ground-truth labels follow blob order: 0,0,...,1,1,...,2,2,...
	const truth = data.map((_, i) => Math.floor(i / samplesPerBlob))
	expect(randIndex(predicted, truth)).toBeGreaterThan(0.8)
})

// Clustering with the 'mode' fitting strategy must still separate three
// well-separated Gaussian blobs.
test('mode fit', () => {
	const model = new GTM(2, 1)
	model._fit_method = 'mode'
	const samplesPerBlob = 50
	const data = Matrix.concat(
		Matrix.concat(
			Matrix.randn(samplesPerBlob, 2, 0, 0.1),
			Matrix.randn(samplesPerBlob, 2, 5, 0.1)
		),
		Matrix.randn(samplesPerBlob, 2, [0, 5], 0.1)
	).toArray()

	for (let epoch = 0; epoch < 100; epoch++) {
		model.fit(data)
	}
	const predicted = model.predict(data)
	expect(predicted).toHaveLength(data.length)

	// Ground-truth labels follow blob order: 0,0,...,1,1,...,2,2,...
	const truth = data.map((_, i) => Math.floor(i / samplesPerBlob))
	expect(randIndex(predicted, truth)).toBeGreaterThan(0.8)
})
})

test('dimensionality reduction', () => {
Expand Down
25 changes: 25 additions & 0 deletions tests/lib/model/monte_carlo.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,23 @@ jest.retryTimes(3)

import MCAgent from '../../../lib/model/monte_carlo.js'
import GridRLEnvironment from '../../../lib/rl/grid.js'
import InHypercubeRLEnvironment from '../../../lib/rl/inhypercube.js'

// Constructor checks: default and explicit table resolution.
describe('constructor', () => {
	test('default', () => {
		const environment = new InHypercubeRLEnvironment()
		const mcAgent = new MCAgent(environment)

		// Default table resolution is 20.
		expect(mcAgent._table.resolution).toBe(20)
	})

	test('resolution', () => {
		const environment = new InHypercubeRLEnvironment()
		const mcAgent = new MCAgent(environment, 6)

		expect(mcAgent._table.resolution).toBe(6)
	})
})

test('update', () => {
const env = new GridRLEnvironment()
Expand Down Expand Up @@ -42,3 +59,11 @@ test('get_score', () => {
expect(score[0]).toHaveLength(10)
expect(score[0][0]).toHaveLength(4)
})

// Calling get_action without an explicit epsilon must still yield one action.
test('get_action default', () => {
	const environment = new GridRLEnvironment()
	const mcAgent = new MCAgent(environment, environment.size[0])

	expect(mcAgent.get_action(environment.state())).toHaveLength(1)
})
Loading