Skip to content

Commit 8a62b9b

Browse files
committed
tensor:基础算子完成
1 parent 6c7d951 commit 8a62b9b

10 files changed

+421
-418
lines changed

dl/array_math.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,10 @@ func Div[T Number](a, b []T) (c []T) {
4343
}
4444
return
4545
}
46+
47+
func Max[T Number](a, b T) T {
48+
if a > b {
49+
return a
50+
}
51+
return b
52+
}

dl/tensor_op_l2.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package dl
2+
3+
func (t *Tensor[T]) BroadcastShape(shape []int) []int {
4+
maxShape := Max(len(t.Shape), len(shape))
5+
result := make([]int, maxShape)
6+
for i := 1; i <= maxShape; i++ {
7+
var dim1 int
8+
if i <= len(t.Shape) {
9+
dim1 = t.Shape[len(t.Shape)-i]
10+
} else {
11+
dim1 = 1
12+
}
13+
var dim2 int
14+
if i <= len(shape) {
15+
dim2 = shape[len(shape)-i]
16+
} else {
17+
dim2 = 1
18+
}
19+
20+
if dim1 != dim2 && dim1 != 1 && dim2 != 1 {
21+
return nil
22+
}
23+
result[maxShape-i] = Max(dim1, dim2)
24+
}
25+
return result
26+
}
27+
28+
type BroadcastCase int
29+
30+
const (
31+
XToX BroadcastCase = iota
32+
NullTo1
33+
XTo1
34+
)
35+
36+
func (t *Tensor[T]) BroadcastMap(broadcastShape []int) []BroadcastCase {
37+
broadcastMap := make([]BroadcastCase, len(broadcastShape))
38+
s := len(broadcastShape) - len(t.Shape)
39+
for i := 0; i < s; i++ {
40+
broadcastMap[i] = NullTo1
41+
}
42+
for i := s; i < len(broadcastShape); i++ {
43+
if t.Shape[i-s] == broadcastShape[i] {
44+
broadcastMap[i] = XToX
45+
} else if t.Shape[i-s] == 1 {
46+
broadcastMap[i] = XTo1
47+
} else {
48+
panic("Shapes are not broadcastable for operation")
49+
}
50+
}
51+
return broadcastMap
52+
}
53+
func FromBroadcastIndices(broadcastMap []BroadcastCase, broadcastIndices []int) (indices []int) {
54+
indices = make([]int, 0)
55+
for i, j := 0, 0; i < len(broadcastIndices); i++ {
56+
switch broadcastMap[i] {
57+
case XToX:
58+
indices = append(indices, broadcastIndices[i])
59+
j++
60+
case NullTo1:
61+
continue
62+
case XTo1:
63+
indices = append(indices, 0)
64+
j++
65+
}
66+
}
67+
return
68+
}
69+
func (t *Tensor[T]) Broadcast(broadcastShape []int) *Tensor[T] {
70+
broadcastMap := t.BroadcastMap(broadcastShape)
71+
result := NewTensor[T](broadcastShape)
72+
73+
result.Range(len(broadcastShape), func(indices []int) {
74+
oldIndices := FromBroadcastIndices(broadcastMap, indices)
75+
result.Set(indices, t.Get(oldIndices...))
76+
})
77+
return result
78+
}

dl/tensor_op_l2_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package dl
2+
3+
import "testing"
4+
5+
func TestBroadcast(t *testing.T) {
6+
a := NewTensor[float32]([]int{2, 3})
7+
a.Linear(1, float64(a.Len()))
8+
bShape := a.BroadcastShape([]int{3, 2, 3})
9+
b := a.Broadcast(bShape)
10+
b.Print("%0.f")
11+
}

