python/skytools update
authorMarko Kreen <markokr@gmail.com>
Fri, 13 Feb 2009 10:03:53 +0000 (12:03 +0200)
committerMarko Kreen <markokr@gmail.com>
Fri, 13 Feb 2009 12:21:01 +0000 (14:21 +0200)
- docstrings
- some preliminary python 3.0 compat (var names, print())
- sync with 2.1-stable

adminscript:
- move exec_cmd function to dbscript

dbstruct:
- support sequnces.  SERIAL columns are not automatically created,
  but the link beteween column and sequence is.

psycopgwrapper:
- drop support for psycopg1
- beginnings of quick DB-API / DictRow description.

quoting:
- new unquote_fqident() function, reverse of quote_fqident()
- quote_statement() accepts both row and dict

dbscript:
- catch startup errors
- use log.exception for exceptions, will result in nicer logs

sqltools:
- exists_sequence()

_pyquoting:
- fix typo in variable name

python/skytools/__init__.py
python/skytools/_pyquoting.py
python/skytools/adminscript.py
python/skytools/config.py
python/skytools/dbstruct.py
python/skytools/installer_config.py.in
python/skytools/parsing.py
python/skytools/psycopgwrapper.py
python/skytools/quoting.py
python/skytools/scripting.py
python/skytools/sqltools.py

index e42f06a306bf4411499bbb9db71f69887a14133d..b958b483d3aa081223eeb7d9ae9f694a977d8065 100644 (file)
@@ -1,6 +1,10 @@
 
 """Tools for Python database scripts."""
 
+__version__ = '3.0'
+
+__pychecker__ = 'no-miximport'
+
 import skytools.quoting
 import skytools.config
 import skytools.psycopgwrapper
index 8f72eb5f40063bb6c9bb115b50fed226b823288a..e5511687a98fe649667bd1e22f64bbeeb121cc99 100644 (file)
@@ -136,6 +136,7 @@ _esc_map = {
 }
 
 def _sub_unescape_c(m):
+    """unescape single escape seq."""
     v = m.group(1)
     if (len(v) == 1) and (v < '0' or v > '7'):
         try:
@@ -152,8 +153,9 @@ def unescape(val):
     return _esc_rc.sub(_sub_unescape_c, val)
 
 _esql_re = r"''|\\([0-7]{1,3}|.)"
-_esql_rc = re.compile(_esc_re)
+_esql_rc = re.compile(_esql_re)
 def _sub_unescape_sqlext(m):
+    """Unescape extended-quoted string."""
     if m.group() == "''":
         return "'"
     v = m.group(1)
index 399e5cd35003f429f467df63aaea128309e6cf2f..ce51ae445ff3abfe97a1233166518d7f474195eb 100644 (file)
@@ -11,15 +11,25 @@ from skytools.quoting import quote_statement
 __all__ = ['AdminScript']
 
 class AdminScript(DBScript):
+    """Contains common admin script tools.
+
+    Second argument (first is .ini file) is takes as command
+    name.  If class method 'cmd_' + arg exists, it is called,
+    otherwise error is given.
+    """
     def __init__(self, service_name, args):
+        """AdminScript init."""
         DBScript.__init__(self, service_name, args)
-        self.pidfile = self.pidfile + ".admin"
+        if self.pidfile:
+            self.pidfile = self.pidfile + ".admin"
 
         if len(self.args) < 2:
             self.log.error("need command")
             sys.exit(1)
 
     def work(self):
+        """Non-looping work function, calls command function."""
+
         self.set_single_loop(1)
 
         cmd = self.args[1]
@@ -47,6 +57,7 @@ class AdminScript(DBScript):
         fn(*cmdargs)
 
     def fetch_list(self, db, sql, args, keycol = None):
+        """Fetch a resultset from db, optionally turnin it info value list."""
         curs = db.cursor()
         curs.execute(sql, args)
         rows = curs.fetchall()
@@ -81,85 +92,25 @@ class AdminScript(DBScript):
         fmt = '%%-%ds' * (len(widths) - 1) + '%%s'
         fmt = fmt % tuple(widths[:-1])
         if desc:
-            print desc
-        print fmt % tuple(fields)
-        print fmt % tuple(['-'*15] * len(fields))
+            print(desc)
+        print(fmt % tuple(fields))
+        print(fmt % tuple(['-'*15] * len(fields)))
 
         for row in rows:
-            print fmt % tuple([row[k] for k in fields])
-        print '\n'
+            print(fmt % tuple([row[k] for k in fields]))
+        print('\n')
         return 1
 
 
-    def _exec_cmd(self, curs, sql, args):
-        self.log.debug("exec_cmd: %s" % quote_statement(sql, args))
-        curs.execute(sql, args)
-        ok = True
-        rows = curs.fetchall()
-        for row in rows:
-            try:
-                code = row['ret_code']
-                msg = row['ret_note']
-            except KeyError:
-                self.log.error("Query does not conform to exec_cmd API:")
-                self.log.error("SQL: %s" % quote_statement(sql, args))
-                self.log.error("Row: %s" % repr(row.copy()))
-                sys.exit(1)
-            level = code / 100
-            if level == 1:
-                self.log.debug("%d %s" % (code, msg))
-            elif level == 2:
-                self.log.info("%d %s" % (code, msg))
-            elif level == 3:
-                self.log.warning("%d %s" % (code, msg))
-            else:
-                self.log.error("%d %s" % (code, msg))
-                self.log.error("Query was: %s" % skytools.quote_statement(sql, args))
-                ok = False
-        return (ok, rows)
-
-    def _exec_cmd_many(self, curs, sql, baseargs, extra_list):
-        ok = True
-        rows = []
-        for a in extra_list:
-            (tmp_ok, tmp_rows) = self._exec_cmd(curs, sql, baseargs + [a])
-            ok = tmp_ok and ok
-            rows += tmp_rows
-        return (ok, rows)
-
-    def exec_cmd(self, db, q, args, commit = True):
-        (ok, rows) = self._exec_cmd(db.cursor(), q, args)
-        if ok:
-            if commit:
-                self.log.info("COMMIT")
-                db.commit()
-            return rows
-        else:
-            self.log.info("ROLLBACK")
-            db.rollback()
-            raise EXception("rollback")
-
-    def exec_cmd_many(self, db, sql, baseargs, extra_list, commit = True):
-        curs = db.cursor()
-        (ok, rows) = self._exec_cmd_many(curs, sql, baseargs, extra_list)
-        if ok:
-            if commit:
-                self.log.info("COMMIT")
-                db.commit()
-            return rows
-        else:
-            self.log.info("ROLLBACK")
-            db.rollback()
-            raise EXception("rollback")
-
-
     def exec_stmt(self, db, sql, args):
