python/londiste rewrite for cascading
authorMarko Kreen <markokr@gmail.com>
Fri, 13 Feb 2009 11:14:37 +0000 (13:14 +0200)
committerMarko Kreen <markokr@gmail.com>
Fri, 13 Feb 2009 13:20:32 +0000 (15:20 +0200)
New features:
- Cascading
- 'execute' command for running SQL scripts on nodes
- Parallel COPY
- Partition merge
- Sequences are pushed from root
- Rename 'add' command to 'add-table'
- --create switch to add-seq / add-table

also drop the never-implemented file-based transport classes.

python/londiste.py
python/londiste/__init__.py
python/londiste/compare.py
python/londiste/file_read.py [deleted file]
python/londiste/file_write.py [deleted file]
python/londiste/installer.py [deleted file]
python/londiste/playback.py
python/londiste/repair.py
python/londiste/setup.py
python/londiste/syncer.py
python/londiste/table_copy.py

index 44818d8bb51fec2072ceaa6c13483bdf74b5e971..62f434e9ed8c1f104f790a1cfb537ac3165bb8ab 100755 (executable)
@@ -10,9 +10,9 @@ if os.path.exists(os.path.join(sys.path[0], 'londiste.py')) \
     and not os.path.exists(os.path.join(sys.path[0], 'londiste')):
     del sys.path[0]
 
-import londiste, pgq.setadmin
+import londiste, pgq.cascade.admin
 
-command_usage = pgq.setadmin.command_usage + """
+command_usage = pgq.cascade.admin.command_usage + """
 Replication Daemon:
   worker                replay events to subscriber
 
@@ -31,17 +31,18 @@ Replication Extra:
   fkeys                 print out fkey drop/create commands
   compare [TBL ...]     compare table contents on both sides
   repair [TBL ...]      repair data on subscriber
+  execute [FILE ...]    execute SQL files on set
 
 Internal Commands:
   copy                  copy table logic
 """
 
 cmd_handlers = (
-    (('init-root', 'init-branch', 'init-leaf', 'members', 'tag-dead', 'tag-alive',
+    (('create-root', 'create-branch', 'create-leaf', 'members', 'tag-dead', 'tag-alive',
       'change-provider', 'rename-node', 'status', 'pause', 'resume',
       'switchover', 'failover'), londiste.LondisteSetup),
-    (('add', 'remove', 'add-seq', 'remove-seq', 'tables', 'seqs',
-      'missing', 'resync', 'check', 'fkeys'), londiste.LondisteSetup),
+    (('add-table', 'remove-table', 'add-seq', 'remove-seq', 'tables', 'seqs',
+      'missing', 'resync', 'check', 'fkeys', 'execute'), londiste.LondisteSetup),
     (('worker', 'replay'), londiste.Replicator),
     (('compare',), londiste.Comparator),
     (('repair',), londiste.Repairer),
@@ -53,7 +54,7 @@ class Londiste(skytools.DBScript):
         skytools.DBScript.__init__(self, 'londiste', args)
 
         if len(self.args) < 2:
-            print "need command"
+            print("need command")
             sys.exit(1)
         cmd = self.args[1]
         self.script = None
@@ -62,7 +63,7 @@ class Londiste(skytools.DBScript):
                 self.script = cls(args)
                 break
         if not self.script:
-            print "Unknown command '%s', use --help for help" % cmd
+            print("Unknown command '%s', use --help for help" % cmd)
             sys.exit(1)
 
     def start(self):
@@ -83,13 +84,16 @@ class Londiste(skytools.DBScript):
                 help = "add: keep old data", default=False)
         g.add_option("--provider",
                 help = "init: upstream node temp connect string")
-        g.add_option("--create",
+        g.add_option("--create", action="store_true",
+                help = "add: create table/seq if not exist")
+        g.add_option("--create-only",
                 help = "add: create table/seq if not exist (seq,pkey,full,indexes,fkeys)")
+        g.add_option("--target",
+                help = "switchover: target node")
+        g.add_option("--merge",
+                help = "create-leaf: combined queue name")
         p.add_option_group(g)
-
         return p
-    def opt_create_cb(self, option, opt_str, value, parser):
-        print opt_str, '=', value
 
 if __name__ == '__main__':
     script = Londiste(sys.argv[1:])
index cd76c0df6b5989a3b02a00033c78a4e02ab68d3f..1b06de8a342d6b832ec568826c36d6a3d835d6f1 100644 (file)
@@ -1,31 +1,25 @@
 
 """Replication on top of PgQ."""
 
+__pychecker__ = 'no-miximport'
+
 import londiste.playback
 import londiste.compare
-import londiste.file_read
-import londiste.file_write
 import londiste.setup
 import londiste.table_copy
-import londiste.installer
 import londiste.repair
 
 from londiste.playback import *
 from londiste.compare import *
-from londiste.file_read import *
-from londiste.file_write import *
 from londiste.setup import *
 from londiste.table_copy import *
-from londiste.installer import *
 from londiste.repair import *
 
 __all__ = (
     londiste.playback.__all__ +
     londiste.compare.__all__ +
-    londiste.file_read.__all__ +
-    londiste.file_write.__all__ +
     londiste.setup.__all__ +
     londiste.table_copy.__all__ +
-    londiste.installer.__all__ +
     londiste.repair.__all__ )
 
+
index 8624d7ebd087af1bc7c8836391ea997ff4d7bcab..3369f3a1c1fc04d1db818698eb3b1735d2403c95 100644 (file)
@@ -12,6 +12,9 @@ __all__ = ['Comparator']
 from londiste.syncer import Syncer
 
 class Comparator(Syncer):
+    """Simple checker based on Syncer.
+    When tables are in sync runs simple SQL query on them.
+    """
     def process_sync(self, tbl, src_db, dst_db):
         """Actual comparision."""
 
@@ -22,7 +25,7 @@ class Comparator(Syncer):
 
         q = "select count(1) from only _TABLE_"
         q = self.cf.get('compare_sql', q)
-        q = q.replace('_TABLE_', tbl)
+        q = q.replace('_TABLE_', skytools.quote_fqident(tbl))
 
         self.log.debug("srcdb: " + q)
         src_curs.execute(q)
diff --git a/python/londiste/file_read.py b/python/londiste/file_read.py
deleted file mode 100644 (file)
index 2902bda..0000000
+++ /dev/null
@@ -1,52 +0,0 @@
-
-"""Reads events from file instead of db queue."""
-
-import sys, os, re, skytools
-
-from playback import *
-from table_copy import *
-
-__all__ = ['FileRead']
-
-file_regex = r"^tick_0*([0-9]+)\.sql$"
-file_rc = re.compile(file_regex)
-
-
-class FileRead(CopyTable):
-    """Reads events from file instead of db queue.
-    
-    Incomplete implementation.
-    """
-
-    def __init__(self, args, log = None):
-        CopyTable.__init__(self, args, log, copy_thread = 0)
-
-    def launch_copy(self, tbl):
-        # copy immidiately
-        self.do_copy(t)
-
-    def work(self):
-        last_batch = self.get_last_batch(curs)
-        list = self.get_file_list()
-
-    def get_list(self):
-        """Return list of (first_batch, full_filename) pairs."""
-
-        src_dir = self.cf.get('file_src')
-        list = os.listdir(src_dir)
-        list.sort()
-        res = []
-        for fn in list:
-            m = file_rc.match(fn)
-            if not m:
-                self.log.debug("Ignoring file: %s" % fn)
-                continue
-            full = os.path.join(src_dir, fn)
-            batch_id = int(m.group(1))
-            res.append((batch_id, full))
-        return res
-
-if __name__ == '__main__':
-    script = Replicator(sys.argv[1:])
-    script.start()
-
diff --git a/python/londiste/file_write.py b/python/londiste/file_write.py
deleted file mode 100644 (file)
index 86e16aa..0000000
+++ /dev/null
@@ -1,67 +0,0 @@
-
-"""Writes events into file."""
-
-import sys, os, skytools
-from cStringIO import StringIO
-from playback import *
-
-__all__ = ['FileWrite']
-
-class FileWrite(Replicator):
-    """Writes events into file.
-    
-    Incomplete implementation.
-    """
-
-    last_successful_batch = None
-
-    def load_state(self, batch_id):
-        # maybe check if batch exists on filesystem?
-        self.cur_tick = self.cur_batch_info['tick_id']
-        self.prev_tick = self.cur_batch_info['prev_tick_id']
-        return 1
-
-    def process_batch(self, db, batch_id, ev_list):
-        pass
-
-    def save_state(self, do_commit):
-        # nothing to save
-        pass
-
-    def sync_tables(self, dst_db):
-        # nothing to sync
-        return 1
-
-    def interesting(self, ev):
-        # wants all of them
-        return 1
-
-    def handle_data_event(self, ev):
-        fmt = self.sql_command[ev.type]
-        sql = fmt % (ev.ev_extra1, ev.data)
-        row = "%s -- txid:%d" % (sql, ev.txid)
-        self.sql_list.append(row)
-        ev.tag_done()
-
-    def handle_system_event(self, ev):
-        row = "-- sysevent:%s txid:%d data:%s" % (
-                ev.type, ev.txid, ev.data)
-        self.sql_list.append(row)
-        ev.tag_done()
-
-    def flush_sql(self):
-        self.sql_list.insert(0, "-- tick:%d prev:%s" % (
-                             self.cur_tick, self.prev_tick))
-        self.sql_list.append("-- end_tick:%d\n" % self.cur_tick)
-        # store result
-        dir = self.cf.get("file_dst")
-        fn = os.path.join(dir, "tick_%010d.sql" % self.cur_tick)
-        f = open(fn, "w")
-        buf = "\n".join(self.sql_list)
-        f.write(buf)
-        f.close()
-
-if __name__ == '__main__':
-    script = Replicator(sys.argv[1:])
-    script.start()
-
diff --git a/python/londiste/installer.py b/python/londiste/installer.py
deleted file mode 100644 (file)
index e76cdb2..0000000
+++ /dev/null
@@ -1,27 +0,0 @@
-
-"""Functions to install londiste and its depentencies into database."""
-
-import os, skytools
-
-__all__ = ['install_provider', 'install_subscriber']
-
-provider_object_list = [
-    skytools.DBLanguage("plpgsql"),
-    skytools.DBFunction('txid_current_snapshot', 0, sql_file = "txid.sql"),
-    skytools.DBSchema('pgq', sql_file = "pgq.sql"),
-    skytools.DBSchema('londiste', sql_file = "londiste.sql")
-]
-
-subscriber_object_list = [
-    skytools.DBLanguage("plpgsql"),
-    skytools.DBSchema('londiste', sql_file = "londiste.sql")
-]
-
-def install_provider(curs, log):
-    """Installs needed code into provider db."""
-    skytools.db_install(curs, provider_object_list, log)
-
-def install_subscriber(curs, log):
-    """Installs needed code into subscriber db."""
-    skytools.db_install(curs, subscriber_object_list, log)
-
index a9599691cef804fcecd5db33bb6755029b1be436..040da7a585f7d66ef1ba417d0433a243d2077221 100644 (file)
@@ -3,7 +3,9 @@
 """Basic replication core."""
 
 import sys, os, time
-import skytools, pgq
+import skytools
+
+from pgq.cascade.worker import CascadedWorker
 
 __all__ = ['Replicator', 'TableState',
     'TABLE_MISSING', 'TABLE_IN_COPY', 'TABLE_CATCHING_UP',
@@ -53,13 +55,25 @@ class Counter(object):
 class TableState(object):
     """Keeps state about one table."""
     def __init__(self, name, log):
+        """Init TableState for one table."""
         self.name = name
         self.log = log
-        self.forget()
-        self.changed = 0
+        # same as forget:
+        self.state = TABLE_MISSING
+        self.last_snapshot_tick = None
+        self.str_snapshot = None
+        self.from_snapshot = None
+        self.sync_tick_id = None
+        self.ok_batch_count = 0
+        self.last_tick = 0
         self.skip_truncate = False
+        self.copy_role = None
+        self.dropped_ddl = None
+        # except this
+        self.changed = 0
 
     def forget(self):
+        """Reset all info."""
         self.state = TABLE_MISSING
         self.last_snapshot_tick = None
         self.str_snapshot = None
@@ -71,6 +85,7 @@ class TableState(object):
         self.changed = 1
 
     def change_snapshot(self, str_snapshot, tag_changed = 1):
+        """Set snapshot."""
         if self.str_snapshot == str_snapshot:
             return
         self.log.debug("%s: change_snapshot to %s" % (self.name, str_snapshot))
@@ -86,6 +101,7 @@ class TableState(object):
             self.changed = 1
 
     def change_state(self, state, tick_id = None):
+        """Set state."""
         if self.state == state and self.sync_tick_id == tick_id:
             return
         self.state = state
@@ -138,14 +154,18 @@ class TableState(object):
 
         return state
 
-    def loaded_state(self, merge_state, str_snapshot, skip_truncate):
+    def loaded_state(self, row):
+        """Update object with info from db."""
+
         self.log.debug("loaded_state: %s: %s / %s" % (
-                       self.name, merge_state, str_snapshot))
-        self.change_snapshot(str_snapshot, 0)
-        self.state = self.parse_state(merge_state)
+                       self.name, row['merge_state'], row['custom_snapshot']))
+        self.change_snapshot(row['custom_snapshot'], 0)
+        self.state = self.parse_state(row['merge_state'])
         self.changed = 0
-        self.skip_truncate = skip_truncate
-        if merge_state == "?":
+        self.skip_truncate = row['skip_truncate']
+        self.copy_role = row['copy_role']
+        self.dropped_ddl = row['dropped_ddl']
+        if row['merge_state'] == "?":
             self.changed = 1
 
     def interesting(self, ev, tick_id, copy_thread):
@@ -210,38 +230,7 @@ class TableState(object):
         if self.last_snapshot_tick < prev_tick:
             self.change_snapshot(None)
 
-class SeqCache(object):
-    def __init__(self):
-        self.seq_list = []
-        self.val_cache = {}
-
-    def set_seq_list(self, seq_list):
-        self.seq_list = seq_list
-        new_cache = {}
-        for seq in seq_list:
-            val = self.val_cache.get(seq)
-            if val:
-                new_cache[seq] = val
-        self.val_cache = new_cache
-
-    def resync(self, src_curs, dst_curs):
-        if len(self.seq_list) == 0:
-            return
-        dat = ".last_value, ".join(self.seq_list)
-        dat += ".last_value"
-        q = "select %s from %s" % (dat, ",".join(self.seq_list))
-        src_curs.execute(q)
-        row = src_curs.fetchone()
-        for i in range(len(self.seq_list)):
-            seq = self.seq_list[i]
-            cur = row[i]
-            old = self.val_cache.get(seq)
-            if old != cur:
-                q = "select setval(%s, %s)"
-                dst_curs.execute(q, [seq, cur])
-                self.val_cache[seq] = cur
-
-class Replicator(pgq.SetConsumer):
+class Replicator(CascadedWorker):
     """Replication core."""
 
     sql_command = {
@@ -253,30 +242,40 @@ class Replicator(pgq.SetConsumer):
     # batch info
     cur_tick = 0
     prev_tick = 0
+    copy_table_name = None # filled by Copytable()
+    sql_list = []
 
     def __init__(self, args):
-        pgq.SetConsumer.__init__(self, 'londiste', args)
+        """Replication init."""
+        CascadedWorker.__init__(self, 'londiste', 'db', args)
 
         self.table_list = []
         self.table_map = {}
 
         self.copy_thread = 0
-        self.seq_cache = SeqCache()
+        self.set_name = self.queue_name
 
         self.parallel_copies = self.cf.getint('parallel_copies', 1)
         if self.parallel_copies < 1:
-            raise Excpetion('Bad value for parallel_copies: %d' % self.parallel_copies)
+            raise Exception('Bad value for parallel_copies: %d' % self.parallel_copies)
 
-    def process_set_batch(self, src_db, dst_db, ev_list):
+    def connection_setup(self, dbname, db):
+        if dbname == 'db':
+            curs = db.cursor()
+            curs.execute("set session_replication_role = 'replica'")
+            db.commit()
+
+    def process_remote_batch(self, src_db, tick_id, ev_list, dst_db):
         "All work for a batch.  Entry point from SetConsumer."
 
         # this part can play freely with transactions
 
-        dst_curs = dst_db.cursor()
+        self.sync_database_encodings(src_db, dst_db)
         
-        self.cur_tick = self.src_queue.cur_tick
-        self.prev_tick = self.src_queue.prev_tick
+        self.cur_tick = self._batch_info['tick_id']
+        self.prev_tick = self._batch_info['prev_tick_id']
 
+        dst_curs = dst_db.cursor()
         self.load_table_state(dst_curs)
         self.sync_tables(src_db, dst_db)
 
@@ -289,38 +288,15 @@ class Replicator(pgq.SetConsumer):
         # now the actual event processing happens.
         # they must be done all in one tx in dst side
         # and the transaction must be kept open so that
-        # the SerialConsumer can save last tick and commit.
-
-        self.sync_database_encodings(src_db, dst_db)
-
-        self.handle_seqs(dst_curs)
-
-        q = "select pgq.set_connection_context(%s)"
-        dst_curs.execute(q, [self.set_name])
+        # the cascade-consumer can save last tick and commit.
 
         self.sql_list = []
-        pgq.SetConsumer.process_set_batch(self, src_db, dst_db, ev_list)
+        CascadedWorker.process_remote_batch(self, src_db, tick_id, ev_list, dst_db)
         self.flush_sql(dst_curs)
 
         # finalize table changes
         self.save_table_state(dst_curs)
 
-    def handle_seqs(self, dst_curs):
-        return # FIXME
-        if self.copy_thread:
-            return
-
-        q = "select * from londiste.subscriber_get_seq_list(%s)"
-        dst_curs.execute(q, [self.pgq_queue_name])
-        seq_list = []
-        for row in dst_curs.fetchall():
-            seq_list.append(row[0])
-
-        self.seq_cache.set_seq_list(seq_list)
-
-        src_curs = self.get_database('provider_db').cursor()
-        self.seq_cache.resync(src_curs, dst_curs)
-
     def sync_tables(self, src_db, dst_db):
         """Table sync loop.
         
