Add an option to split the key and value fields
authorLukáš Lalinský <lalinsky@gmail.com>
Sun, 16 Oct 2011 11:20:31 +0000 (13:20 +0200)
committerLukáš Lalinský <lalinsky@gmail.com>
Sun, 16 Oct 2011 11:24:21 +0000 (13:24 +0200)
python/skytools/parsing.py

index cf4914171352661e75101bf6468b6b5b001fb564..5f6ab9ade0b4f8ec7e701f647b700b456fd303b7 100644 (file)
@@ -69,7 +69,7 @@ class _logtriga_parser:
         for typ, tok in sql_tokenizer(sql, ignore_whitespace = True):
             yield tok
 
-    def parse_insert(self, tk, fields, values):
+    def parse_insert(self, tk, fields, values, key_fields, key_values):
         """Handler for inserts."""
         # (col1, col2) values ('data', null)
         if tk.next() != "(":
@@ -96,7 +96,7 @@ class _logtriga_parser:
         t = tk.next()
         raise Exception("expected EOF, got " + repr(t))
 
-    def parse_update(self, tk, fields, values):
+    def parse_update(self, tk, fields, values, key_fields, key_values):
         """Handler for updates."""
         # col1 = 'data1', col2 = null where pk1 = 'pk1' and pk2 = 'pk2'
         while 1:
@@ -104,7 +104,6 @@ class _logtriga_parser:
             if tk.next() != "=":
                 raise Exception("syntax error")
             values.append(tk.next())
-            
             t = tk.next()
             if t == ",":
                 continue
@@ -114,30 +113,35 @@ class _logtriga_parser:
                 raise Exception("syntax error, expected WHERE or , got "+repr(t))
         while 1:
             fld = tk.next()
-            fields.append(fld)
+            key_fields.append(fld)
             self.pklist.append(fld)
             if tk.next() != "=":
                 raise Exception("syntax error")
-            values.append(tk.next())
+            key_values.append(tk.next())
             t = tk.next()
             if t.lower() != "and":
                 raise Exception("syntax error, expected AND got "+repr(t))
 
-    def parse_delete(self, tk, fields, values):
+    def parse_delete(self, tk, fields, values, key_fields, key_values):
         """Handler for deletes."""
         # pk1 = 'pk1' and pk2 = 'pk2'
         while 1:
             fld = tk.next()
-            fields.append(fld)
+            key_fields.append(fld)
             self.pklist.append(fld)
             if tk.next() != "=":
                 raise Exception("syntax error")
-            values.append(tk.next())
+            key_values.append(tk.next())
             t = tk.next()
             if t.lower() != "and":
                 raise Exception("syntax error, expected AND, got "+repr(t))
 
-    def parse_sql(self, op, sql, pklist=None):
+    def _create_dbdict(self, fields, values):
+        fields = [skytools.unquote_ident(f) for f in fields]
+        values = [skytools.unquote_literal(f) for f in values]
+        return skytools.dbdict(zip(fields, values))
+
+    def parse_sql(self, op, sql, pklist=None, splitkeys=False):
         """Main entry point."""
         if pklist is None:
             self.pklist = []
@@ -146,26 +150,31 @@ class _logtriga_parser:
         tk = self.tokenizer(sql)
         fields = []
         values = []
+        key_fields = []
+        key_values = []
         try:
             if op == "I":
-                self.parse_insert(tk, fields, values)
+                self.parse_insert(tk, fields, values, key_fields, key_values)
             elif op == "U":
-                self.parse_update(tk, fields, values)
+                self.parse_update(tk, fields, values, key_fields, key_values)
             elif op == "D":
-                self.parse_delete(tk, fields, values)
+                self.parse_delete(tk, fields, values, key_fields, key_values)
             raise Exception("syntax error")
         except StopIteration:
             # last sanity check
-            if len(fields) == 0 or len(fields) != len(values):
+            if (len(fields) + len(key_fields) == 0 or
+                len(fields) != len(values) or
+                len(key_fields) != len(key_values)):
                 raise Exception("syntax error, fields do not match values")
-        fields = [skytools.unquote_ident(f) for f in fields]
-        values = [skytools.unquote_literal(f) for f in values]
-        return skytools.dbdict(zip(fields, values))
+        if splitkeys:
+            return (self._create_dbdict(key_fields, key_values),
+                    self._create_dbdict(fields, values))
+        return self._create_dbdict(fields + key_fields, values + key_values)
 
-def parse_logtriga_sql(op, sql):
-    return parse_sqltriga_sql(op, sql)
+def parse_logtriga_sql(op, sql, splitkeys=False):
+    return parse_sqltriga_sql(op, sql, splitkeys=splitkeys)
 
-def parse_sqltriga_sql(op, sql, pklist=None):
+def parse_sqltriga_sql(op, sql, pklist=None, splitkeys=False):
     """Parse partial SQL used by pgq.sqltriga() back to data values.
 
     Parser has following limitations:
@@ -173,7 +182,7 @@ def parse_sqltriga_sql(op, sql, pklist=None):
      - Does not support dollar quoting.
      - Does not support complex expressions anywhere. (hashtext(col1) = hashtext(val1))
      - WHERE expression must not contain IS (NOT) NULL
-     - Does not support updating pk value.
+     - Does not support updating pk value, unless you use the splitkeys parameter.
 
     Returns dict of col->data pairs.
 
@@ -188,8 +197,24 @@ def parse_sqltriga_sql(op, sql, pklist=None):
     Delete event:
     >>> parse_logtriga_sql('D', "id = 1 and id2 = 'str''val'")
     {'id2': "str'val", 'id': '1'}
+
+    If you set the splitkeys parameter, it will return two dicts, one for key
+    fields and one for data fields.
+
+    Insert event:
+    >>> parse_logtriga_sql('I', '(id, data) values (1, null)', splitkeys=True)
+    ({}, {'data': None, 'id': '1'})
+
+    Update event:
+    >>> parse_logtriga_sql('U', "data='foo' where id = 1", splitkeys=True)
+    ({'id': '1'}, {'data': 'foo'})
+
+    Delete event:
+    >>> parse_logtriga_sql('D', "id = 1 and id2 = 'str''val'", splitkeys=True)
+    ({'id2': "str'val", 'id': '1'}, {})
+
     """
-    return _logtriga_parser().parse_sql(op, sql, pklist)
+    return _logtriga_parser().parse_sql(op, sql, pklist, splitkeys=splitkeys)
 
 
 def parse_tabbed_table(txt):