move common db patterns to skytools.adminscript module
authorMarko Kreen <markokr@gmail.com>
Tue, 22 Apr 2008 12:42:47 +0000 (12:42 +0000)
committerMarko Kreen <markokr@gmail.com>
Tue, 22 Apr 2008 12:42:47 +0000 (12:42 +0000)
python/skytools/__init__.py
python/skytools/adminscript.py [new file with mode: 0644]

index dc71872b6818dcaca6995dda800a3599dceecee3..e42f06a306bf4411499bbb9db71f69887a14133d 100644 (file)
@@ -9,6 +9,7 @@ import skytools.gzlog
 import skytools.scripting
 import skytools.parsing
 import skytools.dbstruct
+import skytools.adminscript
 
 from skytools.psycopgwrapper import *
 from skytools.config import *
@@ -18,6 +19,7 @@ from skytools.scripting import *
 from skytools.sqltools import *
 from skytools.quoting import *
 from skytools.parsing import *
+from skytools.adminscript import *
 
 __all__ = (skytools.psycopgwrapper.__all__
         + skytools.config.__all__
@@ -26,5 +28,6 @@ __all__ = (skytools.psycopgwrapper.__all__
         + skytools.scripting.__all__
         + skytools.sqltools.__all__
         + skytools.quoting.__all__
+        + skytools.adminscript.__all__
         + skytools.parsing.__all__)
 
diff --git a/python/skytools/adminscript.py b/python/skytools/adminscript.py
new file mode 100644 (file)
index 0000000..15aa13e
--- /dev/null
@@ -0,0 +1,137 @@
+#! /usr/bin/env python
+
+"""Admin scripting.
+"""
+
+import sys, os, skytools
+
+from skytools.scripting import DBScript
+
+__all__ = ['AdminScript']
+
+class AdminScript(DBScript):
+    def __init__(self, service_name, args):
+        DBScript.__init__(self, service_name, args)
+        self.pidfile = self.pidfile + ".admin"
+
+        if len(self.args) < 2:
+            self.log.error("need command")
+            sys.exit(1)
+
+    def work(self):
+        self.set_single_loop(1)
+        cmd = self.args[1]
+        fname = "cmd_" + cmd.replace('-', '_')
+        if hasattr(self, fname):
+            getattr(self, fname)(self.args[2:])
+        else:
+            self.log.error('bad subcommand, see --help for usage')
+            sys.exit(1)
+
+    def fetch_list(self, curs, sql, args, keycol = None):
+        curs.execute(sql, args)
+        rows = curs.dictfetchall()
+        if not keycol:
+            res = rows
+        else:
+            res = [r[keycol] for r in rows]
+        return res
+
+    def display_table(self, desc, curs, sql, args = [], fields = []):
+        """Display multirow query as a table."""
+
+        curs.execute(sql, args)
+        rows = curs.fetchall()
+        if len(rows) == 0:
+            return 0
+
+        if not fields:
+            fields = [f[0] for f in curs.description]
+        
+        widths = [15] * len(fields)
+        for row in rows:
+            for i, k in enumerate(fields):
+                rlen = row[k] and len(row) or 0
+                widths[i] = widths[i] > rlen and widths[i] or rlen
+        widths = [w + 2 for w in widths]
+
+        fmt = '%%-%ds' * (len(widths) - 1) + '%%s'
+        fmt = fmt % tuple(widths[:-1])
+        if desc:
+            print desc
+        print fmt % tuple(fields)
+        print fmt % tuple(['-'*15] * len(fields))
+            
+        for row in rows:
+            print fmt % tuple([row[k] for k in fields])
+        print '\n'
+        return 1
+
+    def db_display_table(self, db, desc, sql, args = [], fields = []):
+        curs = db.cursor()
+        res = self.display_table(desc, curs, sql, args, fields)
+        db.commit()
+        return res
+        
+
+    def exec_checked(self, curs, sql, args):
+        curs.execute(sql, args)
+        ok = True
+        for row in curs.fetchall():
+            level = row['ret_code'] / 100
+            if level == 1:
+                self.log.debug("%d %s" % (row[0], row[1]))
+            elif level == 2:
+                self.log.info("%d %s" % (row[0], row[1]))
+            elif level == 3:
+                self.log.warning("%d %s" % (row[0], row[1]))
+            else:
+                self.log.error("%d %s" % (row[0], row[1]))
+                ok = False
+        return ok
+
+    def exec_many(self, curs, sql, baseargs, extra_list):
+        ok = True
+        for a in extra_list:
+            tmp = self.exec_checked(curs, sql, baseargs + [a])
+            ok = tmp and ok
+        return ok
+
+    def db_cmd(self, db, q, args, commit = True):
+        ok = self.exec_checked(db.cursor(), q, args)
+        if ok:
+            if commit:
+                self.log.info("COMMIT")
+                db.commit()
+        else:
+            self.log.info("ROLLBACK")
+            db.rollback()
+            raise EXception("rollback")
+
+    def db_cmd_many(self, db, sql, baseargs, extra_list, commit = True):
+        curs = db.cursor()
+        ok = self.exec_many(curs, sql, baseargs, extra_list)
+        if ok:
+            if commit:
+                self.log.info("COMMIT")
+                db.commit()
+        else:
+            self.log.info("ROLLBACK")
+            db.rollback()
+
+
+    def exec_sql(self, db, q, args):
+        self.log.debug(q)
+        curs = db.cursor()
+        curs.execute(q, args)
+        db.commit()
+
+    def exec_query(self, db, q, args):
+        self.log.debug(q)
+        curs = db.cursor()
+        curs.execute(q, args)
+        res = curs.dictfetchall()
+        db.commit()
+        return res
+
+