more checker
authorMarko Kreen <markokr@gmail.com>
Thu, 23 Sep 2010 00:52:16 +0000 (17:52 -0700)
committerMarko Kreen <markokr@gmail.com>
Wed, 6 Oct 2010 15:53:32 +0000 (18:53 +0300)
python/skytools/checker.py

index 765afd1428947a230e8a3584adde3273c66cd9a9..dc89f721da136b87abaf86a1d3163c0cd0ac2b14 100644 (file)
 """Catch moment when tables are in sync on master and slave.
 """
 
-import sys, time, os
+import sys, time, os, subprocess
 
 import pkgloader
 pkgloader.require('skytools', '3.0')
 import skytools
 
-CONFDB = "dbname=confdb host=confdb.service user=replicator"
-
-def unescape(s):
-    """Remove copy escapes."""
-    return skytools.unescape_copy(s)
-
-def get_pkey_list(curs, tbl):
-    """Get list of pkey fields in right order."""
-
-    oid = skytools.get_table_oid(curs, tbl)
-    q = """SELECT k.attname FROM pg_index i, pg_attribute k
-            WHERE i.indrelid = %s AND k.attrelid = i.indexrelid
-              AND i.indisprimary AND k.attnum > 0 AND NOT k.attisdropped
-            ORDER BY k.attnum"""
-    curs.execute(q, [oid])
-    list = []
-    for row in curs.fetchall():
-        list.append(row[0])
-    return list
-
-def get_column_list(curs, tbl):
-    """Get list of columns in right order."""
-
-    oid = skytools.get_table_oid(curs, tbl)
-    q = """SELECT a.attname FROM pg_attribute a
-            WHERE a.attrelid = %s
-              AND a.attnum > 0 AND NOT a.attisdropped
-            ORDER BY a.attnum"""
-    curs.execute(q, [oid])
-    list = []
-    for row in curs.fetchall():
-        list.append(row[0])
-    return list
-
-class Checker(skytools.DBScript):
+class TableRepair:
     """Checks that tables in two databases are in sync."""
