Skip to content

Commit c7fb7b0

Browse files
authored
Improve some functions in util (#647)
1 parent a825c5a commit c7fb7b0

File tree

6 files changed

+221
-17
lines changed

6 files changed

+221
-17
lines changed

lib/util/complex.js

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,4 +134,23 @@ export default class Complex {
134134
new Complex(Math.cos(th + (4 * Math.PI) / 3) * s, Math.sin(th + (4 * Math.PI) / 3) * s),
135135
]
136136
}
137+
138+
/**
139+
* Returns value of complex exponential function.
140+
*
141+
* @returns {Complex} Exponential value
142+
*/
143+
exp() {
144+
const a = Math.exp(this._real)
145+
return new Complex(a * Math.cos(this._imag), a * Math.sin(this._imag))
146+
}
147+
148+
/**
149+
* Returns value of complex log function.
150+
*
151+
* @returns {Complex} Principal log value
152+
*/
153+
log() {
154+
return new Complex(Math.log(this.abs()), Math.atan2(this._imag, this._real))
155+
}
137156
}

lib/util/graph.js

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1825,6 +1825,26 @@ export default class Graph {
18251825
return false
18261826
}
18271827

1828+
/**
1829+
* Returns if this is oriented graph or not.
1830+
*
1831+
* @returns {boolean} `true` if this is oriented graph
1832+
*/
1833+
isOriented() {
1834+
const n = this._nodes.length
1835+
const amat = Array.from({ length: n }, () => Array(n).fill(false))
1836+
for (const e of this._edges) {
1837+
if (!e.direct) {
1838+
return false
1839+
}
1840+
if (amat[e[0]][e[1]]) {
1841+
return false
1842+
}
1843+
amat[e[0]][e[1]] = amat[e[1]][e[0]] = true
1844+
}
1845+
return true
1846+
}
1847+
18281848
/**
18291849
* Returns if this is weighted graph or not.
18301850
*

lib/util/tensor.js

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -339,20 +339,30 @@ export default class Tensor {
339339
* @returns {Tensor} Selected tensor
340340
*/
341341
select(idx, axis = 0) {
342-
if (axis !== 0) {
343-
throw new MatrixException('Invalid axis. Only 0 is accepted.')
342+
if (axis < 0 || this.dimension <= axis) {
343+
throw new MatrixException('Invalid axis.')
344344
}
345345
if (!Array.isArray(idx)) {
346346
idx = [idx]
347347
}
348-
let s = 1
349-
for (let d = 1; d < this.dimension; d++) {
350-
s *= this._size[d]
348+
let step = 1
349+
let sublen = 1
350+
for (let d = 0; d < axis; d++) {
351+
step *= this._size[d]
352+
}
353+
for (let d = axis + 1; d < this.dimension; d++) {
354+
sublen *= this._size[d]
351355
}
352-
const t = new Tensor([idx.length, ...this._size.slice(1)])
356+
const newSizes = this._size.concat()
357+
newSizes[axis] = idx.length
358+
const t = new Tensor(newSizes)
353359
for (let i = 0; i < idx.length; i++) {
354-
for (let j = 0; j < s; j++) {
355-
t._value[i * s + j] = this._value[idx[i] * s + j]
360+
for (let k = 0; k < step; k++) {
361+
const toff1 = k * idx.length * sublen + i * sublen
362+
const toff2 = k * this._size[axis] * sublen + idx[i] * sublen
363+
for (let l = 0; l < sublen; l++) {
364+
t._value[toff1 + l] = this._value[toff2 + l]
365+
}
356366
}
357367
}
358368
return t
@@ -375,17 +385,14 @@ export default class Tensor {
375385
} else if (to < from) {
376386
throw new MatrixException('Invalid index.')
377387
}
378-
let s = 1
379-
for (let d = 0; d < this.dimension; d++) {
380-
if (d === axis) {
381-
continue
382-
}
383-
s *= this._size[d]
384-
}
385388
const newSizes = this._size.concat()
386389
newSizes[axis] = to - from
387390
const t = new Tensor(newSizes)
388391
if (axis === 0) {
392+
let s = 1
393+
for (let d = 1; d < this.dimension; d++) {
394+
s *= this._size[d]
395+
}
389396
t._value = this._value.slice(from * s, to * s)
390397
} else {
391398
for (let i = 0; i < t.length; i++) {

tests/lib/util/complex.test.js

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,4 +128,28 @@ describe('Complex', () => {
128128
expect(r.imaginary).toBeCloseTo(complex.imaginary)
129129
}
130130
})
131+
132+
test('exp', () => {
133+
const complex = new Complex(Math.random(), Math.random())
134+
const exp = complex.exp()
135+
expect(exp.real).toBeCloseTo(Math.exp(complex.real) * Math.cos(complex.imaginary))
136+
expect(exp.imaginary).toBeCloseTo(Math.exp(complex.real) * Math.sin(complex.imaginary))
137+
138+
const log = exp.log()
139+
expect(log.real).toBeCloseTo(complex.real)
140+
expect(log.imaginary).toBeCloseTo(complex.imaginary)
141+
})
142+
143+
test('log', () => {
144+
const complex = new Complex(Math.random(), Math.random())
145+
const log = complex.log()
146+
expect(log.real).toBeCloseTo(Math.log(Math.sqrt(complex.real ** 2 + complex.imaginary ** 2)))
147+
expect(log.imaginary).toBeLessThanOrEqual(Math.PI)
148+
expect(log.imaginary).toBeGreaterThan(-Math.PI)
149+
expect(log.imaginary).toBeCloseTo(Math.atan2(complex.imaginary, complex.real))
150+
151+
const exp = log.exp()
152+
expect(exp.real).toBeCloseTo(complex.real)
153+
expect(exp.imaginary).toBeCloseTo(complex.imaginary)
154+
})
131155
})

tests/lib/util/graph.test.js

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1672,6 +1672,34 @@ describe('graph', () => {
16721672
})
16731673
})
16741674

1675+
describe('isOriented', () => {
1676+
test('oriented', () => {
1677+
const graph = new Graph(3, [
1678+
{ 0: 0, 1: 1, direct: true },
1679+
{ 0: 1, 1: 2, direct: true },
1680+
])
1681+
expect(graph.isOriented()).toBeTruthy()
1682+
})
1683+
1684+
test('empty', () => {
1685+
const graph = new Graph(3)
1686+
expect(graph.isOriented()).toBeTruthy()
1687+
})
1688+
1689+
test('undirected', () => {
1690+
const graph = new Graph(3, [[0, 1]])
1691+
expect(graph.isOriented()).toBeFalsy()
1692+
})
1693+
1694+
test('directed', () => {
1695+
const graph = new Graph(3, [
1696+
{ 0: 0, 1: 1, direct: true },
1697+
{ 0: 1, 1: 0, direct: true },
1698+
])
1699+
expect(graph.isOriented()).toBeFalsy()
1700+
})
1701+
})
1702+
16751703
describe('isWeighted', () => {
16761704
test('weighted', () => {
16771705
const graph = new Graph(3, [{ 0: 0, 1: 1, value: 2 }])

tests/lib/util/tensor.test.js

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,59 @@ describe('Tensor', () => {
433433
})
434434
})
435435

436-
test.each([-1, 1])('axis %i', i => {
436+
describe('axis 1', () => {
437+
test.each([0, 1, 2])('scalar %p', k => {
438+
const ten = Tensor.randn([3, 4, 5])
439+
const slice = ten.select(k, 1)
440+
expect(slice.sizes).toEqual([3, 1, 5])
441+
for (let i = 0; i < 3; i++) {
442+
for (let j = 0; j < 5; j++) {
443+
expect(slice.at(i, 0, j)).toBe(ten.at(i, k, j))
444+
}
445+
}
446+
})
447+
448+
test.each([[[0]], [[1]], [[2]], [[0, 0]], [[1, 2]], [[2, 0]]])('array %p', k => {
449+
const ten = Tensor.randn([3, 4, 5])
450+
const slice = ten.select(k, 1)
451+
expect(slice.sizes).toEqual([3, k.length, 5])
452+
for (let t = 0; t < k.length; t++) {
453+
for (let i = 0; i < 3; i++) {
454+
for (let j = 0; j < 5; j++) {
455+
expect(slice.at(i, t, j)).toBe(ten.at(i, k[t], j))
456+
}
457+
}
458+
}
459+
})
460+
})
461+
462+
describe('axis 2', () => {
463+
test.each([0, 1, 2])('scalar %p', k => {
464+
const ten = Tensor.randn([3, 4, 5])
465+
const slice = ten.select(k, 2)
466+
expect(slice.sizes).toEqual([3, 4, 1])
467+
for (let i = 0; i < 3; i++) {
468+
for (let j = 0; j < 4; j++) {
469+
expect(slice.at(i, j, 0)).toBe(ten.at(i, j, k))
470+
}
471+
}
472+
})
473+
474+
test.each([[[0]], [[1]], [[2]], [[0, 0]], [[1, 2]], [[2, 0]]])('array %p', k => {
475+
const ten = Tensor.randn([3, 4, 5])
476+
const slice = ten.select(k, 2)
477+
expect(slice.sizes).toEqual([3, 4, k.length])
478+
for (let t = 0; t < k.length; t++) {
479+
for (let i = 0; i < 3; i++) {
480+
for (let j = 0; j < 4; j++) {
481+
expect(slice.at(i, j, t)).toBe(ten.at(i, j, k[t]))
482+
}
483+
}
484+
}
485+
})
486+
})
487+
488+
test.each([-1, 3])('axis %i', i => {
437489
const ten = new Tensor([2, 3, 4])
438490
expect(() => ten.select(0, i)).toThrow('Invalid axis.')
439491
})
@@ -618,7 +670,61 @@ describe('Tensor', () => {
618670
}
619671
})
620672

621-
test.each([-1, 1])('fail invalid axis %p', axis => {
673+
test('axis 1', () => {
674+
const org = Tensor.randn([3, 4, 5])
675+
const ten = org.copy()
676+
ten.shuffle(1)
677+
678+
const expidx = []
679+
for (let t = 0; t < org.sizes[1]; t++) {
680+
for (let i = 0; i < org.sizes[1]; i++) {
681+
let flg = true
682+
for (let j = 0; j < org.sizes[0]; j++) {
683+
for (let k = 0; k < org.sizes[2]; k++) {
684+
flg &= ten.at(j, t, k) === org.at(j, i, k)
685+
}
686+
}
687+
if (flg) {
688+
expidx.push(i)
689+
break
690+
}
691+
}
692+
}
693+
expidx.sort((a, b) => a - b)
694+
expect(expidx).toHaveLength(org.sizes[1])
695+
for (let i = 0; i < org.sizes[1]; i++) {
696+
expect(expidx[i]).toBe(i)
697+
}
698+
})
699+
700+
test('axis 2', () => {
701+
const org = Tensor.randn([3, 4, 5])
702+
const ten = org.copy()
703+
ten.shuffle(2)
704+
705+
const expidx = []
706+
for (let t = 0; t < org.sizes[2]; t++) {
707+
for (let i = 0; i < org.sizes[2]; i++) {
708+
let flg = true
709+
for (let j = 0; j < org.sizes[0]; j++) {
710+
for (let k = 0; k < org.sizes[1]; k++) {
711+
flg &= ten.at(j, k, t) === org.at(j, k, i)
712+
}
713+
}
714+
if (flg) {
715+
expidx.push(i)
716+
break
717+
}
718+
}
719+
}
720+
expidx.sort((a, b) => a - b)
721+
expect(expidx).toHaveLength(org.sizes[2])
722+
for (let i = 0; i < org.sizes[2]; i++) {
723+
expect(expidx[i]).toBe(i)
724+
}
725+
})
726+
727+
test.each([-1, 3])('fail invalid axis %p', axis => {
622728
const mat = Tensor.randn([2, 3, 4])
623729
expect(() => mat.shuffle(axis)).toThrow('Invalid axis.')
624730
})

0 commit comments

Comments
 (0)