@@ -337,7 +313,8 @@ class Replicator(pgq.SetConsumer):
 
             if res == SYNC_EXIT:
                 self.log.debug('Sync tables: exit')
-                self.unregister_consumer(src_db.cursor())
+                if not self.copy_thread:
+                    self.unregister_consumer()
                 src_db.commit()
                 sys.exit(0)
             elif res == SYNC_OK:
@@ -374,7 +351,7 @@ class Replicator(pgq.SetConsumer):
             src_db.commit()
             for t in self.get_tables_in_state(TABLE_MISSING):
                 if t.name not in pmap:
-                    self.log.warning("Table %s not availalbe on provider" % t.name)
+                    self.log.warning("Table %s not available on provider" % t.name)
                     continue
                 pt = pmap[t.name]
                 if pt.state != TABLE_OK: # or pt.custom_snapshot: # FIXME: does snapsnot matter?
@@ -402,6 +379,7 @@ class Replicator(pgq.SetConsumer):
         
         return ret
 
+
     def sync_from_copy_thread(self, cnt, src_db, dst_db):
         "Copy thread sync logic."
 
@@ -423,6 +401,11 @@ class Replicator(pgq.SetConsumer):
             # wait for main thread to react
             return SYNC_LOOP
         elif t.state == TABLE_CATCHING_UP:
+
+            # partition merging
+            if t.copy_role == 'wait-replay':
+                return SYNC_LOOP
+
             # is there more work?
             if self.work_state:
                 return SYNC_OK
@@ -442,33 +425,123 @@ class Replicator(pgq.SetConsumer):
             # nothing to do
             return SYNC_EXIT
 
