python/skytools: add doctest-based regtests to few non-sql functions
authorMarko Kreen <markokr@gmail.com>
Wed, 14 Oct 2009 13:37:17 +0000 (16:37 +0300)
committerMarko Kreen <markokr@gmail.com>
Wed, 14 Oct 2009 13:37:17 +0000 (16:37 +0300)
Seems to be better testing method than ad-hoc scripts.  They will
serve as examples too.

Also fix few minor problems found in the process:
- parse_pgarray:  check if str ends with }
- parse_pgarray: support NULL
- quote_fqident: add 'public.' schema to idents without schema
- fq_name_parts: return always list

python/skytools/parsing.py
python/skytools/quoting.py
python/skytools/sqltools.py

index d50b14c0dabf5447ebdf864fdbb2dbdf04ef0e97..4b92306e0e62a2cc6e5e4df534a1eafa4798a47d 100644 (file)
@@ -14,11 +14,18 @@ _rc_listelem = re.compile(r'( [^,"}]+ | ["] ( [^"\\]+ | [\\]. )* ["] )', re.X)
 
 # _parse_pgarray
 def parse_pgarray(array):
-    """ Parse Postgres array and return list of items inside it
-        Used to deserialize data recived from service layer parameters
+    r"""Parse Postgres array and return list of items inside it.
+
+    Examples:
+    >>> parse_pgarray('{}')
+    []
+    >>> parse_pgarray('{a,b,null,"null"}')
+    ['a', 'b', None, 'null']
+    >>> parse_pgarray(r'{"a,a","b\"b","c\\c"}')
+    ['a,a', 'b"b', 'c\\c']
     """
-    if not array or array[0] != "{":
-        raise Exception("bad array format: must start with {")
+    if not array or array[0] != "{" or array[-1] != '}':
+        raise Exception("bad array format: must be surrounded with {}")
     res = []
     pos = 1
     while 1:
@@ -27,16 +34,19 @@ def parse_pgarray(array):
             break
         pos2 = m.end()
         item = array[pos:pos2]
-        if len(item) > 0 and item[0] == '"':
-            item = item[1:-1]
-        item = unescape(item)
-        res.append(item)
+        if len(item) == 4 and item.upper() == "NULL":
+            val = None
+        else:
+            if len(item) > 0 and item[0] == '"':
+                item = item[1:-1]
+            val = unescape(item)
+        res.append(val)
 
         pos = pos2 + 1
         if array[pos2] == "}":
             break
         elif array[pos2] != ",":
-            raise Exception("bad array format: expected ,} got " + array[pos2])
+            raise Exception("bad array format: expected ,} got " + repr(array[pos2]))
     return res
 
 #
@@ -136,26 +146,45 @@ class _logtriga_parser:
         return dbdict(zip(fields, values))
 
 def parse_logtriga_sql(op, sql):
-    """Parse partial SQL used by logtriga() back to data values.
+    return parse_sqltriga_sql(op, sql)
+
+def parse_sqltriga_sql(op, sql):
+    """Parse partial SQL used by pgq.sqltriga() back to data values.
 
     Parser has following limitations:
      - Expects standard_quoted_strings = off
      - Does not support dollar quoting.
      - Does not support complex expressions anywhere. (hashtext(col1) = hashtext(val1))
      - WHERE expression must not contain IS (NOT) NULL
-     - Does not support updateing pk value.
+     - Does not support updating pk value.
 
     Returns dict of col->data pairs.
+
+    Insert event:
+    >>> parse_logtriga_sql('I', '(id, data) values (1, null)')
+    {'data': None, 'id': '1'}
+
+    Update event:
+    >>> parse_logtriga_sql('U', "data='foo' where id = 1")
+    {'data': 'foo', 'id': '1'}
+
+    Delete event:
+    >>> parse_logtriga_sql('D', "id = 1 and id2 = 'str''val'")
+    {'id2': "str'val", 'id': '1'}
     """
     return _logtriga_parser().parse_sql(op, sql)
 
 
 def parse_tabbed_table(txt):
