skytools: Separate generic scripting from DBScript
authorMarko Kreen <markokr@gmail.com>
Thu, 7 Apr 2011 12:57:28 +0000 (15:57 +0300)
committerMarko Kreen <markokr@gmail.com>
Fri, 15 Apr 2011 10:20:21 +0000 (13:20 +0300)
python/pgq/cascade/consumer.py
python/skytools/scripting.py

index fdc8a48b948056f33fe6c2b366f738795d313bad..798ee7e74414890ecaeec00eafffc044fc414763 100644 (file)
@@ -274,7 +274,7 @@ class CascadedConsumer(Consumer):
         q = "select * from pgq_node.set_consumer_completed(%s, %s, %s)"
         self.exec_cmd(dst_db, q, [ self.queue_name, self.consumer_name, tick_id ])
 
-    def exception_hook(self, det, emsg, cname):
+    def exception_hook(self, det, emsg):
         try:
             dst_db = self.get_database(self.target_db)
             q = "select * from pgq_node.set_consumer_error(%s, %s, %s)"
@@ -282,4 +282,5 @@ class CascadedConsumer(Consumer):
         except:
             self.log.warning("Failure to call pgq_node.set_consumer_error()")
         self.reset()
+        Consumer.exception_hook(self, det, emsg)
 
index 298c1efa648c01d253220d16751289084ae3912e..a8baffd69c9924a7739aa9e7c55b80e9a21fb658 100644 (file)
@@ -17,22 +17,8 @@ except ImportError:
 
 __pychecker__ = 'no-badexcept'
 
-#: how old connections need to be closed
-DEF_CONN_AGE = 20*60  # 20 min
-
-#: isolation level not set
-I_DEFAULT = -1
-
-#: isolation level constant for AUTOCOMMIT
-I_AUTOCOMMIT = 0
-#: isolation level constant for READ COMMITTED
-I_READ_COMMITTED = 1
-#: isolation level constant for SERIALIZABLE
-I_SERIALIZABLE = 2
-
-__all__ = ['DBScript', 'I_AUTOCOMMIT', 'I_READ_COMMITTED', 'I_SERIALIZABLE',
-           'signal_pidfile', 'UsageError']
-#__all__ += ['daemonize', 'run_single_process']
+__all__ = ['BaseScript', 'signal_pidfile', 'UsageError', 'daemonize',
+            'DBScript', 'I_AUTOCOMMIT', 'I_READ_COMMITTED', 'I_SERIALIZABLE']
 
 class UsageError(Exception):
     """User induced error."""
@@ -197,97 +183,10 @@ def _init_log(job_name, service_name, cf, log_level, is_daemon):
 
     return log
 
-class DBCachedConn(object):
-    """Cache a db connection."""
-    def __init__(self, name, loc, max_age = DEF_CONN_AGE, verbose = False, setup_func=None, channels=[]):
-        self.name = name
-        self.loc = loc
-        self.conn = None
-        self.conn_time = 0
-        self.max_age = max_age
-        self.autocommit = -1
-        self.isolation_level = I_DEFAULT
-        self.verbose = verbose
-        self.setup_func = setup_func
-        self.listen_channel_list = []
-
-    def fileno(self):
-        if not self.conn:
-            return None
-        return self.conn.cursor().fileno()
-
-    def get_connection(self, autocommit = 0, isolation_level = I_DEFAULT, listen_channel_list = []):
-        # autocommit overrider isolation_level
-        if autocommit:
-            if isolation_level == I_SERIALIZABLE:
-                raise Exception('autocommit is not compatible with I_SERIALIZABLE')
-            isolation_level = I_AUTOCOMMIT
-
-        # default isolation_level is READ COMMITTED
-        if isolation_level < 0:
-            isolation_level = I_READ_COMMITTED
-
-        # new conn?
-        if not self.conn:
-            self.isolation_level = isolation_level
-            self.conn = skytools.connect_database(self.loc)
-            self.conn.my_name = self.name
-
-            self.conn.set_isolation_level(isolation_level)
-            self.conn_time = time.time()
-            if self.setup_func:
-                self.setup_func(self.name, self.conn)
-        else:
-            if self.isolation_level != isolation_level:
-                raise Exception("Conflict in isolation_level")
 
-        self._sync_listen(listen_channel_list)
 
