bring new quoting & parsing code to head
authorMarko Kreen <markokr@gmail.com>
Thu, 28 Feb 2008 09:27:25 +0000 (09:27 +0000)
committerMarko Kreen <markokr@gmail.com>
Thu, 28 Feb 2008 09:27:25 +0000 (09:27 +0000)
python/modules/cquoting.c [new file with mode: 0644]
python/skytools/__init__.py
python/skytools/_pyquoting.py [new file with mode: 0644]
python/skytools/parsing.py [new file with mode: 0644]
python/skytools/quoting.py
setup.py

diff --git a/python/modules/cquoting.c b/python/modules/cquoting.c
new file mode 100644 (file)
index 0000000..7d22e1e
--- /dev/null
@@ -0,0 +1,637 @@
+
+#define PY_SSIZE_T_CLEAN
+#include <Python.h>
+
+#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
+typedef int Py_ssize_t;
+#define PY_SSIZE_T_MAX INT_MAX
+#define PY_SSIZE_T_MIN INT_MIN
+#endif
+
+typedef enum { false = 0, true = 1 } bool;
+
+/*
+ * Common buffer management.
+ */
+
+struct Buf {
+       unsigned char *ptr;
+       unsigned long pos;
+       unsigned long alloc;
+};
+
+static unsigned char *buf_init(struct Buf *buf, unsigned init_size)
+{
+       if (init_size < 256)
+               init_size = 256;
+       buf->ptr = PyMem_Malloc(init_size);
+       if (buf->ptr) {
+               buf->pos = 0;
+               buf->alloc = init_size;
+       }
+       return buf->ptr;
+}
+
+/* return new pos */
+static unsigned char *buf_enlarge(struct Buf *buf, unsigned need_room)
+{
+       unsigned alloc = buf->alloc;
+       unsigned need_size = buf->pos + need_room;
+       unsigned char *ptr;
+
+       /* no alloc needed */
+       if (need_size < alloc)
+               return buf->ptr + buf->pos;
+
+       if (alloc <= need_size / 2)
+               alloc = need_size;
+       else
+               alloc = alloc * 2;
+
+       ptr = PyMem_Realloc(buf->ptr, alloc);
+       if (!ptr)
+               return NULL;
+
+       buf->ptr = ptr;
+       buf->alloc = alloc;
+       return buf->ptr + buf->pos;
+}
+
+static void buf_free(struct Buf *buf)
+{
+       PyMem_Free(buf->ptr);
+       buf->ptr = NULL;
+       buf->pos = buf->alloc = 0;
+}
+
+static inline unsigned char *buf_get_target_for(struct Buf *buf, unsigned len)
+{
+       if (buf->pos + len <= buf->alloc)
+               return buf->ptr + buf->pos;
+       else
+               return buf_enlarge(buf, len);
+}
+
+static inline void buf_set_target(struct Buf *buf, unsigned char *newpos)
+{
+       assert(buf->ptr + buf->pos <= newpos);
+       assert(buf->ptr + buf->alloc >= newpos);
+
+       buf->pos = newpos - buf->ptr;
+}
+
+static inline int buf_put(struct Buf *buf, unsigned char c)
+{
+       if (buf->pos < buf->alloc) {
+               buf->ptr[buf->pos++] = c;
+               return 1;
+       } else if (buf_enlarge(buf, 1)) {
+               buf->ptr[buf->pos++] = c;
+               return 1;
+       }
+       return 0;
+}
+
+static PyObject *buf_pystr(struct Buf *buf, unsigned start_pos, unsigned char *newpos)
+{
+       PyObject *res;
+       if (newpos)
+               buf_set_target(buf, newpos);
+       res = PyString_FromStringAndSize((char *)buf->ptr + start_pos, buf->pos - start_pos);
+       buf_free(buf);
+       return res;
+}
+
+/*
+ * Get string data
+ */
+
+static Py_ssize_t get_buffer(PyObject *obj, unsigned char **buf_p, PyObject **tmp_obj_p)
+{
+       PyBufferProcs *bfp;
+       PyObject *str = NULL;
+       Py_ssize_t res;
+
+       /* check for None */
+       if (obj == Py_None) {
+               PyErr_Format(PyExc_TypeError, "None is not allowed here");
+               return -1;
+       }
+
+       /* is string or unicode ? */
+       if (PyString_Check(obj) || PyUnicode_Check(obj)) {
+               if (PyString_AsStringAndSize(obj, (char**)buf_p, &res) < 0)
+                       return -1;
+               return res;
+       }
+
+       /* try to get buffer */
+       bfp = obj->ob_type->tp_as_buffer;
+       if (bfp && bfp->bf_getsegcount(obj, NULL) == 1)
+               return bfp->bf_getreadbuffer(obj, 0, (void**)buf_p);
+
+       /*
+        * Not a string-like object, run str() or it.
+        */
+
+       /* are we in recursion? */
+       if (tmp_obj_p == NULL) {
+               PyErr_Format(PyExc_TypeError, "Cannot convert to string - get_buffer() recusively failed");
+               return -1;
+       }
+
+       /* do str() then */
+       str = PyObject_Str(obj);
+       res = -1;
+       if (str != NULL) {
+               res = get_buffer(str, buf_p, NULL);
+               if (res >= 0) {
+                       *tmp_obj_p = str;
+               } else {
+                       Py_CLEAR(str);
+               }
+       }
+       return res;
+}
+
+/*
+ * Common argument parsing.
+ */
+
+typedef PyObject *(*quote_fn)(unsigned char *src, Py_ssize_t src_len);
+
+static PyObject *common_quote(PyObject *args, quote_fn qfunc)
+{
+       unsigned char *src = NULL;
+        Py_ssize_t src_len = 0;
+       PyObject *arg, *res, *strtmp = NULL;
+        if (!PyArg_ParseTuple(args, "O", &arg))
+                return NULL;
+       if (arg != Py_None) {
+               src_len = get_buffer(arg, &src, &strtmp);
+               if (src_len < 0)
+                       return NULL;
+       }
+       res = qfunc(src, src_len);
+       Py_CLEAR(strtmp);
+       return res;
+}
+
+/*
+ * Simple quoting functions.
+ */
+
+static const char doc_quote_literal[] =
+"Quote a literal value for SQL.\n"
+"\n"
+"If string contains '\\', it is quoted and result is prefixed with E.\n"
+"Input value of None results in string \"null\" without quotes.\n"
+"\n"
+"C implementation.\n";
+
+static PyObject *quote_literal_body(unsigned char *src, Py_ssize_t src_len)
+{
+       struct Buf buf;
+       unsigned char *esc, *dst, *src_end = src + src_len;
+       unsigned int start_ofs = 1;
+
+       if (src == NULL)
+               return PyString_FromString("null");
+
+       esc = dst = buf_init(&buf, src_len * 2 + 2 + 1);
+        if (!dst)
+               return NULL;
+
+       *dst++ = ' ';
+       *dst++ = '\'';
+        while (src < src_end) {
+               if (*src == '\\') {
+                       *dst++ = '\\';
+                       start_ofs = 0;
+               } else if (*src == '\'') {
+                       *dst++ = '\'';
+               }
+               *dst++ = *src++;
+        }
+       *dst++ = '\'';
+       if (start_ofs == 0)
+               *esc = 'E';
+       return buf_pystr(&buf, start_ofs, dst);
+}
+
+static PyObject *quote_literal(PyObject *self, PyObject *args)
+{
+       return common_quote(args, quote_literal_body);
+}
+
+/* COPY field */
+static const char doc_quote_copy[] =
+"Quoting for COPY data.  None is converted to \\N.\n\n"
+"C implementation.";
+
+static PyObject *quote_copy_body(unsigned char *src, Py_ssize_t src_len)
+{
+       unsigned char *dst, *src_end = src + src_len;
+       struct Buf buf;
+
+       if (src == NULL)
+               return PyString_FromString("\\N");
+
+       dst = buf_init(&buf, src_len * 2);
+        if (!dst)
+               return NULL;
+
+        while (src < src_end) {
+               switch (*src) {
+               case '\t': *dst++ = '\\'; *dst++ = 't'; src++; break;
+               case '\n': *dst++ = '\\'; *dst++ = 'n'; src++; break;
+               case '\r': *dst++ = '\\'; *dst++ = 'r'; src++; break;
+               case '\\': *dst++ = '\\'; *dst++ = '\\'; src++; break;
+               default: *dst++ = *src++; break;
+               }
+        }
+       return buf_pystr(&buf, 0, dst);
+}
+
+static PyObject *quote_copy(PyObject *self, PyObject *args)
+{
+       return common_quote(args, quote_copy_body);
+}
+
+/* raw bytea for byteain() */
+static const char doc_quote_bytea_raw[] =
+"Quoting for bytea parser.  Returns None as None.\n"
+"\n"
+"C implementation.";
+
+static PyObject *quote_bytea_raw_body(unsigned char *src, Py_ssize_t src_len)
+{
+       unsigned char *dst, *src_end = src + src_len;
+       struct Buf buf;
+
+       if (src == NULL) {
+               Py_INCREF(Py_None);
+               return Py_None;
+       }
+
+       dst = buf_init(&buf, src_len * 4);
+        if (!dst)
+               return NULL;
+
+        while (src < src_end) {
+               if (*src < 0x20 || *src >= 0x7F) {
+                       *dst++ = '\\';
+                       *dst++ = '0' + (*src >> 6);
+                       *dst++ = '0' + ((*src >> 3) & 7);
+                       *dst++ = '0' + (*src & 7);
+                       src++;
+               } else {
+                       if (*src == '\\')
+                               *dst++ = '\\';
+                       *dst++ = *src++;
+               }
+        }
+       return buf_pystr(&buf, 0, dst);
+}
+
+static PyObject *quote_bytea_raw(PyObject *self, PyObject *args)
+{
+       return common_quote(args, quote_bytea_raw_body);
+}
+
+/* C unescape */
+static const char doc_unescape[] =
+"Unescape C-style escaped string.\n\n"
+"C implementation.";
+
+static PyObject *unescape_body(unsigned char *src, Py_ssize_t src_len)
+{
+       unsigned char *dst, *src_end = src + src_len;
+       struct Buf buf;
+
+       if (src == NULL) {
+               PyErr_Format(PyExc_TypeError, "None not allowed");
+               return NULL;
+       }
+
+       dst = buf_init(&buf, src_len);
+        if (!dst)
+               return NULL;
+
+        while (src < src_end) {
+               if (*src != '\\') {
+                       *dst++ = *src++;
+                       continue;
+               }
+               if (++src >= src_end)
+                       goto failed;
+               switch (*src) {
+               case 't': *dst++ = '\t'; src++; break;
+               case 'n': *dst++ = '\n'; src++; break;
+               case 'r': *dst++ = '\r'; src++; break;
+               case 'a': *dst++ = '\a'; src++; break;
+               case 'b': *dst++ = '\b'; src++; break;
+               default:
+                       if (*src >= '0' && *src <= '7') {
+                               unsigned char c = *src++ - '0';
+                               if (src < src_end && *src >= '0' && *src <= '7') {
+                                       c = (c << 3) | ((*src++) - '0');
+                                       if (src < src_end && *src >= '0' && *src <= '7')
+                                               c = (c << 3) | ((*src++) - '0');
+                               }
+                               *dst++ = c;
+                       } else {
+                               *dst++ = *src++;
+                       }
+               }
+        }
+       return buf_pystr(&buf, 0, dst);
+failed:
+       PyErr_Format(PyExc_ValueError, "Broken string - \\ at the end");
+       return NULL;
+}
+
+static PyObject *unescape(PyObject *self, PyObject *args)
+{
+       return common_quote(args, unescape_body);
+}
+
+/*
+ * urlencode of dict
+ */
+
+static bool urlenc(struct Buf *buf, PyObject *obj)
+{
+       Py_ssize_t len;
+       unsigned char *src, *dst;
+       PyObject *strtmp = NULL;
+       static const unsigned char hextbl[] = "0123456789abcdef";
+       bool ok = false;
+
+       len = get_buffer(obj, &src, &strtmp);
+       if (len < 0)
+               goto failed;
+
+       dst = buf_get_target_for(buf, len * 3);
+       if (!dst)
+               goto failed;
+
+       while (len--) {
+               if ((*src >= 'a' && *src <= 'z') ||
+                   (*src >= 'A' && *src <= 'Z') ||
+                   (*src >= '0' && *src <= '9') ||
+                   (*src == '.' || *src == '_' || *src == '-'))
+               {
+                       *dst++ = *src++;
+               } else if (*src == ' ') {
+                       *dst++ = '+'; src++;
+               } else {
+                       *dst++ = '%';
+                       *dst++ = hextbl[*src >> 4];
+                       *dst++ = hextbl[*src & 0xF];
+                       src++;
+               }
+       }
+       buf_set_target(buf, dst);
+       ok = true;
+failed:
+       Py_CLEAR(strtmp);
+       return ok;
+}
+
+/* urlencode key+val pair.  val can be None */
+static bool urlenc_keyval(struct Buf *buf, PyObject *key, PyObject *value, bool needAmp)
+{
+       if (needAmp && !buf_put(buf, '&'))
+               return false;
+       if (!urlenc(buf, key))
+               return false;
+       if (value != Py_None) {
+               if (!buf_put(buf, '='))
+                       return false;
+               if (!urlenc(buf, value))
+                       return false;
+       }
+       return true;
+}
+
+/* encode native dict using PyDict_Next */
+static PyObject *encode_dict(PyObject *data)
+{
+       PyObject *key, *value;
+       Py_ssize_t pos = 0;
+       bool needAmp = false;
+       struct Buf buf;
+       if (!buf_init(&buf, 1024))
+               return NULL;
+       while (PyDict_Next(data, &pos, &key, &value)) {
+               if (!urlenc_keyval(&buf, key, value, needAmp))
+                       goto failed;
+               needAmp = true;
+       }
+       return buf_pystr(&buf, 0, NULL);
+failed:
+       buf_free(&buf);
+       return NULL;
+}
+
+/* encode custom object using .iteritems() */
+static PyObject *encode_dictlike(PyObject *data)
+{
+       PyObject *key = NULL, *value = NULL, *tup, *iter;
+       struct Buf buf;
+       bool needAmp = false;
+       
+       if (!buf_init(&buf, 1024))
+               return NULL;
+
+       iter = PyObject_CallMethod(data, "iteritems", NULL);
+       if (iter == NULL) {
+               buf_free(&buf);
+               return NULL;
+       }
+
+       while ((tup = PyIter_Next(iter))) {
+               key = PySequence_GetItem(tup, 0);
+               value = key ? PySequence_GetItem(tup, 1) : NULL;
+               Py_CLEAR(tup);
+               if (!key || !value)
+                       goto failed;
+
+               if (!urlenc_keyval(&buf, key, value, needAmp))
+                       goto failed;
+               needAmp = true;
+
+               Py_CLEAR(key);
+               Py_CLEAR(value);
+       }
+       /* allow error from iterator */
+       if (PyErr_Occurred())
+               goto failed;
+
+       Py_CLEAR(iter);
+       return buf_pystr(&buf, 0, NULL);
+failed:
+       buf_free(&buf);
+       Py_CLEAR(iter);
+       Py_CLEAR(key);
+       Py_CLEAR(value);
+       return NULL;
+}
+
+static const char doc_db_urlencode[] =
+"Urlencode for database records.\n"
+"If a value is None the key is output without '='.\n"
+"\n"
+"C implementation.";
+
+static PyObject *db_urlencode(PyObject *self, PyObject *args)
+{
+       PyObject *data;
+        if (!PyArg_ParseTuple(args, "O", &data))
+                return NULL;
+       if (PyDict_Check(data)) {
+               return encode_dict(data);
+       } else {
+               return encode_dictlike(data);
+       }
+}
+
+/*
+ * urldecode to dict
+ */
+
+static inline int gethex(unsigned char c)
+{
+       if (c >= '0' && c <= '9') return c - '0';
+       c |= 0x20;
+       if (c >= 'a' && c <= 'f') return c - 'a' + 10;
+       return -1;
+}
+
+static PyObject *get_elem(unsigned char *buf, unsigned char **src_p, unsigned char *src_end)
+{
+       int c1, c2;
+       unsigned char *src = *src_p;
+       unsigned char *dst = buf;
+
+       while (src < src_end) {
+               switch (*src) {
+               case '%':
+                       if (++src + 2 > src_end)
+                               goto hex_incomplete;
+                        if ((c1 = gethex(*src++)) < 0)
+                               goto hex_invalid;
+                        if ((c2 = gethex(*src++)) < 0)
+                               goto hex_invalid;
+                       *dst++ = (c1 << 4) | c2;
+                       break;
+               case '+':
+                       *dst++ = ' '; src++;
+                       break;
+               case '&':
+               case '=':
+                       goto gotit;
+               default:
+                       *dst++ = *src++;
+               }
+       }
+gotit:
+       *src_p = src;
+       return PyString_FromStringAndSize((char *)buf, dst - buf);
+
+hex_incomplete:
+       PyErr_Format(PyExc_ValueError, "Incomplete hex code");
+       return NULL;
+hex_invalid:
+       PyErr_Format(PyExc_ValueError, "Invalid hex code");
+       return NULL;
+}
+
+static const char doc_db_urldecode[] =
+"Urldecode from string to dict.\n"
+"NULL are detected by missing '='.\n"
+"Duplicate keys are ignored - only latest is kept.\n"
+"\n"
+"C implementation.";
+
+static PyObject *db_urldecode(PyObject *self, PyObject *args)
+{
+       unsigned char *src, *src_end;
+       Py_ssize_t src_len;
+       PyObject *dict = NULL, *key = NULL, *value = NULL;
+       struct Buf buf;
+
+        if (!PyArg_ParseTuple(args, "t#", &src, &src_len))
+                return NULL;
+       if (!buf_init(&buf, src_len))
+               return NULL;
+
+       dict = PyDict_New();
+       if (!dict) {
+               buf_free(&buf);
+               return NULL;
+       }
+
+       src_end = src + src_len;
+       while (src < src_end) {
+                if (*src == '&') {
+                    src++;
+                    continue;
+                }
+
+               key = get_elem(buf.ptr, &src, src_end);
+               if (!key)
+                       goto failed;
+
+               if (src < src_end && *src == '=') {
+                       src++;
+                       value = get_elem(buf.ptr, &src, src_end);
+                       if (value == NULL)
+                               goto failed;
+               } else {
+                       Py_INCREF(Py_None);
+                       value = Py_None;
+               }
+
+               /* lessen memory usage by intering */
+               PyString_InternInPlace(&key);
+
+               if (PyDict_SetItem(dict, key, value) < 0)
+                       goto failed;
+               Py_CLEAR(key);
+               Py_CLEAR(value);
+       }
+       buf_free(&buf);
+       return dict;
+failed:
+       buf_free(&buf);
+       Py_CLEAR(key);
+       Py_CLEAR(value);
+       Py_CLEAR(dict);
+       return NULL;
+}
+
+/*
+ * Module initialization
+ */
+
+static PyMethodDef
+cquoting_methods[] = {
+       { "quote_literal", quote_literal, METH_VARARGS, doc_quote_literal },
+       { "quote_copy", quote_copy, METH_VARARGS, doc_quote_copy },
+       { "quote_bytea_raw", quote_bytea_raw, METH_VARARGS, doc_quote_bytea_raw },
+       { "unescape", unescape, METH_VARARGS, doc_unescape },
+       { "db_urlencode", db_urlencode, METH_VARARGS, doc_db_urlencode },
+       { "db_urldecode", db_urldecode, METH_VARARGS, doc_db_urldecode },
+       { NULL }
+};
+
+PyMODINIT_FUNC
+init_cquoting(void)
+{
+       PyObject *module;
+       module = Py_InitModule("_cquoting", cquoting_methods);
+       PyModule_AddStringConstant(module, "__doc__", "fast quoting for skytools");
+}
+
index 7b7dd1268d8d644d1091a317dccb6be3a9d522b4..898840956830093a135ed1070339db6be79ecdcf 100644 (file)
@@ -9,6 +9,7 @@ from gzlog import *
 from scripting import *
 from sqltools import *
 from quoting import *
