@@ -148,6 +148,10 @@ func (t *Tensor[T]) SumShape(dims []int) []int {
 // Print prints the values of the Tensor
 func (t *Tensor[T]) Print(format_ ...string) {
+	format := "%.2f"
+	if len(format_) > 0 {
+		format = format_[0]
+	}
 	fmt.Print("shape:[")
 	for i := 0; i < len(t.Shape); i++ {
 		fmt.Print(t.Shape[i])
@@ -162,35 +166,35 @@ func (t *Tensor[T]) Print(format_ ...string) {
 			if i > 0 {
 				fmt.Print(" ")
 			}
-			fmt.Printf(format_[0], t.Data[i])
+			fmt.Printf(format, t.Data[i])
 		}
 		fmt.Println("]")
 	} else if len(t.Shape) == 2 {
-		fmt.Print("[")
+		fmt.Println("[")
 		for i := 0; i < t.Shape[0]; i++ {
 			fmt.Print(" [")
 			for j := 0; j < t.Shape[1]; j++ {
 				if j > 0 {
 					fmt.Print(" ")
 				}
-				fmt.Printf(format_[0], t.Data[i*t.Shape[1]+j])
+				fmt.Printf(format, t.Data[i*t.Shape[1]+j])
 			}
 
 			fmt.Print("]")
 			if i < t.Shape[0]-1 {
 				fmt.Print(",")
 			}
+			fmt.Println()
 		}
 		fmt.Println("]")
 	} else {
 		t.Range(len(t.Shape)-2, func(indices []int) {
 			start := t.LinearAt(indices)
-
 			fmt.Print("[", fmt.Sprint(indices), "]=")
 			m := NewTensor[T](t.Shape[len(t.Shape)-2:])
 			end := start + m.Len()
 			m.Data = t.Data[start:end]
-			m.Print(format_[0])
+			m.Print(format)
 		})
 	}
 }
@@ -255,6 +259,7 @@ func (t *Tensor[T]) Sum(dims []int) *Tensor[T] {
 		for i, j := 0, 0; i < len(t.Shape); i++ {
 			if sumMap[i] == 0 {
 				outputIndices[j] = indices[i]
+				j++
 			}
 		}
 
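The added j++ is the actual fix in this hunk: without it, every retained dimension writes into outputIndices[0] and the later entries stay zero. A standalone sketch of the intended compaction, using the same sumMap convention as the hunk (the concrete shape and values are made up for illustration):

```go
package main

import "fmt"

func main() {
	// Summing a 3-D tensor over dim 1: sumMap marks reduced dims with 1.
	sumMap := []int{0, 1, 0}
	indices := []int{4, 7, 2}       // one coordinate in the input tensor
	outputIndices := make([]int, 2) // corresponding coordinate in the reduced tensor

	for i, j := 0, 0; i < len(indices); i++ {
		if sumMap[i] == 0 {
			outputIndices[j] = indices[i]
			j++ // without this, the result would be [2 0] instead of [4 2]
		}
	}
	fmt.Println(outputIndices) // [4 2]
}
```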
@@ -265,3 +270,36 @@ func (t *Tensor[T]) Sum(dims []int) *Tensor[T] {
 	})
 	return result
 }
+func (a *Tensor[T]) MatMulShape(b *Tensor[T]) (c []int) {
+	if len(a.Shape) < 2 || len(b.Shape) < 2 {
+		panic("TensorCPU dimensions do not match for multiplication")
+	}
+	if a.Shape[len(a.Shape)-1] != b.Shape[len(b.Shape)-2] {
+		panic("TensorCPU dimensions do not match for multiplication")
+	}
+	resultShape := make([]int, len(a.Shape))
+	copy(resultShape, a.Shape)
+	resultShape[len(resultShape)-1] = b.Shape[len(b.Shape)-1]
+	return resultShape
+}
+
+// MatMul implements matrix multiplication for higher-dimensional Tensors.
+// The last two dimensions must satisfy: the number of columns of A equals the number of rows of B.
+func (a *Tensor[T]) MatMul(b *Tensor[T]) (c *Tensor[T]) {
+	c = NewTensor[T](a.MatMulShape(b))
+	c.Range(len(c.Shape)-2, func(indices []int) {
+		aIdx := a.LinearAt(indices)
+		bIdx := b.LinearAt(indices)
+		cIdx := c.LinearAt(indices)
+
+		m, k, n := c.Shape[len(c.Shape)-2], a.Shape[len(a.Shape)-1], c.Shape[len(c.Shape)-1]
+		for i := 0; i < m; i++ {
+			for j := 0; j < n; j++ {
+				for x := 0; x < k; x++ {
+					c.Data[cIdx+i*n+j] += a.Data[aIdx+i*k+x] * b.Data[bIdx+x*n+j]
+				}
+			}
+		}
+	})
+	return c
+}
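For reference, a minimal usage sketch of the new MatMul together with the reworked Print default. It assumes the code sits in the same package as Tensor/NewTensor from the diff and that the element type constraint admits float32; the function name, shapes, and values are illustrative only:

```go
// Sketch only: Tensor, NewTensor, Print and MatMul are the identifiers from the
// diff above; everything else here is made up for illustration.
func exampleMatMul() {
	// Batched multiply: a holds 2 stacked 2x3 matrices, b holds 2 stacked 3x4 matrices.
	a := NewTensor[float32]([]int{2, 2, 3})
	b := NewTensor[float32]([]int{2, 3, 4})
	for i := range a.Data {
		a.Data[i] = float32(i)
	}
	for i := range b.Data {
		b.Data[i] = 1 // with an all-ones b, each entry of c is a row sum of a
	}

	c := a.MatMul(b) // MatMulShape gives [2 2 4]

	c.Print()        // uses the new "%.2f" default format
	c.Print("%6.1f") // an explicit format still overrides the default
}
```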