-        # done
-        return self.conn
-
-    def _sync_listen(self, new_clist):
-        if not new_clist and not self.listen_channel_list:
-            return
-        curs = self.conn.cursor()
-        for ch in self.listen_channel_list:
-            if ch not in new_clist:
-                curs.execute("UNLISTEN %s" % skytools.quote_ident(ch))
-        for ch in new_clist:
-            if ch not in self.listen_channel_list:
-                curs.execute("LISTEN %s" % skytools.quote_ident(ch))
-        if self.isolation_level != I_AUTOCOMMIT:
-            self.conn.commit()
-        self.listen_channel_list = new_clist[:]
-
-    def refresh(self):
-        if not self.conn:
-            return
-        #for row in self.conn.notifies():
-        #    if row[0].lower() == "reload":
-        #        self.reset()
-        #        return
-        if not self.max_age:
-            return
-        if time.time() - self.conn_time >= self.max_age:
-            self.reset()
-
-    def reset(self):
-        if not self.conn:
-            return
-
-        # drop reference
-        conn = self.conn
-        self.conn = None
-        self.listen_channel_list = []
-
-        # close
-        try:
-            conn.close()
-        except: pass
-
-class DBScript(object):
-    """Base class for database scripts.
+class BaseScript(object):
+    """Base class for service scripts.
 
     Handles logging, daemonizing, config, errors.
 
@@ -314,9 +213,6 @@ class DBScript(object):
         #   1 - enabled, unless non-daemon on console (os.isatty())
         #   2 - always enabled
         #use_skylog = 0
-
-        # default lifetime for database connections (in seconds)
-        #connection_lifetime = 1200
     """
     service_name = None
     job_name = None
@@ -353,12 +249,10 @@ class DBScript(object):
         @param args: cmdline args (sys.argv[1:]), but can be overrided
         """
         self.service_name = service_name
-        self.db_cache = {}
         self.go_daemon = 0
         self.need_reload = 0
         self.stat_dict = {}
         self.log_level = logging.INFO
-        self._listen_map = {} # dbname: channel_list
 
         # parse command line
         parser = self.init_optparse()
@@ -591,50 +485,9 @@ class DBScript(object):
         self.log.info(logmsg)
         self.stat_dict = {}
 
-    def connection_hook(self, dbname, conn):
-        pass
-
-    def get_database(self, dbname, autocommit = 0, isolation_level = -1,
-                     cache = None, connstr = None):
-        """Load cached database connection.
-        
-        User must not store it permanently somewhere,
-        as all connections will be invalidated on reset.
-        """
-
-        max_age = self.cf.getint('connection_lifetime', DEF_CONN_AGE)
-        if not cache:
-            cache = dbname
-        if cache in self.db_cache:
-            dbc = self.db_cache[cache]
-        else:
-            if not connstr:
-                connstr = self.cf.get(dbname)
-            self.log.debug("Connect '%s' to '%s'" % (cache, connstr))
-            dbc = DBCachedConn(cache, connstr, max_age, setup_func = self.connection_hook)
-            self.db_cache[cache] = dbc
-
-        clist = []
-        if cache in self._listen_map:
-            clist = self._listen_map[cache]
-
-        return dbc.get_connection(autocommit, isolation_level, clist)
-
-    def close_database(self, dbname):
-        """Explicitly close a cached connection.
-        
-        Next call to get_database() will reconnect.
-        """
-        if dbname in self.db_cache:
-            dbc = self.db_cache[dbname]
-            dbc.reset()
-            del self.db_cache[dbname]
-
     def reset(self):
-        "Something bad happened, reset all connections."
-        for dbc in self.db_cache.values():
-            dbc.reset()
-        self.db_cache = {}
+        "Something bad happened, reset all state."
+        pass
 
     def run(self):
         "Thread main loop."
@@ -651,13 +504,6 @@ class DBScript(object):
             # do some work
             work = self.run_once()
 
-            # send stats that was added
-            self.send_stats()
-
-            # reconnect if needed
-            for dbc in self.db_cache.values():
-                dbc.refresh()
-
             if not self.looping or self.loop_delay < 0:
                 break
 
@@ -673,7 +519,12 @@ class DBScript(object):
                     break
 
     def run_once(self):
-        return self.run_func_safely(self.work, True)
+        state = self.run_func_safely(self.work, True)
+
+        # send stats that was added
+        self.send_stats()
+
+        return state
 
     def run_func_safely(self, func, prefer_looping = False):
         "Run users work function, safely."
@@ -702,31 +553,13 @@ class DBScript(object):
                 self.log.info("got KeyboardInterrupt, exiting")
             self.reset()
             sys.exit(1)
-        except skytools.DBError, d:
-            self.send_stats()
-            if d.cursor and d.cursor.connection:
-                cname = d.cursor.connection.my_name
-                dsn = d.cursor.connection.dsn
-                sql = d.cursor.query
-                if len(sql) > 200: # avoid logging londiste huge batched queries 
-                    sql = sql[:60] + " ..."
-                emsg = str(d).strip()
-                self.log.exception("Job %s got error on connection '%s': %s.   Query: %s" % (
-                    self.job_name, cname, emsg, sql))
-            else:
-                n = "psycopg2.%s" % d.__class__.__name__
-                emsg = str(d).rstrip()
-                self.log.exception("Job %s crashed: %s: %s" % (
-                       self.job_name, n, emsg))
         except Exception, d:
             self.send_stats()
             emsg = str(d).rstrip()