-    cnt_insert = 0
-    cnt_update = 0
-    cnt_delete = 0
-    total_src = 0
-    total_dst = 0
-    pkey_list = []
-    common_fields = []
-
-    def __init__(self, args):
-        """Checker init."""
-        skytools.DBScript.__init__(self, 'cross_mover', args)
-        self.set_single_loop(1)
-        self.log.info('Checker starting %s' % str(args))
-       # compat names
-        self.queue_name = self.cf.get("pgq_queue_name", '')
-        self.consumer_name = self.cf.get('pgq_consumer_id', '')
-        # good names
-        if not self.queue_name:
-            self.queue_name = self.cf.get("queue_name")
-        if not self.consumer_name:
-            self.consumer_name = self.cf.get('consumer_name', self.job_name)
-        self.lock_timeout = self.cf.getfloat('lock_timeout', 10)
-        # get tables to be compared
-        if not self.options.table_list:
-            self.log.error("--table is required")
-        # create temp pidfile 
-        if self.pidfile:
-            self.pidfile += ".repair"
-
-    def set_lock_timeout(self, curs):
-        ms = int(1000 * self.lock_timeout)
-        if ms > 0:
-            q = "SET LOCAL statement_timeout = %d" % ms
-            self.log.debug(q)
-            curs.execute(q)
-
-    def init_optparse(self, p=None):
-        """ Initialize cmdline switches.
-        """
-        p = skytools.DBScript.init_optparse(self, p)
-        p.add_option("--table", dest='table_list', help="space separated list of table names")
-        p.add_option("--part_expr", dest='part_expr', help="table partitioning expression")
-
-        return p
-
-    def check_consumer(self, setup_curs):
-        """ Before locking anything check if consumer is working ok.
-        """
-        self.log.info("Queue: %s Consumer: %s" % (self.queue_name, self.consumer_name)) 
-        # get ticker lag
-        q = "select extract(epoch from ticker_lag) from pgq.get_queue_info(%s);"
-        setup_curs.execute(q, [self.queue_name])
-        ticker_lag = setup_curs.fetchone()[0]
-        self.log.info("Ticker lag: %s" % ticker_lag)
-        # get consumer lag
-        q = "select extract(epoch from lag) from pgq.get_consumer_info(%s, %s);"
-        setup_curs.execute(q, [self.queue_name, self.consumer_name])
-        res = setup_curs.fetchall()
-        if len(res) == 0:
-            self.log.error('No such consumer')
-            sys.exit(1)
-        consumer_lag = res[0][0]
-        self.log.info("Consumer lag: %s" % consumer_lag) 
-        # check that lag is acceptable
-        if consumer_lag > ticker_lag + 10:
-            self.log.error('Consumer lagging too much, cannot proceed')
-            sys.exit(1)
-
-    def work(self):
-        """Syncer main function."""
-        # get sourcedb connection and slots provided there
-        setup_db = self.get_database('setup_db', autocommit = 1, connstr = self.cf.get('src_db'))
-        setup_curs = setup_db.cursor()
-        setup_curs.execute("select hostname(), current_database();")
-        r_source = setup_curs.fetchone()
-        self.log.info("Source: %s" % str(r_source))
-        
-        # get proxy db name and host (used to find out target cluster target partitons and their respective slots)
-        proxy_db = self.get_database('dst_db', autocommit = 1)
-        proxy_curs = proxy_db.cursor()
-        proxy_curs.execute("select hostname(), current_database();")
-        r_proxy = proxy_curs.fetchone()
-        self.log.info("Proxy: %s" % str(r_proxy))
-        
-        # get target partitions from confdb and do also some sanity checks
-        conf_db = self.get_database('conf_db', autocommit = 1, connstr = CONFDB)
-        conf_curs = conf_db.cursor()
-        q = "select db_name, hostname, slots, max_slot from dba.get_cross_targets(%s, %s, %s, %s)"
-        conf_curs.execute(q, r_source + r_proxy)
-        targets = conf_curs.fetchall()
-        
-        # get special purpose connections for magic locking
-        lock_db = self.get_database('lock_db', connstr = self.cf.get('src_db'))
-        src_db = self.get_database('src_db', isolation_level = skytools.I_SERIALIZABLE)
-        
-        # check that consumer is up and running 
-        self.check_consumer(setup_curs)
-        
-        # loop over all tables and all targets
-        mismatch_count = 0
-        for tbl in self.options.table_list.split():
-            self.log.info("Checking table: %s" % tbl)
-            tbl = skytools.fq_name(tbl)
-            for target in targets:
-                self.log.info("Target: %s" % str(target))
-                connstr = "dbname=%s host=%s user=replicator" % (target[0], target[1])
-                fn = "%s.%s" % (target[1], target[0])
-                dst_db = self.get_database(target[0], isolation_level = skytools.I_SERIALIZABLE, connstr = connstr)
-                where = "%s & %s in (%s)" % (self.options.part_expr, target[3],target[2])
-                if not self.check_table(tbl, lock_db, src_db, dst_db, setup_curs, where, fn):
-                    mismatch_count += 1
-                lock_db.commit()
-                src_db.commit()
-                dst_db.commit()
-        if mismatch_count > 0:
-            self.log.error("%s mismatching tables found" % mismatch_count)
-            sys.exit(1)
 
-    def force_tick(self, setup_curs):
-        """ Force tick into source queue so that consumer can move on faster 
-        """
-        q = "select pgq.force_tick(%s)"
-        setup_curs.execute(q, [self.queue_name])
-        res = setup_curs.fetchone()
-        cur_pos = res[0]
-
-        start = time.time()
-        while 1:
-            time.sleep(0.5)
-            setup_curs.execute(q, [self.queue_name])
-            res = setup_curs.fetchone()
-            if res[0] != cur_pos:
-                # new pos
-                return res[0]
-
-            # dont loop more than 10 secs
-            dur = time.time() - start
-            if dur > 10 and not self.options.force:
-                raise Exception("Ticker seems dead")
-
-    def check_table(self, tbl, lock_db, src_db, dst_db, setup_curs, where, target):
-        """ Get transaction to same state, then process.
-        """
-        lock_curs = lock_db.cursor()
-        src_curs = src_db.cursor()
-        dst_curs = dst_db.cursor()
-
-        if not skytools.exists_table(src_curs, tbl):
-            self.log.warning("Table %s does not exist on provider side" % tbl)
-            return
-        if not skytools.exists_table(dst_curs, tbl):
-            self.log.warning("Table %s does not exist on subscriber side" % tbl)
-            return
-
-        # lock table in separate connection
-        self.log.info('Locking %s' % tbl)
-        lock_db.commit()
-        self.set_lock_timeout(lock_curs)
-        lock_time = time.time()
-        lock_curs.execute("LOCK TABLE %s IN SHARE MODE" % skytools.quote_fqident(tbl))
-
-        # now wait until consumer has updated target table until locking
-        self.log.info('Syncing %s' % tbl)
-
-        # consumer must get further than this tick
-        tick_id = self.force_tick(setup_curs)
-        # try to force second tick also
-        self.force_tick(setup_curs)
-
-        # take server time
-        setup_curs.execute("select to_char(now(), 'YYYY-MM-DD HH24:MI:SS.MS')")
-        tpos = setup_curs.fetchone()[0]
-        # now wait
-        while 1:
-            time.sleep(0.5)
-
-            q = "select now() - lag > timestamp %s, now(), lag from pgq.get_consumer_info(%s, %s)"
-            setup_curs.execute(q, [tpos, self.queue_name, self.consumer_name])
-            res = setup_curs.fetchall()
-            if len(res) == 0:
-                raise Exception('No such consumer')
-            row = res[0]
-            self.log.debug("tpos=%s now=%s lag=%s ok=%s" % (tpos, row[1], row[2], row[0]))
-            if row[0]:
-                break
-
-            # limit lock time
-            if time.time() > lock_time + self.lock_timeout:
-                self.log.error('Consumer lagging too much, exiting')
-                lock_db.rollback()
-                sys.exit(1)
-        
-        # take snapshot on provider side
-        src_db.commit()
-        src_curs.execute("SELECT 1")
+    def __init__(self, table_name, log):
+        self.table_name = table_name
+        self.fq_table_name = skytools.quote_fqident(table_name)
+        self.log = log
+        self.reset()
 
-        # take snapshot on subscriber side
-        dst_db.commit()
-        dst_curs.execute("SELECT 1")
-
-        # release lock
-        lock_db.commit()
-
-        # do work
-        result = self.do_compare(tbl, src_db, dst_db, where)
-        if not result:
-            self.do_repair(tbl, src_db, dst_db, where, target)
-        # done
-        src_db.commit()
-        dst_db.commit()
-
-        return result
+    def reset(self):
+        self.cnt_insert = 0
+        self.cnt_update = 0
+        self.cnt_delete = 0
+        self.total_src = 0
+        self.total_dst = 0
+        self.pkey_list = []
+        self.common_fields = []
 
-    def do_compare(self, tbl, src_db, dst_db, where):
+    def do_repair(self, src_db, dst_db, where, pfx = 'repair', apply_fixes = False):
         """Actual comparision."""
 
-        src_curs = src_db.cursor()
-        dst_curs = dst_db.cursor()
-
-        self.log.info('Counting %s' % tbl)
-
-        q = "select count(1) as cnt, sum(hashtext(t.*::text)) as chksum from only _TABLE_ t where %s;" %  where
-        q = self.cf.get('compare_sql', q)
-        q = q.replace('_TABLE_', skytools.quote_fqident(tbl))
-
-        f = "%(cnt)d rows, checksum=%(chksum)s"
-        f = self.cf.get('compare_fmt', f)
-
-        self.log.debug("srcdb: " + q)
-        src_curs.execute(q)
-        src_row = src_curs.fetchone()
-        src_str = f % src_row
-        self.log.info("srcdb: %s" % src_str)
-
-        self.log.debug("dstdb: " + q)
-        dst_curs.execute(q)
-        dst_row = dst_curs.fetchone()
-        dst_str = f % dst_row
-        self.log.info("dstdb: %s" % dst_str)
-
-        if src_str != dst_str:
-            self.log.warning("%s: Results do not match!" % tbl)
-            return False
-        else:
-            self.log.info("%s: OK!" % tbl)
-            return True
-
-    def do_repair(self, tbl, src_db, dst_db, where, target):
-        """Actual comparision."""
+        self.reset()
 
         src_curs = src_db.cursor()
         dst_curs = dst_db.cursor()
 
-        self.log.info('Checking %s' % tbl)
+        self.log.info('Checking %s' % self.table_name)
 
-        self.common_fields = []
-        self.pkey_list = []
-        copy_tbl = self.gen_copy_tbl(tbl, src_curs, dst_curs, where)
+        copy_tbl = self.gen_copy_tbl(src_curs, dst_curs, where)
 
-        dump_src = target + "__" + tbl + ".src"
-        dump_dst = target + "__" + tbl + ".dst"
+        dump_src = "%s.%s.src" % (pfx, self.table_name)
+        dump_dst = "%s.%s.dst" % (pfx, self.table_name)
+        fix = "%s.%s.fix" % (pfx, self.table_name)
 
-        self.log.info("Dumping src table: %s" % tbl)
-        self.dump_table(tbl, copy_tbl, src_curs, dump_src)
+        self.log.info("Dumping src table: %s" % self.table_name)
+        self.dump_table(copy_tbl, src_curs, dump_src)
         src_db.commit()
-        self.log.info("Dumping dst table: %s" % tbl)
-        self.dump_table(tbl, copy_tbl, dst_curs, dump_dst)
+        self.log.info("Dumping dst table: %s" % self.table_name)
+        self.dump_table(copy_tbl, dst_curs, dump_dst)
         dst_db.commit()
         
-        self.log.info("Sorting src table: %s" % tbl)
+        self.log.info("Sorting src table: %s" % self.table_name)
+        self.do_sort(dump_src, dump_src + '.sorted')
 
-        s_in, s_out = os.popen4("sort --version")
-        s_ver = s_out.read()
-        del s_in, s_out
-        if s_ver.find("coreutils") > 0:
-            args = "-S 30%"
-        else:
-            args = ""
-        os.system("sort %s -T . -o %s.sorted %s" % (args, dump_src, dump_src))
-        self.log.info("Sorting dst table: %s" % tbl)
-        os.system("sort %s -T . -o %s.sorted %s" % (args, dump_dst, dump_dst))
+        self.log.info("Sorting dst table: %s" % self.table_name)
+        self.do_sort(dump_dst, dump_dst + '.sorted')
 
-        self.dump_compare(tbl, dump_src + ".sorted", dump_dst + ".sorted", target)
+        self.dump_compare(dump_src + ".sorted", dump_dst + ".sorted", fix)
 
         os.unlink(dump_src)
         os.unlink(dump_dst)
         os.unlink(dump_src + ".sorted")
         os.unlink(dump_dst + ".sorted")
 
-    def gen_copy_tbl(self, tbl, src_curs, dst_curs, where):
+        if apply_fixes:
+            pass
+
+    def do_sort(self, src, dst):
+        p = subprocess.Popen(["sort", "--version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        s_ver = p.communicate()[0]
+        del p
+
+        xenv = os.environ.copy()
+        xenv['LANG'] = 'C'
+        xenv['LC_ALL'] = 'C'
+
+        cmdline = ['sort', '-T', '.']
+        if s_ver.find("coreutils") > 0:
+            cmdline.append('-S')
+            cmdline.append('30%')
+        cmdline.append('-o')
+        cmdline.append(dst)
+        cmdline.append(src)
+        p = subprocess.Popen(cmdline, env = xenv)
+        if p.wait() != 0:
+            raise Exception('sort failed')
+
+    def gen_copy_tbl(self, src_curs, dst_curs, where):
         """Create COPY expession from common fields."""
-        self.pkey_list = get_pkey_list(src_curs, tbl)
-        dst_pkey = get_pkey_list(dst_curs, tbl)
+        self.pkey_list = skytools.get_table_pkeys(src_curs, self.table_name)
+        dst_pkey = skytools.get_table_pkeys(dst_curs, self.table_name)
         if dst_pkey != self.pkey_list:
             self.log.error('pkeys do not match')
             sys.exit(1)
 
-        src_cols = get_column_list(src_curs, tbl)
-        dst_cols = get_column_list(dst_curs, tbl)
+        src_cols = skytools.get_table_columns(src_curs, self.table_name)
+        dst_cols = skytools.get_table_columns(dst_curs, self.table_name)
         field_list = []
         for f in self.pkey_list:
             field_list.append(f)
@@ -356,19 +108,21 @@ class Checker(skytools.DBScript):
 
         fqlist = [skytools.quote_ident(col) for col in field_list]
 
-        tbl_expr = "( select %s from %s where %s )" % (",".join(fqlist), skytools.quote_fqident(tbl), where)
+        tbl_expr = "select %s from %s" % (",".join(fqlist), self.fq_table_name)
+        if where:
+            tbl_expr += ' where ' + where
+        tbl_expr = "COPY (%s) TO STDOUT" % tbl_expr
 
         self.log.debug("using copy expr: %s" % tbl_expr)
 
         return tbl_expr
 
-    def dump_table(self, tbl, copy_tbl, curs, fn):
+    def dump_table(self, copy_cmd, curs, fn):
         """Dump table to disk."""
         f = open(fn, "w", 64*1024)
-        curs.copy_to(f, copy_tbl)
-        size = f.tell()
+        curs.copy_expert(f, copy_cmd)
+        self.log.info('%s: Got %d bytes' % (self.table_name, f.tell()))
         f.close()
-        self.log.info('%s: Got %d bytes' % (tbl, size))
 
     def get_row(self, ln):
         """Parse a row into dict."""
@@ -380,14 +134,9 @@ class Checker(skytools.DBScript):
             row[self.common_fields[i]] = t[i]
         return row
 
-    def dump_compare(self, tbl, src_fn, dst_fn, target):
+    def dump_compare(self, src_fn, dst_fn, fix):
         """Dump + compare single table."""
-        self.log.info("Comparing dumps: %s" % tbl)
-        self.cnt_insert = 0
-        self.cnt_update = 0
-        self.cnt_delete = 0
-        self.total_src = 0
-        self.total_dst = 0
+        self.log.info("Comparing dumps: %s" % self.table_name)
         f1 = open(src_fn, "r", 64*1024)
         f2 = open(dst_fn, "r", 64*1024)
         src_ln = f1.readline()
@@ -395,7 +144,6 @@ class Checker(skytools.DBScript):
         if src_ln: self.total_src += 1
         if dst_ln: self.total_dst += 1
 
-        fix = "fix.%s.%s.sql" % (target, tbl)
         if os.path.isfile(fix):
             os.unlink(fix)
 
@@ -408,15 +156,15 @@ class Checker(skytools.DBScript):
                 diff = self.cmp_keys(src_row, dst_row)
                 if diff > 0:
                     # src > dst
-                    self.got_missed_delete(tbl, dst_row, fix)
+                    self.got_missed_delete(dst_row, fix)
                     keep_src = 1
                 elif diff < 0:
                     # src < dst
-                    self.got_missed_insert(tbl, src_row, fix)
+                    self.got_missed_insert(src_row, fix)
                     keep_dst = 1
                 else:
                     if self.cmp_data(src_row, dst_row) != 0:
-                        self.got_missed_update(tbl, src_row, dst_row, fix)
+                        self.got_missed_update(src_row, dst_row, fix)
 
             if not keep_src:
                 src_ln = f1.readline()
@@ -427,10 +175,10 @@ class Checker(skytools.DBScript):
 
         self.log.info("finished %s: src: %d rows, dst: %d rows,"\
                     " missed: %d inserts, %d updates, %d deletes" % (
-                tbl, self.total_src, self.total_dst,
+                self.table_name, self.total_src, self.total_dst,
                 self.cnt_insert, self.cnt_update, self.cnt_delete))
 
-    def got_missed_insert(self, tbl, src_row, fn):
+    def got_missed_insert(self, src_row, fn):
         """Create sql for missed insert."""
         self.cnt_insert += 1
         fld_list = self.common_fields
@@ -438,43 +186,43 @@ class Checker(skytools.DBScript):
         val_list = []
         for f in fld_list:
             fq_list.append(skytools.quote_ident(f))
-            v = unescape(src_row[f])
+            v = skytools.unescape_copy(src_row[f])
             val_list.append(skytools.quote_literal(v))
         q = "insert into %s (%s) values (%s);" % (
-                tbl, ", ".join(fq_list), ", ".join(val_list))
-        self.show_fix(tbl, q, 'insert', fn)
+                self.fq_table_name, ", ".join(fq_list), ", ".join(val_list))
+        self.show_fix(q, 'insert', fn)
 
-    def got_missed_update(self, tbl, src_row, dst_row, fn):
+    def got_missed_update(self, src_row, dst_row, fn):
         """Create sql for missed update."""
         self.cnt_update += 1
         fld_list = self.common_fields
         set_list = []
         whe_list = []
         for f in self.pkey_list:
-            self.addcmp(whe_list, skytools.quote_ident(f), unescape(src_row[f]))
+            self.addcmp(whe_list, skytools.quote_ident(f), skytools.unescape_copy(src_row[f]))
         for f in fld_list:
             v1 = src_row[f]
             v2 = dst_row[f]
             if self.cmp_value(v1, v2) == 0:
                 continue
 
-            self.addeq(set_list, skytools.quote_ident(f), unescape(v1))
-            self.addcmp(whe_list, skytools.quote_ident(f), unescape(v2))
+            self.addeq(set_list, skytools.quote_ident(f), skytools.unescape_copy(v1))
+            self.addcmp(whe_list, skytools.quote_ident(f), skytools.unescape_copy(v2))
 
         q = "update only %s set %s where %s;" % (
-                tbl, ", ".join(set_list), " and ".join(whe_list))
-        self.show_fix(tbl, q, 'update', fn)
+                self.fq_table_name, ", ".join(set_list), " and ".join(whe_list))
+        self.show_fix(q, 'update', fn)
 
-    def got_missed_delete(self, tbl, dst_row, fn):
+    def got_missed_delete(self, dst_row, fn):
         """Create sql for missed delete."""
         self.cnt_delete += 1
         whe_list = []
         for f in self.pkey_list:
-            self.addcmp(whe_list, skytools.quote_ident(f), unescape(dst_row[f]))
-        q = "delete from only %s where %s;" % (skytools.quote_fqident(tbl), " and ".join(whe_list))
-        self.show_fix(tbl, q, 'delete', fn)
+            self.addcmp(whe_list, skytools.quote_ident(f), skytools.unescape_copy(dst_row[f]))
+        q = "delete from only %s where %s;" % (self.fq_table_name, " and ".join(whe_list))
+        self.show_fix(q, 'delete', fn)
 
-    def show_fix(self, tbl, q, desc, fn):
+    def show_fix(self, q, desc, fn):
         """Print/write/apply repair sql."""
         self.log.debug("missed %s: %s" % (desc, q))
         open(fn, "a").write("%s\n" % q)
@@ -544,8 +292,3 @@ class Checker(skytools.DBScript):
                 return 1
         return 0
 
-
-if __name__ == '__main__':
-    script = Checker(sys.argv[1:])
-    script.start()
-