new tokenizer that handles quotes properly, use it everywhere
authorMarko Kreen <markokr@gmail.com>
Tue, 11 Mar 2008 16:20:26 +0000 (16:20 +0000)
committerMarko Kreen <markokr@gmail.com>
Tue, 11 Mar 2008 16:20:26 +0000 (16:20 +0000)
python/skytools/parsing.py

index 1f4dd78105dd80e9431d8774fc3aeec5b2eaf5cc..906c25a92285314f58241c34b96b72a081d43368 100644 (file)
@@ -3,7 +3,8 @@
 
 import re
 
-from skytools.quoting import unescape
+from skytools.quoting import unescape, unquote_sql_string, unquote_sql_ident
+from skytools.sqltools import dbdict
 
 __all__ = ["parse_pgarray", "parse_logtriga_sql", "parse_tabbed_table", "parse_statements"]
 
@@ -41,40 +42,9 @@ def parse_pgarray(array):
 #
 
 class _logtriga_parser:
-    token_re = r"""
-        [ \t\r\n]*
-        ( [a-z][a-z0-9_]*
-        | ["] ( [^"\\]+ | \\. )* ["]
-        | ['] ( [^'\\]+ | \\. | [']['] )* [']
-        | [^ \t\r\n]
-        )"""
-    token_rc = None
-
     def tokenizer(self, sql):
-        if not _logtriga_parser.token_rc:
-            _logtriga_parser.token_rc = re.compile(self.token_re, re.X | re.I)
-        rc = self.token_rc
-
-        pos = 0
-        while 1:
-            m = rc.match(sql, pos)
-            if not m:
-                break
-            pos = m.end()
-            yield m.group(1)
-
-    def unquote_data(self, fields, values):
-        # unquote data and column names
-        data = {}
-        for k, v in zip(fields, values):
-            if k[0] == '"':
-                k = unescape(k[1:-1])
-            if len(v) == 4 and v.lower() == "null":
-                v = None
-            elif v[0] == "'":
-                v = unescape(v[1:-1])
-            data[k] = v
-        return data
+        for typ, tok in sql_tokenizer(sql, ignore_whitespace = True):
+            yield tok
 
     def parse_insert(self, tk, fields, values):
         # (col1, col2) values ('data', null)
@@ -88,17 +58,19 @@ class _logtriga_parser:
             elif t != ",":
                 raise Exception("syntax error")
         if tk.next().lower() != "values":
-            raise Exception("syntax error")
+            raise Exception("syntax error, expected VALUES")
         if tk.next() != "(":
-            raise Exception("syntax error")
+            raise Exception("syntax error, expected (")
         while 1:
+            values.append(tk.next())
             t = tk.next()
             if t == ")":
                 break
             if t == ",":
                 continue
-            values.append(t)
-        tk.next()
+            raise Exception("expected , or ) got "+t)
+        t = tk.next()
+        raise Exception("expected EOF, got " + repr(t))
 
     def parse_update(self, tk, fields, values):
         # col1 = 'data1', col2 = null where pk1 = 'pk1' and pk2 = 'pk2'
@@ -114,27 +86,26 @@ class _logtriga_parser:
             elif t.lower() == "where":
                 break
             else:
-                raise Exception("syntax error")
+                raise Exception("syntax error, expected WHERE or , got "+repr(t))
         while 1:
-            t = tk.next()
-            fields.append(t)
+            fields.append(tk.next())
             if tk.next() != "=":
                 raise Exception("syntax error")
             values.append(tk.next())
             t = tk.next()
             if t.lower() != "and":
-                raise Exception("syntax error")
+                raise Exception("syntax error, expected AND got "+repr(t))
 
     def parse_delete(self, tk, fields, values):
         # pk1 = 'pk1' and pk2 = 'pk2'
         while 1:
-            t = tk.next()
-            if t == "and":
-                continue
-            fields.append(t)
+            fields.append(tk.next())
             if tk.next() != "=":
                 raise Exception("syntax error")
             values.append(tk.next())
+            t = tk.next()
+            if t.lower() != "and":
+                raise Exception("syntax error, expected AND, got "+repr(t))
 
     def parse_sql(self, op, sql):
         tk = self.tokenizer(sql)
@@ -151,9 +122,10 @@ class _logtriga_parser:
         except StopIteration:
             # last sanity check
             if len(fields) == 0 or len(fields) != len(values):