-    """Parse a tab-separated table into list of dicts.
+    r"""Parse a tab-separated table into list of dicts.
     
     Expect first row to be column names.
 
     Very primitive.
+
+    Example:
+    >>> parse_tabbed_table('col1\tcol2\nval1\tval2\n')
+    [{'col2': 'val2', 'col1': 'val1'}]
     """
 
     txt = txt.replace("\r\n", "\n")
@@ -194,9 +223,15 @@ _ext_sql = r"""(?: (?P<str> [E]? %s ) | %s )""" % (_extstr, _base_sql)
 _std_sql_rc = _ext_sql_rc = None
 
 def sql_tokenizer(sql, standard_quoting = False, ignore_whitespace = False):
-    """Parser SQL to tokens.
+    r"""Parser SQL to tokens.
 
     Iterator, returns (toktype, tokstr) tuples.
+
+    Example
+    >>> [x for x in sql_tokenizer("select * from a.b", ignore_whitespace=True)]
+    [('ident', 'select'), ('sym', '*'), ('ident', 'from'), ('ident', 'a'), ('sym', '.'), ('ident', 'b')]
+    >>> [x for x in sql_tokenizer("\"c olumn\",'str''val'")]
+    [('ident', '"c olumn"'), ('sym', ','), ('str', "'str''val'")]
     """
     global _std_sql_rc, _ext_sql_rc
     if not _std_sql_rc:
@@ -224,6 +259,9 @@ def parse_statements(sql, standard_quoting = False):
     """Parse multi-statement string into separate statements.
 
     Returns list of statements.
+
+    >>> [sql for sql in parse_statements("begin; select 1; select 'foo'; end;")]
+    ['begin;', 'select 1;', "select 'foo';", 'end;']
     """
 
     global _copy_from_stdin_rc
@@ -252,3 +290,7 @@ def parse_statements(sql, standard_quoting = False):
     if pcount != 0:
         raise Exception("syntax error - unbalanced parenthesis")
 
+if __name__ == '__main__':
+    import doctest
+    doctest.testmod()
+
index ce860040a62bff992b4a1e8e565119a7d26dd110..e83d24954ffb714eef99f5e95e5244ff8055c4d7 100644 (file)
@@ -82,8 +82,17 @@ def quote_fqident(s):
 
     The '.' is taken as namespace separator and
     all parts are quoted separately