-            self.log.exception("Job %s crashed: %s" % (
-                       self.job_name, emsg))
-
+            self.reset()
+            self.exception_hook(d, emsg)
         # reset and sleep
         self.reset()
-        self.exception_hook(d, emsg, cname)
         if prefer_looping and self.looping and self.loop_delay > 0:
             self.sleep(20)
             return -1
@@ -734,39 +567,18 @@ class DBScript(object):
 
     def sleep(self, secs):
         """Make script sleep for some amount of time."""
-        fdlist = []
-        for dbname in self._listen_map.keys():
-            if dbname not in self.db_cache:
-                continue
-            fd = self.db_cache[dbname].fileno()
-            if fd is None:
-                continue
-            fdlist.append(fd)
-
-        if not fdlist:
-            return time.sleep(secs)
-
-        try:
-            if hasattr(select, 'poll'):
-                p = select.poll()
-                for fd in fdlist:
-                    p.register(fd, select.POLLIN)
-                p.poll(int(secs * 1000))
-            else:
-                select.select(fdlist, [], [], secs)
-        except select.error, d:
-            self.log.info('wait canceled')
+        time.sleep(secs)
 
-    def exception_hook(self, det, emsg, cname):
+    def exception_hook(self, det, emsg):
         """Called on after exception processing.
 
         Can do additional logging.
 
         @param det: exception details
         @param emsg: exception msg
-        @param cname: connection name or None
         """