+from parsing import *
 
 __all__ = (psycopgwrapper.__all__
         + config.__all__
@@ -16,5 +17,6 @@ __all__ = (psycopgwrapper.__all__
         + gzlog.__all__
         + scripting.__all__
         + sqltools.__all__
+        + parsing.__all__
         + quoting.__all__ )
 
diff --git a/python/skytools/_pyquoting.py b/python/skytools/_pyquoting.py
new file mode 100644 (file)
index 0000000..28a5757
--- /dev/null
@@ -0,0 +1,153 @@
+# _pyquoting.py
+
+"""Various helpers for string quoting/unquoting.
+
+Here is pure Python that should match C code in _cquoting.
+"""
+
+import urllib, re
+
+__all__ = [
+    "quote_literal", "quote_copy", "quote_bytea_raw",
+    "db_urlencode", "db_urldecode", "unescape",
+]
+
+# 
+# SQL quoting
+#
+
+def quote_literal(s):
+    """Quote a literal value for SQL.
+
+    If string contains '\\', it is quoted and result is prefixed with E.
+    Input value of None results in string "null" without quotes.
+
+    Python implementation.
+    """
+
+    if s == None:
+        return "null"
+    s = str(s).replace("'", "''")
+    s2 = s.replace("\\", "\\\\")
+    if len(s) != len(s2):
+        return "E'" + s2 + "'"
+    return "'" + s2 + "'"
+
+def quote_copy(s):
+    """Quoting for copy command.  None is converted to \\N.
+    
+    Python implementation.
+    """
+
+    if s == None:
+        return "\\N"
+    s = str(s)
+    s = s.replace("\\", "\\\\")
+    s = s.replace("\t", "\\t")
+    s = s.replace("\n", "\\n")
+    s = s.replace("\r", "\\r")
+    return s
+
+_bytea_map = None
+def quote_bytea_raw(s):
+    """Quoting for bytea parser.  Returns None as None.
+    
+    Python implementation.
+    """
+    global _bytea_map
+    if s == None:
+        return None
+    if 1 and _bytea_map is None:
+        _bytea_map = {}
+        for i in xrange(256):
+            c = chr(i)
+            if i < 0x20 or i >= 0x7F:
+                _bytea_map[c] = "\\%03o" % i
+            elif c == "\\":
+                _bytea_map[c] = r"\\"
+            else:
+                _bytea_map[c] = c
+    return "".join([_bytea_map[c] for c in s])
+    # faster but does not match c code
+    #return s.replace("\\", "\\\\").replace("\0", "\\000")
+
+#
+# Database specific urlencode and urldecode.
+#
+
+def db_urlencode(dict):
+    """Database specific urlencode.
+
+    Encode None as key without '='.  That means that in "foo&bar=",
+    foo is NULL and bar is empty string.
+
+    Python implementation.
+    """
+
+    elem_list = []
+    for k, v in dict.items():
+        if v is None:
+            elem = urllib.quote_plus(str(k))
+        else:
+            elem = urllib.quote_plus(str(k)) + '=' + urllib.quote_plus(str(v))
+        elem_list.append(elem)
+    return '&'.join(elem_list)
+
+def db_urldecode(qs):
+    """Database specific urldecode.
+
+    Decode key without '=' as None.
+    This also does not support one key several times.
+
+    Python implementation.
+    """
+
+    res = {}
+    for elem in qs.split('&'):
+        if not elem:
+            continue
+        pair = elem.split('=', 1)
+        name = urllib.unquote_plus(pair[0])
+
+        # keep only one instance around
+        name = intern(str(name))
+
+        if len(pair) == 1:
+            res[name] = None
+        else:
+            res[name] = urllib.unquote_plus(pair[1])
+    return res
+
+#
+# Remove C-like backslash escapes
+#
+
+_esc_re = r"\\([0-7]{1,3}|.)"
+_esc_rc = re.compile(_esc_re)
+_esc_map = {
+    't': '\t',
+    'n': '\n',
+    'r': '\r',
+    'a': '\a',
+    'b': '\b',
+    "'": "'",
+    '"': '"',
+    '\\': '\\',
+}
+
+def _sub_unescape(m):
+    v = m.group(1)
+    if (len(v) == 1) and (v < '0' or v > '7'):
+        try:
+            return _esc_map[v]
+        except KeyError:
+            return v
+    else:
+        return chr(int(v, 8))
+
+def unescape(val):
+    """Removes C-style escapes from string.
+    Python implementation.
+    """
+    return _esc_rc.sub(_sub_unescape, val)
+
diff --git a/python/skytools/parsing.py b/python/skytools/parsing.py
new file mode 100644 (file)
index 0000000..1f4dd78
--- /dev/null
@@ -0,0 +1,272 @@
+
+"""Various parsers for Postgres-specific data formats."""
+
+import re
+
+from skytools.quoting import unescape
+
+__all__ = ["parse_pgarray", "parse_logtriga_sql", "parse_tabbed_table", "parse_statements"]
+
+_rc_listelem = re.compile(r'( [^,"}]+ | ["] ( [^"\\]+ | [\\]. )* ["] )', re.X)
+
+# _parse_pgarray
+def parse_pgarray(array):
+    """ Parse Postgres array and return list of items inside it
+        Used to deserialize data recived from service layer parameters
+    """
+    if not array or array[0] != "{":
+        raise Exception("bad array format: must start with {")
+    res = []
+    pos = 1
+    while 1:
+        m = _rc_listelem.search(array, pos)
+        if not m:
+            break
+        pos2 = m.end()
+        item = array[pos:pos2]
+        if len(item) > 0 and item[0] == '"':
+            item = item[1:-1]
+        item = unescape(item)
+        res.append(item)
+
+        pos = pos2 + 1
+        if array[pos2] == "}":
+            break
+        elif array[pos2] != ",":
+            raise Exception("bad array format: expected ,} got " + array[pos2])
+    return res
+
+#
+# parse logtriga partial sql
+#
+
+class _logtriga_parser:
+    token_re = r"""
+        [ \t\r\n]*
+        ( [a-z][a-z0-9_]*
+        | ["] ( [^"\\]+ | \\. )* ["]
+        | ['] ( [^'\\]+ | \\. | [']['] )* [']
+        | [^ \t\r\n]
+        )"""
+    token_rc = None
+
+    def tokenizer(self, sql):
+        if not _logtriga_parser.token_rc:
+            _logtriga_parser.token_rc = re.compile(self.token_re, re.X | re.I)
+        rc = self.token_rc
+
+        pos = 0
+        while 1:
+            m = rc.match(sql, pos)
+            if not m:
+                break
+            pos = m.end()
+            yield m.group(1)
+
+    def unquote_data(self, fields, values):
+        # unquote data and column names
+        data = {}
+        for k, v in zip(fields, values):
+            if k[0] == '"':
+                k = unescape(k[1:-1])
+            if len(v) == 4 and v.lower() == "null":
+                v = None
+            elif v[0] == "'":
+                v = unescape(v[1:-1])
+            data[k] = v
+        return data
+
+    def parse_insert(self, tk, fields, values):
+        # (col1, col2) values ('data', null)
+        if tk.next() != "(":
+            raise Exception("syntax error")
+        while 1:
+            fields.append(tk.next())
+            t = tk.next()
+            if t == ")":
+                break
+            elif t != ",":
+                raise Exception("syntax error")
+        if tk.next().lower() != "values":
+            raise Exception("syntax error")
+        if tk.next() != "(":
+            raise Exception("syntax error")
+        while 1:
+            t = tk.next()
+            if t == ")":
+                break
+            if t == ",":
+                continue
+            values.append(t)
+        tk.next()
+
+    def parse_update(self, tk, fields, values):
+        # col1 = 'data1', col2 = null where pk1 = 'pk1' and pk2 = 'pk2'
+        while 1:
+            fields.append(tk.next())
+            if tk.next() != "=":
+                raise Exception("syntax error")
+            values.append(tk.next())
+            
+            t = tk.next()
+            if t == ",":
+                continue
+            elif t.lower() == "where":
+                break
+            else:
+                raise Exception("syntax error")
+        while 1:
+            t = tk.next()
+            fields.append(t)
+            if tk.next() != "=":
+                raise Exception("syntax error")
+            values.append(tk.next())
+            t = tk.next()
+            if t.lower() != "and":
+                raise Exception("syntax error")
+
+    def parse_delete(self, tk, fields, values):
+        # pk1 = 'pk1' and pk2 = 'pk2'
+        while 1:
+            t = tk.next()
+            if t == "and":
+                continue
+            fields.append(t)
+            if tk.next() != "=":
+                raise Exception("syntax error")
+            values.append(tk.next())
+
+    def parse_sql(self, op, sql):
+        tk = self.tokenizer(sql)
+        fields = []
+        values = []
+        try:
+            if op == "I":
+                self.parse_insert(tk, fields, values)
+            elif op == "U":
+                self.parse_update(tk, fields, values)
+            elif op == "D":
+                self.parse_delete(tk, fields, values)
+            raise Exception("syntax error")
+        except StopIteration:
+            # last sanity check
+            if len(fields) == 0 or len(fields) != len(values):
+                raise Exception("syntax error")
+
+        return self.unquote_data(fields, values)
+
+def parse_logtriga_sql(op, sql):
+    """Parse partial SQL used by logtriga() back to data values.
+
+    Parser has following limitations:
+    - Expects standard_quoted_strings = off
+    - Does not support dollar quoting.
+    - Does not support complex expressions anywhere. (hashtext(col1) = hashtext(val1))
+    - WHERE expression must not contain IS (NOT) NULL
+    - Does not support updateing pk value.
+
+    Returns dict of col->data pairs.
+    """
+    return _logtriga_parser().parse_sql(op, sql)
+
+
+def parse_tabbed_table(txt):
+    """Parse a tab-separated table into list of dicts.
+    
+    Expect first row to be column names.
+
+    Very primitive.
+    """
+
+    txt = txt.replace("\r\n", "\n")
+    fields = None
+    data = []
+    for ln in txt.split("\n"):
+        if not ln:
+            continue
+        if not fields:
+            fields = ln.split("\t")
+            continue
+        cols = ln.split("\t")
+        if len(cols) != len(fields):
+            continue
+        row = dict(zip(fields, cols))
+        data.append(row)
+    return data
+
+
+_sql_token_re = r"""
+    ( [a-z][a-z0-9_$]*
+    | ["] ( [^"\\]+ | \\. )* ["]
+    | ['] ( [^'\\]+ | \\. | [']['] )* [']
+    | [$] ([_a-z][_a-z0-9]*)? [$]
+    | (?P<ws> \s+ | [/][*] | [-][-][^\n]* )
+    | .
+    )"""
+_sql_token_rc = None
+_copy_from_stdin_re = "copy.*from\s+stdin"
+_copy_from_stdin_rc = None
+
+def _sql_tokenizer(sql):
+    global _sql_token_rc, _copy_from_stdin_rc
+    if not _sql_token_rc:
+        _sql_token_rc = re.compile(_sql_token_re, re.X | re.I)
+        _copy_from_stdin_rc = re.compile(_copy_from_stdin_re, re.X | re.I)
+    rc = _sql_token_rc
+
+    pos = 0
+    while 1:
+        m = rc.match(sql, pos)
+        if not m:
+            break
+        pos = m.end()
+        tok = m.group(1)
+        ws = m.start('ws') >= 0 # it tok empty?
+        if tok == "/*":
+            end = sql.find("*/", pos)
+            if end < 0:
+                raise Exception("unterminated c comment")
+            pos = end + 2
+            tok = sql[ m.start() : pos]
+        elif len(tok) > 1 and tok[0] == "$" and tok[-1] == "$":
+            end = sql.find(tok, pos)
+            if end < 0:
+                raise Exception("unterminated dollar string")
+            pos = end + len(tok)
+            tok = sql[ m.start() : pos]
+        yield (ws, tok)
+
+def parse_statements(sql):
+    """Parse multi-statement string into separate statements.
+
+    Returns list of statements.
+    """
+
+    tk = _sql_tokenizer(sql)
+    tokens = []
+    pcount = 0 # '(' level
+    while 1:
+        try:
+            ws, t = tk.next()
+        except StopIteration:
+            break
+        # skip whitespace and comments before statement
+        if len(tokens) == 0 and ws:
+            continue
+        # keep the rest
+        tokens.append(t)
+        if t == "(":
+            pcount += 1
+        elif t == ")":
+            pcount -= 1
+        elif t == ";" and pcount == 0:
+            sql = "".join(tokens)
+            if _copy_from_stdin_rc.match(sql):
+                raise Exception("copy from stdin not supported")
+            yield ("".join(tokens))
+            tokens = []
+    if len(tokens) > 0:
+        yield ("".join(tokens))
+    if pcount != 0:
+        raise Exception("syntax error - unbalanced parenthesis")
+
index 594646a4a0283f3789ed096d1034d55af5e148d9..10d4626a5e5ea1576adb0834618e8f3cc2ecfba6 100644 (file)
@@ -4,49 +4,23 @@
 
 import urllib, re
 
-from skytools.psycopgwrapper import QuotedString
-
 __all__ = [
     "quote_literal", "quote_copy", "quote_bytea_raw",
+    "db_urlencode", "db_urldecode", "unescape",
+
     "quote_bytea_literal", "quote_bytea_copy", "quote_statement",
-    "quote_ident", "quote_fqident", "quote_json",
-    "db_urlencode", "db_urldecode", "unescape", "unescape_copy"
+    "quote_ident", "quote_fqident", "quote_json", "unescape_copy"
 ]
 
+try:
+    from _cquoting import *
+except ImportError:
+    from _pyquoting import *
+
 # 
 # SQL quoting
 #
 
-def quote_literal(s):
-    """Quote a literal value for SQL.
-    
-    Surronds it with single-quotes.
-    """
-
-    if s == None:
-        return "null"
-    s = QuotedString(str(s))
-    return str(s)
-
-def quote_copy(s):
-    """Quoting for copy command."""
-
-    if s == None:
-        return "\\N"
-    s = str(s)
-    s = s.replace("\\", "\\\\")
-    s = s.replace("\t", "\\t")
-    s = s.replace("\n", "\\n")
-    s = s.replace("\r", "\\r")
-    return s
-
-def quote_bytea_raw(s):
-    """Quoting for bytea parser."""
-
-    if s == None:
-        return None
-    return s.replace("\\", "\\\\").replace("\0", "\\000")
-
 def quote_bytea_literal(s):
     """Quote bytea for regular SQL."""
 
@@ -125,214 +99,9 @@ def quote_json(s):
         return "null"
     return '"%s"' % _jsre.sub(_json_quote_char, s)
 
-#
-# Database specific urlencode and urldecode.
-#
-
-def db_urlencode(dict):
-    """Database specific urlencode.
-
-    Encode None as key without '='.  That means that in "foo&bar=",
-    foo is NULL and bar is empty string.
-    """
-
-    elem_list = []
-    for k, v in dict.items():
-        if v is None:
-            elem = urllib.quote_plus(str(k))
-        else:
-            elem = urllib.quote_plus(str(k)) + '=' + urllib.quote_plus(str(v))
-        elem_list.append(elem)
-    return '&'.join(elem_list)
-
-def db_urldecode(qs):
-    """Database specific urldecode.
-
-    Decode key without '=' as None.
-    This also does not support one key several times.
-    """
-
-    res = {}
-    for elem in qs.split('&'):
-        if not elem:
-            continue
-        pair = elem.split('=', 1)
-        name = urllib.unquote_plus(pair[0])
-
-        # keep only one instance around
-        name = intern(name)
-
-        if len(pair) == 1:
-            res[name] = None
-        else:
-            res[name] = urllib.unquote_plus(pair[1])
-    return res
-
-#
-# Remove C-like backslash escapes
-#
-
-_esc_re = r"\\([0-7][0-7][0-7]|.)"
-_esc_rc = re.compile(_esc_re)
-_esc_map = {
-    't': '\t',
-    'n': '\n',
-    'r': '\r',
-    'a': '\a',
-    'b': '\b',
-    "'": "'",
-    '"': '"',
-    '\\': '\\',
-}
-
-def _sub_unescape(m):
-    v = m.group(1)
-    if len(v) == 1:
-        return _esc_map[v]
-    else:
-        return chr(int(v, 8))
-
-def unescape(val):
-    """Removes C-style escapes from string."""
-    return _esc_rc.sub(_sub_unescape, val)
-
 def unescape_copy(val):
     """Removes C-style escapes, also converts "\N" to None."""
     if val == r"\N":
         return None
     return unescape(val)
 
-
-#
-# parse logtriga partial sql
-#
-
-class _logtriga_parser:
-    token_re = r"""
-        [ \t\r\n]*
-        ( [a-z][a-z0-9_]*
-        | ["] ( [^"\\]+ | \\. )* ["]
-        | ['] ( [^'\\]+ | \\. | [']['] )* [']
-        | [^ \t\r\n]
-        )"""
-    token_rc = None
-
-    def tokenizer(self, sql):
-        if not _logtriga_parser.token_rc:
-            _logtriga_parser.token_rc = re.compile(self.token_re, re.X | re.I)
-        rc = self.token_rc
-
-        pos = 0
-        while 1:
-            m = rc.match(sql, pos)
-            if not m:
-                break
-            pos = m.end()
-            yield m.group(1)
-
-    def unquote_data(self, fields, values):
-        # unquote data and column names
-        data = {}
-        for k, v in zip(fields, values):
-            if k[0] == '"':
-                k = unescape(k[1:-1])
-            if len(v) == 4 and v.lower() == "null":
-                v = None
-            elif v[0] == "'":
-                v = unescape(v[1:-1])
-            data[k] = v
-        return data
-
-    def parse_insert(self, tk, fields, values):
-        # (col1, col2) values ('data', null)
-        if tk.next() != "(":
-            raise Exception("syntax error")
-        while 1:
-            fields.append(tk.next())
-            t = tk.next()
-            if t == ")":
-                break
-            elif t != ",":
-                raise Exception("syntax error")
-        if tk.next().lower() != "values":
-            raise Exception("syntax error")
-        if tk.next() != "(":
-            raise Exception("syntax error")
-        while 1:
-            t = tk.next()
-            if t == ")":
-                break
-            if t == ",":
-                continue
-            values.append(t)
-        tk.next()
-
-    def parse_update(self, tk, fields, values):
-        # col1 = 'data1', col2 = null where pk1 = 'pk1' and pk2 = 'pk2'
-        while 1:
-            fields.append(tk.next())
-            if tk.next() != "=":
-                raise Exception("syntax error")
-            values.append(tk.next())
-            
-            t = tk.next()
-            if t == ",":
-                continue
-            elif t.lower() == "where":
-                break
-            else:
-                raise Exception("syntax error")
-        while 1:
-            t = tk.next()
-            fields.append(t)
-            if tk.next() != "=":
-                raise Exception("syntax error")
-            values.append(tk.next())
-            t = tk.next()
-            if t.lower() != "and":
-                raise Exception("syntax error")
-
-    def parse_delete(self, tk, fields, values):
-        # pk1 = 'pk1' and pk2 = 'pk2'
-        while 1:
-            t = tk.next()
-            if t == "and":
-                continue
-            fields.append(t)
-            if tk.next() != "=":
-                raise Exception("syntax error")
-            values.append(tk.next())
-
-    def parse_sql(self, op, sql):
-        tk = self.tokenizer(sql)
-        fields = []
-        values = []
-        try:
-            if op == "I":
-                self.parse_insert(tk, fields, values)
-            elif op == "U":
-                self.parse_update(tk, fields, values)
-            elif op == "D":
-                self.parse_delete(tk, fields, values)
-            raise Exception("syntax error")
-        except StopIteration:
-            # last sanity check
-            if len(fields) == 0 or len(fields) != len(values):
-                raise Exception("syntax error")
-
-        return self.unquote_data(fields, values)
-
-def parse_logtriga_sql(op, sql):
-    """Parse partial SQL used by logtriga() back to data values.
-
-    Parser has following limitations:
-    - Expects standard_quoted_strings = off
-    - Does not support dollar quoting.
-    - Does not support complex expressions anywhere. (hashtext(col1) = hashtext(val1))
-    - WHERE expression must not contain IS (NOT) NULL
-    - Does not support updateing pk value.
-
-    Returns dict of col->data pairs.
-    """
-    return _logtriga_parser().parse_sql(op, sql)
-
index 84819f9a7e91fe428f458daa766c1cadc6358209..01deb25c9b1cbc01ca0beafd34e592f3cf2aa76d 100755 (executable)
--- a/setup.py
+++ b/setup.py
@@ -53,5 +53,6 @@ setup(
         'scripts/scriptmgr.ini.templ',
         ]),
       ('share/skytools', share_dup_files)],
+    ext_modules=[Extension("skytools._cquoting", ['python/modules/cquoting.c'])],
 )