dl/tensor_op_l3.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,64 @@ func (t *Tensor[T]) Sum(dims []int) *Tensor[T] {
101101
})
102102
return result
103103
}
104+
105+
func (t *Tensor[T]) OpTensorInPlace(other *Tensor[T], op func(a, b T) T) {
106+
if Equal(t.Shape, other.Shape) {
107+
t.Range(len(t.Shape), func(indices []int) {
108+
t.Set(indices, op(t.Get(indices...), other.Get(indices...)))
109+
})
110+
return
111+
}
112+
broadcastShape := t.BroadcastShape(other.Shape)
113+
if broadcastShape == nil {
114+
panic("shapes are not broadcastable for inplace operation,my shape:" + fmt.Sprint(t.Shape) + " other shape:" + fmt.Sprint(other.Shape))
115+
}
116+
if !Equal(t.Shape, broadcastShape) {
117+
panic("shapes are not broadcastable for inplace operation,my shape:" + fmt.Sprint(t.Shape) + " broadcastedShape:" + fmt.Sprint(broadcastShape))
118+
}
119+
otherMap := other.BroadcastMap(broadcastShape)
120+
t.Range(len(t.Shape), func(indices []int) {
121+
otherIndices := FromBroadcastIndices(otherMap, indices)
122+
t.Set(indices, op(t.Get(indices...), other.Get(otherIndices...)))
123+
})
124+
}
125+
126+
func (t *Tensor[T]) OpNumberInPlace(other T, op func(a, b T) T) {
127+
for i := 0; i < t.Len(); i++ {
128+
t.Data[i] = op(t.Data[i], other)
129+
}
130+
}
131+
132+
func (t *Tensor[T]) AddInPlace(other *Tensor[T]) *Tensor[T] {
133+
t.OpTensorInPlace(other, func(a, b T) T { return a + b })
134+
return t
135+
}
136+
func (t *Tensor[T]) AddNumberInPlace(other T) *Tensor[T] {
137+
t.OpNumberInPlace(other, func(a, b T) T { return a + b })
138+
return t
139+
}
140+
141+
func (t *Tensor[T]) SubInPlace(other *Tensor[T]) *Tensor[T] {
142+
t.OpTensorInPlace(other, func(a, b T) T { return a - b })
143+
return t
144+
}
145+
func (t *Tensor[T]) SubNumberInPlace(other T) *Tensor[T] {
146+
t.OpNumberInPlace(other, func(a, b T) T { return a - b })
147+
return t
148+
}
149+
func (t *Tensor[T]) MulInPlace(other *Tensor[T]) *Tensor[T] {
150+
t.OpTensorInPlace(other, func(a, b T) T { return a * b })
151+
return t
152+
}
153+
func (t *Tensor[T]) MulNumberInPlace(other T) *Tensor[T] {
154+
t.OpNumberInPlace(other, func(a, b T) T { return a * b })
155+
return t
156+
}
157+
func (t *Tensor[T]) DivInPlace(other *Tensor[T]) *Tensor[T] {
158+
t.OpTensorInPlace(other, func(a, b T) T { return a / b })
159+
return t
160+
}
161+
func (t *Tensor[T]) DivNumberInPlace(other T) *Tensor[T] {
162+
t.OpNumberInPlace(other, func(a, b T) T { return a / b })
163+
return t
164+
}

dl/tensor_op_l3_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package dl
2+
3+
import (
4+
"testing"
5+
6+
"git.array2d.com/ai/deepgo/py"
7+
)
8+
9+
func TestTensor_Transpose(t *testing.T) {
10+
a := NewTensor([]int{2, 3}, 1, 2, 3, 4, 5, 6)
11+
at := a.Transpose([]int{1, 0})
12+
at.Print("%d")
13+
14+
b := NewTensor[float32]([]int{4, 3, 2})
15+
b.Linear(0, float64(b.Len()))
16+
b.Print()
17+
bt := b.Transpose([]int{0, 2, 1})
18+
bt.Print("%0.f")
19+
}
20+
21+
func TestSum(t *testing.T) {
22+
testCases := []struct {
23+
Data []float32
24+
Shape []int
25+
}{
26+
{
27+
Data: []float32{-1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
28+
Shape: []int{3, 3},
29+
},
30+
{
31+
Data: []float32{1.0, 2.0, -3.0, -4.0, -5.0, 6.0, -7.0, 8.0, 9.0, 10.0},
32+
Shape: []int{2, 5},
33+
},
34+
{
35+
Data: []float32{
36+
64, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49,
37+
48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33,
38+
32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17,
39+
16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
40+
},
41+
Shape: []int{4, 4, 4},
42+
},
43+
}
44+
45+
for index, tc := range testCases {
46+
inputTensor := NewTensor(tc.Shape, tc.Data...)
47+
48+
// 使用 Go 实现计算 Sum
49+
axes := []int{0, len(tc.Shape) - 1}
50+
goResult := inputTensor.Sum(axes)
51+
52+
// 使用 Python 计算 Sum
53+
pyResult, pyShape, err := py.CalculateA_breturnC("tensor_op_A_b_return_C.py", "sum", tc.Data, tc.Shape, axes)
54+
if err != nil {
55+
t.Fatalf("计算Python Sum时出错: %v", err)
56+
}
57+
pyTensor := NewTensor(pyShape, pyResult...)
58+
// 比较结果
59+
if !TensorAlmostEqual(goResult, pyTensor, 1e-19) {
60+
t.Errorf("Sum结果不匹配。\nGo结果: %v\nPy结果: %v", goResult.Data, pyTensor.Data)
61+
t.Errorf("shape不匹配。\nGo结果: %v\nPy结果: %v", goResult.Shape, pyTensor.Shape)
62+
} else {
63+
t.Log("Sum结果与Python一致", index)
64+
}
65+
}
66+
}
67+
68+
func TestTensor_AddInPlace(t *testing.T) {
69+
a := NewTensor[float32]([]int{2, 3})
70+
a.Linear(1, float64(a.Len()))
71+
b := NewTensor[float32]([]int{3, 2, 3})
72+
b.Linear(1, float64(b.Len()))
73+
c := a.AddInPlace(b)
74+
c.Print("%0.f")
75+
}

0 commit comments

Comments
 (0)