skytools_upgrade: support 3.0-upgrades, walk databases
authorMarko Kreen <markokr@gmail.com>
Fri, 25 Nov 2011 09:17:18 +0000 (11:17 +0200)
committerMarko Kreen <markokr@gmail.com>
Fri, 25 Nov 2011 09:17:18 +0000 (11:17 +0200)
scripts/skytools_upgrade.py
setup_skytools.py

index c3472f154808c486290832c234ac7891b2a97cf9..389f66772a33c019726dfe5b311ae2f4ae94786a 100755 (executable)
@@ -2,41 +2,40 @@
 
 """Upgrade script for versioned schemas."""
 
-import sys, os, re
+usage = """
+    %prog [--user=U] [--host=H] [--port=P] --all
+    %prog [--user=U] [--host=H] [--port=P] DB1 [ DB2 ... ]\
+"""
+
+import sys, os, re, optparse
 
 import pkgloader
 pkgloader.require('skytools', '3.0')
-
 import skytools
+from skytools.natsort import natsort_key
 
-ver_rx = r"(\d+)([.](\d+)([.](\d+))?)?"
-ver_rc = re.compile(ver_rx)
 
-def detect_londiste215(curs):
-    return skytools.exists_table(curs, 'londiste.subscriber_pending_fkeys')
+# schemas, where .upgrade.sql is enough
+AUTO_UPGRADE = ('pgq', 'pgq_node', 'pgq_coop', 'londiste')
 
-version_list = [
- ['pgq', '2.1.5', 'v2.1.5_pgq_core.sql', None],
- # those vers did not have version func
- ['pgq_ext', '2.1.5', 'v2.1.5_pgq_ext.sql', None], # ok to reapply
- ['londiste', '2.1.5', 'v2.1.5_londiste.sql', detect_londiste215], # not ok to reapply
+# fetch list of databases
+DB_LIST = "select datname from pg_database "\
+          " where not datistemplate and datallowconn "\
+          " order by 1"
 
- ['pgq_ext', '2.1.6', 'v2.1.6_pgq_ext.sql', None],
- ['londiste', '2.1.6', 'v2.1.6_londiste.sql', None],
+# dont support upgrade from 2.x (yet?)
+version_list = [
+    ['pgq', '3.0', None, None],
+    ['londiste', '3.0', None, None],
+]
 
- ['pgq', '2.1.7', 'v2.1.7_pgq_core.sql', None],
- ['londiste', '2.1.7', 'v2.1.7_londiste.sql', None],
 
- ['pgq', '2.1.8', 'v2.1.8_pgq_core.sql', None],
-]
+def is_version_ge(a, b):
+    """Return True if a is greater or equal than b."""
+    va = natsort_key(a)
+    vb = natsort_key(b)
+    return va >= vb
 
-def parse_ver(ver):
-    m = ver_rc.match(ver)
-    if not ver: return 0
-    v0 = int(m.group(1) or "0")
-    v1 = int(m.group(3) or "0")
-    v2 = int(m.group(5) or "0")
-    return ((v0 * 100) + v1) * 100 + v2
 
 def check_version(curs, schema, new_ver_str, recheck_func=None):
     funcname = "%s.version" % schema
@@ -48,38 +47,115 @@ def check_version(curs, schema, new_ver_str, recheck_func=None):
     q = "select %s()" % funcname
     curs.execute(q)
     old_ver_str = curs.fetchone()[0]
-    new_ver = parse_ver(new_ver_str)
-    old_ver = parse_ver(old_ver_str)
-    return old_ver >= new_ver
-    
+    return is_version_ge(old_ver_str, new_ver_str)
+
 
 class DbUpgrade(skytools.DBScript):
-    def upgrade(self, db):
+    """Upgrade all Skytools schemas in Postgres cluster."""
+
+    def upgrade(self, dbname, db):
+        """Upgrade all schemas in single db."""
+
         curs = db.cursor()
-        for schema, ver, sql, recheck_fn in version_list:
+        for schema, ver, fn, recheck_fn in version_list:
             if not skytools.exists_schema(curs, schema):
                 continue
 
             if check_version(curs, schema, ver, recheck_fn):
                 continue
 
-            fn = "upgrade/final/%s" % sql
+            if fn is None:
+                self.log.info('%s: Cannot upgrade %s, too old version', dbname, schema)
+                continue
+
+            curs = db.cursor()
+            curs.execute('begin')
             skytools.installer_apply_file(db, fn, self.log)
