Skip to content

Commit f662cca

Browse files
author
vinhha
committed
Implement optional func to init Store
1 parent 1a480f2 commit f662cca

File tree

6 files changed

+285
-178
lines changed

6 files changed

+285
-178
lines changed

go.mod

+5
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,27 @@ module github.com/go-oauth2/mysql/v4
33
go 1.17
44

55
require (
6+
github.com/DATA-DOG/go-sqlmock v1.5.0
67
github.com/go-oauth2/oauth2/v4 v4.1.1
78
github.com/go-sql-driver/mysql v1.5.0
89
github.com/json-iterator/go v1.1.10
910
github.com/smartystreets/goconvey v1.6.4
11+
github.com/stretchr/testify v1.6.1
1012
gopkg.in/gorp.v2 v2.2.0
1113
)
1214

1315
require (
16+
github.com/davecgh/go-spew v1.1.1 // indirect
1417
github.com/go-gorp/gorp v2.2.0+incompatible // indirect
1518
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 // indirect
1619
github.com/jtolds/gls v4.20.0+incompatible // indirect
1720
github.com/lib/pq v1.10.4 // indirect
1821
github.com/mattn/go-sqlite3 v1.14.12 // indirect
1922
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
2023
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 // indirect
24+
github.com/pmezard/go-difflib v1.0.0 // indirect
2125
github.com/poy/onpar v1.1.2 // indirect
2226
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d // indirect
2327
github.com/ziutek/mymysql v1.5.4 // indirect
28+
gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 // indirect
2429
)

go.sum

+5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
22
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
33
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
4+
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
5+
github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
46
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
57
github.com/a8m/expect v1.0.0/go.mod h1:4IwSCMumY49ScypDnjNbYEjgVeqy1/U2cEs3Lat96eA=
68
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
@@ -86,8 +88,10 @@ github.com/klauspost/compress v1.10.4/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYs
8688
github.com/klauspost/compress v1.10.10/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
8789
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
8890
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
91+
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
8992
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
9093
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
94+
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
9195
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
9296
github.com/lib/pq v1.10.4 h1:SO9z7FRPzA03QhHKJrH5BXA6HU1rS4V2nIVrrNC1iYk=
9397
github.com/lib/pq v1.10.4/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
@@ -262,6 +266,7 @@ google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2
262266
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
263267
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
264268
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
269+
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
265270
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
266271
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
267272
gopkg.in/gorp.v2 v2.2.0 h1:rTlFZHz1gP1GZplUFomSJgnkDRU4rMxOoHTiV6aDSBk=

mysql.go

+23-177
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
11
package mysql
22

33
import (
4-
"context"
54
"database/sql"
6-
"fmt"
7-
"io"
85
"os"
96
"time"
107

11-
"github.com/go-oauth2/oauth2/v4"
12-
"github.com/go-oauth2/oauth2/v4/models"
13-
"github.com/json-iterator/go"
148
"gopkg.in/gorp.v2"
159
)
1610