-    def process_set_event(self, dst_curs, ev):
+    def do_copy(self, tbl, src_db, dst_db):
+        """Callback for actual copy implementation."""
+        raise Exception('do_copy not implemented')
+
+    def process_remote_event(self, src_curs, dst_curs, ev):
         """handle one event"""
         self.log.debug("New event: id=%s / type=%s / data=%s / extra1=%s" % (ev.id, ev.type, ev.data, ev.extra1))
         if ev.type in ('I', 'U', 'D'):
             self.handle_data_event(ev, dst_curs)
-        elif ev.type == 'add-table':
+            ev.tag_done()
+        elif ev.type[:2] in ('I:', 'U:', 'D:'):
+            self.handle_urlenc_event(ev, dst_curs)
+            ev.tag_done()
+        elif ev.type == "TRUNCATE":
+            self.flush_sql(dst_curs)
+            self.handle_truncate_event(ev, dst_curs)
+            ev.tag_done()
+        elif ev.type == 'EXECUTE':
+            self.flush_sql(dst_curs)
+            self.handle_execute_event(ev, dst_curs)
+            ev.tag_done()
+        elif ev.type == 'londiste.add-table':
+            self.flush_sql(dst_curs)
             self.add_set_table(dst_curs, ev.data)
-        elif ev.type == 'remove-table':
+            ev.tag_done()
+        elif ev.type == 'londiste.remove-table':
+            self.flush_sql(dst_curs)
             self.remove_set_table(dst_curs, ev.data)
+            ev.tag_done()
+        elif ev.type == 'londiste.remove-seq':
+            self.flush_sql(dst_curs)
+            self.remove_set_seq(dst_curs, ev.data)
+            ev.tag_done()
+        elif ev.type == 'londiste.update-seq':
+            self.flush_sql(dst_curs)
+            self.update_seq(dst_curs, ev)
+            ev.tag_done()
         else:
-            pgq.SetConsumer.process_set_event(self, dst_curs, ev)
+            CascadedWorker.process_remote_event(self, src_curs, dst_curs, ev)
 
     def handle_data_event(self, ev, dst_curs):
+        """handle one data event"""
         t = self.get_table_by_name(ev.extra1)
         if t and t.interesting(ev, self.cur_tick, self.copy_thread):
             # buffer SQL statements, then send them together
+            fqname = skytools.quote_fqident(ev.extra1)
             fmt = self.sql_command[ev.type]
-            sql = fmt % (ev.extra1, ev.data)
-            self.sql_list.append(sql)
-            if len(self.sql_list) > 200:
-                self.flush_sql(dst_curs)
+            sql = fmt % (fqname, ev.data)
+
+            self.apply_sql(sql, dst_curs)
+        else:
+            self.stat_increase('ignored_events')
+
+    def handle_urlenc_event(self, ev, dst_curs):
+        """handle one urlencoded event"""
+        t = self.get_table_by_name(ev.extra1)
+        if not t or not t.interesting(ev, self.cur_tick, self.copy_thread):
+            self.stat_increase('ignored_events')
+            return
+        
+        # parse event
+        pklist = ev.type[2:].split(',')
+        row = skytools.db_urldecode(ev.data)
+        op = ev.type[0]
+        tbl = ev.extra1
+
+        # generate sql
+        if op == 'I':
+            sql = skytools.mk_insert_sql(row, tbl, pklist)
+        elif op == 'U':
+            sql = skytools.mk_update_sql(row, tbl, pklist)
+        elif op == 'D':
+            sql = skytools.mk_delete_sql(row, tbl, pklist)
         else:
+            raise Exception('bug: bad op')
+
+        self.apply_sql(sql, dst_curs)
+
+    def handle_truncate_event(self, ev, dst_curs):
+        """handle one truncate event"""
+        t = self.get_table_by_name(ev.extra1)
+        if not t or not t.interesting(ev, self.cur_tick, self.copy_thread):
             self.stat_increase('ignored_events')
-        ev.tag_done()
+            return
+
+        fqname = skytools.quote_fqident(ev.extra1)
+        sql = "TRUNCATE %s;" % fqname
+        self.apply_sql(sql, dst_curs)
+
+    def handle_execute_event(self, ev, dst_curs):
+        """handle one EXECUTE event"""
+
+        if self.copy_thread:
+            return
+
+        # parse event
+        fname = ev.extra1
+        sql = ev.data
+
+        # fixme: curs?
+        q = "select * from londiste.execute_start(%s, %s, %s, false)"
+        res = self.exec_cmd(dst_curs, q, [self.queue_name, fname, sql], commit = False)
+        ret = res[0]['ret_code']
+        if ret != 200:
+            return
+        for stmt in skytools.parse_statements(sql):
+            dst_curs.execute(stmt)
+        q = "select * from londiste.execute_finish(%s, %s)"
+        self.exec_cmd(dst_curs, q, [self.queue_name, fname], commit = False)
+
+    def apply_sql(self, sql, dst_curs):
+        self.sql_list.append(sql)
+        if len(self.sql_list) > 200:
+            self.flush_sql(dst_curs)
 
     def flush_sql(self, dst_curs):
-        # send all buffered statements at once
+        """Send all buffered statements to DB."""
 
         if len(self.sql_list) == 0:
             return
@@ -479,6 +552,7 @@ class Replicator(pgq.SetConsumer):
         dst_curs.execute(buf)
 
     def interesting(self, ev):
+        """See if event is interesting."""
         if ev.type not in ('I', 'U', 'D'):
             raise Exception('bug - bad event type in .interesting')
         t = self.get_table_by_name(ev.extra1)
@@ -488,17 +562,26 @@ class Replicator(pgq.SetConsumer):
             return 0
 
     def add_set_table(self, dst_curs, tbl):
-        q = "select londiste.set_add_table(%s, %s)"
+        """There was new table added to root, remember it."""
+
+        q = "select londiste.global_add_table(%s, %s)"
         dst_curs.execute(q, [self.set_name, tbl])
 
     def remove_set_table(self, dst_curs, tbl):
+        """There was table dropped from root, remember it."""
         if tbl in self.table_map:
             t = self.table_map[tbl]
             del self.table_map[tbl]
             self.table_list.remove(t)
-        q = "select londiste.set_remove_table(%s, %s)"
+        q = "select londiste.global_remove_table(%s, %s)"
         dst_curs.execute(q, [self.set_name, tbl])
 
+    def remove_set_seq(self, dst_curs, seq):
+        """There was seq dropped from root, remember it."""
+
+        q = "select londiste.global_remove_seq(%s, %s)"
+        dst_curs.execute(q, [self.set_name, seq])
+
     def load_table_state(self, curs):
         """Load table state from database.
         
@@ -506,17 +589,18 @@ class Replicator(pgq.SetConsumer):
         to load state on every batch.
         """
 
-        q = "select table_name, custom_snapshot, merge_state, skip_truncate"\
-            "  from londiste.node_get_table_list(%s)"
+        q = "select * from londiste.get_table_list(%s)"
         curs.execute(q, [self.set_name])
 
         new_list = []
         new_map = {}
         for row in curs.dictfetchall():
+            if not row['local']:
+                continue
             t = self.get_table_by_name(row['table_name'])
             if not t:
                 t = TableState(row['table_name'], self.log)
-            t.loaded_state(row['merge_state'], row['custom_snapshot'], row['skip_truncate'])
+            t.loaded_state(row)
             new_list.append(t)
             new_map[t.name] = t
 
@@ -526,34 +610,35 @@ class Replicator(pgq.SetConsumer):
     def get_state_map(self, curs):
         """Get dict of table states."""
 
-        q = "select table_name, custom_snapshot, merge_state, skip_truncate"\
-            "  from londiste.node_get_table_list(%s)"
+        q = "select * from londiste.get_table_list(%s)"
         curs.execute(q, [self.set_name])
 
         new_map = {}
         for row in curs.fetchall():
+            if not row['local']:
+                continue
             t = TableState(row['table_name'], self.log)
-            t.loaded_state(row['merge_state'], row['custom_snapshot'], row['skip_truncate'])
+            t.loaded_state(row)
             new_map[t.name] = t
         return new_map
 
     def save_table_state(self, curs):
         """Store changed table state in database."""
 
-        got_changes = 0
         for t in self.table_list:
             if not t.changed:
                 continue
             merge_state = t.render_state()
             self.log.info("storing state of %s: copy:%d new_state:%s" % (
                             t.name, self.copy_thread, merge_state))
-            q = "select londiste.node_set_table_state(%s, %s, %s, %s)"
+            q = "select londiste.local_set_table_state(%s, %s, %s, %s)"
             curs.execute(q, [self.set_name,
                              t.name, t.str_snapshot, merge_state])
             t.changed = 0
-            got_changes = 1
 
     def change_table_state(self, dst_db, tbl, state, tick_id = None):
+        """Change state for table."""
+
         tbl.change_state(state, tick_id)
         self.save_table_state(dst_db.cursor())
         dst_db.commit()
@@ -569,6 +654,7 @@ class Replicator(pgq.SetConsumer):
                 yield t
 
     def get_table_by_name(self, name):
+        """Returns cached state object."""
         if name.find('.') < 0:
             name = "public.%s" % name
         if name in self.table_map:
@@ -576,6 +662,7 @@ class Replicator(pgq.SetConsumer):
         return None
 
     def launch_copy(self, tbl_stat):
+        """Run parallel worker for copy."""
         self.log.info("Launching copy process")
         script = sys.argv[0]
         conf = self.cf.filename
@@ -587,7 +674,7 @@ class Replicator(pgq.SetConsumer):
         # otherwise new copy will exit immidiately.
         # FIXME: should not happen on per-table pidfile ???
         copy_pidfile = "%s.copy.%s" % (self.pidfile, tbl_stat.name)
-        while os.path.isfile(copy_pidfile):
+        while skytools.signal_pidfile(copy_pidfile, 0):
             self.log.warning("Waiting for existing copy to exit")
             time.sleep(2)
 
@@ -630,28 +717,54 @@ class Replicator(pgq.SetConsumer):
         """Restore fkeys that have both tables on sync."""
         dst_curs = dst_db.cursor()
         # restore fkeys -- one at a time
-        q = "select * from londiste.node_get_valid_pending_fkeys(%s)"
+        q = "select * from londiste.get_valid_pending_fkeys(%s)"
         dst_curs.execute(q, [self.set_name])
