skytools.checker: generalize more
author Marko Kreen <markokr@gmail.com>
Tue, 12 Oct 2010 12:08:18 +0000 (15:08 +0300)
committer Marko Kreen <markokr@gmail.com>
Tue, 12 Oct 2010 12:08:18 +0000 (15:08 +0300)
python/skytools/checker.py

index dc89f721da136b87abaf86a1d3163c0cd0ac2b14..b378f2ad67825d9987e852c8df95ccfd80a48841 100644 (file)
@@ -1,8 +1,9 @@
+#! /usr/bin/env python
 
 """Catch moment when tables are in sync on master and slave.
 """
 
-import sys, time, os, subprocess
+import sys, time, os
 
 import pkgloader
 pkgloader.require('skytools', '3.0')
@@ -48,7 +49,7 @@ class TableRepair:
         self.log.info("Dumping dst table: %s" % self.table_name)
         self.dump_table(copy_tbl, dst_curs, dump_dst)
         dst_db.commit()
-        
+
         self.log.info("Sorting src table: %s" % self.table_name)
         self.do_sort(dump_src, dump_src + '.sorted')
 
@@ -292,3 +293,287 @@ class TableRepair:
                 return 1
         return 0
 
+
+class Syncer(skytools.DBScript):
+    """Brings a table into a comparable state on two databases.
+
+    The table is locked on the source side until the queue consumer has
+    caught up, then a transaction is opened on each database.
+    """
+    # seconds: statement timeout while taking the table lock, and the
+    # longest time to wait for the consumer while the lock is held
+    lock_timeout = 10
+    # NOTE(review): the two limits below are not read by this class;
+    # check_consumer() uses a fixed 10 second margin - confirm intent
+    ticker_lag_limit = 20
+    consumer_lag_limit = 20
+
+    def sync_table(self, cstr1, cstr2, queue_name, consumer_name, table_name):
+        """Syncer main function.
+
+        cstr1 is the connect string of the source database (where the
+        queue lives and the table gets locked), cstr2 of the target.
+
+        Returns (src_db, dst_db) that are in transaction
+        where table should be in sync.
+
+        Exits with code 1 if the consumer does not catch up within
+        lock_timeout seconds; raises Exception if the consumer is not
+        registered on the queue.
+        """
+
+        setup_db = self.get_database('setup_db', connstr = cstr1, autocommit = 1)
+        lock_db = self.get_database('lock_db', connstr = cstr1)
+
+        src_db = self.get_database('src_db', connstr = cstr1,
+                isolation_level = skytools.I_SERIALIZABLE)
+        dst_db = self.get_database('dst_db', connstr = cstr2,
+                isolation_level = skytools.I_SERIALIZABLE)
+
+        lock_curs = lock_db.cursor()
+        setup_curs = setup_db.cursor()
+        src_curs = src_db.cursor()
+        dst_curs = dst_db.cursor()
+
+        self.check_consumer(setup_curs, queue_name, consumer_name)
+
+        # lock table in separate connection
+        self.log.info('Locking %s' % table_name)
+        self.set_lock_timeout(lock_curs)
+        lock_time = time.time()
+        in_sync = False
+        try:
+            lock_curs.execute("LOCK TABLE %s IN SHARE MODE" % skytools.quote_fqident(table_name))
+
+            # now wait until consumer has applied everything up to the lock
+            self.log.info('Syncing %s' % table_name)
+
+            # consumer must get further than this tick
+            self.force_tick(setup_curs, queue_name)
+            # try to force second tick also
+            self.force_tick(setup_curs, queue_name)
+
+            # take server time
+            setup_curs.execute("select to_char(now(), 'YYYY-MM-DD HH24:MI:SS.MS')")
+            tpos = setup_curs.fetchone()[0]
+
+            # now wait
+            q = "select now() - lag > timestamp %s, now(), lag from pgq.get_consumer_info(%s, %s)"
+            while 1:
+                time.sleep(0.5)
+
+                setup_curs.execute(q, [tpos, queue_name, consumer_name])
+                res = setup_curs.fetchall()
+                if len(res) == 0:
+                    raise Exception('No such consumer: %s/%s' % (queue_name, consumer_name))
+                row = res[0]
+                self.log.debug("tpos=%s now=%s lag=%s ok=%s" % (tpos, row[1], row[2], row[0]))
+                if row[0]:
+                    break
+
+                # limit lock time
+                if time.time() > lock_time + self.lock_timeout:
+                    self.log.error('Consumer lagging too much, exiting')
+                    sys.exit(1)
+
+            # take snapshot on provider side
+            src_db.commit()
+            src_curs.execute("SELECT 1")
+
+            # take snapshot on subscriber side
+            dst_db.commit()
+            dst_curs.execute("SELECT 1")
+
+            in_sync = True
+        finally:
+            # release lock; on any failure (sys.exit included) drop it
+            # right away instead of leaving the table locked
+            if in_sync:
+                lock_db.commit()
+            else:
+                lock_db.rollback()
+
+        self.close_database('setup_db')
+        self.close_database('lock_db')
+
+        return (src_db, dst_db)
+
+    def set_lock_timeout(self, curs):
+        """Limit statement run time on curs to lock_timeout seconds.
+
+        Uses SET LOCAL, so the setting is scoped to the current
+        transaction.  Does nothing when lock_timeout is zero or negative.
+        """
+        ms = int(1000 * self.lock_timeout)
+        if ms > 0:
+            q = "SET LOCAL statement_timeout = %d" % ms
+            self.log.debug(q)
+            curs.execute(q)
+
+    def check_consumer(self, curs, queue_name, consumer_name):
+        """ Before locking anything check if consumer is working ok.
+
+        Logs ticker lag and consumer lag (both in seconds).  Exits with
+        code 1 when either query returns no row, or when consumer lag
+        exceeds ticker lag by more than 10 seconds.
+        """
+        self.log.info("Queue: %s Consumer: %s" % (queue_name, consumer_name))
+
+        curs.execute('select current_database()')
+        self.log.info('Actual db: %s' % curs.fetchone()[0])
+
+        # get ticker lag
+        q = "select extract(epoch from ticker_lag) from pgq.get_queue_info(%s);"
+        curs.execute(q, [queue_name])
+        res = curs.fetchall()
+        if len(res) == 0:
+            # no row came back - most likely the queue does not exist
+            self.log.error('check_consumer: No such queue: %s' % queue_name)
+            sys.exit(1)
+        ticker_lag = res[0][0]
+        self.log.info("Ticker lag: %s" % ticker_lag)
+
+        # get consumer lag
+        q = "select extract(epoch from lag) from pgq.get_consumer_info(%s, %s);"
+        curs.execute(q, [queue_name, consumer_name])
+        res = curs.fetchall()
+        if len(res) == 0:
+            self.log.error('check_consumer: No such consumer: %s/%s' % (queue_name, consumer_name))
+            sys.exit(1)
+        consumer_lag = res[0][0]
+
+        # check that lag is acceptable
+        self.log.info("Consumer lag: %s" % consumer_lag)
+        if consumer_lag > ticker_lag + 10:
+            self.log.error('Consumer lagging too much, cannot proceed')
+            sys.exit(1)
+
+    def force_tick(self, curs, queue_name):
+        """ Force tick into source queue so that consumer can move on faster
+
+        Calls pgq.force_tick() every 0.5 seconds until the value it
+        returns differs from the first one, and returns that new value.
+
+        Raises Exception after 10 seconds without change, unless the
+        'force' option is set - then it keeps polling.
+        """
+        q = "select pgq.force_tick(%s)"
+        curs.execute(q, [queue_name])
+        cur_pos = curs.fetchone()[0]
+
+        # the 'force' switch is not declared anywhere in this class, so
+        # treat a missing attribute as "off" instead of failing on it
+        force = getattr(self.options, 'force', False)
+
+        start = time.time()
+        while 1:
+            time.sleep(0.5)
+            curs.execute(q, [queue_name])
+            new_pos = curs.fetchone()[0]
+            if new_pos != cur_pos:
+                # new pos
+                return new_pos
+
+            # don't loop more than 10 secs
+            dur = time.time() - start
+            if dur > 10 and not force:
+                raise Exception("Ticker seems dead")
+
+
+class Checker(Syncer):
+    """Checks that tables in two databases are in sync.
+
+    Config options::
+
+        ## data_checker ##
+        confdb = dbname=confdb host=confdb.service
+
+        extra_connstr = user=marko
+
+        # optional: seconds to keep the table locked while waiting for consumer
+        #lock_timeout = 10
+
+        # random params used in queries
+        cluster_name =
+        instance_name =
+        proxy_host =
+        proxy_db =
+
+        # list of tables to be compared
+        table_list = foo, bar, baz
+
+        # rows are filtered with: (<hash_expr> & max_slot) in (slots)
+        hash_expr = hashtext(key_user_name)
+
+        # optional: custom comparison query, _TABLE_ is replaced with table name
+        #compare_sql = select count(1) as cnt from only _TABLE_
+        # optional: format applied to the result row of the comparison query
+        #compare_fmt = %%(cnt)d rows
+
+        # gets no args
+        source_query =
+         select h.hostname, d.db_name
+           from dba.cluster c
+                join dba.cluster_host ch on (ch.key_cluster = c.id_cluster)
+                join conf.host h on (h.id_host = ch.key_host)
+                join dba.database d on (d.key_host = ch.key_host)
+          where c.db_name = '%(cluster_name)s'
+            and c.instance_name = '%(instance_name)s'
+            and d.mk_db_type = 'partition'
+            and d.mk_db_status = 'active'
+          order by d.db_name, h.hostname
+
+
+        target_query =
+            select db_name, hostname, slots, max_slot
+              from dba.get_cross_targets(%%(hostname)s, %%(db_name)s, '%(proxy_host)s', '%(proxy_db)s')
+
+        consumer_query =
+            select q.queue_name, c.consumer_name
+              from conf.host h
+              join dba.database d on (d.key_host = h.id_host)
+              join dba.pgq_queue q on (q.key_database = d.id_database)
+              join dba.pgq_consumer c on (c.key_queue = q.id_queue)
+             where h.hostname = %%(hostname)s
+               and d.db_name = %%(db_name)s
+               and q.queue_name like 'xm%%%%'
+    """
+
+    def __init__(self, args):
+        """Checker init.
+
+        args is the command-line argument list; it is handed to the
+        base class together with the service name 'data_checker'.
+        """
+        Syncer.__init__(self, 'data_checker', args)
+        self.set_single_loop(1)
+        self.log.info('Checker starting %s' % str(args))
+
+        # config value overrides the class-level default
+        self.lock_timeout = self.cf.getfloat('lock_timeout', 10)
+
+        self.table_list = self.cf.getlist('table_list')
+
+    def work(self):
+        """Checker main function.
+
+        Runs source_query on confdb; for each source row looks up its
+        queue/consumer and its targets, then syncs and compares every
+        table in table_list for each source/target pair.
+
+        Raises Exception when consumer_query finds nothing for a source.
+        """
+
+        source_query = self.cf.get('source_query')
+        target_query = self.cf.get('target_query')
+        consumer_query = self.cf.get('consumer_query')
+        hash_expr = self.cf.get('hash_expr')
+        extra_connstr = self.cf.get('extra_connstr')
+
+        confdb = self.get_database('confdb', autocommit=1)
+        curs = confdb.cursor()
+        curs.execute(source_query)
+        src_rows = curs.fetchall()
+
+        mismatches = 0
+        for src_row in src_rows:
+            # NOTE(review): reset() in the table loop is defined outside
+            # this file; presumably it closes cached connections, confdb
+            # included - so take a fresh cursor for every source row
+            curs = self.get_database('confdb', autocommit=1).cursor()
+
+            s_host = src_row['hostname']
+            s_db = src_row['db_name']
+
+            curs.execute(consumer_query, src_row)
+            r = curs.fetchone()
+            if r is None:
+                # clear message instead of a TypeError on None below
+                raise Exception('No consumer found: db=%s host=%s' % (s_db, s_host))
+            consumer_name = r['consumer_name']
+            queue_name = r['queue_name']
+
+            curs.execute(target_query, src_row)
+            for dst_row in curs.fetchall():
+                d_db = dst_row['db_name']
+                d_host = dst_row['hostname']
+                slots = dst_row['slots']
+                max_slot = dst_row['max_slot']
+
+                self.log.info('Source: db=%s host=%s queue=%s consumer=%s' % (
+                              s_db, s_host, queue_name, consumer_name))
+                self.log.info('Target: db=%s host=%s slots=%s' % (d_db, d_host, slots))
+
+                cstr1 = "dbname=%s host=%s %s" % (s_db, s_host, extra_connstr)
+                cstr2 = "dbname=%s host=%s %s" % (d_db, d_host, extra_connstr)
+                where = "(%s & %d) in (%s)" % (hash_expr, max_slot, slots)
+
+                for tbl in self.table_list:
+                    src_db, dst_db = self.sync_table(cstr1, cstr2, queue_name, consumer_name, tbl)
+                    if not self.do_compare(tbl, src_db, dst_db, where):
+                        mismatches += 1
+                    self.reset()
+
+        # per-table warnings are easy to miss in a long run
+        if mismatches:
+            self.log.warning('%d comparison(s) did not match' % mismatches)
+
+    def do_compare(self, tbl, src_db, dst_db, where):
+        """Actual comparison.
+
+        Runs the same query on both connections, formats each result
+        row with compare_fmt and compares the two strings.  Both
+        transactions are committed afterwards.
+
+        The default query counts rows and sums hashtext() over the
+        rows matching 'where'; config option compare_sql replaces it.
+
+        Returns True when the formatted results are equal, else False.
+        """
+
+        src_curs = src_db.cursor()
+        dst_curs = dst_db.cursor()
+
+        self.log.info('Counting %s' % tbl)
+
+        q = "select count(1) as cnt, sum(hashtext(t.*::text)) as chksum from only _TABLE_ t where %s;" %  where
+        q = self.cf.get('compare_sql', q)
+        q = q.replace('_TABLE_', skytools.quote_fqident(tbl))
+
+        f = "%(cnt)d rows, checksum=%(chksum)s"
+        f = self.cf.get('compare_fmt', f)
+
+        self.log.debug("srcdb: " + q)
+        src_curs.execute(q)
+        src_row = src_curs.fetchone()
+        src_str = f % src_row
+        self.log.info("srcdb: %s" % src_str)
+
+        self.log.debug("dstdb: " + q)
+        dst_curs.execute(q)
+        dst_row = dst_curs.fetchone()
+        dst_str = f % dst_row
+        self.log.info("dstdb: %s" % dst_str)
+
+        src_db.commit()
+        dst_db.commit()
+
+        if src_str != dst_str:
+            self.log.warning("%s: Results do not match!" % tbl)
+            return False
+        else:
+            self.log.info("%s: OK!" % tbl)
+            return True
+
+
+if __name__ == '__main__':
+    # script entry point: hand over the arguments without the program name
+    Checker(sys.argv[1:]).start()
+