@@ -69,20 +63,34 @@ func NewStore(config *Config, tableName string, gcInterval int) *Store {
6963
// tableName table name (default oauth2_token),
7064
// GC time interval (in seconds, default 600)
7165
func NewStoreWithDB(db *sql.DB, tableName string, gcInterval int) *Store {
66+
// Init store with options
67+
store := NewStoreWithOpts(db,
68+
WithSQLDialect(gorp.MySQLDialect{Encoding: "UTF8", Engine: "MyISAM"}),
69+
WithTableName(tableName),
70+
WithGCTimeInterval(gcInterval),
71+
)
72+
73+
go store.gc()
74+
return store
75+
}
76+
77+
// NewStoreWithOpts create mysql store instance with apply custom input,
78+
// db sql.DB,
79+
// tableName table name (default oauth2_token),
80+
// GC time interval (in seconds, default 600)
81+
func NewStoreWithOpts(db *sql.DB, opts ...Option) *Store {
82+
// Init store with default value
7283
store := &Store{
7384
db: &gorp.DbMap{Db: db, Dialect: gorp.MySQLDialect{Encoding: "UTF8", Engine: "MyISAM"}},
7485
tableName: "oauth2_token",
7586
stdout: os.Stderr,
76-
}
77-
if tableName != "" {
78-
store.tableName = tableName
87+
ticker: time.NewTicker(time.Second * time.Duration(600)),
7988
}
8089

81-
interval := 600
82-
if gcInterval > 0 {
83-
interval = gcInterval
90+
// Apply with optional function
91+
for _, opt := range opts {
92+
opt.apply(store)
8493
}
85-
store.ticker = time.NewTicker(time.Second * time.Duration(interval))
8694

8795
table := store.db.AddTableWithName(StoreItem{}, store.tableName)
8896
table.AddIndex("idx_code", "Btree", []string{"code"})
@@ -94,171 +102,9 @@ func NewStoreWithDB(db *sql.DB, tableName string, gcInterval int) *Store {
94102
if err != nil {
95103
panic(err)
96104
}
97-
store.db.CreateIndex()
105+
106+
_ = store.db.CreateIndex()
98107

99108
go store.gc()
100109
return store
101110
}
102-
103-
// Store mysql token store
104-
type Store struct {
105-
tableName string
106-
db *gorp.DbMap
107-
stdout io.Writer
108-
ticker *time.Ticker
109-
}
110-
111-
// SetStdout set error output
112-
func (s *Store) SetStdout(stdout io.Writer) *Store {
113-
s.stdout = stdout
114-
return s
115-
}
116-
117-
// Close close the store
118-
func (s *Store) Close() {
119-
s.ticker.Stop()
120-
s.db.Db.Close()
121-
}
122-
123-
func (s *Store) gc() {
124-
for range s.ticker.C {
125-
s.clean()
126-
}
127-
}
128-
129-
func (s *Store) clean() {
130-
now := time.Now().Unix()
131-
query := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE expired_at<=? OR (code='' AND access='' AND refresh='')", s.tableName)
132-
n, err := s.db.SelectInt(query, now)
133-
if err != nil || n == 0 {
134-
if err != nil {
135-
s.errorf(err.Error())
136-
}
137-
return
138-
}
139-
140-
_, err = s.db.Exec(fmt.Sprintf("DELETE FROM %s WHERE expired_at<=? OR (code='' AND access='' AND refresh='')", s.tableName), now)
141-
if err != nil {
142-
s.errorf(err.Error())
143-
}
144-
}
145-
146-
func (s *Store) errorf(format string, args ...interface{}) {
147-
if s.stdout != nil {
148-
buf := fmt.Sprintf("[OAUTH2-MYSQL-ERROR]: "+format, args...)
149-
s.stdout.Write([]byte(buf))
150-
}
151-
}
152-
153-
// Create create and store the new token information
154-
func (s *Store) Create(ctx context.Context, info oauth2.TokenInfo) error {
155-
buf, _ := jsoniter.Marshal(info)
156-
item := &StoreItem{
157-
Data: string(buf),
158-
}
159-
160-
if code := info.GetCode(); code != "" {
161-
item.Code = code
162-
item.ExpiredAt = info.GetCodeCreateAt().Add(info.GetCodeExpiresIn()).Unix()
163-
} else {
164-
item.Access = info.GetAccess()
165-
item.ExpiredAt = info.GetAccessCreateAt().Add(info.GetAccessExpiresIn()).Unix()
166-
167-
if refresh := info.GetRefresh(); refresh != "" {
168-
item.Refresh = info.GetRefresh()
169-
item.ExpiredAt = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Unix()
170-
}
171-
}
172-
173-
return s.db.Insert(item)
174-
}
175-
176-
// RemoveByCode delete the authorization code
177-
func (s *Store) RemoveByCode(ctx context.Context, code string) error {
178-
query := fmt.Sprintf("UPDATE %s SET code='' WHERE code=? LIMIT 1", s.tableName)
179-
_, err := s.db.Exec(query, code)
180-
if err != nil && err == sql.ErrNoRows {
181-
return nil
182-
}
183-
return err
184-
}
185-
186-
// RemoveByAccess use the access token to delete the token information
187-
func (s *Store) RemoveByAccess(ctx context.Context, access string) error {
188-
query := fmt.Sprintf("UPDATE %s SET access='' WHERE access=? LIMIT 1", s.tableName)
189-
_, err := s.db.Exec(query, access)
190-
if err != nil && err == sql.ErrNoRows {
191-
return nil
192-
}
193-
return err
194-
}
195-
196-
// RemoveByRefresh use the refresh token to delete the token information
197-
func (s *Store) RemoveByRefresh(ctx context.Context, refresh string) error {
198-
query := fmt.Sprintf("UPDATE %s SET refresh='' WHERE refresh=? LIMIT 1", s.tableName)
199-
_, err := s.db.Exec(query, refresh)
200-
if err != nil && err == sql.ErrNoRows {
201-
return nil
202-
}
203-
return err
204-
}
205-
206-
func (s *Store) toTokenInfo(data string) oauth2.TokenInfo {
207-
var tm models.Token
208-
jsoniter.Unmarshal([]byte(data), &tm)
209-
return &tm
210-
}
211-
212-
// GetByCode use the authorization code for token information data
213-
func (s *Store) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
214-
if code == "" {
215-
return nil, nil
216-
}
217-
218-
query := fmt.Sprintf("SELECT * FROM %s WHERE code=? LIMIT 1", s.tableName)
219-
var item StoreItem
220-
err := s.db.SelectOne(&item, query, code)
221-
if err != nil {
222-
if err == sql.ErrNoRows {
223-
return nil, nil
224-
}
225-
return nil, err
226-
}
227-
return s.toTokenInfo(item.Data), nil
228-
}
229-
230-
// GetByAccess use the access token for token information data
231-
func (s *Store) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
232-
if access == "" {
233-
return nil, nil
234-
}
235-
236-
query := fmt.Sprintf("SELECT * FROM %s WHERE access=? LIMIT 1", s.tableName)
237-
var item StoreItem
238-
err := s.db.SelectOne(&item, query, access)
239-
if err != nil {
240-
if err == sql.ErrNoRows {
241-
return nil, nil
242-
}
243-
return nil, err
244-
}
245-
return s.toTokenInfo(item.Data), nil
246-
}
247-
248-
// GetByRefresh use the refresh token for token information data
249-
func (s *Store) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
250-
if refresh == "" {
251-
return nil, nil
252-
}
253-
254-
query := fmt.Sprintf("SELECT * FROM %s WHERE refresh=? LIMIT 1", s.tableName)
255-
var item StoreItem
256-
err := s.db.SelectOne(&item, query, refresh)
257-
if err != nil {
258-
if err == sql.ErrNoRows {
259-
return nil, nil
260-
}
261-
return nil, err
262-
}
263-
return s.toTokenInfo(item.Data), nil
264-
}

