"""Upgrade script for versioned schemas."""
-import sys, os, re
+usage = """
+ %prog [--user=U] [--host=H] [--port=P] --all
+ %prog [--user=U] [--host=H] [--port=P] DB1 [ DB2 ... ]\
+"""
+
+import sys, os, re, optparse
import pkgloader
pkgloader.require('skytools', '3.0')
-
import skytools
+from skytools.natsort import natsort_key
-ver_rx = r"(\d+)([.](\d+)([.](\d+))?)?"
-ver_rc = re.compile(ver_rx)
-def detect_londiste215(curs):
- return skytools.exists_table(curs, 'londiste.subscriber_pending_fkeys')
+# schemas, where .upgrade.sql is enough
+AUTO_UPGRADE = ('pgq', 'pgq_node', 'pgq_coop', 'londiste')
-version_list = [
- ['pgq', '2.1.5', 'v2.1.5_pgq_core.sql', None],
- # those vers did not have version func
- ['pgq_ext', '2.1.5', 'v2.1.5_pgq_ext.sql', None], # ok to reapply
- ['londiste', '2.1.5', 'v2.1.5_londiste.sql', detect_londiste215], # not ok to reapply
+# fetch list of databases
+DB_LIST = "select datname from pg_database "\
+ " where not datistemplate and datallowconn "\
+ " order by 1"
- ['pgq_ext', '2.1.6', 'v2.1.6_pgq_ext.sql', None],
- ['londiste', '2.1.6', 'v2.1.6_londiste.sql', None],
+# dont support upgrade from 2.x (yet?)
+version_list = [
+ ['pgq', '3.0', None, None],
+ ['londiste', '3.0', None, None],
+]
- ['pgq', '2.1.7', 'v2.1.7_pgq_core.sql', None],
- ['londiste', '2.1.7', 'v2.1.7_londiste.sql', None],
- ['pgq', '2.1.8', 'v2.1.8_pgq_core.sql', None],
-]
+def is_version_ge(a, b):
+ """Return True if a is greater or equal than b."""
+ va = natsort_key(a)
+ vb = natsort_key(b)
+ return va >= vb
-def parse_ver(ver):
- m = ver_rc.match(ver)
- if not ver: return 0
- v0 = int(m.group(1) or "0")
- v1 = int(m.group(3) or "0")
- v2 = int(m.group(5) or "0")
- return ((v0 * 100) + v1) * 100 + v2
def check_version(curs, schema, new_ver_str, recheck_func=None):
funcname = "%s.version" % schema
q = "select %s()" % funcname
curs.execute(q)
old_ver_str = curs.fetchone()[0]
- new_ver = parse_ver(new_ver_str)
- old_ver = parse_ver(old_ver_str)
- return old_ver >= new_ver
-
+ return is_version_ge(old_ver_str, new_ver_str)
+
class DbUpgrade(skytools.DBScript):
- def upgrade(self, db):
+ """Upgrade all Skytools schemas in Postgres cluster."""
+
+ def upgrade(self, dbname, db):
+ """Upgrade all schemas in single db."""
+
curs = db.cursor()
- for schema, ver, sql, recheck_fn in version_list:
+ for schema, ver, fn, recheck_fn in version_list:
if not skytools.exists_schema(curs, schema):
continue
if check_version(curs, schema, ver, recheck_fn):
continue
- fn = "upgrade/final/%s" % sql
+ if fn is None:
+ self.log.info('%s: Cannot upgrade %s, too old version', dbname, schema)
+ continue
+
+ curs = db.cursor()
+ curs.execute('begin')
skytools.installer_apply_file(db, fn, self.log)
+ curs.execute('commit')
def work(self):
+ """Loop over databases."""
+
self.set_single_loop(1)
- # loop over hosts
- for cstr in self.args:
- db = self.get_database('db', connstr = cstr, autocommit = 1)
- self.upgrade(db)
+ self.load_cur_versions()
+
+ # loop over all dbs
+ dblst = self.args
+ if self.options.all:
+ db = self.connect_db('postgres')
+ curs = db.cursor()
+ curs.execute(DB_LIST)
+ dblst = []
+ for row in curs.fetchall():
+ dblst.append(row[0])
+ self.close_database('db')
+ elif not dblst:
+ raise skytools.UsageError('Give --all or list of database names on command line')
+
+ # loop over connstrs
+ for dbname in dblst:
+ if self.last_sigint:
+ break
+ self.log.info("%s: connecting", dbname)
+ db = self.connect_db(dbname)
+ self.upgrade(dbname, db)
self.close_database('db')
+ def load_cur_versions(self):
+ """Load current version numbers from .upgrade.sql files."""
+
+ vrc = re.compile(r"^ \s+ return \s+ '([0-9.]+)';", re.X | re.I | re.M)
+ for s in AUTO_UPGRADE:
+ fn = '%s.upgrade.sql' % s
+ fqfn = skytools.installer_find_file(fn)
+ try:
+ f = open(fqfn, 'r')
+ except IOError, d:
+ raise skytools.UsageError('%s: cannot find upgrade file: %s [%s]' % (s, fqfn, str(d)))
+
+ sql = f.read()
+ f.close()
+ m = vrc.search(sql)
+ if not m:
+ raise skytools.UsageError('%s: failed to detect version' % fqfn)
+
+ ver = m.group(1)
+ cur = [s, ver, fn, None]
+ self.log.info("Loaded %s %s from %s", s, ver, fqfn)
+ version_list.append(cur)
+
+ def connect_db(self, dbname):
+ """Create connect string, then connect."""
+
+ elems = ["dbname='%s'" % dbname]
+ if self.options.host:
+ elems.append("host='%s'" % self.options.host)
+ if self.options.port:
+ elems.append("port='%s'" % self.options.port)
+ if self.options.user:
+ elems.append("user='%s'" % self.options.user)
+ cstr = ' '.join(elems)
+ return self.get_database('db', connstr = cstr, autocommit = 1)
+
+ def init_optparse(self, parser=None):
+ """Setup commend-line flags."""
+ p = skytools.DBScript.init_optparse(self, parser)
+ p.set_usage(usage)
+ g = optparse.OptionGroup(p, "options for skytools_upgrade")
+ g.add_option("--all", action="store_true", help = 'upgrade all databases')
+ g.add_option("--user", help = 'username to use')
+ g.add_option("--host", help = 'hostname to use')
+ g.add_option("--port", help = 'port to use')
+ p.add_option_group(g)
+ return p
+
def load_config(self):
- return skytools.Config(self.service_name, None,
- user_defs = {'use_skylog': '0', 'job_name': 'db_upgrade'})
+ """Disable config file."""
+ return skytools.Config(self.service_name, None,
+ user_defs = {'use_skylog': '0', 'job_name': 'db_upgrade'})
if __name__ == '__main__':
- script = DbUpgrade('db_upgrade', sys.argv[1:])
+ script = DbUpgrade('skytools_upgrade', sys.argv[1:])
script.start()