Skip to content

Commit f283db5

Browse files
Add support for getting db session form context
1 parent cfe6c55 commit f283db5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+587
-12
lines changed

docs/reference/config.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ The `gen` mapping supports the following keys:
167167
that returns all valid enum values.
168168
- `emit_sql_as_comment`:
169169
- If true, emits the SQL statement as a code-block comment above the generated function, appending to any existing comments. Defaults to `false`.
170+
- `get_db_from_context`:
171+
- If true, emits `New` method for `Querier` with a function argument which accepts a ctx argument and returns DBTX. Defaults to `false`.
170172
- `build_tags`:
171173
- If set, add a `//go:build <build_tags>` directive at the beginning of each generated Go file.
172174
- `initialisms`:
@@ -414,6 +416,7 @@ packages:
414416
emit_pointers_for_null_types: false
415417
emit_enum_valid_method: false
416418
emit_all_enum_values: false
419+
get_db_from_context: false
417420
build_tags: "some_tag"
418421
json_tags_case_style: "camel"
419422
omit_unused_structs: false
@@ -469,6 +472,8 @@ Each mapping in the `packages` collection has the following keys:
469472
- `emit_all_enum_values`:
470473
- If true, emit a function per enum type
471474
that returns all valid enum values.
475+
- `get_db_from_context`:
476+
- If true, emits `New` method for `Querier` with a function argument which accepts a ctx argument and returns DBTX. Defaults to `false`.
472477
- `build_tags`:
473478
- If set, add a `//go:build <build_tags>` directive at the beginning of each generated Go file.
474479
- `json_tags_case_style`:

internal/codegen/golang/gen.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type tmplCtx struct {
3737
EmitMethodsWithDBArgument bool
3838
EmitEnumValidMethod bool
3939
EmitAllEnumValues bool
40+
GetDBFromContext bool
4041
UsesCopyFrom bool
4142
UsesBatch bool
4243
OmitSqlcVersion bool
@@ -65,6 +66,9 @@ func (t *tmplCtx) codegenQueryMethod(q Query) string {
6566
if t.EmitMethodsWithDBArgument {
6667
db = "db"
6768
}
69+
if t.GetDBFromContext {
70+
db = "q.getDBFromContext(ctx)"
71+
}
6872

6973
switch q.Cmd {
7074
case ":one":
@@ -177,6 +181,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
177181
EmitMethodsWithDBArgument: options.EmitMethodsWithDbArgument,
178182
EmitEnumValidMethod: options.EmitEnumValidMethod,
179183
EmitAllEnumValues: options.EmitAllEnumValues,
184+
GetDBFromContext: options.GetDBFromContext,
180185
UsesCopyFrom: usesCopyFrom(queries),
181186
UsesBatch: usesBatch(queries),
182187
SQLDriver: parseDriver(options.SqlPackage),

internal/codegen/golang/opts/options.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type Options struct {
2525
EmitEnumValidMethod bool `json:"emit_enum_valid_method,omitempty" yaml:"emit_enum_valid_method"`
2626
EmitAllEnumValues bool `json:"emit_all_enum_values,omitempty" yaml:"emit_all_enum_values"`
2727
EmitSqlAsComment bool `json:"emit_sql_as_comment,omitempty" yaml:"emit_sql_as_comment"`
28+
GetDBFromContext bool `json:"get_db_from_context,omitempty" yaml:"get_db_from_context"`
2829
JsonTagsCaseStyle string `json:"json_tags_case_style,omitempty" yaml:"json_tags_case_style"`
2930
Package string `json:"package" yaml:"package"`
3031
Out string `json:"out" yaml:"out"`
@@ -147,6 +148,12 @@ func ValidateOpts(opts *Options) error {
147148
if opts.EmitMethodsWithDbArgument && opts.EmitPreparedQueries {
148149
return fmt.Errorf("invalid options: emit_methods_with_db_argument and emit_prepared_queries options are mutually exclusive")
149150
}
151+
if opts.GetDBFromContext && opts.EmitPreparedQueries {
152+
return fmt.Errorf("invalid options: get_db_from_context and emit_prepared_queries options are mutually exclusive")
153+
}
154+
if opts.GetDBFromContext && opts.EmitMethodsWithDbArgument {
155+
return fmt.Errorf("invalid options: get_db_from_context and emit_methods_with_db_argument options are mutually exclusive")
156+
}
150157
if *opts.QueryParameterLimit < 0 {
151158
return fmt.Errorf("invalid options: query parameter limit must not be negative")
152159
}

internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@ func (q *Queries) {{.MethodName}}(ctx context.Context{{if $.EmitMethodsWithDBArg
4040
go convertRowsFor{{.MethodName}}(pw, {{.Arg.Name}})
4141
// The string interpolation is necessary because LOAD DATA INFILE requires
4242
// the file name to be given as a literal string.
43-
result, err := {{if (not $.EmitMethodsWithDBArgument)}}q.{{end}}db.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' INTO TABLE {{.TableIdentifierForMySQL}} %s ({{range $index, $name := .Arg.ColumnNames}}{{if gt $index 0}}, {{end}}{{$name}}{{end}})", "Reader::" + rh, mysqltsv.Escaping))
43+
{{- $db := "q.db"}}
44+
{{- if $.EmitMethodsWithDBArgument}}
45+
{{- $db = "db"}}
46+
{{- else if $.GetDBFromContext}}
47+
{{- $db = "q.getDBFromContext(ctx)"}}
48+
{{- end}}
49+
result, err := {{$db}}.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' INTO TABLE {{.TableIdentifierForMySQL}} %s ({{range $index, $name := .Arg.ColumnNames}}{{if gt $index 0}}, {{end}}{{$name}}{{end}})", "Reader::" + rh, mysqltsv.Escaping))
4450
if err != nil {
4551
return 0, err
4652
}

internal/codegen/golang/templates/pgx/batchCode.tmpl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,13 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ if $.EmitMethodsWithDB
4646
}
4747
batch.Queue({{.ConstantName}}, vals...)
4848
}
49-
br := {{if not $.EmitMethodsWithDBArgument}}q.{{end}}db.SendBatch(ctx, batch)
49+
{{- $db := "q.db"}}
50+
{{- if $.EmitMethodsWithDBArgument}}
51+
{{- $db = "db"}}
52+
{{- else if $.GetDBFromContext}}
53+
{{- $db = "q.getDBFromContext(ctx)"}}
54+
{{- end}}
55+
br := {{$db}}.SendBatch(ctx, batch)
5056
return &{{.MethodName}}BatchResults{br,len({{.Arg.Name}}),false}
5157
}
5258

internal/codegen/golang/templates/pgx/copyfromCopy.tmpl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,12 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair
4242
return db.CopyFrom(ctx, {{.TableIdentifierAsGoSlice}}, {{.Arg.ColumnNamesAsGoSlice}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}})
4343
{{- else -}}
4444
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) (int64, error) {
45-
return q.db.CopyFrom(ctx, {{.TableIdentifierAsGoSlice}}, {{.Arg.ColumnNamesAsGoSlice}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}})
45+
46+
{{- $db := "db"}}
47+
{{- if $.GetDBFromContext}}
48+
{{- $db = "getDBFromContext(ctx)"}}
49+
{{- end}}
50+
return q.{{$db}}.CopyFrom(ctx, {{.TableIdentifierAsGoSlice}}, {{.Arg.ColumnNamesAsGoSlice}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}})
4651
{{- end}}
4752
}
4853

internal/codegen/golang/templates/pgx/dbCode.tmpl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,24 @@ type DBTX interface {
1515
{{ if .EmitMethodsWithDBArgument}}
1616
func New() *Queries {
1717
return &Queries{}
18+
{{- else if .GetDBFromContext}}
19+
func New(getDBFromContext func(context.Context) DBTX) *Queries {
20+
return &Queries{getDBFromContext: getDBFromContext}
1821
{{- else -}}
1922
func New(db DBTX) *Queries {
2023
return &Queries{db: db}
2124
{{- end}}
2225
}
2326

2427
type Queries struct {
25-
{{if not .EmitMethodsWithDBArgument}}
28+
{{- if .GetDBFromContext}}
29+
getDBFromContext func(context.Context) DBTX
30+
{{- else if not .EmitMethodsWithDBArgument}}
2631
db DBTX
2732
{{end}}
2833
}
2934

30-
{{if not .EmitMethodsWithDBArgument}}
35+
{{if and (not .EmitMethodsWithDBArgument) (not .GetDBFromContext)}}
3136
func (q *Queries) WithTx(tx pgx.Tx) *Queries {
3237
return &Queries{
3338
db: tx,

internal/codegen/golang/templates/pgx/queryCode.tmpl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}}
77
{{$.Q}}
88
{{end}}
99

10+
{{- $db := "db" }}
11+
{{- if $.GetDBFromContext}}
12+
{{- $db = "getDBFromContext(ctx)"}}
13+
{{- end}}
14+
1015
{{if ne (hasPrefix .Cmd ":batch") true}}
1116
{{if .Arg.EmitStruct}}
1217
type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}}
@@ -31,7 +36,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (
3136
row := db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}})
3237
{{- else -}}
3338
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) {
34-
row := q.db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}})
39+
row := q.{{$db}}.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}})
3540
{{- end}}
3641
{{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }}
3742
var {{.Ret.Name}} {{.Ret.Type}}
@@ -49,7 +54,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (
4954
rows, err := db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}})
5055
{{- else -}}
5156
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) {
52-
rows, err := q.db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}})
57+
rows, err := q.{{$db}}.Query(ctx, {{.ConstantName}}, {{.Arg.Params}})
5358
{{- end}}
5459
if err != nil {
5560
return nil, err
@@ -82,7 +87,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) e
8287
_, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
8388
{{- else -}}
8489
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error {
85-
_, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
90+
_, err := q.{{$db}}.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
8691
{{- end}}
8792
return err
8893
}
@@ -96,7 +101,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (
96101
result, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
97102
{{- else -}}
98103
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) {
99-
result, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
104+
result, err := q.{{$db}}.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
100105
{{- end}}
101106
if err != nil {
102107
return 0, err
@@ -113,7 +118,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (
113118
return db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
114119
{{- else -}}
115120
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error) {
116-
return q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
121+
return q.{{$db}}.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
117122
{{- end}}
118123
}
119124
{{end}}

internal/codegen/golang/templates/stdlib/dbCode.tmpl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ type DBTX interface {
99
{{ if .EmitMethodsWithDBArgument}}
1010
func New() *Queries {
1111
return &Queries{}
12+
{{- else if .GetDBFromContext}}
13+
func New(getDBFromContext func(context.Context) DBTX) *Queries {
14+
return &Queries{getDBFromContext: getDBFromContext}
1215
{{- else -}}
1316
func New(db DBTX) *Queries {
1417
return &Queries{db: db}
@@ -77,7 +80,9 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar
7780
{{end}}
7881

7982
type Queries struct {
80-
{{- if not .EmitMethodsWithDBArgument}}
83+
{{- if .GetDBFromContext}}
84+
getDBFromContext func(context.Context) DBTX
85+
{{- else if not .EmitMethodsWithDBArgument}}
8186
db DBTX
8287
{{- end}}
8388

@@ -89,7 +94,7 @@ type Queries struct {
8994
{{- end}}
9095
}
9196

92-
{{if not .EmitMethodsWithDBArgument}}
97+
{{if and (not .EmitMethodsWithDBArgument) (not .GetDBFromContext)}}
9398
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
9499
return &Queries{
95100
db: tx,

internal/config/v_one.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ type v1PackageSettings struct {
4242
EmitEnumValidMethod bool `json:"emit_enum_valid_method,omitempty" yaml:"emit_enum_valid_method"`
4343
EmitAllEnumValues bool `json:"emit_all_enum_values,omitempty" yaml:"emit_all_enum_values"`
4444
EmitSqlAsComment bool `json:"emit_sql_as_comment,omitempty" yaml:"emit_sql_as_comment"`
45+
GetDBFromContext bool `json:"get_db_from_context,omitempty" yaml:"get_db_from_context"`
4546
JSONTagsCaseStyle string `json:"json_tags_case_style,omitempty" yaml:"json_tags_case_style"`
4647
SQLPackage string `json:"sql_package" yaml:"sql_package"`
4748
SQLDriver string `json:"sql_driver" yaml:"sql_driver"`
@@ -152,6 +153,7 @@ func (c *V1GenerateSettings) Translate() Config {
152153
EmitEnumValidMethod: pkg.EmitEnumValidMethod,
153154
EmitAllEnumValues: pkg.EmitAllEnumValues,
154155
EmitSqlAsComment: pkg.EmitSqlAsComment,
156+
GetDBFromContext: pkg.GetDBFromContext,
155157
Package: pkg.Name,
156158
Out: pkg.Path,
157159
SqlPackage: pkg.SQLPackage,

internal/config/v_one.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@
134134
"emit_sql_as_comment": {
135135
"type": "boolean"
136136
},
137+
"get_db_from_context": {
138+
"type": "boolean"
139+
},
137140
"build_tags": {
138141
"type": "string"
139142
},

internal/config/v_two.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@
143143
"emit_sql_as_comment": {
144144
"type": "boolean"
145145
},
146+
"get_db_from_context": {
147+
"type": "boolean"
148+
},
146149
"build_tags": {
147150
"type": "string"
148151
},

internal/endtoend/testdata/get_db_from_context/mysql/go/db.go

Lines changed: 25 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/get_db_from_context/mysql/go/models.go

Lines changed: 16 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/get_db_from_context/mysql/go/query.sql.go

Lines changed: 42 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
/* name: GetAll :many */
2+
SELECT * FROM users;
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
CREATE TABLE users (
2+
id integer NOT NULL AUTO_INCREMENT PRIMARY KEY,
3+
first_name varchar(255) NOT NULL,
4+
last_name varchar(255),
5+
age integer NOT NULL
6+
);
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"version": "1",
3+
"packages": [
4+
{
5+
"name": "querytest",
6+
"path": "go",
7+
"schema": "schema.sql",
8+
"queries": "query.sql",
9+
"engine": "mysql",
10+
"get_db_from_context": true
11+
}
12+
]
13+
}

0 commit comments

Comments
 (0)