1
1
package mysql
2
2
3
3
import (
4
- "context"
5
4
"database/sql"
6
- "fmt"
7
- "io"
8
5
"os"
9
6
"time"
10
7
11
- "github.com/go-oauth2/oauth2/v4"
12
- "github.com/go-oauth2/oauth2/v4/models"
13
- "github.com/json-iterator/go"
14
8
"gopkg.in/gorp.v2"
15
9
)
16
10
@@ -69,20 +63,34 @@ func NewStore(config *Config, tableName string, gcInterval int) *Store {
69
63
// tableName table name (default oauth2_token),
70
64
// GC time interval (in seconds, default 600)
71
65
func NewStoreWithDB (db * sql.DB , tableName string , gcInterval int ) * Store {
66
+ // Init store with options
67
+ store := NewStoreWithOpts (db ,
68
+ WithSQLDialect (gorp.MySQLDialect {Encoding : "UTF8" , Engine : "MyISAM" }),
69
+ WithTableName (tableName ),
70
+ WithGCTimeInterval (gcInterval ),
71
+ )
72
+
73
+ go store .gc ()
74
+ return store
75
+ }
76
+
77
+ // NewStoreWithOpts create mysql store instance with apply custom input,
78
+ // db sql.DB,
79
+ // tableName table name (default oauth2_token),
80
+ // GC time interval (in seconds, default 600)
81
+ func NewStoreWithOpts (db * sql.DB , opts ... Option ) * Store {
82
+ // Init store with default value
72
83
store := & Store {
73
84
db : & gorp.DbMap {Db : db , Dialect : gorp.MySQLDialect {Encoding : "UTF8" , Engine : "MyISAM" }},
74
85
tableName : "oauth2_token" ,
75
86
stdout : os .Stderr ,
76
- }
77
- if tableName != "" {
78
- store .tableName = tableName
87
+ ticker : time .NewTicker (time .Second * time .Duration (600 )),
79
88
}
80
89
81
- interval := 600
82
- if gcInterval > 0 {
83
- interval = gcInterval
90
+ // Apply with optional function
91
+ for _ , opt := range opts {
92
+ opt . apply ( store )
84
93
}
85
- store .ticker = time .NewTicker (time .Second * time .Duration (interval ))
86
94
87
95
table := store .db .AddTableWithName (StoreItem {}, store .tableName )
88
96
table .AddIndex ("idx_code" , "Btree" , []string {"code" })
@@ -94,171 +102,9 @@ func NewStoreWithDB(db *sql.DB, tableName string, gcInterval int) *Store {
94
102
if err != nil {
95
103
panic (err )
96
104
}
97
- store .db .CreateIndex ()
105
+
106
+ _ = store .db .CreateIndex ()
98
107
99
108
go store .gc ()
100
109
return store
101
110
}
102
-
103
- // Store mysql token store
104
- type Store struct {
105
- tableName string
106
- db * gorp.DbMap
107
- stdout io.Writer
108
- ticker * time.Ticker
109
- }
110
-
111
- // SetStdout set error output
112
- func (s * Store ) SetStdout (stdout io.Writer ) * Store {
113
- s .stdout = stdout
114
- return s
115
- }
116
-
117
- // Close close the store
118
- func (s * Store ) Close () {
119
- s .ticker .Stop ()
120
- s .db .Db .Close ()
121
- }
122
-
123
- func (s * Store ) gc () {
124
- for range s .ticker .C {
125
- s .clean ()
126
- }
127
- }
128
-
129
- func (s * Store ) clean () {
130
- now := time .Now ().Unix ()
131
- query := fmt .Sprintf ("SELECT COUNT(*) FROM %s WHERE expired_at<=? OR (code='' AND access='' AND refresh='')" , s .tableName )
132
- n , err := s .db .SelectInt (query , now )
133
- if err != nil || n == 0 {
134
- if err != nil {
135
- s .errorf (err .Error ())
136
- }
137
- return
138
- }
139
-
140
- _ , err = s .db .Exec (fmt .Sprintf ("DELETE FROM %s WHERE expired_at<=? OR (code='' AND access='' AND refresh='')" , s .tableName ), now )
141
- if err != nil {
142
- s .errorf (err .Error ())
143
- }
144
- }
145
-
146
- func (s * Store ) errorf (format string , args ... interface {}) {
147
- if s .stdout != nil {
148
- buf := fmt .Sprintf ("[OAUTH2-MYSQL-ERROR]: " + format , args ... )
149
- s .stdout .Write ([]byte (buf ))
150
- }
151
- }
152
-
153
- // Create create and store the new token information
154
- func (s * Store ) Create (ctx context.Context , info oauth2.TokenInfo ) error {
155
- buf , _ := jsoniter .Marshal (info )
156
- item := & StoreItem {
157
- Data : string (buf ),
158
- }
159
-
160
- if code := info .GetCode (); code != "" {
161
- item .Code = code
162
- item .ExpiredAt = info .GetCodeCreateAt ().Add (info .GetCodeExpiresIn ()).Unix ()
163
- } else {
164
- item .Access = info .GetAccess ()
165
- item .ExpiredAt = info .GetAccessCreateAt ().Add (info .GetAccessExpiresIn ()).Unix ()
166
-
167
- if refresh := info .GetRefresh (); refresh != "" {
168
- item .Refresh = info .GetRefresh ()
169
- item .ExpiredAt = info .GetRefreshCreateAt ().Add (info .GetRefreshExpiresIn ()).Unix ()
170
- }
171
- }
172
-
173
- return s .db .Insert (item )
174
- }
175
-
176
- // RemoveByCode delete the authorization code
177
- func (s * Store ) RemoveByCode (ctx context.Context , code string ) error {
178
- query := fmt .Sprintf ("UPDATE %s SET code='' WHERE code=? LIMIT 1" , s .tableName )
179
- _ , err := s .db .Exec (query , code )
180
- if err != nil && err == sql .ErrNoRows {
181
- return nil
182
- }
183
- return err
184
- }
185
-
186
- // RemoveByAccess use the access token to delete the token information
187
- func (s * Store ) RemoveByAccess (ctx context.Context , access string ) error {
188
- query := fmt .Sprintf ("UPDATE %s SET access='' WHERE access=? LIMIT 1" , s .tableName )
189
- _ , err := s .db .Exec (query , access )
190
- if err != nil && err == sql .ErrNoRows {
191
- return nil
192
- }
193
- return err
194
- }
195
-
196
- // RemoveByRefresh use the refresh token to delete the token information
197
- func (s * Store ) RemoveByRefresh (ctx context.Context , refresh string ) error {
198
- query := fmt .Sprintf ("UPDATE %s SET refresh='' WHERE refresh=? LIMIT 1" , s .tableName )
199
- _ , err := s .db .Exec (query , refresh )
200
- if err != nil && err == sql .ErrNoRows {
201
- return nil
202
- }
203
- return err
204
- }
205
-
206
- func (s * Store ) toTokenInfo (data string ) oauth2.TokenInfo {
207
- var tm models.Token
208
- jsoniter .Unmarshal ([]byte (data ), & tm )
209
- return & tm
210
- }
211
-
212
- // GetByCode use the authorization code for token information data
213
- func (s * Store ) GetByCode (ctx context.Context , code string ) (oauth2.TokenInfo , error ) {
214
- if code == "" {
215
- return nil , nil
216
- }
217
-
218
- query := fmt .Sprintf ("SELECT * FROM %s WHERE code=? LIMIT 1" , s .tableName )
219
- var item StoreItem
220
- err := s .db .SelectOne (& item , query , code )
221
- if err != nil {
222
- if err == sql .ErrNoRows {
223
- return nil , nil
224
- }
225
- return nil , err
226
- }
227
- return s .toTokenInfo (item .Data ), nil
228
- }
229
-
230
- // GetByAccess use the access token for token information data
231
- func (s * Store ) GetByAccess (ctx context.Context , access string ) (oauth2.TokenInfo , error ) {
232
- if access == "" {
233
- return nil , nil
234
- }
235
-
236
- query := fmt .Sprintf ("SELECT * FROM %s WHERE access=? LIMIT 1" , s .tableName )
237
- var item StoreItem
238
- err := s .db .SelectOne (& item , query , access )
239
- if err != nil {
240
- if err == sql .ErrNoRows {
241
- return nil , nil
242
- }
243
- return nil , err
244
- }
245
- return s .toTokenInfo (item .Data ), nil
246
- }
247
-
248
- // GetByRefresh use the refresh token for token information data
249
- func (s * Store ) GetByRefresh (ctx context.Context , refresh string ) (oauth2.TokenInfo , error ) {
250
- if refresh == "" {
251
- return nil , nil
252
- }
253
-
254
- query := fmt .Sprintf ("SELECT * FROM %s WHERE refresh=? LIMIT 1" , s .tableName )
255
- var item StoreItem
256
- err := s .db .SelectOne (& item , query , refresh )
257
- if err != nil {
258
- if err == sql .ErrNoRows {
259
- return nil , nil
260
- }
261
- return nil , err
262
- }
263
- return s .toTokenInfo (item .Data ), nil
264
- }
0 commit comments