+        """Run regular non-query SQL on db."""
         self.log.debug("exec_stmt: %s" % quote_statement(sql, args))
         curs = db.cursor()
         curs.execute(sql, args)
         db.commit()
 
     def exec_query(self, db, sql, args):
+        """Run regular query SQL on db."""
         self.log.debug("exec_query: %s" % quote_statement(sql, args))
         curs = db.cursor()
         curs.execute(sql, args)
index a0e96ec59cbca3a0c9e82b744d4e0efb65b38b2a..fbe0b6ed1a2f3a190deb0cf085a3d809e8755040 100644 (file)
@@ -1,7 +1,7 @@
 
 """Nicer config class."""
 
-import sys, os, ConfigParser, socket
+import os, ConfigParser, socket
 
 __all__ = ['Config']
 
@@ -52,7 +52,7 @@ class Config(object):
         """Reads string value, if not set then default."""
         try:
             return self.cf.get(self.main_section, key)
-        except ConfigParser.NoOptionError, det:
+        except ConfigParser.NoOptionError:
             if default == None:
                 raise Exception("Config value not set: " + key)
             return default
@@ -61,7 +61,7 @@ class Config(object):
         """Reads int value, if not set then default."""
         try:
             return self.cf.getint(self.main_section, key)
-        except ConfigParser.NoOptionError, det:
+        except ConfigParser.NoOptionError:
             if default == None:
                 raise Exception("Config value not set: " + key)
             return default
@@ -70,7 +70,7 @@ class Config(object):
         """Reads boolean value, if not set then default."""
         try:
             return self.cf.getboolean(self.main_section, key)
-        except ConfigParser.NoOptionError, det:
+        except ConfigParser.NoOptionError:
             if default == None:
                 raise Exception("Config value not set: " + key)
             return default
@@ -79,7 +79,7 @@ class Config(object):
         """Reads float value, if not set then default."""
         try:
             return self.cf.getfloat(self.main_section, key)
-        except ConfigParser.NoOptionError, det:
+        except ConfigParser.NoOptionError:
             if default == None:
                 raise Exception("Config value not set: " + key)
             return default
@@ -94,7 +94,7 @@ class Config(object):
             for v in s.split(","):
                 res.append(v.strip())
             return res
-        except ConfigParser.NoOptionError, det:
+        except ConfigParser.NoOptionError:
             if default == None:
                 raise Exception("Config value not set: " + key)
             return default
@@ -129,7 +129,7 @@ class Config(object):
         for key in keys:
             try:
                 return self.cf.get(self.main_section, key)
-            except ConfigParser.NoOptionError, det:
+            except ConfigParser.NoOptionError:
                 pass
 
         if default == None:
index 1c7741c5dc27ffbd769cf1ee737e725773ab4e13..2a1c8e47b78cebcb5ecd66c95d3155c6cc8bbbf5 100644 (file)
@@ -1,14 +1,15 @@
 """Find table structure and allow CREATE/DROP elements from it.
 """
 
-import sys, re
+import re
 
 from skytools.sqltools import fq_name_parts, get_table_oid
-from skytools.quoting import quote_ident, quote_fqident
+from skytools.quoting import quote_ident, quote_fqident, quote_literal, unquote_fqident
 
-__all__ = ['TableStruct',
+__all__ = ['TableStruct', 'SeqStruct',
     'T_TABLE', 'T_CONSTRAINT', 'T_INDEX', 'T_TRIGGER',
-    'T_RULE', 'T_GRANT', 'T_OWNER', 'T_PKEY', 'T_ALL']
+    'T_RULE', 'T_GRANT', 'T_OWNER', 'T_PKEY', 'T_ALL',
+    'T_SEQUENCE']
 
 T_TABLE       = 1 << 0
 T_CONSTRAINT  = 1 << 1
@@ -17,8 +18,9 @@ T_TRIGGER     = 1 << 3
 T_RULE        = 1 << 4
 T_GRANT       = 1 << 5
 T_OWNER       = 1 << 6
+T_SEQUENCE    = 1 << 7
 T_PKEY        = 1 << 20 # special, one of constraints