-        list = dst_curs.dictfetchall()
-        for row in list:
+        fkey_list = dst_curs.dictfetchall()
+        for row in fkey_list:
             self.log.info('Creating fkey: %(fkey_name)s (%(from_table)s --> %(to_table)s)' % row)
             q2 = "select londiste.restore_table_fkey(%(from_table)s, %(fkey_name)s)"
             dst_curs.execute(q2, row)
             dst_db.commit()
     
     def drop_fkeys(self, dst_db, table_name):
-        # drop all foreign keys to and from this table
-        # they need to be dropped one at a time to avoid deadlocks with user code
+        """Drop all foreign keys to and from this table.
+
+        They need to be dropped one at a time to avoid deadlocks with user code.
+        """
+
         dst_curs = dst_db.cursor()
         q = "select * from londiste.find_table_fkeys(%s)"
         dst_curs.execute(q, [table_name])
-        list = dst_curs.dictfetchall()
-        for row in list:
+        fkey_list = dst_curs.dictfetchall()
+        for row in fkey_list:
             self.log.info('Dropping fkey: %s' % row['fkey_name'])
             q2 = "select londiste.drop_table_fkey(%(from_table)s, %(fkey_name)s)"
             dst_curs.execute(q2, row)
             dst_db.commit()
         
+    def process_root_node(self, dst_db):
+        """On root node send seq changes to queue."""
+
+        CascadedWorker.process_root_node(self, dst_db)
+
+        q = "select * from londiste.root_check_seqs(%s)"
+        self.exec_cmd(dst_db, q, [self.queue_name])
+
+    def update_seq(self, dst_curs, ev):
+        if self.copy_thread:
+            return
+
+        val = int(ev.data)
+        seq = ev.extra1
+        q = "select * from londiste.global_update_seq(%s, %s, %s)"
+        self.exec_cmd(dst_curs, q, [self.queue_name, seq, val])
+
+    def copy_event(self, dst_curs, ev):
+        # send only data events down (skipping seqs also)
+        if ev.type[:9] in ('londiste.', 'EXECUTE', 'TRUNCATE'):
+            return
+        CascadedWorker.copy_event(self, dst_curs, ev)
+
 if __name__ == '__main__':
     script = Replicator(sys.argv[1:])
     script.start()
index 5fb1cf22b078dbbca402526207ab032772366663..414425a506001e2b90f4f9b3ffcab98709846ecf 100644 (file)
@@ -12,6 +12,7 @@ from syncer import Syncer
 __all__ = ['Repairer']
 
 def unescape(s):
+    """Remove copy escapes."""
     return skytools.unescape_copy(s)
 
 def get_pkey_list(curs, tbl):
@@ -45,6 +46,13 @@ def get_column_list(curs, tbl):
 class Repairer(Syncer):
     """Walks tables in primary key order and checks if data matches."""
 
+    cnt_insert = 0
+    cnt_update = 0
+    cnt_delete = 0
+    total_src = 0
+    total_dst = 0
+    pkey_list = []
+    common_fields = []
 
     def process_sync(self, tbl, src_db, dst_db):
         """Actual comparision."""
@@ -89,6 +97,7 @@ class Repairer(Syncer):
         os.unlink(dump_dst + ".sorted")
 
     def gen_copy_tbl(self, tbl, src_curs, dst_curs):
+        """Create COPY expression from common fields."""
         self.pkey_list = get_pkey_list(src_curs, tbl)
         dst_pkey = get_pkey_list(dst_curs, tbl)
         if dst_pkey != self.pkey_list:
@@ -108,20 +117,24 @@ class Repairer(Syncer):
 
         self.common_fields = field_list
 
-        tbl_expr = "%s (%s)" % (tbl, ",".join(field_list))
+        fqlist = [skytools.quote_ident(col) for col in field_list]
+
+        tbl_expr = "%s (%s)" % (skytools.quote_ident(tbl), ",".join(fqlist))
 
         self.log.debug("using copy expr: %s" % tbl_expr)
 
         return tbl_expr
 
     def dump_table(self, tbl, copy_tbl, curs, fn):
+        """Dump table to disk."""
         f = open(fn, "w", 64*1024)
         curs.copy_to(f, copy_tbl)
         size = f.tell()
         f.close()
-        self.log.info('Got %d bytes' % size)
+        self.log.info('%s: Got %d bytes' % (tbl, size))
 
     def get_row(self, ln):
+        """Parse a row into dict."""
         if not ln:
             return None
         t = ln[:-1].split('\t')
@@ -131,6 +144,7 @@ class Repairer(Syncer):
         return row
 
     def dump_compare(self, tbl, src_fn, dst_fn):
+        """Dump + compare single table."""
         self.log.info("Comparing dumps: %s" % tbl)
         self.cnt_insert = 0
         self.cnt_update = 0
@@ -154,12 +168,12 @@ class Repairer(Syncer):
                 src_row = self.get_row(src_ln)
                 dst_row = self.get_row(dst_ln)
 
-                cmp = self.cmp_keys(src_row, dst_row)
-                if cmp > 0:
+                diff = self.cmp_keys(src_row, dst_row)
+                if diff > 0:
                     # src > dst
                     self.got_missed_delete(tbl, dst_row)
                     keep_src = 1
-                elif cmp < 0:
+                elif diff < 0:
                     # src < dst
                     self.got_missed_insert(tbl, src_row)
                     keep_dst = 1
@@ -180,55 +194,63 @@ class Repairer(Syncer):
                 self.cnt_insert, self.cnt_update, self.cnt_delete))
 
     def got_missed_insert(self, tbl, src_row):
+        """Create sql for missed insert."""
         self.cnt_insert += 1
         fld_list = self.common_fields
+        fq_list = []
         val_list = []
         for f in fld_list:
+            fq_list.append(skytools.quote_ident(f))
             v = unescape(src_row[f])
             val_list.append(skytools.quote_literal(v))
         q = "insert into %s (%s) values (%s);" % (
-                tbl, ", ".join(fld_list), ", ".join(val_list))
+                skytools.quote_fqident(tbl), ", ".join(fq_list), ", ".join(val_list))
         self.show_fix(tbl, q, 'insert')
 
     def got_missed_update(self, tbl, src_row, dst_row):
+        """Create sql for missed update."""
         self.cnt_update += 1
         fld_list = self.common_fields
         set_list = []
         whe_list = []
         for f in self.pkey_list:
-            self.addcmp(whe_list, f, unescape(src_row[f]))
+            self.addcmp(whe_list, skytools.quote_ident(f), unescape(src_row[f]))
         for f in fld_list:
             v1 = src_row[f]
             v2 = dst_row[f]
             if self.cmp_value(v1, v2) == 0:
                 continue
 
-            self.addeq(set_list, f, unescape(v1))
-            self.addcmp(whe_list, f, unescape(v2))
+            self.addeq(set_list, skytools.quote_ident(f), unescape(v1))
+            self.addcmp(whe_list, skytools.quote_ident(f), unescape(v2))
 
         q = "update only %s set %s where %s;" % (
                 tbl, ", ".join(set_list), " and ".join(whe_list))
         self.show_fix(tbl, q, 'update')
 
     def got_missed_delete(self, tbl, dst_row):
+        """Create sql for missed delete."""
         self.cnt_delete += 1
         whe_list = []
         for f in self.pkey_list:
-            self.addcmp(whe_list, f, unescape(dst_row[f]))
-        q = "delete from only %s where %s;" % (tbl, " and ".join(whe_list))
+            self.addcmp(whe_list, skytools.quote_ident(f), unescape(dst_row[f]))
+        q = "delete from only %s where %s;" % (skytools.quote_fqident(tbl), " and ".join(whe_list))
         self.show_fix(tbl, q, 'delete')
 
     def show_fix(self, tbl, q, desc):
-        #self.log.warning("missed %s: %s" % (desc, q))
+        """Print/write/apply repair sql."""
+        self.log.debug("missed %s: %s" % (desc, q))
         fn = "fix.%s.sql" % tbl
         open(fn, "a").write("%s\n" % q)
 
     def addeq(self, list, f, v):
+        """Add quoted SET."""
         vq = skytools.quote_literal(v)
         s = "%s = %s" % (f, vq)
         list.append(s)
 
     def addcmp(self, list, f, v):
+        """Add quoted comparison."""
         if v is None:
             s = "%s is null" % f
         else:
@@ -237,6 +259,7 @@ class Repairer(Syncer):
         list.append(s)
 
     def cmp_data(self, src_row, dst_row):
+        """Compare data field-by-field."""
         for k in self.common_fields:
             v1 = src_row[k]
             v2 = dst_row[k]
@@ -245,6 +268,7 @@ class Repairer(Syncer):
         return 0
 
     def cmp_value(self, v1, v2):
+        """Compare single field, tolerates tz vs notz dates."""
         if v1 == v2:
             return 0
 
index fea3e154b92b74e62ae1d2061e3170ba81fb61c3..30854de24880c55831fa287a11d7ac9ae5d673f7 100644 (file)
@@ -5,19 +5,37 @@
 
 import sys, os, skytools
 
-import pgq.setadmin
+from pgq.cascade.admin import CascadeAdmin
 
 __all__ = ['LondisteSetup']
 
-class LondisteSetup(pgq.setadmin.SetAdmin):
+class LondisteSetup(CascadeAdmin):
+    """Londiste-specific admin commands."""
     initial_db_name = 'node_db'
     extra_objs = [ skytools.DBSchema("londiste", sql_file="londiste.sql") ]
+    provider_location = None
     def __init__(self, args):
-        pgq.setadmin.SetAdmin.__init__(self, 'londiste', args)
-        self.set_name = self.cf.get("set_name")
+        """Londiste setup init."""
+        CascadeAdmin.__init__(self, 'londiste', 'db', args, worker_setup = True)
+
+        # compat
+        self.queue_name = self.cf.get('pgq_queue_name', '')
+        # real
+        if not self.queue_name:
+            self.queue_name = self.cf.get('queue_name')
+
+        self.set_name = self.queue_name
+
+    def connection_setup(self, dbname, db):
+        if dbname == 'db':
+            curs = db.cursor()
+            curs.execute("set session_replication_role = 'replica'")
+            db.commit()
 
     def init_optparse(self, parser=None):