-                raise Exception("syntax error")
-
-        return self.unquote_data(fields, values)
+                raise Exception("syntax error, fields do not match values")
+        fields = [unquote_sql_ident(f) for f in fields]
+        values = [unquote_sql_string(f) for f in values]
+        return dbdict(zip(fields, values))
 
 def parse_logtriga_sql(op, sql):
     """Parse partial SQL used by logtriga() back to data values.
@@ -195,24 +167,37 @@ def parse_tabbed_table(txt):
     return data
 
 
-_sql_token_re = r"""
-    ( [a-z][a-z0-9_$]*
-    | ["] ( [^"\\]+ | \\. )* ["]
-    | ['] ( [^'\\]+ | \\. | [']['] )* [']
-    | [$] ([_a-z][_a-z0-9]*)? [$]
-    | (?P<ws> \s+ | [/][*] | [-][-][^\n]* )
-    | .
-    )"""
-_sql_token_rc = None
-_copy_from_stdin_re = "copy.*from\s+stdin"
-_copy_from_stdin_rc = None
+_extstr = r""" ['] (?: [^'\\]+ | \\. | [']['] )* ['] """
+_stdstr = r""" ['] (?: [^']+ | [']['] )* ['] """
+_base_sql = r"""
+      (?P<ident>  [a-z][a-z0-9_$]* | ["] (?: [^"]+ | ["]["] )* ["] )
+    | (?P<dolq>   (?P<dname> [$] (?: [_a-z][_a-z0-9]*)? [$] )
+                  .*?
+                  (?P=dname) )
+    | (?P<num>    [0-9][0-9.e]*
+    | (?P<numarg> [$] [0-9]+ )
+    | (?P<pyold>  [%][(] [a-z0-9_]+ [)][s] | [%][%])
+    | (?P<pynew>  [{] [^}]+ [}] | [{][{] | [}] [}] )
+    | (?P<ws>     (?: \s+ | [/][*] .*? [*][/] | [-][-][^\n]* )+ )
+    | (?P<sym>    . )"""
+_std_sql = r"""(?: (?P<str> [E] %s | %s ) | %s )""" % (_extstr, _stdstr, _base_sql)
+_ext_sql = r"""(?: (?P<str> [E]? %s ) | %s )""" % (_extstr, _base_sql)
+_std_sql_rc = _ext_sql_rc = None
+
+def sql_tokenizer(sql, standard_quoting = False, ignore_whitespace = False):
+    """Parser SQL to tokens.
+
+    Iterator, returns (toktype, tokstr) tuples.
+    """
+    global _std_sql_rc, _ext_sql_rc
+    if not _std_sql_rc:
+        _std_sql_rc = re.compile(_std_sql, re.X | re.I | re.S)
+        _ext_sql_rc = re.compile(_ext_sql, re.X | re.I | re.S)
 
-def _sql_tokenizer(sql):
-    global _sql_token_rc, _copy_from_stdin_rc
-    if not _sql_token_rc:
-        _sql_token_rc = re.compile(_sql_token_re, re.X | re.I)
-        _copy_from_stdin_rc = re.compile(_copy_from_stdin_re, re.X | re.I)
-    rc = _sql_token_rc
+    if standard_quoting:
+        rc = _std_sql_rc
+    else:
+        rc = _ext_sql_rc
 
     pos = 0
     while 1:
@@ -220,38 +205,26 @@ def _sql_tokenizer(sql):
         if not m:
             break
         pos = m.end()
-        tok = m.group(1)
-        ws = m.start('ws') >= 0 # it tok empty?
-        if tok == "/*":
-            end = sql.find("*/", pos)
-            if end < 0:
-                raise Exception("unterminated c comment")
-            pos = end + 2
-            tok = sql[ m.start() : pos]
-        elif len(tok) > 1 and tok[0] == "$" and tok[-1] == "$":
-            end = sql.find(tok, pos)
-            if end < 0:
-                raise Exception("unterminated dollar string")
-            pos = end + len(tok)
-            tok = sql[ m.start() : pos]
-        yield (ws, tok)
+        typ = m.lastgroup
+        if not ignore_whitespace or typ != "ws":
+            yield (m.lastgroup, m.group())
 
+_copy_from_stdin_re = "copy.*from\s+stdin"
+_copy_from_stdin_rc = None
 def parse_statements(sql):
     """Parse multi-statement string into separate statements.
 
     Returns list of statements.
     """
 
-    tk = _sql_tokenizer(sql)
+    global _copy_from_stdin_rc
+    if not _copy_from_stdin_rc:
+        _copy_from_stdin_rc = re.compile(_copy_from_stdin_re, re.X | re.I)
     tokens = []
     pcount = 0 # '(' level
-    while 1:
-        try:
-            ws, t = tk.next()
-        except StopIteration:
-            break
+    for typ, t in _sql_tokenizer(sql):
         # skip whitespace and comments before statement
-        if len(tokens) == 0 and ws:
+        if len(tokens) == 0 and typ == "ws":
             continue
         # keep the rest
         tokens.append(t)