-        pass
+        self.log.exception("Job %s crashed: %s" % (
+                   self.job_name, emsg))
 
     def work(self):
         """Here should user's processing happen.
@@ -787,6 +599,152 @@ class DBScript(object):
         signal.signal(signal.SIGHUP, self.hook_sighup)
         signal.signal(signal.SIGINT, self.hook_sigint)
 
+##
+##  DBScript
+##
+
+#: how old connections need to be closed
+DEF_CONN_AGE = 20*60  # 20 min
+
+#: isolation level not set
+I_DEFAULT = -1
+
+#: isolation level constant for AUTOCOMMIT
+I_AUTOCOMMIT = 0
+#: isolation level constant for READ COMMITTED
+I_READ_COMMITTED = 1
+#: isolation level constant for SERIALIZABLE
+I_SERIALIZABLE = 2
+
+
+class DBScript(BaseScript):
+    """Base class for database scripts.
+
+    Handles database connection state.
+
+    Config template::
+
+        ## Parameters for skytools.DBScript ##
+
+        # default lifetime for database connections (in seconds)
+        #connection_lifetime = 1200
+    """
+
+    def __init__(self, service_name, args):
+        """Script setup.
+
+        User class should override work() and optionally __init__(), startup(),
+        reload(), reset() and init_optparse().
+
+        NB: in case of daemon, the __init__() and startup()/work() will be
+        run in different processes.  So nothing fancy should be done in __init__().
+        
+        @param service_name: unique name for script.
+            It will be also default job_name, if not specified in config.
+        @param args: cmdline args (sys.argv[1:]), but can be overrided
+        """
+        self.db_cache = {}
+        self._listen_map = {} # dbname: channel_list
+        BaseScript.__init__(self, service_name, args)
+
+    def connection_hook(self, dbname, conn):
+        pass
+
+    def get_database(self, dbname, autocommit = 0, isolation_level = -1,
+                     cache = None, connstr = None):
+        """Load cached database connection.
+        
+        User must not store it permanently somewhere,
+        as all connections will be invalidated on reset.
+        """
+
+        max_age = self.cf.getint('connection_lifetime', DEF_CONN_AGE)
+        if not cache:
+            cache = dbname
+        if cache in self.db_cache:
+            dbc = self.db_cache[cache]
+        else:
+            if not connstr:
+                connstr = self.cf.get(dbname)
+            self.log.debug("Connect '%s' to '%s'" % (cache, connstr))
+            dbc = DBCachedConn(cache, connstr, max_age, setup_func = self.connection_hook)
+            self.db_cache[cache] = dbc
+
+        clist = []
+        if cache in self._listen_map:
+            clist = self._listen_map[cache]
+
+        return dbc.get_connection(autocommit, isolation_level, clist)
+
+    def close_database(self, dbname):
+        """Explicitly close a cached connection.
+        
+        Next call to get_database() will reconnect.
+        """
+        if dbname in self.db_cache:
+            dbc = self.db_cache[dbname]
+            dbc.reset()
+            del self.db_cache[dbname]
+
+    def reset(self):
+        "Something bad happened, reset all connections."
+        for dbc in self.db_cache.values():
+            dbc.reset()
+        self.db_cache = {}
+        BaseScript.reset(self)
+
+    def run_once(self):
+        state = BaseScript.run_once(self)
+
+        # reconnect if needed
+        for dbc in self.db_cache.values():
+            dbc.refresh()
+
+        return state
+
+    def exception_hook(self, d, emsg):
+        """Log database and query details from exception."""
+        curs = getattr(d, 'cursor', None)
+        conn = getattr(curs, 'connection', None)
+        cname = getattr(conn, 'my_name', None)
+        if cname:
+            # Properly named connection
+            cname = d.cursor.connection.my_name
+            dsn = getattr(conn, 'dsn', '?')
+            sql = getattr(curs, 'query', '?')
+            if len(sql) > 200: # avoid logging londiste huge batched queries 
+                sql = sql[:60] + " ..."
+            emsg = str(d).strip()
+            self.log.exception("Job %s got error on connection '%s': %s.   Query: %s" % (
+                self.job_name, cname, emsg, sql))
+        else:
+            BaseScript.exception_hook(self, d, emsg)
+
+    def sleep(self, secs):
+        """Make script sleep for some amount of time."""
+        fdlist = []
+        for dbname in self._listen_map.keys():
+            if dbname not in self.db_cache:
+                continue
+            fd = self.db_cache[dbname].fileno()
+            if fd is None:
+                continue
+            fdlist.append(fd)
+
+        if not fdlist:
+            return BaseScript.sleep(self, secs)
+
+        try:
+            if hasattr(select, 'poll'):
+                p = select.poll()
+                for fd in fdlist:
+                    p.register(fd, select.POLLIN)
+                p.poll(int(secs * 1000))
+            else:
+                select.select(fdlist, [], [], secs)
+        except select.error, d:
+            self.log.info('wait canceled')
+
     def _exec_cmd(self, curs, sql, args, quiet = False):
         """Internal tool: Run SQL on cursor."""
         self.log.debug("exec_cmd: %s" % skytools.quote_statement(sql, args))
@@ -903,3 +861,95 @@ class DBScript(object):
         except ValueError:
             pass
 
+class DBCachedConn(object):
+    """Cache a db connection."""
+    def __init__(self, name, loc, max_age = DEF_CONN_AGE, verbose = False, setup_func=None, channels=[]):
+        self.name = name
+        self.loc = loc
+        self.conn = None
+        self.conn_time = 0
+        self.max_age = max_age
+        self.autocommit = -1
+        self.isolation_level = I_DEFAULT
+        self.verbose = verbose
+        self.setup_func = setup_func
+        self.listen_channel_list = []
+
+    def fileno(self):
+        if not self.conn:
+            return None
+        return self.conn.cursor().fileno()
+
+    def get_connection(self, autocommit = 0, isolation_level = I_DEFAULT, listen_channel_list = []):
+        # autocommit overrider isolation_level
+        if autocommit:
+            if isolation_level == I_SERIALIZABLE:
+                raise Exception('autocommit is not compatible with I_SERIALIZABLE')
+            isolation_level = I_AUTOCOMMIT
+
+        # default isolation_level is READ COMMITTED
+        if isolation_level < 0:
+            isolation_level = I_READ_COMMITTED
+
+        # new conn?
+        if not self.conn:
+            self.isolation_level = isolation_level
+            self.conn = skytools.connect_database(self.loc)
+            self.conn.my_name = self.name
+
+            self.conn.set_isolation_level(isolation_level)
+            self.conn_time = time.time()
+            if self.setup_func:
+                self.setup_func(self.name, self.conn)
+        else:
+            if self.isolation_level != isolation_level:
+                raise Exception("Conflict in isolation_level")
+
+        self._sync_listen(listen_channel_list)
+
+        # done
+        return self.conn
+
+    def _sync_listen(self, new_clist):
+        if not new_clist and not self.listen_channel_list:
+            return
+        curs = self.conn.cursor()
+        for ch in self.listen_channel_list:
+            if ch not in new_clist:
+                curs.execute("UNLISTEN %s" % skytools.quote_ident(ch))
+        for ch in new_clist:
+            if ch not in self.listen_channel_list:
+                curs.execute("LISTEN %s" % skytools.quote_ident(ch))
+        if self.isolation_level != I_AUTOCOMMIT:
+            self.conn.commit()
+        self.listen_channel_list = new_clist[:]
+
+    def refresh(self):
+        if not self.conn:
+            return
+        #for row in self.conn.notifies():
+        #    if row[0].lower() == "reload":
+        #        self.reset()
+        #        return
+        if not self.max_age:
+            return
+        if time.time() - self.conn_time >= self.max_age:
+            self.reset()
+
+    def reset(self):
+        if not self.conn:
+            return
+
+        # drop reference
+        conn = self.conn
+        self.conn = None
+        self.listen_channel_list = []
+
+        # close
+        try:
+            conn.close()
+        except: pass
+
+
+
+