-        p = pgq.setadmin.SetAdmin.init_optparse(self, parser)
+        """Add londiste switches to cascadeadmin ones."""
+
+        p = CascadeAdmin.init_optparse(self, parser)
         p.add_option("--expect-sync", action="store_true", dest="expect_sync",
                     help = "no copy needed", default=False)
         p.add_option("--skip-truncate", action="store_true", dest="skip_truncate",
@@ -26,266 +44,425 @@ class LondisteSetup(pgq.setadmin.SetAdmin):
                     help="force", default=False)
         p.add_option("--all", action="store_true",
                     help="include all tables", default=False)
+        p.add_option("--create", action="store_true",
+                    help="create table/seq if not exist", default=False)
+        p.add_option("--create-only",
+                    help="pkey,fkeys,indexes")
         return p
 
     def extra_init(self, node_type, node_db, provider_db):
+        """Callback from CascadeAdmin init."""
         if not provider_db:
             return
         pcurs = provider_db.cursor()
         ncurs = node_db.cursor()
-        q = "select table_name from londiste.set_get_table_list(%s)"
+
+        # sync tables
+        q = "select table_name from londiste.get_table_list(%s)"
         pcurs.execute(q, [self.set_name])
         for row in pcurs.fetchall():
             tbl = row['table_name']
-            q = "select * from londiste.set_add_table(%s, %s)"
+            q = "select * from londiste.global_add_table(%s, %s)"
             ncurs.execute(q, [self.set_name, tbl])
+
+        # sync seqs
+        q = "select seq_name, last_value from londiste.get_seq_list(%s)"
+        pcurs.execute(q, [self.set_name])
+        for row in pcurs.fetchall():
+            seq = row['seq_name']
+            val = row['last_value']
+            q = "select * from londiste.update_seq(%s, %s, %s)"
+            ncurs.execute(q, [self.set_name, seq, val])
+
+        # done
         node_db.commit()
         provider_db.commit()
 
-    def cmd_add(self, *args):
-        q = "select * from londiste.node_add_table(%s, %s)"
-        db = self.get_database('node_db')
-        self.exec_cmd_many(db, q, [self.set_name], args)
+    def cmd_add_table(self, *args):
+        """Attach table(s) to local node."""
 
-    def cmd_remove(self, *args):
-        q = "select * from londiste.node_remove_table(%s, %s)"
-        db = self.get_database('node_db')
+        dst_db = self.get_database('db')
+        dst_curs = dst_db.cursor()
+        src_db = self.get_provider_db()
+        src_curs = src_db.cursor()
+
+        src_tbls = self.fetch_set_tables(src_curs)
+        dst_tbls = self.fetch_set_tables(dst_curs)
+        src_db.commit()
+        self.sync_table_list(dst_curs, src_tbls, dst_tbls)
+        dst_db.commit()
+
+        # dont check for exist/not here (root handling)
+        problems = False
+        for tbl in args:
+            tbl = skytools.fq_name(tbl)
+            if (tbl in src_tbls) and not src_tbls[tbl]:
+                self.log.error("Table %s does not exist on provider, need to switch to different provider" % tbl)
+                problems = True
+        if problems:
+            self.log.error("Problems, canceling operation")
+            sys.exit(1)
+
+        # pick proper create flags
+        create = self.options.create_only
+        if not create and self.options.create:
+            create = 'full'
+
+        fmap = {
+            "full": skytools.T_ALL,
+            "pkey": skytools.T_PKEY,
+        }
+        create_flags = 0
+        if create:
+            for f in create.split(','):
+                if f not in fmap:
+                    raise Exception("bad --create-only flag: " + f)
+                create_flags += fmap[f]
+
+        # seems ok
+        for tbl in args:
+            tbl = skytools.fq_name(tbl)
+            self.add_table(src_db, dst_db, tbl, create_flags)
+
+    def add_table(self, src_db, dst_db, tbl, create_flags):
+        src_curs = src_db.cursor()
+        dst_curs = dst_db.cursor()
+        if create_flags:
+            if skytools.exists_table(dst_curs, tbl):
+                self.log.info('Table %s already exist, not touching' % tbl)
+            else:
+                s = skytools.TableStruct(src_curs, tbl)
+                src_db.commit()
+                s.create(dst_curs, create_flags, log = self.log)
+        q = "select * from londiste.local_add_table(%s, %s)"
+        self.exec_cmd(dst_curs, q, [self.set_name, tbl])
+        dst_db.commit()
+    
+    def sync_table_list(self, dst_curs, src_tbls, dst_tbls):
+        for tbl in src_tbls.keys():
+            q = "select * from londiste.global_add_table(%s, %s)"
+            if tbl not in dst_tbls:
+                self.log.info("Table %s info missing from subscriber, adding" % tbl)
+                self.exec_cmd(dst_curs, q, [self.set_name, tbl])
+                dst_tbls[tbl] = False
+        for tbl in dst_tbls.keys():
+            q = "select * from londiste.global_remove_table(%s, %s)"
+            if tbl not in src_tbls:
+                self.log.info("Table %s gone but exists on subscriber, removing" % tbl)
+                self.exec_cmd(dst_curs, q, [self.set_name, tbl])
+                del dst_tbls[tbl]
+
+    def fetch_set_tables(self, curs):
+        q = "select table_name, local from londiste.get_table_list(%s)"
+        curs.execute(q, [self.set_name])
+        res = {}
+        for row in curs.fetchall():
+            res[row[0]] = row[1]
+        return res
+
+    def cmd_remove_table(self, *args):
+        """Detach table(s) from local node."""
+        q = "select * from londiste.local_remove_table(%s, %s)"
+        db = self.get_database('db')
         self.exec_cmd_many(db, q, [self.set_name], args)
 
     def cmd_add_seq(self, *args):
-        q = "select * from londiste.node_add_seq(%s, %s)"
-        db = self.get_database('node_db')
-        self.exec_cmd_many(db, q, [self.set_name], args)
+        """Attach seqs(s) to local node."""
+        dst_db = self.get_database('db')
+        dst_curs = dst_db.cursor()
+        src_db = self.get_provider_db()
+        src_curs = src_db.cursor()
+
+        src_seqs = self.fetch_seqs(src_curs)
+        dst_seqs = self.fetch_seqs(dst_curs)
+        src_db.commit()
+        self.sync_seq_list(dst_curs, src_seqs, dst_seqs)
+        dst_db.commit()
+
+        # pick proper create flags
+        create = self.options.create_only
+        if not create and self.options.create:
+            create = 'full'
+
+        fmap = {
+            "full": skytools.T_SEQUENCE,
+        }
+        create_flags = 0
+        if create:
+            for f in create.split(','):
+                if f not in fmap:
+                    raise Exception("bad --create-only flag: " + f)
+                create_flags += fmap[f]
+
+        # seems ok
+        for seq in args:
+            seq = skytools.fq_name(seq)
+            self.add_seq(src_db, dst_db, seq, create_flags)
+        dst_db.commit()
+
+    def add_seq(self, src_db, dst_db, seq, create_flags):
+        src_curs = src_db.cursor()
+        dst_curs = dst_db.cursor()
+        if create_flags:
+            if skytools.exists_sequence(dst_curs, seq):
+                self.log.info('Sequence %s already exist, not creating' % seq)
+            else:
+                s = skytools.SeqStruct(src_curs, seq)
+                src_db.commit()
+                s.create(dst_curs, create_flags, log = self.log)
+        q = "select * from londiste.local_add_seq(%s, %s)"
+        self.exec_cmd(dst_curs, q, [self.set_name, seq])
+
+    def fetch_seqs(self, curs):
+        q = "select seq_name, last_value, local from londiste.get_seq_list(%s)"
+        curs.execute(q, [self.set_name])
+        res = {}
+        for row in curs.fetchall():
+            res[row[0]] = row
+        return res
+
+    def sync_seq_list(self, dst_curs, src_seqs, dst_seqs):
+        for seq in src_seqs.keys():
+            q = "select * from londiste.global_update_seq(%s, %s, %s)"
+            if seq not in dst_seqs:
+                self.log.info("Sequence %s info missing from subscriber, adding" % seq)
+                self.exec_cmd(dst_curs, q, [self.set_name, seq, src_seqs[seq]['last_value']])
+                tmp = src_seqs[seq].copy()
+                tmp['local'] = False
+                dst_seqs[seq] = tmp
+        for seq in dst_seqs.keys():
+            q = "select * from londiste.global_remove_seq(%s, %s)"
+            if seq not in src_seqs:
+                self.log.info("Sequence %s gone but exists on subscriber, removing" % seq)
+                self.exec_cmd(dst_curs, q, [self.set_name, seq])
+                del dst_seqs[seq]
 
     def cmd_remove_seq(self, *args):
-        q = "select * from londiste.node_remove_seq(%s, %s)"
-        db = self.get_database('node_db')
+        """Detach seqs(s) from local node."""
+        q = "select * from londiste.local_remove_seq(%s, %s)"
+        db = self.get_database('db')
         self.exec_cmd_many(db, q, [self.set_name], args)
 
     def cmd_resync(self, *args):
+        """Reload data from provider node."""
+        # fixme
         q = "select * from londiste.node_resync_table(%s, %s)"
-        db = self.get_database('node_db')
+        db = self.get_database('db')
         self.exec_cmd_many(db, q, [self.set_name], args)
 
     def cmd_tables(self):
-        q = "select table_name, merge_state from londiste.node_get_table_list(%s)"
-        db = self.get_database('node_db')
-        self.db_display_table(db, "Tables on node", q, [self.set_name])
+        """Show attached tables."""
+        q = "select table_name, local, merge_state from londiste.get_table_list(%s)"
+        db = self.get_database('db')
+        self.display_table(db, "Tables on node", q, [self.set_name])
 
     def cmd_seqs(self):
