diff --git a/statement_parser.go b/statement_parser.go index 53117941..c58b25db 100644 --- a/statement_parser.go +++ b/statement_parser.go @@ -61,6 +61,112 @@ func parseParameters(sql string) (string, []string, error) { return findParams('?', sql) } +// Skips all whitespaces from the given position and returns the +// position of the next non-whitespace character or len(sql) if +// the string does not contain any whitespaces after pos. +func skipWhitespaces(sql []rune, pos int) int { + for ; pos < len(sql); pos++ { + c := sql[pos] + if c == '-' && len(sql) > pos+1 && sql[pos+1] == '-' { + // This is a single line comment starting with '--'. + skipSingleLineComment(sql, pos+2) + } else if c == '#' { + // This is a single line comment starting with '#'. + skipSingleLineComment(sql, pos+1) + } else if c == '/' && len(sql) > pos+1 && sql[pos+1] == '*' { + // This is a multi line comment starting with '/*'. + skipMultiLineComment(sql, pos) + } else if !unicode.IsSpace(c) { + break + } + } + return pos +} + +// Skips the next character, quoted literal, quoted identifier or comment in +// the given sql string from the given position and returns the position of the +// next character. +func skip(sql []rune, pos int) (int, error) { + if pos >= len(sql) { + return pos, nil + } + c := sql[pos] + + if c == '\'' || c == '"' || c == '`' { + // This is a quoted string or quoted identifier. + return skipQuoted(sql, pos, c) + } else if c == '-' && len(sql) > pos+1 && sql[pos+1] == '-' { + // This is a single line comment starting with '--'. + return skipSingleLineComment(sql, pos+2), nil + } else if c == '#' { + // This is a single line comment starting with '#'. + return skipSingleLineComment(sql, pos+1), nil + } else if c == '/' && len(sql) > pos+1 && sql[pos+1] == '*' { + // This is a multi line comment starting with '/*'. + return skipMultiLineComment(sql, pos), nil + } + return pos + 1, nil +} + +func skipSingleLineComment(sql []rune, pos int) int { + for ; pos < len(sql) && sql[pos] != '\n'; pos++ { + } + return min(pos+1, len(sql)) +} + +func skipMultiLineComment(sql []rune, pos int) int { + // Skip '/*'. + pos = pos + 2 + // Search for the first '*/' sequence. Note that GoogleSQL does not support + // nested comments, so any occurrence of '/*' inside the comment does not + // have any special meaning. + for ; pos < len(sql); pos++ { + if sql[pos] == '*' && len(sql) > pos+1 && sql[pos+1] == '/' { + return pos + 2 + } + } + return pos +} + +func skipQuoted(sql []rune, pos int, quote rune) (int, error) { + isTripleQuoted := len(sql) > pos+2 && sql[pos+1] == quote && sql[pos+2] == quote + if isTripleQuoted { + pos += 3 + } else { + pos += 1 + } + for ; pos < len(sql); pos++ { + c := sql[pos] + if c == quote { + if isTripleQuoted { + // Check if this is the end of the triple-quoted string. + if len(sql) > pos+2 && sql[pos+1] == quote && sql[pos+2] == quote { + return pos + 3, nil + } + } else { + // This was the end quote. + return pos + 1, nil + } + } else if len(sql) > pos+1 && c == '\\' && sql[pos+1] == quote { + // This is an escaped quote (e.g. 'foo\'bar'). + // Note that in raw strings, the \ officially does not start an + // escape sequence, but the result is still the same, as in a raw + // string 'both characters are preserved'. + pos += 1 + } else if !isTripleQuoted && c == '\n' { + break + } + } + return 0, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "SQL statement contains an unclosed literal: %s", string(sql))) +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + // RemoveCommentsAndTrim removes any comments in the query string and trims any // spaces at the beginning and end of the query. This makes checking what type // of query a string is a lot easier, as only the first word(s) need to be diff --git a/statement_parser_test.go b/statement_parser_test.go index 5262bf9b..9453d682 100644 --- a/statement_parser_test.go +++ b/statement_parser_test.go @@ -989,6 +989,170 @@ func TestFindParams_Errors(t *testing.T) { } } +func TestSkip(t *testing.T) { + tests := []struct { + input string + pos int + skipped string + invalid bool + }{ + { + input: "", + skipped: "", + }, + { + input: "1 ", + skipped: "1", + }, + { + input: "12 ", + skipped: "1", + }, + { + input: "12 ", + pos: 1, + skipped: "2", + }, + { + input: "12", + pos: 2, + skipped: "", + }, + { + input: "'foo' ", + skipped: "'foo'", + }, + { + input: "'foo''bar' ", + skipped: "'foo'", + }, + { + input: "'foo' 'bar' ", + skipped: "'foo'", + }, + { + input: "'foo''bar' ", + pos: 5, + skipped: "'bar'", + }, + { + input: `'foo"bar"' `, + skipped: `'foo"bar"'`, + }, + { + input: `"foo'bar'" `, + skipped: `"foo'bar'"`, + }, + { + input: "`foo'bar'` ", + skipped: "`foo'bar'`", + }, + { + input: "'''foo'bar''' ", + skipped: "'''foo'bar'''", + }, + { + input: "'''foo\\'bar''' ", + skipped: "'''foo\\'bar'''", + }, + { + input: "'''foo\\'\\'bar''' ", + skipped: "'''foo\\'\\'bar'''", + }, + { + input: "'''foo\\'\\'\\'bar''' ", + skipped: "'''foo\\'\\'\\'bar'''", + }, + { + input: "```foo`bar``` ", + skipped: "```foo`bar```", + }, + { + input: `"""foo"bar""" `, + skipped: `"""foo"bar"""`, + }, + { + input: "-- comment", + skipped: "-- comment", + }, + { + input: "-- comment\nselect * from foo", + skipped: "-- comment\n", + }, + { + input: "# comment", + skipped: "# comment", + }, + { + input: "# comment\nselect * from foo", + skipped: "# comment\n", + }, + { + input: "/* comment */", + skipped: "/* comment */", + }, + { + input: "/* comment */ select * from foo", + skipped: "/* comment */", + }, + { + input: "/* comment /* GoogleSQL does not support nested comments */ select * from foo", + skipped: "/* comment /* GoogleSQL does not support nested comments */", + }, + { + // GoogleSQL does not support dollar-quoted strings. + input: "$tag$not a string$tag$ select * from foo", + skipped: "$", + }, + { + input: "/* 'test' */ foo", + skipped: "/* 'test' */", + }, + { + input: "-- 'test' \n foo", + skipped: "-- 'test' \n", + }, + { + input: "'/* test */' foo", + skipped: "'/* test */'", + }, + { + input: "'foo\\'' ", + skipped: "'foo\\''", + }, + { + input: "r'foo\\'' ", + pos: 1, + skipped: "'foo\\''", + }, + { + input: "'''foo\\'\\'\\'bar''' ", + skipped: "'''foo\\'\\'\\'bar'''", + }, + { + input: "'foo\n' ", + invalid: true, + }, + { + input: "'''foo\n''' ", + skipped: "'''foo\n'''", + }, + } + for _, test := range tests { + pos, err := skip([]rune(test.input), test.pos) + if test.invalid && err == nil { + t.Errorf("missing expected error for %s", test.input) + } else if !test.invalid && err != nil { + t.Errorf("got unexpected error for %s: %v", test.input, err) + } else { + skipped := test.input[test.pos:pos] + if skipped != test.skipped { + t.Errorf("skipped mismatch\nGot: %v\nWant: %v", skipped, test.skipped) + } + } + } +} + var fuzzQuerySamples = []string{"", "SELECT 1;", "RUN BATCH", "ABORT BATCH", "Show variable Retry_Aborts_Internally", "@{JOIN_METHOD=HASH_JOIN SELECT * FROM PersonsTable"} func init() {