Skip to content

Commit 47c96de

Browse files
example of ml function (#67)
1 parent 473da51 commit 47c96de

File tree

1 file changed

+132
-0
lines changed

1 file changed

+132
-0
lines changed

examples/std/ml_func/main.go

+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"errors"
7+
"fmt"
8+
"github.com/timeplus-io/proton-go-driver/v2"
9+
"log"
10+
"math/rand"
11+
"time"
12+
)
13+
14+
func trainModel(conn *sql.DB, ctx context.Context) (err error) {
15+
_, err = conn.ExecContext(ctx, "DROP stream IF EXISTS train_data;")
16+
if err != nil {
17+
return
18+
}
19+
_, err = conn.ExecContext(ctx, "CREATE STREAM train_data(param1 int, param2 int, target int) engine=Memory;")
20+
if err != nil {
21+
return
22+
}
23+
scope, err := conn.Begin()
24+
if err != nil {
25+
return
26+
}
27+
batch, err := scope.PrepareContext(ctx, "INSERT INTO train_data(param1, param2, target) VALUES")
28+
if err != nil {
29+
return
30+
}
31+
for i := int32(0); i < 100000; i++ {
32+
x := rand.Int31n(10000)
33+
y := rand.Int31n(10000)
34+
z := x + y
35+
_, err = batch.ExecContext(ctx, x, y, z)
36+
if err != nil {
37+
return
38+
}
39+
}
40+
err = scope.Commit()
41+
if err != nil {
42+
return
43+
}
44+
_, err = conn.ExecContext(ctx, "DROP STREAM IF EXISTS your_model;")
45+
if err != nil {
46+
return
47+
}
48+
_, err = conn.ExecContext(ctx, "CREATE STREAM your_model ENGINE = Memory AS SELECT stochastic_linear_regression_state(0.01, 1.0, 10, 'Adam', target, param1, param2) AS state FROM train_data;")
49+
if err != nil {
50+
return
51+
}
52+
return nil
53+
}
54+
55+
func insertTestData(conn *sql.DB, ctx context.Context) {
56+
time.Sleep(time.Second)
57+
scope, err := conn.Begin()
58+
if err != nil {
59+
log.Fatal(err)
60+
}
61+
batch, err := scope.PrepareContext(ctx, "INSERT INTO test_data(param1, param2) VALUES")
62+
if err != nil {
63+
log.Fatal(err)
64+
}
65+
for i := 0; i < 100; i++ {
66+
x := rand.Int31n(10000)
67+
y := rand.Int31n(10000)
68+
_, err = batch.Exec(x, y)
69+
if err != nil {
70+
log.Fatal(err)
71+
}
72+
}
73+
err = scope.Commit()
74+
if err != nil {
75+
log.Fatal(err)
76+
}
77+
}
78+
79+
func testModel(conn *sql.DB, ctx context.Context) (err error) {
80+
_, err = conn.ExecContext(ctx, "DROP STREAM IF EXISTS test_data;")
81+
if err != nil {
82+
return
83+
}
84+
_, err = conn.ExecContext(ctx, "CREATE stream test_data(param1 int, param2 int);")
85+
if err != nil {
86+
return
87+
}
88+
go insertTestData(conn, ctx)
89+
ctxCancel, cancel := context.WithTimeout(ctx, time.Second*10)
90+
defer cancel()
91+
rows, err := conn.QueryContext(ctxCancel, "WITH (SELECT state FROM your_model) AS model SELECT param1, param2, eval_ml_method(model, param1, param2) FROM test_data;")
92+
for rows.Next() {
93+
var param1, param2 int32
94+
var target float64
95+
err = rows.Scan(&param1, &param2, &target)
96+
if err != nil {
97+
return
98+
}
99+
fmt.Printf("(%d, %d) -> %f\n", param1, param2, target)
100+
}
101+
rows.Close()
102+
return rows.Err()
103+
}
104+
105+
func example() error {
106+
conn, err := sql.Open("proton", "proton://127.0.0.1:8463?dial_timeout=1s&compress=true")
107+
if err != nil {
108+
return err
109+
}
110+
conn.SetMaxIdleConns(5)
111+
conn.SetMaxOpenConns(10)
112+
conn.SetConnMaxLifetime(time.Hour)
113+
ctx := proton.Context(context.Background(), proton.WithSettings(proton.Settings{
114+
"max_block_size": 10,
115+
}))
116+
err = trainModel(conn, ctx)
117+
if err != nil {
118+
return err
119+
}
120+
err = testModel(conn, ctx)
121+
if err != nil && !errors.Is(err, context.DeadlineExceeded) {
122+
return err
123+
}
124+
return nil
125+
}
126+
127+
func main() {
128+
err := example()
129+
if err != nil {
130+
log.Fatal(err)
131+
}
132+
}

0 commit comments

Comments
 (0)