+            curs.execute('commit')
 
     def work(self):
+        """Loop over databases."""
+
         self.set_single_loop(1)
         
-        # loop over hosts
-        for cstr in self.args:
-            db = self.get_database('db', connstr = cstr, autocommit = 1)
-            self.upgrade(db)
+        self.load_cur_versions()
+
+        # loop over all dbs
+        dblst = self.args
+        if self.options.all:
+            db = self.connect_db('postgres')
+            curs = db.cursor()
+            curs.execute(DB_LIST)
+            dblst = []
+            for row in curs.fetchall():
+                dblst.append(row[0])
+            self.close_database('db')
+        elif not dblst:
+            raise skytools.UsageError('Give --all or list of database names on command line')
+
+        # loop over connstrs
+        for dbname in dblst:
+            if self.last_sigint:
+                break
+            self.log.info("%s: connecting", dbname)
+            db = self.connect_db(dbname)
+            self.upgrade(dbname, db)
             self.close_database('db')
 
+    def load_cur_versions(self):
+        """Load current version numbers from .upgrade.sql files."""
+
+        vrc = re.compile(r"^ \s+ return \s+ '([0-9.]+)';", re.X | re.I | re.M)
+        for s in AUTO_UPGRADE:
+            fn = '%s.upgrade.sql' % s
+            fqfn = skytools.installer_find_file(fn)
+            try:
+                f = open(fqfn, 'r')
+            except IOError, d:
+                raise skytools.UsageError('%s: cannot find upgrade file: %s [%s]' % (s, fqfn, str(d)))
+
+            sql = f.read()
+            f.close()
+            m = vrc.search(sql)
+            if not m:
+                raise skytools.UsageError('%s: failed to detect version' % fqfn)
+
+            ver = m.group(1)
+            cur = [s, ver, fn, None]
+            self.log.info("Loaded %s %s from %s", s, ver, fqfn)
+            version_list.append(cur)
+
+    def connect_db(self, dbname):
+        """Create connect string, then connect."""
+
+        elems = ["dbname='%s'" % dbname]
+        if self.options.host:
+            elems.append("host='%s'" % self.options.host)
+        if self.options.port:
+            elems.append("port='%s'" % self.options.port)
+        if self.options.user:
+            elems.append("user='%s'" % self.options.user)
+        cstr = ' '.join(elems)
+        return self.get_database('db', connstr = cstr, autocommit = 1)
+
+    def init_optparse(self, parser=None):
+        """Setup commend-line flags."""
+        p = skytools.DBScript.init_optparse(self, parser)
+        p.set_usage(usage)
+        g = optparse.OptionGroup(p, "options for skytools_upgrade")
+        g.add_option("--all", action="store_true", help = 'upgrade all databases')
+        g.add_option("--user", help = 'username to use')
+        g.add_option("--host", help = 'hostname to use')
+        g.add_option("--port", help = 'port to use')
+        p.add_option_group(g)
+        return p
+
     def load_config(self):
-         return skytools.Config(self.service_name, None,
-                 user_defs = {'use_skylog': '0', 'job_name': 'db_upgrade'})
+        """Disable config file."""
+        return skytools.Config(self.service_name, None,
+                user_defs = {'use_skylog': '0', 'job_name': 'db_upgrade'})
 
 if __name__ == '__main__':
-    script = DbUpgrade('db_upgrade', sys.argv[1:])
+    script = DbUpgrade('skytools_upgrade', sys.argv[1:])
     script.start()
 
index ddd3cc7ea19d2e5584f16794afeefe6b23eced15..ca3e5607bad90211dd7be713c28caf1d1c30eb6e 100755 (executable)
@@ -51,8 +51,16 @@ if not INSTALL_SCRIPTS:
 sql_files = [
    'sql/pgq/pgq.sql',
    'sql/londiste/londiste.sql',
-   'sql/pgq_ext/pgq_ext.sql',
    'sql/pgq_node/pgq_node.sql',
+   'sql/pgq_coop/pgq_coop.sql',
+   'sql/pgq_ext/pgq_ext.sql',
+
+   'sql/pgq/pgq.upgrade.sql',
+   'sql/pgq_node/pgq_node.upgrade.sql',
+   'sql/londiste/londiste.upgrade.sql',
+   'sql/pgq_coop/pgq_coop.upgrade.sql',
+
+   'upgrade/final/v3.0_pgq_core.sql',
    #'sql/txid/txid.sql',
 ]