Skip to content

Commit a904f81

Browse files
committed
tensor:matmul 按range实现
1 parent 6fc5ef8 commit a904f81

File tree

4 files changed

+233
-300
lines changed

4 files changed

+233
-300
lines changed

dl/tensor.go

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ func (t *Tensor[T]) SumShape(dims []int) []int {
148148

149149
// Print 打印Tensor的值
150150
func (t *Tensor[T]) Print(format_ ...string) {
151+
format := "%.2f"
152+
if len(format_) > 0 {
153+
format = format_[0]
154+
}
151155
fmt.Print("shape:[")
152156
for i := 0; i < len(t.Shape); i++ {
153157
fmt.Print(t.Shape[i])
@@ -162,35 +166,35 @@ func (t *Tensor[T]) Print(format_ ...string) {
162166
if i > 0 {
163167
fmt.Print(" ")
164168
}
165-
fmt.Printf(format_[0], t.Data[i])
169+
fmt.Printf(format, t.Data[i])
166170
}
167171
fmt.Println("]")
168172
} else if len(t.Shape) == 2 {
169-
fmt.Print("[")
173+
fmt.Println("[")
170174
for i := 0; i < t.Shape[0]; i++ {
171175
fmt.Print(" [")
172176
for j := 0; j < t.Shape[1]; j++ {
173177
if j > 0 {
174178
fmt.Print(" ")
175179
}
176-
fmt.Printf(format_[0], t.Data[i*t.Shape[1]+j])
180+
fmt.Printf(format, t.Data[i*t.Shape[1]+j])
177181
}
178182

179183
fmt.Print("]")
180184
if i < t.Shape[0]-1 {
181185
fmt.Print(",")
182186
}
187+
fmt.Println()
183188
}
184189
fmt.Println("]")
185190
} else {
186191
t.Range(len(t.Shape)-2, func(indices []int) {
187192
start := t.LinearAt(indices)
188-
189193
fmt.Print("[", fmt.Sprint(indices), "]=")
190194
m := NewTensor[T](t.Shape[len(t.Shape)-2:])
191195
end := start + m.Len()
192196
m.Data = t.Data[start:end]
193-
m.Print(format_[0])
197+
m.Print(format)
194198
})
195199
}
196200
}
@@ -255,6 +259,7 @@ func (t *Tensor[T]) Sum(dims []int) *Tensor[T] {
255259
for i, j := 0, 0; i < len(t.Shape); i++ {
256260
if sumMap[i] == 0 {
257261
outputIndices[j] = indices[i]
262+
j++
258263
}
259264
}
260265

@@ -265,3 +270,36 @@ func (t *Tensor[T]) Sum(dims []int) *Tensor[T] {
265270
})
266271
return result
267272
}
273+
func (a *Tensor[T]) MatMulShape(b *Tensor[T]) (c []int) {
274+
if len(a.Shape) < 2 || len(b.Shape) < 2 {
275+
panic("TensorCPU dimensions do not match for multiplication")
276+
}
277+
if a.Shape[len(a.Shape)-1] != b.Shape[len(b.Shape)-2] {
278+
panic("TensorCPU dimensions do not match for multiplication")
279+
}
280+
resultShape := make([]int, len(a.Shape))
281+
copy(resultShape, a.Shape)
282+
resultShape[len(resultShape)-1] = b.Shape[len(b.Shape)-1]
283+
return resultShape
284+
}
285+
286+
// MatMul 实现高维矩阵 Tensor 的矩阵乘法
287+
// 矩阵的最后两维满足:A矩阵的列数B矩阵的行数相等
288+
func (a *Tensor[T]) MatMul(b *Tensor[T]) (c *Tensor[T]) {
289+
c = NewTensor[T](a.MatMulShape(b))
290+
c.Range(len(c.Shape)-2, func(indices []int) {
291+
aIdx := a.LinearAt(indices)
292+
bIdx := b.LinearAt(indices)
293+
cIdx := c.LinearAt(indices)
294+
295+
m, k, n := c.Shape[len(c.Shape)-2], a.Shape[len(a.Shape)-1], c.Shape[len(c.Shape)-1]
296+
for i := 0; i < m; i++ {
297+
for j := 0; j < n; j++ {
298+
for x := 0; x < k; x++ {
299+
c.Data[cIdx+i*n+j] += a.Data[aIdx+i*k+x] * b.Data[bIdx+x*n+j]
300+
}
301+
}
302+
}
303+
})
304+
return c
305+
}

dl/tensor_matmul.go

Lines changed: 0 additions & 108 deletions
This file was deleted.

dl/tensor_matmul_test.go.a

Lines changed: 0 additions & 184 deletions
This file was deleted.

0 commit comments

Comments
 (0)