-        q = "select seq_namefrom londiste.node_get_seq_list(%s)"
-        db = self.get_database('node_db')
-        self.db_display_table(db, "Sequences on node", q, [self.set_name])
+        """Show attached seqs."""
+        q = "select seq_name, local, last_value from londiste.get_seq_list(%s)"
+        db = self.get_database('db')
+        self.display_table(db, "Sequences on node", q, [self.set_name])
 
     def cmd_missing(self):
+        """Show missing tables on local node."""
+        # fixme
         q = "select * from londiste.node_show_missing(%s)"
-        db = self.get_database('node_db')
-        self.db_display_table(db, "Missing objects on node", q, [self.set_name])
+        db = self.get_database('db')
+        self.display_table(db, "Missing objects on node", q, [self.set_name])
 
     def cmd_check(self):
+        """TODO: check if structs match"""
         pass
     def cmd_fkeys(self):
+        """TODO: show removed fkeys."""
         pass
     def cmd_triggers(self):
+        """TODO: show removed triggers."""
         pass
 
+    def cmd_execute(self, *files):
+        db = self.get_database('db')
+        curs = db.cursor()
+        for fn in files:
+            fname = os.path.basename(fn)
+            sql = open(fn, "r").read()
+            q = "select * from londiste.execute_start(%s, %s, %s, true)"
+            self.exec_cmd(db, q, [self.queue_name, fname, sql], commit = False)
+            for stmt in skytools.parse_statements(sql):
+                curs.execute(stmt)
+            q = "select * from londiste.execute_finish(%s, %s)"
+            self.exec_cmd(db, q, [self.queue_name, fname], commit = False)
+        db.commit()
+
+    def get_provider_db(self):
+        if not self.provider_location:
+            db = self.get_database('db')
+            q = 'select * from pgq_node.get_node_info(%s)'
+            res = self.exec_cmd(db, q, [self.queue_name], quiet = True)
+            self.provider_location = res[0]['provider_location']
+        return self.get_database('provider_db', connstr = self.provider_location)
+
 #
 # Old commands
 #
 
-class LondisteSetup_tmp:
-
-    def find_missing_provider_tables(self, pattern='*'):
-        src_db = self.get_database('provider_db')
-        src_curs = src_db.cursor()
-        q = """select schemaname || '.' || tablename as full_name from pg_tables
-                where schemaname not in ('pgq', 'londiste', 'pg_catalog', 'information_schema')
-                  and schemaname !~ 'pg_.*'
-                  and (schemaname || '.' || tablename) ~ %s
-                except select table_name from londiste.provider_get_table_list(%s)"""
-        src_curs.execute(q, [glob2regex(pattern), self.pgq_queue_name])
-        rows = src_curs.fetchall()
-        src_db.commit()
-        list = []
-        for row in rows:
-            list.append(row[0])
-        return list
-                
-    def admin(self):
-        cmd = self.args[2]
-        if cmd == "tables":
-            self.subscriber_show_tables()
-        elif cmd == "missing":
-            self.subscriber_missing_tables()
-        elif cmd == "add":
-            self.subscriber_add_tables(self.args[3:])
-        elif cmd == "remove":
-            self.subscriber_remove_tables(self.args[3:])
-        elif cmd == "resync":
-            self.subscriber_resync_tables(self.args[3:])
-        elif cmd == "register":
-            self.subscriber_register()
-        elif cmd == "unregister":
-            self.subscriber_unregister()
-        elif cmd == "install":
-            self.subscriber_install()
-        elif cmd == "check":
-            self.check_tables(self.get_provider_table_list())
-        elif cmd in ["fkeys", "triggers"]:
-            self.collect_meta(self.get_provider_table_list(), cmd, self.args[3:])
-        elif cmd == "seqs":
-            self.subscriber_list_seqs()
-        elif cmd == "add-seq":
-            self.subscriber_add_seq(self.args[3:])
-        elif cmd == "remove-seq":
-            self.subscriber_remove_seq(self.args[3:])
-        elif cmd == "restore-triggers":
-            self.restore_triggers(self.args[3], self.args[4:])
-        else:
-            self.log.error('bad subcommand: ' + cmd)
-            sys.exit(1)
-
-    def collect_meta(self, table_list, meta, args):
-        """Display fkey/trigger info."""
-
-        if args == []:
-            args = ['pending', 'active']
-            
-        field_map = {'triggers': ['table_name', 'trigger_name', 'trigger_def'],
-                     'fkeys': ['from_table', 'to_table', 'fkey_name', 'fkey_def']}
-        
-        query_map = {'pending': "select %s from londiste.subscriber_get_table_pending_%s(%%s)",
-                     'active' : "select %s from londiste.find_table_%s(%%s)"}
-
-        table_list = self.clean_subscriber_tables(table_list)
-        if len(table_list) == 0:
-            self.log.info("No tables, no fkeys")
-            return
-
-        dst_db = self.get_database('subscriber_db')
-        dst_curs = dst_db.cursor()
-
-        for which in args:
-            union_list = []
-            fields = field_map[meta]
-            q = query_map[which] % (",".join(fields), meta)
-            for tbl in table_list:
-                union_list.append(q % skytools.quote_literal(tbl))
-
-            # use union as fkey may appear in duplicate
-            sql = " union ".join(union_list) + " order by 1"
-            desc = "%s %s" % (which, meta)
-            self.display_table(desc, dst_curs, fields, sql)
-        dst_db.commit()
-
-    def check_tables(self, table_list):
-        src_db = self.get_database('provider_db')
-        src_curs = src_db.cursor()
-        dst_db = self.get_database('subscriber_db')
-        dst_curs = dst_db.cursor()
-
-        failed = 0
-        for tbl in table_list:
-            self.log.info('Checking %s' % tbl)
-            if not skytools.exists_table(src_curs, tbl):
-                self.log.error('Table %s missing from provider side' % tbl)
-                failed += 1
-            elif not skytools.exists_table(dst_curs, tbl):
-                self.log.error('Table %s missing from subscriber side' % tbl)
-                failed += 1
-            else:
-                failed += self.check_table_columns(src_curs, dst_curs, tbl)
-
-        src_db.commit()
-        dst_db.commit()
-
-        return failed
-
-    def restore_triggers(self, tbl, triggers=None):
-        tbl = skytools.fq_name(tbl)
-        if tbl not in self.get_subscriber_table_list():
-            self.log.error("Table %s is not in the subscriber queue." % tbl)
-            sys.exit(1)
-            
-        dst_db = self.get_database('subscriber_db')
-        dst_curs = dst_db.cursor()
-        
-        if not triggers:
-            q = "select count(1) from londiste.subscriber_get_table_pending_triggers(%s)"
-            dst_curs.execute(q, [tbl])
-            if not dst_curs.fetchone()[0]:
-                self.log.info("No pending triggers found for %s." % tbl)
-            else:
-                q = "select londiste.subscriber_restore_all_table_triggers(%s)"
-                dst_curs.execute(q, [tbl])
-        else:
-            for trigger in triggers:
-                q = "select count(1) from londiste.find_table_triggers(%s) where trigger_name=%s"
-                dst_curs.execute(q, [tbl, trigger])
-                if dst_curs.fetchone()[0]:
-                    self.log.info("Trigger %s on %s is already active." % (trigger, tbl))
-                    continue
-                    
-                q = "select count(1) from londiste.subscriber_get_table_pending_triggers(%s) where trigger_name=%s"
-                dst_curs.execute(q, [tbl, trigger])
-                if not dst_curs.fetchone()[0]:
-                    self.log.info("Trigger %s not found on %s" % (trigger, tbl))
-                    continue
-                    
-                q = "select londiste.subscriber_restore_table_trigger(%s, %s)"
-                dst_curs.execute(q, [tbl, trigger])
-        dst_db.commit()
-
-    def check_table_columns(self, src_curs, dst_curs, tbl):
-        src_colrows = find_column_types(src_curs, tbl)
-        dst_colrows = find_column_types(dst_curs, tbl)
-
-        src_cols = make_type_string(src_colrows)
-        dst_cols = make_type_string(dst_colrows)
-        if src_cols.find('k') < 0:
-            self.log.error('provider table %s has no primary key (%s)' % (
-                             tbl, src_cols))
-            return 1
-        if dst_cols.find('k') < 0:
-            self.log.error('subscriber table %s has no primary key (%s)' % (
-                             tbl, dst_cols))
-            return 1
-
-        if src_cols != dst_cols:
-            self.log.warning('table %s structure is not same (%s/%s)'\
-                 ', trying to continue' % (tbl, src_cols, dst_cols))
-
-        err = 0
-        for row in src_colrows:
-            found = 0
-            for row2 in dst_colrows:
-                if row2['name'] == row['name']:
-                    found = 1
-                    break
-            if not found:
-                err = 1
-                self.log.error('%s: column %s on provider not on subscriber'
-                                    % (tbl, row['name']))
-            elif row['type'] != row2['type']:
-                err = 1
-                self.log.error('%s: pk different on column %s'
-                                    % (tbl, row['name']))
-
-        return err
-
-    def find_missing_subscriber_tables(self, pattern='*'):
-        src_db = self.get_database('subscriber_db')
-        src_curs = src_db.cursor()
-        q = """select schemaname || '.' || tablename as full_name from pg_tables
-                where schemaname not in ('pgq', 'londiste', 'pg_catalog', 'information_schema')
-                  and schemaname !~ 'pg_.*'
-                  and schemaname || '.' || tablename ~ %s
-                except select table_name from londiste.provider_get_table_list(%s)"""
-        src_curs.execute(q, [glob2regex(pattern), self.pgq_queue_name])
-        rows = src_curs.fetchall()
-        src_db.commit()
-        list = []
-        for row in rows:
-            list.append(row[0])
-        return list
-
+#class LondisteSetup_tmp(LondisteSetup):
+#
+#    def find_missing_provider_tables(self, pattern='*'):
+#        src_db = self.get_database('provider_db')
+#        src_curs = src_db.cursor()
+#        q = """select schemaname || '.' || tablename as full_name from pg_tables
+#                where schemaname not in ('pgq', 'londiste', 'pg_catalog', 'information_schema')
+#                  and schemaname !~ 'pg_.*'
+#                  and (schemaname || '.' || tablename) ~ %s
+#                except select table_name from londiste.provider_get_table_list(%s)"""
+#        src_curs.execute(q, [glob2regex(pattern), self.queue_name])
+#        rows = src_curs.fetchall()
+#        src_db.commit()
+#        list = []
+#        for row in rows:
+#            list.append(row[0])
+#        return list
+#                
+#    def admin(self):
+#        cmd = self.args[2]
+#        if cmd == "tables":
+#            self.subscriber_show_tables()
+#        elif cmd == "missing":
+#            self.subscriber_missing_tables()
+#        elif cmd == "add":
+#            self.subscriber_add_tables(self.args[3:])
+#        elif cmd == "remove":
+#            self.subscriber_remove_tables(self.args[3:])
+#        elif cmd == "resync":
+#            self.subscriber_resync_tables(self.args[3:])
+#        elif cmd == "register":
+#            self.subscriber_register()
+#        elif cmd == "unregister":
+#            self.subscriber_unregister()
+#        elif cmd == "install":
+#            self.subscriber_install()
+#        elif cmd == "check":
+#            self.check_tables(self.get_provider_table_list())
+#        elif cmd in ["fkeys", "triggers"]:
+#            self.collect_meta(self.get_provider_table_list(), cmd, self.args[3:])
+#        elif cmd == "seqs":
+#            self.subscriber_list_seqs()
+#        elif cmd == "add-seq":
+#            self.subscriber_add_seq(self.args[3:])
+#        elif cmd == "remove-seq":
+#            self.subscriber_remove_seq(self.args[3:])
+#        elif cmd == "restore-triggers":
+#            self.restore_triggers(self.args[3], self.args[4:])
+#        else:
+#            self.log.error('bad subcommand: ' + cmd)
+#            sys.exit(1)
+#
+#    def collect_meta(self, table_list, meta, args):
+#        """Display fkey/trigger info."""
+#
+#        if args == []:
+#            args = ['pending', 'active']
+#            
+#        field_map = {'triggers': ['table_name', 'trigger_name', 'trigger_def'],
+#                     'fkeys': ['from_table', 'to_table', 'fkey_name', 'fkey_def']}
+#        
+#        query_map = {'pending': "select %s from londiste.subscriber_get_table_pending_%s(%%s)",
+#                     'active' : "select %s from londiste.find_table_%s(%%s)"}
+#
+#        table_list = self.clean_subscriber_tables(table_list)
+#        if len(table_list) == 0:
+#            self.log.info("No tables, no fkeys")
+#            return
+#
+#        dst_db = self.get_database('subscriber_db')
+#        dst_curs = dst_db.cursor()
+#
+#        for which in args:
+#            union_list = []
+#            fields = field_map[meta]
+#            q = query_map[which] % (",".join(fields), meta)
+#            for tbl in table_list:
+#                union_list.append(q % skytools.quote_literal(tbl))
+#
+#            # use union as fkey may appear in duplicate
+#            sql = " union ".join(union_list) + " order by 1"
+#            desc = "%s %s" % (which, meta)
+#            self.display_table(desc, dst_curs, fields, sql)
+#        dst_db.commit()
+#
+#    def check_tables(self, table_list):
+#        src_db = self.get_database('provider_db')
+#        src_curs = src_db.cursor()
+#        dst_db = self.get_database('subscriber_db')
+#        dst_curs = dst_db.cursor()
+#
+#        failed = 0
+#        for tbl in table_list:
+#            self.log.info('Checking %s' % tbl)
+#            if not skytools.exists_table(src_curs, tbl):
+#                self.log.error('Table %s missing from provider side' % tbl)
+#                failed += 1
+#            elif not skytools.exists_table(dst_curs, tbl):
+#                self.log.error('Table %s missing from subscriber side' % tbl)
+#                failed += 1
+#            else:
+#                failed += self.check_table_columns(src_curs, dst_curs, tbl)
+#
+#        src_db.commit()
+#        dst_db.commit()
+#
+#        return failed
+#
+#    def check_table_columns(self, src_curs, dst_curs, tbl):
+#        src_colrows = find_column_types(src_curs, tbl)
+#        dst_colrows = find_column_types(dst_curs, tbl)
+#
+#        src_cols = make_type_string(src_colrows)
+#        dst_cols = make_type_string(dst_colrows)
+#        if src_cols.find('k') < 0:
+#            self.log.error('provider table %s has no primary key (%s)' % (
+#                             tbl, src_cols))
+#            return 1
+#        if dst_cols.find('k') < 0:
+#            self.log.error('subscriber table %s has no primary key (%s)' % (
+#                             tbl, dst_cols))
+#            return 1
+#
+#        if src_cols != dst_cols:
+#            self.log.warning('table %s structure is not same (%s/%s)'\
+#                 ', trying to continue' % (tbl, src_cols, dst_cols))
+#
+#        err = 0
+#        for row in src_colrows:
+#            found = 0
+#            for row2 in dst_colrows:
+#                if row2['name'] == row['name']:
+#                    found = 1
+#                    break
+#            if not found:
+#                err = 1
+#                self.log.error('%s: column %s on provider not on subscriber'
+#                                    % (tbl, row['name']))
+#            elif row['type'] != row2['type']:
+#                err = 1
+#                self.log.error('%s: pk different on column %s'
+#                                    % (tbl, row['name']))
+#
+#        return err
+#
+#    def find_missing_subscriber_tables(self, pattern='*'):
+#        src_db = self.get_database('subscriber_db')
+#        src_curs = src_db.cursor()
+#        q = """select schemaname || '.' || tablename as full_name from pg_tables
+#                where schemaname not in ('pgq', 'londiste', 'pg_catalog', 'information_schema')
+#                  and schemaname !~ 'pg_.*'
+#                  and schemaname || '.' || tablename ~ %s
+#                except select table_name from londiste.provider_get_table_list(%s)"""
+#        src_curs.execute(q, [glob2regex(pattern), self.queue_name])
+#        rows = src_curs.fetchall()
+#        src_db.commit()
+#        list = []
+#        for row in rows:
+#            list.append(row[0])
+#        return list
+#
index 4d089a2f1abc84882e1602620920207d41942028..44ee92bd9005f1d8b06badcabb9a553cb44c53a5 100644 (file)
@@ -8,28 +8,47 @@ class Syncer(skytools.DBScript):
     """Walks tables in primary key order and checks if data matches."""
 
     def __init__(self, args):