mysql_test.go

+33-1
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@ package mysql
22

33
import (
44
"context"
5-
"github.com/go-oauth2/oauth2/v4/models"
5+
"regexp"
66
"testing"
77
"time"
88

9+
"github.com/DATA-DOG/go-sqlmock"
10+
"github.com/go-oauth2/oauth2/v4/models"
911
_ "github.com/go-sql-driver/mysql"
1012
. "github.com/smartystreets/goconvey/convey"
13+
"github.com/stretchr/testify/assert"
14+
"gopkg.in/gorp.v2"
1115
)
1216

1317
const (
@@ -111,3 +115,31 @@ func TestTokenStore(t *testing.T) {
111115
})
112116
})
113117
}
118+
119+
func TestNewStoreWithOpts_ShouldReturnStoreNotNil(t *testing.T) {
120+
// ARRANGE
121+
db, mockDB, _ := sqlmock.New()
122+
tableName := "custom_table_name"
123+
124+
// Mock sql exec create table
125+
mockDB.ExpectExec(regexp.QuoteMeta("create table if not exists `custom_table_name` (`id` bigint not null primary key auto_increment, `expired_at` bigint, `code` varchar(255), `access` varchar(255), `refresh` varchar(255), `data` text) engine=InnoDB charset=UTF8;")).
126+
WillReturnResult(sqlmock.NewResult(0, 0))
127+
128+
// Mock query:
129+
mockDB.ExpectQuery(regexp.QuoteMeta("SELECT COUNT(*) FROM custom_table_name WHERE expired_at<=? OR (code='' AND access='' AND refresh='')")).
130+
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow(0))
131+
132+
// ACTION
133+
store := NewStoreWithOpts(db,
134+
WithTableName(tableName),
135+
WithSQLDialect(gorp.MySQLDialect{Engine: "InnoDB", Encoding: "UTF8"}),
136+
WithGCTimeInterval(1000),
137+
)
138+
139+
defer store.clean()
140+
141+
// ASSERT
142+
assert.NotNil(t, store)
143+
assert.NotNil(t, store.ticker)
144+
assert.Equal(t, store.tableName, tableName)
145+
}

options.go

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package mysql
2+
3+
import (
4+
"time"
5+
6+
"gopkg.in/gorp.v2"
7+
)
8+
9+
type Option interface {
10+
apply(*Store)
11+
}
12+
13+
type optionFunc func(store *Store)
14+
15+
func (f optionFunc) apply(store *Store) {
16+
f(store)
17+
}
18+
19+
// WithTableName sets the table name for the store.
20+
func WithTableName(tableName string) Option {
21+
return optionFunc(func(store *Store) {
22+
if tableName != "" {
23+
store.tableName = tableName
24+
}
25+
})
26+
}
27+
28+
// WithSQLDialect sets the database for the store.
29+
func WithSQLDialect(dialect gorp.MySQLDialect) Option {
30+
return optionFunc(func(store *Store) {
31+
store.db.Dialect = dialect
32+
})
33+
}
34+
35+
// WithGCTimeInterval sets the time interval for garbage collection.
36+
func WithGCTimeInterval(interval int) Option {
37+
return optionFunc(func(store *Store) {
38+
if interval != 0 {
39+
store.ticker = time.NewTicker(time.Second * time.Duration(interval))
40+
}
41+
})
42+
}

0 commit comments

Comments
 (0)