Skip to content

Commit 296cc15

Browse files
authored
Improve RNN layers (#722)
* Improve RNN layers * Fix onnx models path
1 parent 7f062fe commit 296cc15

File tree

8 files changed

+871
-418
lines changed

8 files changed

+871
-418
lines changed

lib/model/nns/layer/gru.js

Lines changed: 206 additions & 93 deletions
Large diffs are not rendered by default.

lib/model/nns/layer/index.js

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ import Tensor from '../../../util/tensor.js'
164164
* { type: 'global_max_pool', channel_dim?: number } |
165165
* { type: 'greater' } |
166166
* { type: 'greater_or_equal' } |
167-
* { type: 'gru', size: number, return_sequences?: boolean, w_z?: number[][] | Matrix, w_r?: number[][] | Matrix, w_h?: number[][] | Matrix, u_z?: number[][] | Matrix, u_r?: number[][] | Matrix, u_h?: number[][] | Matrix, b_z?: number[][] | Matrix, b_r?: number[][] | Matrix, b_h?: number[][] | Matrix } |
167+
* { type: 'gru', size: number, return_sequences?: boolean, w_z?: number[][] | Matrix | string, w_r?: number[][] | Matrix | string, w_h?: number[][] | Matrix | string, u_z?: number[][] | Matrix | string, u_r?: number[][] | Matrix | string, u_h?: number[][] | Matrix | string, b_z?: number[][] | Matrix | string, b_r?: number[][] | Matrix | string, b_h?: number[][] | Matrix | string, sequence_dim?: number } |
168168
* { type: 'hard_elish' } |
169169
* { type: 'hard_shrink', l?: number } |
170170
* { type: 'hard_sigmoid', alpha?: number, beta?: number } |
@@ -190,7 +190,7 @@ import Tensor from '../../../util/tensor.js'
190190
* { type: 'logsigmoid' } |
191191
* { type: 'lp_pool', p?: number, kernel: number | number[], stride?: number | number[], padding?: number | number[], channel_dim?: number } |
192192
* { type: 'lrn', alpha?: number, beta?: number, k?: number, n: number, channel_dim?: number } |
193-
* { type: 'lstm', size: number, return_sequences?: boolean, w_z?: number[][] | Matrix, w_in?: number[][] | Matrix, w_for?: number[][] | Matrix, w_out?: number[][] | Matrix, r_z?: number[][] | Matrix, r_in?: number[][] | Matrix, r_for?: number[][] | Matrix, r_out?: number[][] | Matrix, p_in?: number[][] | Matrix, p_for?: number[][] | Matrix, p_out?: number[][] | Matrix, b_z?: number[][] | Matrix, b_in?: number[][] | Matrix, b_for?: number[][] | Matrix, b_out?: number[][] | Matrix } |
193+
* { type: 'lstm', size: number, return_sequences?: boolean, w_z?: number[][] | Matrix | string, w_in?: number[][] | Matrix | string, w_for?: number[][] | Matrix | string, w_out?: number[][] | Matrix | string, r_z?: number[][] | Matrix | string, r_in?: number[][] | Matrix | string, r_for?: number[][] | Matrix | string, r_out?: number[][] | Matrix | string, p_in?: number[][] | Matrix | string, p_for?: number[][] | Matrix | string, p_out?: number[][] | Matrix | string, b_z?: number[][] | Matrix | string, b_in?: number[][] | Matrix | string, b_for?: number[][] | Matrix | string, b_out?: number[][] | Matrix | string, sequence_dim?: number } |
194194
* { type: 'matmul' } |
195195
* { type: 'max' } |
196196
* { type: 'max_pool', kernel: number | number[], stride?: number | number[], padding?: number | number[], channel_dim?: number } |
@@ -230,7 +230,7 @@ import Tensor from '../../../util/tensor.js'
230230
* { type: 'reu' } |
231231
* { type: 'reverse', axis?: number } |
232232
* { type: 'right_bitshift' } |
233-
* { type: 'rnn', size: number, out_size?: number, activation?: string | object, recurrent_activation?: string | object, return_sequences?: boolean, w_xh?: number[][] | Matrix, w_hh?: number[][] | Matrix, w_hy?: number[][] | Matrix, b_xh?: number[][] | Matrix, b_hh?: number[][] | Matrix, b_hy?: number[][] | Matrix } |
233+
* { type: 'rnn', size: number, activation?: string | object, return_sequences?: boolean, w_x?: number[][] | Matrix | string, w_h?: number[][] | Matrix | string, b_x?: number[][] | Matrix | string, b_h?: number[][] | Matrix | string, sequence_dim?: number } |
234234
* { type: 'rootsig' } |
235235
* { type: 'round' } |
236236
* { type: 'rrelu', l?: number, u?: number } |

0 commit comments

Comments
 (0)