+        """Syncer init."""
         skytools.DBScript.__init__(self, 'londiste', args)
         self.set_single_loop(1)
 
-        self.pgq_queue_name = self.cf.get("pgq_queue_name")
-        self.pgq_consumer_id = self.cf.get('pgq_consumer_id', self.job_name)
+        # compat names
+        self.queue_name = self.cf.get("pgq_queue_name", '')
+        self.consumer_name = self.cf.get('pgq_consumer_id', '')
+
+        # good names
+        if not self.queue_name:
+            self.queue_name = self.cf.get("queue_name")
+        if not self.consumer_name:
+            self.consumer_name = self.cf.get('consumer_name', self.job_name)
+
+        self.lock_timeout = self.cf.getfloat('lock_timeout', 10)
 
         if self.pidfile:
             self.pidfile += ".repair"
 
+    def set_lock_timeout(self, curs):
+        ms = int(1000 * self.lock_timeout)
+        if ms > 0:
+            q = "SET LOCAL statement_timeout = %d" % ms
+            self.log.debug(q)
+            curs.execute(q)
+
     def init_optparse(self, p=None):
+        """Initialize cmdline switches."""
         p = skytools.DBScript.init_optparse(self, p)
         p.add_option("--force", action="store_true", help="ignore lag")
         return p
 
     def check_consumer(self, setup_curs):
-        # before locking anything check if consumer is working ok
+        """Before locking anything check if consumer is working ok."""
+
         q = "select extract(epoch from ticker_lag) from pgq.get_queue_info(%s)"
-        setup_curs.execute(q, [self.pgq_queue_name])
+        setup_curs.execute(q, [self.queue_name])
         ticker_lag = setup_curs.fetchone()[0]
         q = "select extract(epoch from lag)"\
             " from pgq.get_consumer_info(%s, %s)"
-        setup_curs.execute(q, [self.pgq_queue_name, self.pgq_consumer_id])
+        setup_curs.execute(q, [self.queue_name, self.consumer_name])
         res = setup_curs.fetchall()
 
         if len(res) == 0:
@@ -42,15 +61,16 @@ class Syncer(skytools.DBScript):
             sys.exit(1)
 
     def get_subscriber_table_state(self, dst_db):
+        """Load table states from subscriber."""
         dst_curs = dst_db.cursor()
         q = "select * from londiste.subscriber_get_table_list(%s)"
-        dst_curs.execute(q, [self.pgq_queue_name])
+        dst_curs.execute(q, [self.queue_name])
         res = dst_curs.dictfetchall()
         dst_db.commit()
         return res
 
     def work(self):
-        src_loc = self.cf.get('provider_db')
+        """Syncer main function."""
         lock_db = self.get_database('provider_db', cache='lock_db')
         setup_db = self.get_database('provider_db', cache='setup_db', autocommit = 1)
 
@@ -92,14 +112,14 @@ class Syncer(skytools.DBScript):
 
     def force_tick(self, setup_curs):
         q = "select pgq.force_tick(%s)"
-        setup_curs.execute(q, [self.pgq_queue_name])
+        setup_curs.execute(q, [self.queue_name])
         res = setup_curs.fetchone()
         cur_pos = res[0]
 
         start = time.time()
         while 1:
             time.sleep(0.5)
-            setup_curs.execute(q, [self.pgq_queue_name])
+            setup_curs.execute(q, [self.queue_name])
             res = setup_curs.fetchone()
             if res[0] != cur_pos:
                 # new pos
@@ -128,8 +148,9 @@ class Syncer(skytools.DBScript):
         # lock table in separate connection
         self.log.info('Locking %s' % tbl)
         lock_db.commit()
-        lock_curs.execute("LOCK TABLE %s IN SHARE MODE" % tbl)
+        self.set_lock_timeout(lock_curs)
         lock_time = time.time()