+
+    Example:
+    >>> quote_fqident('tbl')
+    'public.tbl'
+    >>> quote_fqident('Baz.Foo.Bar')
+    '"Baz"."Foo.Bar"'
     """
-    return '.'.join(map(quote_ident, s.split('.', 1)))
+    tmp = s.split('.', 1)
+    if len(tmp) == 1:
+        return 'public.' + quote_ident(s)
+    return '.'.join(map(quote_ident, tmp))
 
 #
 # quoting for JSON strings
@@ -110,7 +119,14 @@ def quote_json(s):
     return '"%s"' % _jsre.sub(_json_quote_char, s)
 
 def unescape_copy(val):
-    """Removes C-style escapes, also converts "\N" to None."""
+    r"""Removes C-style escapes, also converts "\N" to None.
+
+    Example:
+    >>> unescape_copy(r'baz\tfo\'o')
+    "baz\tfo'o"
+    >>> unescape_copy(r'\N') is None
+    True
+    """
     if val == r"\N":
         return None
     return unescape(val)
@@ -129,3 +145,6 @@ def unquote_fqident(val):
     tmp = val.split('.', 1)
     return "%s.%s" % (unquote_ident(tmp[0]), unquote_ident(tmp[1]))
 
+if __name__ == '__main__':
+    import doctest
+    doctest.testmod()
index 3037df698a554c1c83b68b69d30d2b3ba0a68425..902bf9a6b99f19ee3d6d1596115de3f03da0d61e 100644 (file)
@@ -44,18 +44,34 @@ class dbdict(dict):
 #
 
 def fq_name_parts(tbl):
-    "Return fully qualified name parts."
+    """Return fully qualified name parts.
+
+    >>> fq_name_parts('tbl')
+    ['public', 'tbl']
+    >>> fq_name_parts('foo.tbl')
+    ['foo', 'tbl']
+    >>> fq_name_parts('foo.tbl.baz')
+    ['foo', 'tbl.baz']
+    """
 
     tmp = tbl.split('.', 1)
     if len(tmp) == 1:
-        return ('public', tbl)
+        return ['public', tbl]
     elif len(tmp) == 2:
         return tmp
     else:
         raise Exception('Syntax error in table name:'+tbl)
 
 def fq_name(tbl):
-    "Return fully qualified name."
+    """Return fully qualified name.
+
+    >>> fq_name('tbl')
+    'public.tbl'
+    >>> fq_name('foo.tbl')
+    'foo.tbl'
+    >>> fq_name('foo.tbl.baz')
+    'foo.tbl.baz'
+    """
     return '.'.join(fq_name_parts(tbl))
 
 #
@@ -171,7 +187,19 @@ def exists_temp_table(curs, tbl):
 #
 
 class Snapshot(object):
-    "Represents a PostgreSQL snapshot."
+    """Represents a PostgreSQL snapshot.
+
+    Example:
+    >>> sn = Snapshot('11:20:11,12,15')
+    >>> sn.contains(9)
+    True
+    >>> sn.contains(11)
+    False
+    >>> sn.contains(17)
+    True
+    >>> sn.contains(20)
+    False
+    """
 
     def __init__(self, str):
         "Create snapshot from string."
@@ -235,11 +263,15 @@ def _gen_list_insert(tbl, row, fields, qfields):
     return fmt % (tbl, ",".join(qfields), ",".join(tmp))
 
 def magic_insert(curs, tablename, data, fields = None, use_insert = 0):
-    """Copy/insert a list of dict/list data to database.
-    
+    r"""Copy/insert a list of dict/list data to database.
+
     If curs == None, then the copy or insert statements are returned
     as string.  For list of dict the field list is optional, as its
     possible to guess them from dict keys.
+
+    Example:
+    >>> magic_insert(None, 'tbl', [[1, '1'], [2, '2']], ['col1', 'col2'])
+    'COPY public.tbl (col1,col2) FROM STDIN;\n1\t1\n2\t2\n\\.\n'
     """
     if len(data) == 0:
         return
@@ -486,7 +518,11 @@ def installer_apply_file(db, filename, log):
 #
 
 def mk_insert_sql(row, tbl, pkey_list = None, field_map = None):
-    """Generate INSERT statement from dict data."""
+    """Generate INSERT statement from dict data.
+
+    >>> mk_insert_sql({'id': '1', 'data': None}, 'tbl')
+    "insert into public.tbl (data, id) values (null, '1');"
+    """
 
     col_list = []
     val_list = []
@@ -504,7 +540,11 @@ def mk_insert_sql(row, tbl, pkey_list = None, field_map = None):
                     quote_fqident(tbl), col_str, val_str)
 
 def mk_update_sql(row, tbl, pkey_list, field_map = None):
-    """Generate UPDATE statement from dict data."""
+    r"""Generate UPDATE statement from dict data.
+
+    >>> mk_update_sql({'id': 0, 'id2': '2', 'data': 'str\\'}, 'Table', ['id', 'id2'])
+    'update only public."Table" set data = E\'str\\\\\' where id = \'0\' and id2 = \'2\';'
+    """
 
     if len(pkey_list) < 1:
         raise Exception("update needs pkeys")
@@ -787,3 +827,7 @@ class PLPyQueryBuilder(QueryBuilder):
             res = [dbdict(r) for r in res]
         return res
 
+if __name__ == '__main__':
+    import doctest
+    doctest.testmod()
+