Skip to content

Commit a1618ed

Browse files
authored
Add balancing method in matrix class (#724)
1 parent 296cc15 commit a1618ed

File tree

2 files changed

+131
-1
lines changed

2 files changed

+131
-1
lines changed

lib/util/matrix.js

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3420,6 +3420,60 @@ export default class Matrix {
34203420
return h
34213421
}
34223422

3423+
/**
3424+
* Returns a doubly stochastic matrix.
3425+
*
3426+
* @returns {[number[], Matrix, number[]]} Doubly stochastic matrix
3427+
*/
3428+
balancing() {
3429+
if (!this.isSquare()) {
3430+
throw new MatrixException('Doubly stochastic matrix only defined for square matrix.', this)
3431+
}
3432+
if (this._value.some(v => v <= 0)) {
3433+
throw new MatrixException('Doubly stochastic matrix only calculate for non negative matrix.', this)
3434+
}
3435+
if (this.rows === 1) {
3436+
return [[this._value[0]], new Matrix(1, 1, 1), [1]]
3437+
}
3438+
return this.balancingSinkhornKnopp()
3439+
}
3440+
3441+
/**
3442+
* Returns a doubly stochastic matrix by Sinkhorn-Knopp algorithm.
3443+
*
3444+
* @returns {[number[], Matrix, number[]]} Doubly stochastic matrix
3445+
*/
3446+
balancingSinkhornKnopp() {
3447+
if (!this.isSquare()) {
3448+
throw new MatrixException('Doubly stochastic matrix only defined for square matrix.', this)
3449+
}
3450+
if (this._value.some(v => v <= 0)) {
3451+
throw new MatrixException('Doubly stochastic matrix only calculate for non negative matrix.', this)
3452+
}
3453+
const n = this.rows
3454+
const d1 = Array(n).fill(1)
3455+
const d2 = Array(n).fill(1)
3456+
const a = this.copy()
3457+
let maxCount = 1.0e4
3458+
while (maxCount-- > 0) {
3459+
const s1 = a.sum(1)
3460+
a.div(s1)
3461+
for (let i = 0; i < n; i++) {
3462+
d1[i] *= s1.at(i, 0)
3463+
}
3464+
const s2 = a.sum(0)
3465+
a.div(s2)
3466+
for (let i = 0; i < n; i++) {
3467+
d2[i] *= s2.at(0, i)
3468+
}
3469+
const e = s1.reduce((s, v) => s + Math.abs(v - 1), 0) + s2.reduce((s, v) => s + Math.abs(v - 1), 0)
3470+
if (e < 1.0e-8) {
3471+
break
3472+
}
3473+
}
3474+
return [d1, a, d2]
3475+
}
3476+
34233477
/**
34243478
* Returns a LU decomposition.
34253479
*

tests/lib/util/matrix.test.js

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5405,6 +5405,82 @@ describe('Matrix', () => {
54055405
})
54065406
})
54075407

5408+
describe('balancing', () => {
5409+
test.each([1, 2, 3, 4])('square %d', n => {
5410+
const mat = Matrix.random(n, n, 0.1, 10)
5411+
const [d1, a, d2] = mat.balancing()
5412+
5413+
expect(a.sizes).toEqual([n, n])
5414+
expect(d1).toHaveLength(n)
5415+
expect(d2).toHaveLength(n)
5416+
const s0 = a.sum(0).value
5417+
const s1 = a.sum(1).value
5418+
for (let i = 0; i < n; i++) {
5419+
expect(s0[i]).toBeCloseTo(1)
5420+
expect(s1[i]).toBeCloseTo(1)
5421+
}
5422+
5423+
const dad = Matrix.diag(d1).dot(a).dot(Matrix.diag(d2))
5424+
for (let i = 0; i < n; i++) {
5425+
for (let j = 0; j < n; j++) {
5426+
expect(dad.at(i, j)).toBeCloseTo(mat.at(i, j))
5427+
}
5428+
}
5429+
})
5430+
5431+
test('fail neg value', () => {
5432+
const mat = Matrix.randn(3, 3)
5433+
mat.set(0, 0, -0.1)
5434+
expect(() => mat.balancing()).toThrow('Doubly stochastic matrix only calculate for non negative matrix.')
5435+
})
5436+
5437+
test.each([
5438+
[2, 3],
5439+
[3, 2],
5440+
])('fail(%i, %i)', (r, c) => {
5441+
const mat = Matrix.randn(r, c)
5442+
expect(() => mat.balancing()).toThrow('Doubly stochastic matrix only defined for square matrix.')
5443+
})
5444+
})
5445+
5446+
describe('balancingSinkhornKnopp', () => {
5447+
test.each([1, 2, 3, 4])('square %d', n => {
5448+
const mat = Matrix.random(n, n, 0.1, 10)
5449+
const [d1, a, d2] = mat.balancingSinkhornKnopp()
5450+
5451+
expect(a.sizes).toEqual([n, n])
5452+
expect(d1).toHaveLength(n)
5453+
expect(d2).toHaveLength(n)
5454+
const s0 = a.sum(0).value
5455+
const s1 = a.sum(1).value
5456+
for (let i = 0; i < n; i++) {
5457+
expect(s0[i]).toBeCloseTo(1)
5458+
expect(s1[i]).toBeCloseTo(1)
5459+
}
5460+
5461+
const dad = Matrix.diag(d1).dot(a).dot(Matrix.diag(d2))
5462+
for (let i = 0; i < n; i++) {
5463+
for (let j = 0; j < n; j++) {
5464+
expect(dad.at(i, j)).toBeCloseTo(mat.at(i, j))
5465+
}
5466+
}
5467+
})
5468+
5469+
test('fail neg value', () => {
5470+
const mat = Matrix.randn(3, 3)
5471+
mat.set(0, 0, -0.1)
5472+
expect(() => mat.balancingSinkhornKnopp()).toThrow('Doubly stochastic matrix only calculate for non negative matrix.')
5473+
})
5474+
5475+
test.each([
5476+
[2, 3],
5477+
[3, 2],
5478+
])('fail(%i, %i)', (r, c) => {
5479+
const mat = Matrix.randn(r, c)
5480+
expect(() => mat.balancingSinkhornKnopp()).toThrow('Doubly stochastic matrix only defined for square matrix.')
5481+
})
5482+
})
5483+
54085484
describe('lu', () => {
54095485
test.each([0, 1, 2, 3, 5])('success %i', n => {
54105486
const mat = Matrix.randn(n, n)
@@ -6577,7 +6653,7 @@ describe('Matrix', () => {
65776653
expect(() => mat.eigenValuesQR()).toThrow('Eigen values only define square matrix.')
65786654
})
65796655

6580-
test('iteration not converged', () => {
6656+
test.only('iteration not converged', () => {
65816657
const mat = new Matrix(3, 3, [
65826658
[-0.3, -0.4, 1.7],
65836659
[-0.2, -1.8, -0.8],

0 commit comments

Comments
 (0)