+        lock_curs.execute("LOCK TABLE %s IN SHARE MODE" % skytools.quote_fqident(tbl))
 
         # now wait until consumer has updated target table until locking
         self.log.info('Syncing %s' % tbl)
@@ -148,7 +169,7 @@ class Syncer(skytools.DBScript):
 
             q = "select now() - lag > timestamp %s, now(), lag"\
                 " from pgq.get_consumer_info(%s, %s)"
-            setup_curs.execute(q, [tpos, self.pgq_queue_name, self.pgq_consumer_id])
+            setup_curs.execute(q, [tpos, self.queue_name, self.consumer_name])
             res = setup_curs.fetchall()
 
             if len(res) == 0:
@@ -159,8 +180,8 @@ class Syncer(skytools.DBScript):
             if row[0]:
                 break
 
-            # loop max 10 secs
-            if time.time() > lock_time + 10 and not self.options.force:
+            # limit lock time
+            if time.time() > lock_time + self.lock_timeout and not self.options.force:
                 self.log.error('Consumer lagging too much, exiting')
                 lock_db.rollback()
                 sys.exit(1)
index b02b789fb46bb2f9767853809b40809f31ebdf34..d90e00ae47b5120ebf9aaa0f8eb2a92d186f84ac 100644 (file)
@@ -5,7 +5,7 @@
 For internal usage.
 """
 
-import sys, os, skytools
+import sys, os, time, skytools
 
 from skytools.dbstruct import *
 from londiste.playback import *
@@ -13,30 +13,60 @@ from londiste.playback import *
 __all__ = ['CopyTable']
 
 class CopyTable(Replicator):
+    """Table copy thread implementation."""
+
+    reg_ok = False
+
     def __init__(self, args, copy_thread = 1):
+        """Initializer.  The copy_thread arg shows whether the copy process
+        is separate from the main Playback thread.  copy_thread=0 means
+        copying happens in the same process.
+        """
+
         Replicator.__init__(self, args)
 
         if not copy_thread:
             raise Exception("Combined copy not supported")
 
-        if len(self.args):
-            print "londiste copy requires table name"
+        if len(self.args) != 3:
+            self.log.error("londiste copy requires table name")
+            sys.exit(1)
         self.copy_table_name = self.args[2]
 
-        self.pidfile += ".copy.%s" % self.copy_table_name
-        self.consumer_name += "_copy_%s" % self.copy_table_name
+        sfx = self.get_copy_suffix(self.copy_table_name)
+        self.old_consumer_name = self.consumer_name
+        self.pidfile += sfx
+        self.consumer_name += sfx
         self.copy_thread = 1
         self.main_worker = False
 
+    def get_copy_suffix(self, tblname):
+        return ".copy.%s" % tblname
+
+    def reload_table_stat(self, dst_curs, tblname):
+        self.load_table_state(dst_curs)
+        t = self.table_map[tblname]
+        return t
+
     def do_copy(self, tbl_stat, src_db, dst_db):
+        """Entry point into copying logic."""
 
         dst_db.commit()
 
+        src_curs = src_db.cursor()
+        dst_curs = dst_db.cursor()
+
+        while tbl_stat.copy_role == 'wait-copy':
+            self.log.info('waiting for first partition to initialize copy')
+            time.sleep(10)
+            tbl_stat = self.reload_table_stat(dst_curs, tbl_stat.name)
+            dst_db.commit()
+
         while 1:
             pmap = self.get_state_map(src_db.cursor())
             src_db.commit()
             if tbl_stat.name not in pmap:
-                raise Excpetion("table %s not available on provider" % tbl_stat.name)
+                raise Exception("table %s not available on provider" % tbl_stat.name)
             pt = pmap[tbl_stat.name]
             if pt.state == TABLE_OK:
                 break
@@ -44,6 +74,14 @@ class CopyTable(Replicator):
             self.log.warning("table %s not in sync yet on provider, waiting" % tbl_stat.name)
             time.sleep(10)
 
+        # 0 - don't touch
+        # 1 - single tx
+        # 2 - multi tx
+        cmode = 1
+        if tbl_stat.copy_role == 'lead':
+            cmode = 2
+        elif tbl_stat.copy_role:
+            cmode = 0
 
         # change to SERIALIZABLE isolation level
         src_db.set_isolation_level(skytools.I_SERIALIZABLE)
@@ -51,24 +89,12 @@ class CopyTable(Replicator):
 
         self.sync_database_encodings(src_db, dst_db)
 
-        # initial sync copy
-        src_curs = src_db.cursor()
-        dst_curs = dst_db.cursor()
-
         self.log.info("Starting full copy of %s" % tbl_stat.name)
 
         # just in case, drop all fkeys (in case "replay" was skipped)
         # !! this may commit, so must be done before anything else !!
         self.drop_fkeys(dst_db, tbl_stat.name)
 
-        # drop own triggers
-        q_node_trg = "select * from londiste.node_disable_triggers(%s, %s)"
-        dst_curs.execute(q_node_trg, [self.set_name, tbl_stat.name])
-
-        # drop rest of the triggers
-        q_triggers = "select londiste.drop_all_table_triggers(%s)"
-        dst_curs.execute(q_triggers, [tbl_stat.name])
-
         # find dst struct
         src_struct = TableStruct(src_curs, tbl_stat.name)
         dst_struct = TableStruct(dst_curs, tbl_stat.name)
@@ -89,8 +115,24 @@ class CopyTable(Replicator):
                                  % (tbl_stat.name, c))
 
         # drop unnecessary stuff
-        objs = T_CONSTRAINT | T_INDEX | T_RULE
-        dst_struct.drop(dst_curs, objs, log = self.log)
+        if cmode > 0:
+            # drop indexes
+            objs = T_CONSTRAINT | T_INDEX | T_RULE # | T_TRIGGER
+            dst_struct.drop(dst_curs, objs, log = self.log)
+
+            # drop data
+            if tbl_stat.skip_truncate:
+                self.log.info("%s: skipping truncate" % tbl_stat.name)
+            else:
+                self.log.info("%s: truncating" % tbl_stat.name)
+                dst_curs.execute("truncate " + skytools.quote_fqident(tbl_stat.name))
+
+            if cmode == 2 and tbl_stat.dropped_ddl is None:
+                ddl = dst_struct.get_create_sql(objs)
+                q = "select * from londiste.local_set_table_struct(%s, %s, %s)"
+                self.exec_cmd(dst_curs, q, [self.queue_name, tbl_stat.name, ddl])
+                dst_db.commit()
+                tbl_stat.dropped_ddl = ddl
 
         # do truncate & copy
         self.real_copy(src_curs, dst_curs, tbl_stat, common_cols)
@@ -104,13 +146,25 @@ class CopyTable(Replicator):
         src_db.set_isolation_level(1)
         src_db.commit()
 
-        # restore own triggers
-        q_node_trg = "select * from londiste.node_refresh_triggers(%s, %s)"
-        dst_curs.execute(q_node_trg, [self.set_name, tbl_stat.name])
-
         # create previously dropped objects
-        dst_struct.create(dst_curs, objs, log = self.log)
-        dst_db.commit()
+        if cmode == 1:
+            dst_struct.create(dst_curs, objs, log = self.log)
+        elif cmode == 2:
+            dst_db.commit()
+            while tbl_stat.copy_role == 'lead':
+                self.log.info('waiting for other partitions to finish copy')
+                time.sleep(10)
+                tbl_stat = self.reload_table_stat(dst_curs, tbl_stat.name)
+                dst_db.commit()
+
+            if tbl_stat.dropped_ddl is not None:
+                for ddl in skytools.parse_statements(tbl_stat.dropped_ddl):
+                    self.log.info(ddl)
+                    dst_curs.execute(ddl)
+                q = "select * from londiste.local_set_table_struct(%s, %s, NULL)"
+                self.exec_cmd(dst_curs, q, [self.queue_name, tbl_stat.name])
+                tbl_stat.dropped_ddl = None
+            dst_db.commit()
 
         # set state
         if self.copy_thread:
@@ -121,17 +175,26 @@ class CopyTable(Replicator):
         self.save_table_state(dst_curs)
         dst_db.commit()
 
+        # copy finished
+        if tbl_stat.copy_role == 'wait-replay':
+            return
+
+        # analyze
+        self.log.info("%s: analyze" % tbl_stat.name)
+        dst_curs.execute("analyze " + skytools.quote_fqident(tbl_stat.name))
+        dst_db.commit()
+
+        # if copy done, request immediate tick from pgqadm,
+        # to make state juggling faster.  on mostly idle db-s
+        # each step may take ticker's idle_timeout secs, which is pain.
+        q = "select pgq.force_tick(%s)"
+        src_curs.execute(q, [self.queue_name])
+        src_db.commit()
+
     def real_copy(self, srccurs, dstcurs, tbl_stat, col_list):
-        "Main copy logic."
+        "Actual copy."
 
         tablename = tbl_stat.name
-        # drop data
-        if tbl_stat.skip_truncate:
-            self.log.info("%s: skipping truncate" % tablename)
-        else:
-            self.log.info("%s: truncating" % tablename)
-            dstcurs.execute("truncate " + tablename)
-
         # do copy
         self.log.info("%s: start copy" % tablename)
         stats = skytools.full_copy(tablename, srccurs, dstcurs, col_list)
@@ -139,6 +202,23 @@ class CopyTable(Replicator):
             self.log.info("%s: copy finished: %d bytes, %d rows" % (
                           tablename, stats[0], stats[1]))
 
+    def work(self):
+        if not self.reg_ok:
+            # check if needed? (table, not existing reg)
+            self.register_copy_consumer()
+            self.reg_ok = True
+        return Replicator.work(self)
+
+    def register_copy_consumer(self):
+        # fetch parent consumer state
+        dst_db = self.get_database('db')
+        q = "select * from pgq_node.get_consumer_state(%s, %s)"
+        rows = self.exec_cmd(dst_db, q, [ self.queue_name, self.old_consumer_name ])
+        state = rows[0]
+        loc = state['provider_location']
+
+        self.register_consumer(loc)
+
 if __name__ == '__main__':
     script = CopyTable(sys.argv[1:])
     script.start()