Skip to content

Commit 2de7e87

Browse files
committed
Improve Matrix class and add tests
1 parent 9f7ec61 commit 2de7e87

File tree

3 files changed

+285
-19
lines changed

3 files changed

+285
-19
lines changed

lib/util/math.js

Lines changed: 110 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -784,7 +784,8 @@ export class Matrix {
784784
myu = new Matrix(1, cols, myu)
785785
} else if (myu.rows === cols || myu.cols === 1) {
786786
myu = myu.t
787-
} else if (myu.cols !== cols || myu.rows !== 1) {
787+
}
788+
if (myu.cols !== cols || myu.rows !== 1) {
788789
throw new MatrixException("'myu' cols must be same as 'cols' and rows must be 1.")
789790
}
790791
if (!(sigma instanceof Matrix)) {
@@ -3002,20 +3003,28 @@ export class Matrix {
30023003
* @returns {Matrix}
30033004
*/
30043005
tridiag() {
3006+
return this.tridiagHouseholder()
3007+
}
3008+
3009+
/**
3010+
* Returns a tridiagonal matrix.
3011+
* @returns {Matrix}
3012+
*/
3013+
tridiagHouseholder() {
30053014
if (!this.isSymmetric()) {
30063015
throw new MatrixException('Tridiagonal only define symmetric matrix.', this)
30073016
}
3008-
let a = this.copy()
3009-
let n = this.cols
3017+
const a = this.copy()
3018+
const n = this.cols
30103019
for (let i = 0; i < n - 2; i++) {
3011-
let v = a.slice(i + 1, i, n, i + 1)
3012-
let alpha = v.norm() * (v._value[0] < 0 ? 1 : -1)
3020+
const v = a.slice(i + 1, i, n, i + 1)
3021+
const alpha = v.norm() * (v._value[0] < 0 ? 1 : -1)
30133022
v._value[0] -= alpha
30143023
v.div(v.norm())
30153024

3016-
let new_a = a.slice(i + 1, i + 1)
3017-
let d = new_a.dot(v)
3018-
let g = v.copyMult(v.tDot(d))
3025+
const new_a = a.slice(i + 1, i + 1)
3026+
const d = new_a.dot(v)
3027+
const g = v.copyMult(v.tDot(d))
30193028
g.isub(d)
30203029
g.mult(2)
30213030

@@ -3031,6 +3040,85 @@ export class Matrix {
30313040
return a
30323041
}
30333042

3043+
/**
3044+
* Returns a tridiagonal matrix.
3045+
* @param {number} k
3046+
* @returns {Matrix}
3047+
*/
3048+
tridiagLanczos(k = 0) {
3049+
if (!this.isSymmetric()) {
3050+
throw new MatrixException('Tridiagonal only define symmetric matrix.', this)
3051+
}
3052+
const n = this.cols
3053+
if (k <= 0) {
3054+
k = n
3055+
}
3056+
let s = 0
3057+
let q0 = Matrix.zeros(n, 1)
3058+
let q1 = Matrix.randn(n, 1)
3059+
q1.div(q1.norm())
3060+
3061+
const a = Matrix.zeros(k, k)
3062+
for (let i = 0; i < k; i++) {
3063+
const v = this.dot(q1)
3064+
const t = q1.tDot(v).value[0]
3065+
v.sub(q0.copyMult(s))
3066+
v.sub(q1.copyMult(t))
3067+
s = v.norm()
3068+
q0 = q1
3069+
v.div(s)
3070+
q1 = v
3071+
3072+
a.set(i, i, t)
3073+
if (i < k - 1) {
3074+
a.set(i, i + 1, s)
3075+
a.set(i + 1, i, s)
3076+
}
3077+
}
3078+
return a
3079+
}
3080+
3081+
/**
3082+
* Returns a hessenberg matrix.
3083+
* @returns {Matrix}
3084+
*/
3085+
hessenberg() {
3086+
return this.hessenbergArnoldi()
3087+
}
3088+
3089+
/**
3090+
* Returns a hessenberg matrix.
3091+
* @param {number} k
3092+
* @returns {Matrix}
3093+
*/
3094+
hessenbergArnoldi(k = 0) {
3095+
if (!this.isSquare()) {
3096+
throw new MatrixException('Hessenberg only define square matrix.', this)
3097+
}
3098+
const n = this.cols
3099+
if (k <= 0) {
3100+
k = n
3101+
}
3102+
const h = Matrix.zeros(k, k)
3103+
let q = [Matrix.randn(n, 1)]
3104+
q[0].div(q[0].norm())
3105+
for (let j = 0; j < k; j++) {
3106+
const v = this.dot(q[j])
3107+
for (let i = 0; i <= j; i++) {
3108+
const hij = q[i].tDot(v).value[0]
3109+
v.sub(q[i].copyMult(hij))
3110+
h.set(i, j, hij)
3111+
}
3112+
const hi1j = v.norm()
3113+
v.div(hi1j)
3114+
q[j + 1] = v
3115+
if (j < k - 1) {
3116+
h.set(j + 1, j, hi1j)
3117+
}
3118+
}
3119+
return h
3120+
}
3121+
30343122
/**
30353123
* Returns a LU decomposition.
30363124
* @returns {[Matrix, Matrix]}
@@ -3282,6 +3370,14 @@ export class Matrix {
32823370
return [l, d]
32833371
}
32843372

3373+
/**
3374+
* Returns schur decomposition.
3375+
* @returns {Matrix}
3376+
*/
3377+
schur() {
3378+
throw new MatrixException('Not implemented.')
3379+
}
3380+
32853381
/**
32863382
* Returns eigenvalues and eigenvectors. Eigenvectors correspond to each column of the matrix.
32873383
* @returns {[number[], Matrix]}
@@ -3434,8 +3530,12 @@ export class Matrix {
34343530

34353531
let a = this.copy()
34363532
const ev = []
3437-
if (this.rows > 10 && this.isSymmetric()) {
3438-
a = a.tridiag()
3533+
if (this.rows > 2) {
3534+
if (this.isSymmetric()) {
3535+
a = a.tridiag()
3536+
} else {
3537+
a = a.hessenberg()
3538+
}
34393539
}
34403540
const tol = 1.0e-8
34413541
for (let n = a.rows; n > 2; n--) {

tests/lib/model/q_learning.test.js

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,22 @@ test('default', () => {
55
const env = new GridRLEnvironment()
66
const agent = new QAgent(env, env.size[0])
77

8-
let totalReward = 0
9-
const n = 1000
8+
let totalReward = 0
9+
const n = 1000
1010
for (let i = 0; i < n; i++) {
1111
let curState = env.reset()
1212
while (true) {
1313
const action = agent.get_action(env, curState, 0.01)
14-
const {state, reward, done} = env.step(action)
14+
const { state, reward, done } = env.step(action)
1515
agent.update(action, curState, state, reward)
16-
totalReward += reward
16+
totalReward += reward
1717
curState = state
1818
if (done) {
1919
break
2020
}
2121
}
2222
}
23-
expect(totalReward / n).toBeGreaterThan(-60)
23+
expect(totalReward / n).toBeGreaterThan(-60)
2424
const score = agent.get_score(env)
2525
expect(score).toHaveLength(20)
2626
expect(score[0]).toHaveLength(10)

tests/lib/util/math.test.js

Lines changed: 170 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,21 @@ describe('Matrix', () => {
741741
}
742742
})
743743

744+
test('array mean', () => {
745+
const mat = Matrix.randn(100000, 2, Matrix.fromArray([3, 5]), 1.5)
746+
const [mean, vari] = calcMV(mat)
747+
for (let j = 0; j < 2; j++) {
748+
expect(mean[j]).toBeCloseTo(j * 2 + 3, 1)
749+
for (let k = 0; k < 2; k++) {
750+
if (j === k) {
751+
expect(vari[j][k]).toBeCloseTo(1.5, 1)
752+
} else {
753+
expect(vari[j][k]).toBeCloseTo(0, 1)
754+
}
755+
}
756+
}
757+
})
758+
744759
test('array vari', () => {
745760
const cov = [
746761
[0.3, 0.1],
@@ -755,6 +770,23 @@ describe('Matrix', () => {
755770
}
756771
}
757772
})
773+
774+
test.each([[3, 5, 7], Matrix.randn(2, 2)])('fail invalid mean %p', m => {
775+
expect(() => Matrix.randn(100000, 2, m, 1)).toThrowError(
776+
"'myu' cols must be same as 'cols' and rows must be 1."
777+
)
778+
})
779+
780+
test.each([
781+
[
782+
[1, 2],
783+
[3, 4],
784+
[5, 6],
785+
],
786+
Matrix.randn(2, 3),
787+
])('fail invalid mean %p', s => {
788+
expect(() => Matrix.randn(100000, 2, 0, s)).toThrowError("'sigma' cols and rows must be same as 'cols'.")
789+
})
758790
})
759791

760792
test('diag', () => {
@@ -3614,10 +3646,10 @@ describe('Matrix', () => {
36143646
}
36153647
}
36163648

3617-
const orgeig = mat.eigenJacobi()[0]
3649+
const orgeig = mat.eigenValues()
36183650
for (let i = 0; i < n; i++) {
3619-
const ev = tridiag.eigenInverseIteration(orgeig[i])[0]
3620-
expect(ev).toBeCloseTo(orgeig[i])
3651+
const s = tridiag.copySub(Matrix.eye(n, n, orgeig[i]))
3652+
expect(s.det()).toBeCloseTo(0)
36213653
}
36223654
})
36233655

@@ -3631,6 +3663,105 @@ describe('Matrix', () => {
36313663
})
36323664
})
36333665

3666+
describe('tridiagHouseholder', () => {
3667+
test('symmetric', () => {
3668+
const n = 10
3669+
const mat = Matrix.randn(n, n, 0, 0.1).gram()
3670+
const tridiag = mat.tridiagHouseholder()
3671+
for (let i = 0; i < n; i++) {
3672+
for (let j = 0; j < n; j++) {
3673+
if (Math.abs(i - j) > 1) {
3674+
expect(tridiag.at(i, j)).toBeCloseTo(0)
3675+
} else if (Math.abs(i - j) === 1) {
3676+
expect(tridiag.at(i, j)).toBeCloseTo(tridiag.at(j, i))
3677+
}
3678+
}
3679+
}
3680+
3681+
const orgeig = mat.eigenValues()
3682+
for (let i = 0; i < n; i++) {
3683+
const s = tridiag.copySub(Matrix.eye(n, n, orgeig[i]))
3684+
expect(s.det()).toBeCloseTo(0)
3685+
}
3686+
})
3687+
3688+
test.each([
3689+
[3, 3],
3690+
[2, 3],
3691+
[3, 2],
3692+
])('fail(%i, %i)', (r, c) => {
3693+
const mat = Matrix.randn(r, c)
3694+
expect(() => mat.tridiagHouseholder()).toThrowError('Tridiagonal only define symmetric matrix.')
3695+
})
3696+
})
3697+
3698+
describe('tridiagLanczos', () => {
3699+
test('symmetric', () => {
3700+
const n = 10
3701+
const mat = Matrix.randn(n, n, 0, 0.1).gram()
3702+
const tridiag = mat.tridiagLanczos()
3703+
for (let i = 0; i < n; i++) {
3704+
for (let j = 0; j < n; j++) {
3705+
if (Math.abs(i - j) > 1) {
3706+
expect(tridiag.at(i, j)).toBeCloseTo(0)
3707+
} else if (Math.abs(i - j) === 1) {
3708+
expect(tridiag.at(i, j)).toBeCloseTo(tridiag.at(j, i))
3709+
}
3710+
}
3711+
}
3712+
3713+
const orgeig = mat.eigenValues()
3714+
for (let i = 0; i < n; i++) {
3715+
const s = tridiag.copySub(Matrix.eye(n, n, orgeig[i]))
3716+
expect(s.det()).toBeCloseTo(0)
3717+
}
3718+
})
3719+
3720+
test.todo('k')
3721+
3722+
test.each([
3723+
[3, 3],
3724+
[2, 3],
3725+
[3, 2],
3726+
])('fail(%i, %i)', (r, c) => {
3727+
const mat = Matrix.randn(r, c)
3728+
expect(() => mat.tridiagLanczos()).toThrowError('Tridiagonal only define symmetric matrix.')
3729+
})
3730+
})
3731+
3732+
describe('hessenberg', () => {
3733+
test('symmetric', () => {
3734+
const n = 10
3735+
const mat = Matrix.randn(n, n).gram()
3736+
const hessenberg = mat.hessenberg()
3737+
for (let i = 0; i < n; i++) {
3738+
for (let j = 0; j < n; j++) {
3739+
if (i - j > 1) {
3740+
expect(hessenberg.at(i, j)).toBeCloseTo(0)
3741+
}
3742+
}
3743+
}
3744+
3745+
const orgeig = mat.eigenValues()
3746+
for (let i = 0; i < n; i++) {
3747+
const s = hessenberg.copySub(Matrix.eye(n, n, orgeig[i]))
3748+
expect(s.det()).toBeCloseTo(0)
3749+
}
3750+
})
3751+
3752+
test.todo('not symmetric')
3753+
3754+
test.each([
3755+
[2, 3],
3756+
[3, 2],
3757+
])('fail(%i, %i)', (r, c) => {
3758+
const mat = Matrix.randn(r, c)
3759+
expect(() => mat.hessenberg()).toThrowError('Hessenberg only define square matrix.')
3760+
})
3761+
})
3762+
3763+
test.todo('hessenbergArnoldi')
3764+
36343765
describe('lu', () => {
36353766
test.each([0, 1, 2, 3, 5])('success %i', n => {
36363767
const mat = Matrix.randn(n, n)
@@ -4168,5 +4299,40 @@ describe('Matrix', () => {
41684299
})
41694300
})
41704301

4171-
test.todo('eigenInverseIteration')
4302+
describe('eigenInverseIteration', () => {
4303+
test.each([1, 2, 5])('symmetric %i', n => {
4304+
const mat = Matrix.randn(n, n).gram()
4305+
const ev = mat.eigenValues()
4306+
for (let i = 0; i < n; i++) {
4307+
const e = ev[i] - (ev[i] - (ev[i + 1] || ev[i] - 1)) / 4
4308+
const [eigvalue, eigvector] = mat.eigenInverseIteration(e)
4309+
4310+
expect(eigvalue).toBeCloseTo(ev[i])
4311+
4312+
const cmat = mat.copy()
4313+
for (let k = 0; k < n; k++) {
4314+
cmat.subAt(k, k, eigvalue)
4315+
}
4316+
expect(cmat.det()).toBeCloseTo(0)
4317+
4318+
const x = mat.dot(eigvector)
4319+
const y = eigvector.copyMult(eigvalue)
4320+
for (let k = 0; k < n; k++) {
4321+
expect(x.at(k, 0)).toBeCloseTo(y.at(k, 0))
4322+
}
4323+
const eye = eigvector.tDot(eigvector)
4324+
expect(eye.at(0, 0)).toBeCloseTo(1)
4325+
}
4326+
})
4327+
4328+
test.todo('non symmetric')
4329+
4330+
test.each([
4331+
[2, 3],
4332+
[3, 2],
4333+
])('fail(%i, %i)', (r, c) => {
4334+
const mat = Matrix.randn(r, c)
4335+
expect(() => mat.eigenInverseIteration()).toThrowError('Eigen vectors only define square matrix.')
4336+
})
4337+
})
41724338
})

0 commit comments

Comments
 (0)