-T_ALL = (  T_TABLE | T_CONSTRAINT | T_INDEX
+T_ALL = (  T_TABLE | T_CONSTRAINT | T_INDEX | T_SEQUENCE
          | T_TRIGGER | T_RULE | T_GRANT | T_OWNER )
 
 #
@@ -63,7 +65,7 @@ class TElem(object):
     """Keeps info about one metadata object."""
     SQL = ""
     type = 0
-    def get_create_sql(self, curs):
+    def get_create_sql(self, curs, new_name = None):
         """Return SQL statement for creating or None if not supported."""
         return None
     def get_drop_sql(self, curs):
@@ -78,6 +80,7 @@ class TConstraint(TElem):
           FROM pg_constraint WHERE conrelid = %(oid)s AND contype != 'f'
     """
     def __init__(self, table_name, row):
+        """Init constraint."""
         self.table_name = table_name
         self.name = row['name']
         self.defn = row['def']
@@ -88,6 +91,7 @@ class TConstraint(TElem):
             self.type += T_PKEY
 
     def get_create_sql(self, curs, new_table_name=None):
+        """Generate creation SQL."""
         fmt = "ALTER TABLE ONLY %s ADD CONSTRAINT %s %s;"
         if new_table_name:
             name = self.name
@@ -102,6 +106,7 @@ class TConstraint(TElem):
         return sql
 
     def get_drop_sql(self, curs):
+        """Generate removal sql."""
         fmt = "ALTER TABLE ONLY %s DROP CONSTRAINT %s;"
         sql = fmt % (quote_fqident(self.table_name), quote_ident(self.name))
         return sql
@@ -126,6 +131,7 @@ class TIndex(TElem):
         self.defn = row['defn'] + ';'
 
     def get_create_sql(self, curs, new_table_name = None):
+        """Generate creation SQL."""
         if not new_table_name:
             return self.defn
         # fixme: seems broken
@@ -151,9 +157,10 @@ class TRule(TElem):
         self.defn = row['def']
 
     def get_create_sql(self, curs, new_table_name = None):
+        """Generate creation SQL."""
         if not new_table_name:
             return self.defn
-        # fixme: broken
+        # fixme: broken / quoting
         rx = r"\bTO[ ][a-z0-9._]+[ ]DO[ ]"
         pnew = "TO %s DO " % new_table_name
         return rx_replace(rx, self.defn, pnew)
@@ -161,11 +168,12 @@ class TRule(TElem):
     def get_drop_sql(self, curs):
         return 'DROP RULE %s ON %s' % (quote_ident(self.name), quote_fqident(self.table_name))
 
+
 class TTrigger(TElem):
     """Info about trigger."""
     type = T_TRIGGER
     SQL = """
-        SELECT tgname as name, pg_get_triggerdef(oid) as def 
+        SELECT tgname as name, pg_get_triggerdef(oid) as def
           FROM  pg_trigger
          WHERE tgrelid = %(oid)s AND NOT tgisconstraint
     """
@@ -175,9 +183,11 @@ class TTrigger(TElem):
         self.defn = row['def'] + ';'
 
     def get_create_sql(self, curs, new_table_name = None):
+        """Generate creation SQL."""
         if not new_table_name:
             return self.defn
-        # fixme: broken
+
+        # fixme: broken / quoting
         rx = r"\bON[ ][a-z0-9._]+[ ]"
         pnew = "ON %s " % new_table_name
         return rx_replace(rx, self.defn, pnew)
@@ -198,6 +208,7 @@ class TOwner(TElem):
         self.owner = row['owner']
 
     def get_create_sql(self, curs, new_name = None):
+        """Generate creation SQL."""
         if not new_name:
             new_name = self.table_name
         return 'ALTER TABLE %s OWNER TO %s;' % (quote_fqident(new_name), quote_ident(self.owner))
@@ -217,38 +228,40 @@ class TGrant(TElem):
         return ", ".join([ self.acl_map[c] for c in acl ])
 
     def parse_relacl(self, relacl):
+        """Parse ACL to tuple of (user, acl, who)"""
         if relacl is None:
             return []
         if len(relacl) > 0 and relacl[0] == '{' and relacl[-1] == '}':
             relacl = relacl[1:-1]
-        list = []
+        tup_list = []
         for f in relacl.split(','):
             user, tmp = f.strip('"').split('=')
             acl, who = tmp.split('/')
-            list.append((user, acl, who))
-        return list
+            tup_list.append((user, acl, who))
+        return tup_list
 
     def __init__(self, table_name, row, new_name = None):
         self.name = table_name
         self.acl_list = self.parse_relacl(row['relacl'])
 
     def get_create_sql(self, curs, new_name = None):
+        """Generate creation SQL."""
         if not new_name:
             new_name = self.name
 
-        list = []
+        sql_list = []
         for user, acl, who in self.acl_list:
             astr = self.acl_to_grants(acl)
             sql = "GRANT %s ON %s TO %s;" % (astr, quote_fqident(new_name), quote_ident(user))
-            list.append(sql)
-        return "\n".join(list)
+            sql_list.append(sql)
+        return "\n".join(sql_list)
 
     def get_drop_sql(self, curs):
-        list = []
+        sql_list = []
         for user, acl, who in self.acl_list:
             sql = "REVOKE ALL FROM %s ON %s;" % (quote_ident(user), quote_fqident(self.name))
-            list.append(sql)
-        return "\n".join(list)
+            sql_list.append(sql)
+        return "\n".join(sql_list)
 
 class TColumn(TElem):
     """Info about table column."""
@@ -257,8 +270,9 @@ class TColumn(TElem):
             a.attname || ' '
                 || format_type(a.atttypid, a.atttypmod)
                 || case when a.attnotnull then ' not null' else '' end
-                || case when a.atthasdef then ' ' || d.adsrc else '' end
-            as def
+                || case when a.atthasdef then ' default ' || d.adsrc else '' end
+            as def,
+            pg_get_serial_sequence(%(fq2name)s, a.attname) as seqname
           from pg_attribute a left join pg_attrdef d
             on (d.adrelid = a.attrelid and d.adnum = a.attnum)
          where a.attrelid = %(oid)s
@@ -266,9 +280,13 @@ class TColumn(TElem):
            and a.attnum > 0
          order by a.attnum;
     """
+    seqname = None
     def __init__(self, table_name, row):
         self.name = row['name']
         self.column_def = row['def']
+        self.sequence = None
+        if row['seqname']:
+            self.seqname = unquote_fqident(row['seqname'])
 
 class TTable(TElem):
     """Info about table only (columns)."""
@@ -278,6 +296,7 @@ class TTable(TElem):
         self.col_list = col_list
 
     def get_create_sql(self, curs, new_name = None):
+        """Generate creation SQL."""
         if not new_name:
             new_name = self.name
         sql = "create table %s (" % quote_fqident(new_name)
@@ -287,53 +306,88 @@ class TTable(TElem):
             sep = ",\n\t"
         sql += "\n);"
         return sql
-    
+
     def get_drop_sql(self, curs):
         return "DROP TABLE %s;" % quote_fqident(self.name)
 
+class TSeq(TElem):
+    """Info about sequence."""
+    type = T_SEQUENCE
+    SQL = """SELECT *, %(owner)s as "owner" from %(fqname)s """
+    def __init__(self, seq_name, row):
+        self.name = seq_name
+        defn = ''
+        self.owner = row['owner']
+        if row['increment_by'] != 1:
+            defn += ' INCREMENT BY %d' % row['increment_by']
+        if row['min_value'] != 1:
+            defn += ' MINVALUE %d' % row['min_value']
+        if row['max_value'] != 9223372036854775807:
+            defn += ' MAXVALUE %d' % row['max_value']
+        last_value = row['last_value']
+        if row['is_called']:
+            last_value += row['increment_by']
+            if last_value >= row['max_value']:
+                raise Exception('duh, seq passed max_value')
+        if last_value != 1:
+            defn += ' START %d' % last_value
+        if row['cache_value'] != 1:
+            defn += ' CACHE %d' % row['cache_value']
+        if row['is_cycled']:
+            defn += ' CYCLE '
+        if self.owner:
+            defn += ' OWNED BY %s' % self.owner
+        self.defn = defn
+
+    def get_create_sql(self, curs, new_seq_name = None):
+        """Generate creation SQL."""
+
+        # we are in table def, forget full def
+        if self.owner:
+            sql = "ALTER SEQUENCE %s OWNED BY %s" % (
+                    quote_fqident(self.name), self.owner )
+            return sql
+
+        name = self.name
+        if new_seq_name:
+            name = new_seq_name
+        sql = 'CREATE SEQUENCE %s %s;' % (quote_fqident(name), self.defn)
+        return sql
+
+    def get_drop_sql(self, curs):
+        if self.owner:
+            return ''
+        return 'DROP SEQUENCE %s;' % quote_fqident(self.name)
+
 #
 # Main table object, loads all the others
 #
 
-class TableStruct(object):
-    """Collects and manages all info about table.
+class BaseStruct(object):
+    """Collects and manages all info about a higher-level db object.
 
     Allow to issue CREATE/DROP statements about any
     group of elements.
     """
-    def __init__(self, curs, table_name):
+    object_list = []
+    def __init__(self, curs, name):
         """Initializes class by loading info about table_name from database."""
 
-        self.table_name = table_name
-
-        # fill args
-        schema, name = fq_name_parts(table_name)
-        args = {
-            'schema': schema,
-            'table': name,
-            'oid': get_table_oid(curs, table_name),
-            'pg_class_oid': get_table_oid(curs, 'pg_catalog.pg_class'),
-        }
-        
-        # load table struct
-        self.col_list = self._load_elem(curs, args, TColumn)
-        self.object_list = [ TTable(table_name, self.col_list) ]
-
-        # load additional objects
-        to_load = [TConstraint, TIndex, TTrigger, TRule, TGrant, TOwner]
-        for eclass in to_load:
-            self.object_list += self._load_elem(curs, args, eclass)
+        self.name = name
+        self.fqname = quote_fqident(name)
 
-    def _load_elem(self, curs, args, eclass):
-        list = []
+    def _load_elem(self, curs, name, args, eclass):
+        """Fetch element(s) from db."""
+        elem_list = []
+        #print "Loading %s, name=%s, args=%s" % (repr(eclass), repr(name), repr(args))
         curs.execute(eclass.SQL % args)
         for row in curs.dictfetchall():
-            list.append(eclass(self.table_name, row))
-        return list
+            elem_list.append(eclass(name, row))
+        return elem_list
 
     def create(self, curs, objs, new_table_name = None, log = None):
         """Issues CREATE statements for requested set of objects.
-        
+
         If new_table_name is giver, creates table under that name
         and also tries to rename all indexes/constraints that conflict
         with existing table.
@@ -361,6 +415,57 @@ class TableStruct(object):
                     log.debug(sql)
                 curs.execute(sql)
 
+    def get_create_sql(self, objs):
+        res = []
+        for o in self.object_list:
+            if o.type & objs:
+                sql = o.get_create_sql(None, None)
+                if sql:
+                    res.append(sql)
+        return "".join(res)
+
+class TableStruct(BaseStruct):
+    """Collects and manages all info about table.
+
+    Allow to issue CREATE/DROP statements about any
+    group of elements.
+    """
+    def __init__(self, curs, table_name):
+        """Initializes class by loading info about table_name from database."""
+
+        BaseStruct.__init__(self, curs, table_name)
+
+        self.table_name = table_name
+
+        # fill args
+        schema, name = fq_name_parts(table_name)
+        args = {
+            'schema': schema,
+            'table': name,
+            'fqname': self.fqname,
+            'fq2name': quote_literal(self.fqname),
+            'oid': get_table_oid(curs, table_name),
+            'pg_class_oid': get_table_oid(curs, 'pg_catalog.pg_class'),
+        }
+
+        # load table struct
+        self.col_list = self._load_elem(curs, self.name, args, TColumn)
+        self.object_list = [ TTable(table_name, self.col_list) ]
+        self.seq_list = []
+
+        # load seqs
+        for col in self.col_list:
+            if col.seqname:
+                owner = self.fqname + '.' + quote_ident(col.name)
+                seq_args = { 'fqname': col.seqname, 'owner': quote_literal(owner) }
+                self.seq_list += self._load_elem(curs, col.seqname, seq_args, TSeq)
+        self.object_list += self.seq_list
+
+        # load additional objects
+        to_load = [TConstraint, TIndex, TTrigger, TRule, TGrant, TOwner]
+        for eclass in to_load:
+            self.object_list += self._load_elem(curs, self.name, args, eclass)
+
     def get_column_list(self):
         """Returns list of column names the table has."""
 
@@ -369,11 +474,28 @@ class TableStruct(object):
             res.append(c.name)
         return res
 
+class SeqStruct(BaseStruct):
+    """Collects and manages all info about sequence.
+
+    Allow to issue CREATE/DROP statements about any
+    group of elements.
+    """
+    def __init__(self, curs, seq_name):
+        """Initializes class by loading info about table_name from database."""
+
+        BaseStruct.__init__(self, curs, seq_name)
+
+        # fill args
+        args = { 'fqname': self.fqname, 'owner': 'null' }
+
+        # load table struct
+        self.object_list = self._load_elem(curs, seq_name, args, TSeq)
+
 def test():
     from skytools import connect_database
     db = connect_database("dbname=fooz")
     curs = db.cursor()
-    
+
     s = TableStruct(curs, "public.data1")
 
     s.drop(curs, T_ALL)
index 06c9b9560ef531711246188ea0bc393037406e24..a01621f7b71b333cf8ce136280c05eb499998294 100644 (file)
@@ -1,4 +1,8 @@
 
+"""SQL script locations."""
+
+__all__ = ['sql_locations']
+
 sql_locations = [
     "@SQLDIR@",
 ]
index 3aa94991e065851acfa5304646a1992976e9cc4d..36b17d537c9a7008284294559e04b512559ae256 100644 (file)
@@ -42,11 +42,14 @@ def parse_pgarray(array):
 #
 
 class _logtriga_parser:
+    """Parses logtriga/sqltriga partial SQL to values."""
     def tokenizer(self, sql):
+        """Token generator."""
         for typ, tok in sql_tokenizer(sql, ignore_whitespace = True):
             yield tok
 
     def parse_insert(self, tk, fields, values):
+        """Handler for inserts."""
         # (col1, col2) values ('data', null)
         if tk.next() != "(":
             raise Exception("syntax error")
@@ -73,6 +76,7 @@ class _logtriga_parser:
         raise Exception("expected EOF, got " + repr(t))
 
     def parse_update(self, tk, fields, values):
+        """Handler for updates."""
         # col1 = 'data1', col2 = null where pk1 = 'pk1' and pk2 = 'pk2'
         while 1:
             fields.append(tk.next())
@@ -97,6 +101,7 @@ class _logtriga_parser:
                 raise Exception("syntax error, expected AND got "+repr(t))
 
     def parse_delete(self, tk, fields, values):
+        """Handler for deletes."""
         # pk1 = 'pk1' and pk2 = 'pk2'
         while 1:
             fields.append(tk.next())
@@ -108,6 +113,7 @@ class _logtriga_parser:
                 raise Exception("syntax error, expected AND, got "+repr(t))
 
     def parse_sql(self, op, sql):
+        """Main entry point."""
         tk = self.tokenizer(sql)
         fields = []
         values = []
index 9d23e5068a2e4e1e2ac29b7cf0f9b400fd2e4ef4..7bce0c337caa0fda28447c32d63b8ac292a8a278 100644 (file)
@@ -1,16 +1,60 @@
 
-"""Wrapper around psycopg1/2.
+"""Wrapper around psycopg2.
 
-Preferred is psycopg2, fallback to psycopg1.
+Database connection provides regular DB-API 2.0 interface.
 
-Interface provided is psycopg1:
-    - dict* methods.
-    - new columns can be assigned to row.
+Connection object methods::
 
-"""
+    .cursor()
+
+    .commit()
+
+    .rollback()
+
+    .close()
+
+Cursor methods::
+
+    .execute(query[, args])
+
+    .fetchone()
+
+    .fetchall()
+
+
+Sample usage::
 
-import sys
+    db = self.get_database('somedb')
+    curs = db.cursor()
+
+    # query arguments as array
+    q = "select * from table where id = %s and name = %s"
+    curs.execute(q, [1, 'somename'])
+
+    # query arguments as dict
+    q = "select id, name from table where id = %(id)s and name = %(name)s"
+    curs.execute(q, {'id': 1, 'name': 'somename'})
+
+    # loop over resultset
+    for row in curs.fetchall():
+
+        # columns can be asked by index:
+        id = row[0]
+        name = row[1]
+
+        # and by name:
+        id = row['id']
+        name = row['name']
+
+    # now commit the transaction
+    db.commit()
+
+Deprecated interface:  .dictfetchall/.dictfetchone functions on cursor.
+Plain .fetchall() / .fetchone() give exact same result.
+
+"""
 
+# no exports
 __all__ = []
 
 ##from psycopg2.psycopg1 import connect as _pgconnect
@@ -54,6 +98,7 @@ class _CompatCursor(psycopg2.extras.DictCursor):
 
 class _CompatConnection(psycopg2.extensions.connection):
     """Connection object that uses _CompatCursor."""
+    my_name = '?'
     def cursor(self):
         return psycopg2.extensions.connection.cursor(self, cursor_factory = _CompatCursor)
 
index 8225c7b06127b5602fef612e5674f0232217f512..9d281254ac9d47e4820c5fb3280895e18f8eef47 100644 (file)
@@ -12,7 +12,7 @@ __all__ = [
     # local
     "quote_bytea_literal", "quote_bytea_copy", "quote_statement",
     "quote_ident", "quote_fqident", "quote_json", "unescape_copy",
-    "unquote_ident",
+    "unquote_ident", "unquote_fqident",
 ]
 
 try:
@@ -34,15 +34,19 @@ def quote_bytea_copy(s):
 
     return quote_copy(quote_bytea_raw(s))
 
-def quote_statement(sql, dict):
+def quote_statement(sql, dict_or_list):
     """Quote whole statement.
 
-    Data values are taken from dict.
+    Data values are taken from dict or list or tuple.
     """
-    xdict = {}
-    for k, v in dict.items():
-        xdict[k] = quote_literal(v)
-    return sql % xdict
+    if hasattr(dict_or_list, 'items'):
+        qdict = {}
+        for k, v in dict_or_list.items():
+            qdict[k] = quote_literal(v)
+        return sql % qdict
+    else:
+        qvals = [quote_literal(v) for v in dict_or_list]
+        return sql % tuple(qvals)
 
 # reserved keywords
 _ident_kwmap = {
@@ -58,6 +62,8 @@ _ident_kwmap = {
 "primary":1, "references":1, "returning":1, "select":1, "session_user":1,
 "some":1, "symmetric":1, "table":1, "then":1, "to":1, "trailing":1, "true":1,
 "union":1, "unique":1, "user":1, "using":1, "when":1, "where":1,
+# greenplum?
+"errors":1,
 }
 
 _ident_bad = re.compile(r"[^a-z0-9_]")
@@ -90,6 +96,7 @@ _jsmap = { "\b": "\\b", "\f": "\\f", "\n": "\\n", "\r": "\\r",
 }
 
 def _json_quote_char(m):
+    """Quote single char."""
     c = m.group(0)
     try:
         return _jsmap[c]
@@ -114,3 +121,11 @@ def unquote_ident(val):
         return val[1:-1].replace('""', '"')
     return val
 
+def unquote_fqident(val):
+    """Unquotes fully-qualified possibly quoted SQL identifier.
+
+    It must be prefixed schema, which does not contain dots.
+    """
+    tmp = val.split('.', 1)
+    return "%s.%s" % (unquote_ident(tmp[0]), unquote_ident(tmp[1]))
+
index b2c26bc03e451ff891a9a63651c704cb506513d1..a61154de48bd7034a9a56031ad8e6b3e6360e0fd 100644 (file)
@@ -1,13 +1,29 @@
 
 """Useful functions and classes for database scripts."""
 
-import sys, os, signal, optparse, traceback, time, errno
+import sys, os, signal, optparse, time, errno
 import logging, logging.handlers, logging.config
 
 from skytools.config import *
 from skytools.psycopgwrapper import connect_database
+from skytools.quoting import quote_statement
 import skytools.skylog
 
+__pychecker__ = 'no-badexcept'
+
+#: how old connections need to be closed
+DEF_CONN_AGE = 20*60  # 20 min
+
+#: isolation level not set
+I_DEFAULT = -1
+
+#: isolation level constant for AUTOCOMMIT
+I_AUTOCOMMIT = 0
+#: isolation level constant for READ COMMITTED
+I_READ_COMMITTED = 1
+#: isolation level constant for SERIALIZABLE
+I_SERIALIZABLE = 2
+
 __all__ = ['DBScript', 'I_AUTOCOMMIT', 'I_READ_COMMITTED', 'I_SERIALIZABLE',
            'signal_pidfile']
 #__all__ += ['daemonize', 'run_single_process']
@@ -80,10 +96,10 @@ def run_single_process(runnable, daemon, pidfile):
     # check if another process is running
     if pidfile and os.path.isfile(pidfile):
         if signal_pidfile(pidfile, 0):
-            print "Pidfile exists, another process running?"
+            print("Pidfile exists, another process running?")
             sys.exit(1)
         else:
-            print "Ignoring stale pidfile"
+            print("Ignoring stale pidfile")
 
     # daemonize if needed and write pidfile
     if daemon:
@@ -122,8 +138,8 @@ def _init_log(job_name, service_name, cf, log_level):
         skytools.skylog.set_service_name(service_name)
 
         # load general config
-        list = ['skylog.ini', '~/.skylog.ini', '/etc/skylog.ini']
-        for fn in list:
+        flist = ['skylog.ini', '~/.skylog.ini', '/etc/skylog.ini']
+        for fn in flist:
             fn = os.path.expanduser(fn)
             if os.path.isfile(fn):
                 defs = {'job_name': job_name, 'service_name': service_name}
@@ -163,33 +179,24 @@ def _init_log(job_name, service_name, cf, log_level):
 
     return log
 
-#: how old connections need to be closed
-DEF_CONN_AGE = 20*60  # 20 min
-
-#: isolation level not set
-I_DEFAULT = -1
-
-#: isolation level constant for AUTOCOMMIT
-I_AUTOCOMMIT = 0
-#: isolation level constant for READ COMMITTED
-I_READ_COMMITTED = 1
-#: isolation level constant for SERIALIZABLE
-I_SERIALIZABLE = 2
-
 class DBCachedConn(object):
     """Cache a db connection."""
-    def __init__(self, name, loc, max_age = DEF_CONN_AGE):
+    def __init__(self, name, loc, max_age = DEF_CONN_AGE, verbose = False, setup_func=None):
         self.name = name
         self.loc = loc
         self.conn = None
         self.conn_time = 0
         self.max_age = max_age
         self.autocommit = -1
-        self.isolation_level = -1
+        self.isolation_level = I_DEFAULT
+        self.verbose = verbose
+        self.setup_func = setup_func
 
-    def get_connection(self, autocommit = 0, isolation_level = -1):
+    def get_connection(self, autocommit = 0, isolation_level = I_DEFAULT):
         # autocommit overrider isolation_level
         if autocommit:
+            if isolation_level == I_SERIALIZABLE:
+                raise Exception('autocommit is not compatible with I_SERIALIZABLE')
             isolation_level = I_AUTOCOMMIT
 
         # default isolation_level is READ COMMITTED
@@ -200,9 +207,12 @@ class DBCachedConn(object):
         if not self.conn:
             self.isolation_level = isolation_level
             self.conn = connect_database(self.loc)
+            self.conn.my_name = self.name
 
             self.conn.set_isolation_level(isolation_level)
             self.conn_time = time.time()
+            if self.setup_func:
+                self.setup_func(self.name, self.conn)
         else:
             if self.isolation_level != isolation_level:
                 raise Exception("Conflict in isolation_level")
@@ -250,6 +260,8 @@ class DBScript(object):
     job_name = None
     cf = None
     log = None
+    pidfile = None
+    loop_delay = 1
 
     def __init__(self, service_name, args):
         """Script setup.
@@ -286,7 +298,7 @@ class DBScript(object):
         if self.options.verbose:
             self.log_level = logging.DEBUG
         if len(self.args) < 1:
-            print "need config file"
+            print("need config file")
             sys.exit(1)
 
         # read config file
@@ -305,6 +317,8 @@ class DBScript(object):
             self.send_signal(signal.SIGHUP)
 
     def load_config(self):
+        """Loads and returns skytools.Config instance."""
+
         conf_file = self.args[0]
         return Config(self.service_name, conf_file)
 
@@ -369,7 +383,21 @@ class DBScript(object):
             if not self.pidfile:
                 self.log.error("Daemon needs pidfile")
                 sys.exit(1)
-        run_single_process(self, self.go_daemon, self.pidfile)
+
+        try:
+            run_single_process(self, self.go_daemon, self.pidfile)
+        except KeyboardInterrupt:
+            raise
+        except SystemExit:
+            raise
+        except Exception:
+            # catch startup errors
+            exc, msg, tb = sys.exc_info()
+            self.log.exception("Job %s crashed on startup: %s: %s" % (
+                       self.job_name, str(exc), str(msg).rstrip()))
+            del tb
+            sys.exit(1)
+
 
     def stop(self):
         """Safely stops processing loop."""
@@ -386,9 +414,15 @@ class DBScript(object):
         "Internal SIGHUP handler.  Minimal code here."
         self.need_reload = 1
 
+    last_sigint = 0
     def hook_sigint(self, sig, frame):
         "Internal SIGINT handler.  Minimal code here."
         self.stop()
+        t = time.time()
+        if t - self.last_sigint < 1:
+            self.log.warning("Double ^C, fast exit")
+            sys.exit(1)
+        self.last_sigint = t
 
     def stat_add(self, key, value):
         """Old, deprecated function."""
@@ -419,6 +453,9 @@ class DBScript(object):
         self.log.info(logmsg)
         self.stat_dict = {}
 
+    def connection_setup(self, dbname, conn):
+        pass
+
     def get_database(self, dbname, autocommit = 0, isolation_level = -1,
                      cache = None, connstr = None):
         """Load cached database connection.
@@ -435,7 +472,7 @@ class DBScript(object):
         else:
             if not connstr:
                 connstr = self.cf.get(dbname)
-            dbc = DBCachedConn(cache, connstr, max_age)
+            dbc = DBCachedConn(cache, connstr, max_age, setup_func = self.connection_setup)
             self.db_cache[cache] = dbc
 
         return dbc.get_connection(autocommit, isolation_level)
@@ -462,15 +499,14 @@ class DBScript(object):
         # run startup, safely
         try:
             self.startup()
-        except KeyboardInterrupt, det:
+        except KeyboardInterrupt:
             raise
-        except SystemExit, det:
+        except SystemExit:
             raise
-        except Exception, det:
+        except Exception:
             exc, msg, tb = sys.exc_info()
-            self.log.fatal("Job %s crashed: %s: '%s' (%s: %s)" % (
-                       self.job_name, str(exc), str(msg).rstrip(),
-                       str(tb), repr(traceback.format_tb(tb))))
+            self.log.exception("Job %s crashed: %s: %s" % (
+                       self.job_name, str(exc), str(msg).rstrip()))
             del tb
             self.reset()
             sys.exit(1)
@@ -523,10 +559,9 @@ class DBScript(object):
         except Exception, d:
             self.send_stats()
             exc, msg, tb = sys.exc_info()
-            self.log.fatal("Job %s crashed: %s: '%s' (%s: %s)" % (
-                       self.job_name, str(exc), str(msg).rstrip(),
-                       str(tb), repr(traceback.format_tb(tb))))
             del tb
+            self.log.exception("Job %s crashed: %s: %s" % (
+                       self.job_name, str(exc), str(msg).rstrip()))
             self.reset()
             if self.looping and not self.do_single_loop:
                 time.sleep(20)
@@ -553,4 +588,82 @@ class DBScript(object):
         signal.signal(signal.SIGHUP, self.hook_sighup)
         signal.signal(signal.SIGINT, self.hook_sigint)
 
+    def _exec_cmd(self, curs, sql, args, quiet = False):
+        """Internal tool: Run SQL on cursor."""
+        self.log.debug("exec_cmd: %s" % quote_statement(sql, args))
+        curs.execute(sql, args)
+        ok = True
+        rows = curs.fetchall()
+        for row in rows:
+            try:
+                code = row['ret_code']
+                msg = row['ret_note']
+            except KeyError:
+                self.log.error("Query does not conform to exec_cmd API:")
+                self.log.error("SQL: %s" % quote_statement(sql, args))
+                self.log.error("Row: %s" % repr(row.copy()))
+                sys.exit(1)
+            level = code / 100
+            if level == 1:
+                self.log.debug("%d %s" % (code, msg))
+            elif level == 2:
+                if quiet:
+                    self.log.debug("%d %s" % (code, msg))
+                else:
+                    self.log.info("%s" % (msg,))
+            elif level == 3:
+                self.log.warning("%s" % (msg,))
+            else:
+                self.log.error("%s" % (msg,))
+                self.log.error("Query was: %s" % quote_statement(sql, args))
+                ok = False
+        return (ok, rows)
+
+    def _exec_cmd_many(self, curs, sql, baseargs, extra_list, quiet = False):
+        """Internal tool: Run SQL on cursor multiple times."""
+        ok = True
+        rows = []
+        for a in extra_list:
+            (tmp_ok, tmp_rows) = self._exec_cmd(curs, sql, baseargs + [a], quiet=quiet)
+            if not tmp_ok:
+                ok = False
+            rows += tmp_rows
+        return (ok, rows)
+
+    def exec_cmd(self, db_or_curs, q, args, commit = True, quiet = False):
+        """Run SQL on db with code/value error handling."""
+        if hasattr(db_or_curs, 'cursor'):
+            db = db_or_curs
+            curs = db.cursor()
+        else:
+            db = None
+            curs = db_or_curs
+        (ok, rows) = self._exec_cmd(curs, q, args, quiet = quiet)
+        if ok:
+            if commit and db:
+                db.commit()
+            return rows
+        else:
+            if db:
+                db.rollback()
+            raise Exception("db error")
+
+    def exec_cmd_many(self, db_or_curs, sql, baseargs, extra_list, commit = True, quiet = False):
+        """Run SQL on db multiple times."""
+        if hasattr(db_or_curs, 'cursor'):
+            db = db_or_curs
+            curs = db.cursor()
+        else:
+            db = None
+            curs = db_or_curs
+        (ok, rows) = self._exec_cmd_many(curs, sql, baseargs, extra_list, quiet=quiet)
+        if ok:
+            if commit and db:
+                db.commit()
+            return rows
+        else:
+            if db:
+                db.rollback()
+            raise Exception("db error")
+
 
index 883bbe8b6115a092e0291024d55872dffbfe5e50..fd4fb85519e4a7f9e8e287fee9142ea58be0d4b0 100644 (file)
@@ -9,6 +9,7 @@ import skytools.installer_config
 __all__ = [
     "fq_name_parts", "fq_name", "get_table_oid", "get_table_pkeys",
     "get_table_columns", "exists_schema", "exists_table", "exists_type",
+    "exists_sequence",
     "exists_function", "exists_language", "Snapshot", "magic_insert",
     "CopyPipe", "full_copy", "DBObject", "DBSchema", "DBTable", "DBFunction",
     "DBLanguage", "db_install", "installer_find_file", "installer_apply_file",
@@ -19,9 +20,15 @@ class dbdict(dict):
     """Wrapper on actual dict that allows
     accessing dict keys as attributes."""
     # obj.foo access
-    def __getattr__(self, k):       return self[k]
-    def __setattr__(self, k, v):    self[k] = v
-    def __delattr__(self, k):       del self[k]
+    def __getattr__(self, k):
+        "Return attribute."
+        return self[k]
+    def __setattr__(self, k, v):
+        "Set attribute."
+        self[k] = v
+    def __delattr__(self, k):
+        "Remove attribute"
+        del self[k]
 
 #
 # Fully qualified table name
@@ -46,6 +53,7 @@ def fq_name(tbl):
 # info about table
 #
 def get_table_oid(curs, table_name):
+    """Find Postgres OID for table."""
     schema, name = fq_name_parts(table_name)
     q = """select c.oid from pg_namespace n, pg_class c
            where c.relnamespace = n.oid
@@ -57,6 +65,7 @@ def get_table_oid(curs, table_name):
     return res[0][0]
 
 def get_table_pkeys(curs, tbl):
+    """Return list of pkey column names."""
     oid = get_table_oid(curs, tbl)
     q = "SELECT k.attname FROM pg_index i, pg_attribute k"\
         " WHERE i.indrelid = %s AND k.attrelid = i.indexrelid"\
@@ -66,6 +75,7 @@ def get_table_pkeys(curs, tbl):
     return map(lambda x: x[0], curs.fetchall())
 
 def get_table_columns(curs, tbl):
+    """Return list of column names for table."""
     oid = get_table_oid(curs, tbl)
     q = "SELECT k.attname FROM pg_attribute k"\
         " WHERE k.attrelid = %s"\
@@ -78,12 +88,14 @@ def get_table_columns(curs, tbl):
 # exist checks
 #
 def exists_schema(curs, schema):
+    """Does schema exists?"""
     q = "select count(1) from pg_namespace where nspname = %s"
     curs.execute(q, [schema])
     res = curs.fetchone()
     return res[0]
 
 def exists_table(curs, table_name):
+    """Does table exists?"""
     schema, name = fq_name_parts(table_name)
     q = """select count(1) from pg_namespace n, pg_class c
            where c.relnamespace = n.oid and c.relkind = 'r'
@@ -92,7 +104,18 @@ def exists_table(curs, table_name):
     res = curs.fetchone()
     return res[0]
 
+def exists_sequence(curs, seq_name):
+    """Does sequence exists?"""
+    schema, name = fq_name_parts(seq_name)
+    q = """select count(1) from pg_namespace n, pg_class c
+           where c.relnamespace = n.oid and c.relkind = 'S'
+             and n.nspname = %s and c.relname = %s"""
+    curs.execute(q, [schema, name])
+    res = curs.fetchone()
+    return res[0]
+
 def exists_type(curs, type_name):
+    """Does type exists?"""
     schema, name = fq_name_parts(type_name)
     q = """select count(1) from pg_namespace n, pg_type t
            where t.typnamespace = n.oid
@@ -102,6 +125,7 @@ def exists_type(curs, type_name):
     return res[0]
 
 def exists_function(curs, function_name, nargs):
+    """Does function exists?"""
     # this does not check arg types, so may match several functions
     schema, name = fq_name_parts(function_name)
     q = """select count(1) from pg_namespace n, pg_proc p
@@ -118,6 +142,7 @@ def exists_function(curs, function_name, nargs):
     return res[0]
 
 def exists_language(curs, lang_name):
+    """Does PL exists?"""
     q = """select count(1) from pg_language
            where lanname = %s"""
     curs.execute(q, [lang_name])
@@ -331,11 +356,13 @@ class DBObject(object):
     sql = None
     sql_file = None
     def __init__(self, name, sql = None, sql_file = None):
+        """Generic dbobject init."""
         self.name = name
         self.sql = sql
         self.sql_file = sql_file
 
     def create(self, curs, log = None):
+        """Create a dbobject."""
         if log:
             log.info('Installing %s' % self.name)
         if self.sql:
@@ -352,13 +379,14 @@ class DBObject(object):
             curs.execute(stmt)
 
     def find_file(self):
+        """Find install script file."""
         full_fn = None
         if self.sql_file[0] == "/":
             full_fn = self.sql_file
         else:
             dir_list = skytools.installer_config.sql_locations
-            for dir in dir_list:
-                fn = os.path.join(dir, self.sql_file)
+            for fdir in dir_list:
+                fn = os.path.join(fdir, self.sql_file)
                 if os.path.isfile(fn):
                     full_fn = fn
                     break
@@ -370,26 +398,32 @@ class DBObject(object):
 class DBSchema(DBObject):
     """Handles db schema."""
     def exists(self, curs):
+        """Does schema exists."""
         return exists_schema(curs, self.name)
 
 class DBTable(DBObject):
     """Handles db table."""
     def exists(self, curs):
+        """Does table exists."""
         return exists_table(curs, self.name)
 
 class DBFunction(DBObject):
     """Handles db function."""
     def __init__(self, name, nargs, sql = None, sql_file = None):
+        """Function object - number of args is significant."""
         DBObject.__init__(self, name, sql, sql_file)
         self.nargs = nargs
     def exists(self, curs):
+        """Does function exists."""
         return exists_function(curs, self.name, self.nargs)
 
 class DBLanguage(DBObject):
     """Handles db language."""
     def __init__(self, name):
+        """PL object - creation happens with CREATE LANGUAGE."""
         DBObject.__init__(self, name, sql = "create language %s" % name)
     def exists(self, curs):
+        """Does PL exists."""
         return exists_language(curs, self.name)
 
 def db_install(curs, list, log = None):
@@ -402,14 +436,15 @@ def db_install(curs, list, log = None):
                 log.info('%s is installed' % obj.name)
 
 def installer_find_file(filename):
+    """Find SQL script from pre-defined paths."""
     full_fn = None
     if filename[0] == "/":
         if os.path.isfile(filename):
             full_fn = filename
     else:
         dir_list = ["."] + skytools.installer_config.sql_locations
-        for dir in dir_list:
-            fn = os.path.join(dir, filename)
+        for fdir in dir_list:
+            fn = os.path.join(fdir, filename)
             if os.path.isfile(fn):
                 full_fn = fn
                 break
@@ -419,6 +454,7 @@ def installer_find_file(filename):
     return full_fn
 
 def installer_apply_file(db, filename, log):
+    """Find SQL file and apply it to db, statement-by-statement."""
     fn = installer_find_file(filename)
     sql = open(fn, "r").read()
     if log: