skytools.checker: fix, add repair-apply mode
authorMarko Kreen <markokr@gmail.com>
Wed, 13 Oct 2010 13:57:37 +0000 (16:57 +0300)
committerMarko Kreen <markokr@gmail.com>
Wed, 13 Oct 2010 13:57:37 +0000 (16:57 +0300)
python/skytools/checker.py [changed mode: 0644->0755]

old mode 100644 (file)
new mode 100755 (executable)
index b378f2a..785cd9d
@@ -3,7 +3,7 @@
 """Catch moment when tables are in sync on master and slave.
 """
 
-import sys, time, os
+import sys, time, os, subprocess
 
 import pkgloader
 pkgloader.require('skytools', '3.0')
@@ -26,6 +26,8 @@ class TableRepair:
         self.total_dst = 0
         self.pkey_list = []
         self.common_fields = []
+        self.apply_fixes = False
+        self.apply_cursor = None
 
     def do_repair(self, src_db, dst_db, where, pfx = 'repair', apply_fixes = False):
         """Actual comparision."""
@@ -35,6 +37,10 @@ class TableRepair:
         src_curs = src_db.cursor()
         dst_curs = dst_db.cursor()
 
+        self.apply_fixes = apply_fixes
+        if apply_fixes:
+            self.apply_cursor = dst_curs
+
         self.log.info('Checking %s' % self.table_name)
 
         copy_tbl = self.gen_copy_tbl(src_curs, dst_curs, where)
@@ -64,7 +70,7 @@ class TableRepair:
         os.unlink(dump_dst + ".sorted")
 
         if apply_fixes:
-            pass
+            dst_db.commit()
 
     def do_sort(self, src, dst):
         p = subprocess.Popen(["sort", "--version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
@@ -121,7 +127,7 @@ class TableRepair:
     def dump_table(self, copy_cmd, curs, fn):
         """Dump table to disk."""
         f = open(fn, "w", 64*1024)
-        curs.copy_expert(f, copy_cmd)
+        curs.copy_expert(copy_cmd, f)
         self.log.info('%s: Got %d bytes' % (self.table_name, f.tell()))
         f.close()
 
@@ -228,6 +234,9 @@ class TableRepair:
         self.log.debug("missed %s: %s" % (desc, q))
         open(fn, "a").write("%s\n" % q)
 
+        if self.apply_fixes:
+            self.apply_cursor.execute(q)
+
     def addeq(self, list, f, v):
         """Add quoted SET."""
         vq = skytools.quote_literal(v)
@@ -443,6 +452,9 @@ class Checker(Syncer):
 
         extra_connstr = user=marko
 
+        # one of: compare, repair, repair-apply
+        check_type = compare
+
         # random params used in queries
         cluster_name =
         instance_name =
@@ -452,7 +464,7 @@ class Checker(Syncer):
         # list of tables to be compared
         table_list = foo, bar, baz
 
-        where_expr = (hashtext(key_user_name) & %%(max_slots)s) in (%%(slots)s)
+        where_expr = (hashtext(key_user_name) & %%(max_slot)s) in (%%(slots)s)
 
         # gets no args
         source_query =
@@ -499,9 +511,11 @@ class Checker(Syncer):
         source_query = self.cf.get('source_query')
         target_query = self.cf.get('target_query')
         consumer_query = self.cf.get('consumer_query')
-        hash_expr = self.cf.get('hash_expr')
+        where_expr = self.cf.get('where_expr')
         extra_connstr = self.cf.get('extra_connstr')
 
+        check = self.cf.get('check_type', 'compare')
+
         confdb = self.get_database('confdb', autocommit=1)
         curs = confdb.cursor()
 
@@ -519,20 +533,27 @@ class Checker(Syncer):
             for dst_row in curs.fetchall():
                 d_db = dst_row['db_name']
                 d_host = dst_row['hostname']
-                slots = dst_row['slots']
-                max_slot = dst_row['max_slot']
-
-                self.log.info('Source: db=%s host=%s queue=%s consumer=%s' % (
-                              s_db, s_host, queue_name, consumer_name))
-                self.log.info('Target: db=%s host=%s slots=%s' % (d_db, d_host, slots))
 
                 cstr1 = "dbname=%s host=%s %s" % (s_db, s_host, extra_connstr)
                 cstr2 = "dbname=%s host=%s %s" % (d_db, d_host, extra_connstr)
-                where = "(%s & %d) in (%s)" % (hash_expr, max_slot, slots)
+                where = where_expr % dst_row
+
+                self.log.info('Source: db=%s host=%s queue=%s consumer=%s' % (
+                                  s_db, s_host, queue_name, consumer_name))
+                self.log.info('Target: db=%s host=%s where=%s' % (d_db, d_host, where))
 
                 for tbl in self.table_list:
                     src_db, dst_db = self.sync_table(cstr1, cstr2, queue_name, consumer_name, tbl)
-                    self.do_compare(tbl, src_db, dst_db, where)
+                    if check == 'compare':
+                        self.do_compare(tbl, src_db, dst_db, where)
+                    elif check == 'repair':
+                        r = TableRepair(tbl, self.log)
+                        r.do_repair(src_db, dst_db, where, 'fix.' + tbl, False)
+                    elif check == 'repair-apply':
+                        r = TableRepair(tbl, self.log)
+                        r.do_repair(src_db, dst_db, where, 'fix.' + tbl, True)
+                    else:
+                        raise Exception('unknown check type')
                     self.reset()
 
     def do_compare(self, tbl, src_db, dst_db, where):