#! /usr/bin/env python
 
 """Londiste setup and sanity checker.
-
 """
+
 import sys, os, skytools
-from installer import *
 
 __all__ = ['LondisteSetup']
 
-def find_column_types(curs, table):
-    table_oid = skytools.get_table_oid(curs, table)
-    if table_oid == None:
-        return None
-
-    key_sql = """
-        SELECT k.attname FROM pg_index i, pg_attribute k
-         WHERE i.indrelid = %d AND k.attrelid = i.indexrelid
-           AND i.indisprimary AND k.attnum > 0 AND NOT k.attisdropped
-        """ % table_oid
-
-    # find columns
-    q = """
-        SELECT a.attname as name,
-               CASE WHEN k.attname IS NOT NULL
-                    THEN 'k' ELSE 'v' END AS type
-          FROM pg_attribute a LEFT JOIN (%s) k ON (k.attname = a.attname)
-         WHERE a.attrelid = %d AND a.attnum > 0 AND NOT a.attisdropped
-         ORDER BY a.attnum
-         """ % (key_sql, table_oid)
-    curs.execute(q)
-    rows = curs.dictfetchall()
-    return rows
-
-def make_type_string(col_rows):
-    res = map(lambda x: x['type'], col_rows)
-    return "".join(res)
-
-def convertGlobs(s):
-    return s.replace('?', '.').replace('*', '.*')
-
-def glob2regex(gpat):
-    plist = [convertGlobs(s) for s in gpat.split('.')]
-    return '^%s$' % '[.]'.join(plist)
-
-class CommonSetup(skytools.DBScript):
+class LondisteSetup(skytools.DBScript):
     def __init__(self, args):
         skytools.DBScript.__init__(self, 'londiste', args)
         self.set_single_loop(1)
         self.pidfile = self.pidfile + ".setup"
 
-        self.pgq_queue_name = self.cf.get("pgq_queue_name")
+        self.set_name = self.cf.get("set_name")
         self.consumer_id = self.cf.get("pgq_consumer_id", self.job_name)
-        self.fake = self.cf.getint('fake', 0)
 
         if len(self.args) < 3:
             self.log.error("need subcommand")
             sys.exit(1)
 
     def run(self):
-        self.admin()
-
-    def fetch_provider_table_list(self, curs, pattern='*'):
-        q = """select table_name, trigger_name
-                 from londiste.provider_get_table_list(%s)
-                 where table_name ~ %s"""
-        curs.execute(q, [self.pgq_queue_name, glob2regex(pattern)])
-        return curs.dictfetchall()
+        cmd = self.args[2]
+        fname = "cmd_" + cmd.replace('-', '_')
+        if hasattr(self, fname):
+            getattr(self, fname)(self, self.args[3:])
+        else:
+            self.log.error('bad subcommand')
+            sys.exit(1)
 
-    def get_provider_table_list(self, pattern='*'):
-        src_db = self.get_database('provider_db')
-        src_curs = src_db.cursor()
-        list = self.fetch_provider_table_list(src_curs, pattern)
-        src_db.commit()
-        res = []
-        for row in list:
-            res.append(row['table_name'])
+    def fetch_list(self, curs, sql, args, keycol = None):
+        curs.execute(sql, args)
+        rows = curs.dictfetchall()
+        if not keycol:
+            res = rows
+        else:
+            res = [r[keycol] for r in rows]
         return res
 
-    def get_provider_seqs(self):
-        src_db = self.get_database('provider_db')
-        src_curs = src_db.cursor()
-        q = """SELECT * from londiste.provider_get_seq_list(%s)"""
-        src_curs.execute(q, [self.pgq_queue_name])
-        src_db.commit()
-        res = []
-        for row in src_curs.fetchall():
-            res.append(row[0])
+    def db_fetch_list(self, sql, args, keycol = None):
+        db = self.get_database('node_db')
+        curs = db.cursor()
+        res = self.fetch_list(curs, sql, keycol)
+        db.commit()
         return res
 
-    def get_all_seqs(self, curs):
-        q = """SELECT n.nspname || '.'|| c.relname
-                 from pg_class c, pg_namespace n
-                where n.oid = c.relnamespace 
-                  and c.relkind = 'S'
-                order by 1"""
-        curs.execute(q)
-        res = []
-        for row in curs.fetchall():
-            res.append(row[0])
-        return res
+    def display_table(self, desc, curs, sql, args = [], fields = []):
+        """Display multirow query as a table."""
 
-    def check_provider_queue(self):
-        src_db = self.get_database('provider_db')
-        src_curs = src_db.cursor()
-        q = "select count(1) from pgq.get_queue_info(%s)"
-        src_curs.execute(q, [self.pgq_queue_name])
-        ok = src_curs.fetchone()[0]
-        src_db.commit()
+        curs.execute(sql, args)
+        rows = curs.fetchall()
+        if len(rows) == 0:
+            return 0
+
+        if not fields:
+            fields = [f[0] for f in curs.description]
         
-        if not ok:
-            self.log.error('Event queue does not exist yet')
-            sys.exit(1)
+        widths = [15] * len(fields)
+        for row in rows:
+            for i, k in enumerate(fields):
+                widths[i] = widths[i] > len(row[k]) and widths[i] or len(row[k])
+        widths = [w + 2 for w in widths]
 
-    def fetch_subscriber_tables(self, curs, pattern = '*'):
-        q = "select * from londiste.subscriber_get_table_list(%s) where table_name ~ %s"
-        curs.execute(q, [self.pgq_queue_name, glob2regex(pattern)])
-        return curs.dictfetchall()
+        fmt = '%%-%ds' * (len(widths) - 1) + '%%s'
+        fmt = fmt % tuple(widths[:-1])
+        if desc:
+            print desc
+        print fmt % tuple(fields)
+        print fmt % tuple(['-'*15] * len(fields))
+            
+        for row in rows:
+            print fmt % tuple([row[k] for k in fields])
+        print '\n'
+        return 1
 
-    def get_subscriber_table_list(self, pattern = '*'):
-        dst_db = self.get_database('subscriber_db')
-        dst_curs = dst_db.cursor()
-        list = self.fetch_subscriber_tables(dst_curs, pattern)
-        dst_db.commit()
-        res = []
-        for row in list:
-            res.append(row['table_name'])
+    def db_display_table(self, desc, sql, args = [], fields = []):
+        db = self.get_database('node_db')
+        curs = db.cursor()
+        res = self.display_table(desc, curs, sql, args, fields)
+        db.commit()
         return res
+        
 
     def init_optparse(self, parser=None):
         p = skytools.DBScript.init_optparse(self, parser)
                     help="include all tables", default=False)
         return p
 
+    def exec_checked(self, curs, sql, args):
+        curs.execute(sql, args)
+        ok = True
+        for row in curs.fetchall():
+            if (res[0] % 100) == 2:
+                self.log.info("%d %s" % (res[0], res[1]))
+            else:
+                self.log.error("%d %s" % (res[0], res[1]))
+                ok = False
+        return ok
+
+    def exec_many(self, curs, sql, baseargs, extra_list):
+        res = True
+        for a in extra_list:
+            ok = self.exec_checked(curs, sql, baseargs + [a])
+            if not ok:
+                res = False
+        return res
 
-#
-# Provider commands
-#
+    def db_exec_many(self, sql, baseargs, extra_list):
+        db = self.get_database('node_db')
+        curs = db.cursor()
+        ok = self.exec_many(curs, sql, baseargs, extra_list)
+        if ok:
+            self.log.info("COMMIT")
+            db.commit()
+        else:
+            self.log.info("ROLLBACK")
+            db.rollback()
 
-class LondisteSetup(CommonSetup):
+    def cmd_add(self, args = []):
+        q = "select londiste.node_add_table(%s, %s)"
+        self.db_exec_many(q, [self.set_name], args)
 
-    def admin(self):
-        cmd = self.args[2]
-        if cmd == "tables":
-            self.provider_show_tables()
-        elif cmd == "add":
-            self.provider_add_tables(self.args[3:])
-        elif cmd == "remove":
-            self.provider_remove_tables(self.args[3:])
-        elif cmd == "add-seq":
-            self.provider_add_seq_list(self.args[3:])
-        elif cmd == "remove-seq":
-            self.provider_remove_seq_list(self.args[3:])
-        elif cmd == "install":
-            self.provider_install()
-        elif cmd == "seqs":
-            self.provider_list_seqs()
-        else:
-            self.log.error('bad subcommand')
-            sys.exit(1)
+    def cmd_remove(self, args = []):
+        q = "select londiste.node_remove_table(%s, %s)"
+        self.db_exec_many(q, [self.set_name], args)
 
-    def provider_list_seqs(self):
-        list = self.get_provider_seqs()
-        for seq in list:
-            print seq
+    def cmd_add_seq(self, args = []):
+        q = "select londiste.node_add_seq(%s, %s)"
+        self.db_exec_many(q, [self.set_name], args)
 
-    def provider_get_all_seqs(self):
-        src_db = self.get_database('provider_db')
-        src_curs = src_db.cursor()
-        list = self.get_all_seqs(src_curs)
-        src_db.commit()
-        return list
+    def cmd_remove_seq(self, args = []):
+        q = "select londiste.node_remove_seq(%s, %s)"
+        self.db_exec_many(q, [self.set_name], args)
 
-    def provider_add_seq_list(self, seq_list):
-        if not seq_list and self.options.all:
-            seq_list = self.provider_get_all_seqs()
+    def cmd_resync(self, args = []):
+        q = "select londiste.node_resync_table(%s, %s)"
+        self.db_exec_many(q, [self.set_name], args)
 
-        for seq in self.args[3:]:
-            self.provider_add_seq(seq)
-        self.provider_notify_change()
+    def cmd_tables(self, args = []):
+        q = "select table_name, merge_state from londiste.node_get_table_list(%s)"
+        self.db_display_table("Tables on node", q, [self.set_name])
 
-    def provider_remove_seq_list(self, seq_list):
-        if not seq_list and self.options.all:
-            seq_list = self.get_provider_seqs()
+    def cmd_seqs(self, args = []):
+        q = "select seq_namefrom londiste.node_get_seq_list(%s)"
+        self.db_display_table("Sequences on node", q, [self.set_name])
 
-        for seq in seq_list:
-            self.provider_remove_seq(seq)
-        self.provider_notify_change()
+    def cmd_missing(self, args = []):
+        q = "select * from londiste.node_show_missing(%s)"
+        self.db_display_table("MIssing objects on node", q, [self.set_name])
 
-    def provider_install(self):
-        src_db = self.get_database('provider_db')
-        src_curs = src_db.cursor()
-        install_provider(src_curs, self.log)
+    def cmd_check(self, args = []):
+        pass
+    def cmd_fkeys(self, args = []):
+        pass
+    def cmd_triggers(self, args = []):
+        pass
 
-        # create event queue
-        q = "select pgq.create_queue(%s)"
-        self.exec_provider(q, [self.pgq_queue_name])
+#
+# Old commands
+#
+
+class LondisteSetup_tmp:
 
     def find_missing_provider_tables(self, pattern='*'):
         src_db = self.get_database('provider_db')
             list.append(row[0])
         return list
                 
-    def provider_add_tables(self, table_list):
-        self.check_provider_queue()
-
-        if self.options.all and not table_list:
-            table_list = ['*.*']
-
-        cur_list = self.get_provider_table_list()
-        for tbl in table_list:
-            tbls = self.find_missing_provider_tables(skytools.fq_name(tbl))
-            
-            for tbl in tbls:
-                if tbl not in cur_list:
-                    self.log.info('Adding %s' % tbl)
-                    self.provider_add_table(tbl)
-                else:
-                    self.log.info("Table %s already added" % tbl)
-        self.provider_notify_change()
-
-    def provider_remove_tables(self, table_list):
-        self.check_provider_queue()
-
-        cur_list = self.get_provider_table_list()
-        if not table_list and self.options.all:
-            table_list = cur_list
-
-        for tbl in table_list:
-            tbls = self.get_provider_table_list(skytools.fq_name(tbl))
-            for tbl in tbls:
-                if tbl not in cur_list:
-                    self.log.info('%s already removed' % tbl)
-                else:
-                    self.log.info("Removing %s" % tbl)
-                    self.provider_remove_table(tbl)
-        self.provider_notify_change()
-
-    def provider_add_table(self, tbl):
-        q = "select londiste.provider_add_table(%s, %s)"
-        self.exec_provider(q, [self.pgq_queue_name, tbl])
-
-    def provider_remove_table(self, tbl):
-        q = "select londiste.provider_remove_table(%s, %s)"
-        self.exec_provider(q, [self.pgq_queue_name, tbl])
-
-    def provider_show_tables(self):
-        self.check_provider_queue()
-        list = self.get_provider_table_list()
-        for tbl in list:
-            print tbl
-
-    def provider_notify_change(self):
-        q = "select londiste.provider_notify_change(%s)"
-        self.exec_provider(q, [self.pgq_queue_name])
-
-    def provider_add_seq(self, seq):
-        seq = skytools.fq_name(seq)
-        q = "select londiste.provider_add_seq(%s, %s)"
-        self.exec_provider(q, [self.pgq_queue_name, seq])
-
-    def provider_remove_seq(self, seq):
-        seq = skytools.fq_name(seq)
-        q = "select londiste.provider_remove_seq(%s, %s)"
-        self.exec_provider(q, [self.pgq_queue_name, seq])
-
-    def exec_provider(self, sql, args):
-        src_db = self.get_database('provider_db')
-        src_curs = src_db.cursor()
-
-        src_curs.execute(sql, args)
-
-        if self.fake:
-            src_db.rollback()
-        else:
-            src_db.commit()
-
-#
-# Subscriber commands
-#
-
-class SubscriberSetup(CommonSetup):
-
     def admin(self):
         cmd = self.args[2]
         if cmd == "tables":
             self.display_table(desc, dst_curs, fields, sql)
         dst_db.commit()
 
-    def display_table(self, desc, curs, fields, sql, args = []):
-        """Display multirow query as a table."""
-
-        curs.execute(sql, args)
-        rows = curs.dictfetchall()
-        if len(rows) == 0:
-            return 0
-        
-        widths = [15] * len(fields)
-        for row in rows:
-            for i, k in enumerate(fields):
-                widths[i] = widths[i] > len(row[k]) and widths[i] or len(row[k])
-        widths = [w + 2 for w in widths]
-
-        fmt = '%%-%ds' * (len(widths) - 1) + '%%s'
-        fmt = fmt % tuple(widths[:-1])
-        print desc
-        print fmt % tuple(fields)
-        print fmt % tuple(['-'*15] * len(fields))
-            
-        for row in rows:
-            print fmt % tuple([row[k] for k in fields])
-        print '\n'
-        return 1
-
-    def clean_subscriber_tables(self, table_list):
-        """Returns fully-quelifies table list of tables
-        that are registered on subscriber.
-        """
-        subscriber_tables = self.get_subscriber_table_list()
-        if not table_list and self.options.all:
-            table_list = subscriber_tables
-        else:
-            newlist = []
-            for tbl in table_list:
-                tbl = skytools.fq_name(tbl)
-                if tbl in subscriber_tables:
-                    newlist.append(tbl)
-                else:
-                    #self.log.warning("table %s not subscribed" % tbl)
-                    pass
-            table_list = newlist
-        return table_list
-
     def check_tables(self, table_list):
         src_db = self.get_database('provider_db')
         src_curs = src_db.cursor()
 
         return err
 
-    def subscriber_install(self):
-        dst_db = self.get_database('subscriber_db')
-        dst_curs = dst_db.cursor()
-
-        install_subscriber(dst_curs, self.log)
-
-        if self.fake:
-            self.log.debug('rollback')
-            dst_db.rollback()
-        else:
-            self.log.debug('commit')
-            dst_db.commit()
-
-    def subscriber_register(self):
-        src_db = self.get_database('provider_db')
-        src_curs = src_db.cursor()
-        src_curs.execute("select pgq.register_consumer(%s, %s)",
-            [self.pgq_queue_name, self.consumer_id])
-        src_db.commit()
-
-    def subscriber_unregister(self):
-        q = "select londiste.subscriber_set_table_state(%s, %s, NULL, NULL)"
-
-        dst_db = self.get_database('subscriber_db')
-        dst_curs = dst_db.cursor()
-        tbl_rows = self.fetch_subscriber_tables(dst_curs)
-        for row in tbl_rows:
-            dst_curs.execute(q, [self.pgq_queue_name, row['table_name']])
-        dst_db.commit()
-
-        src_db = self.get_database('provider_db')
-        src_curs = src_db.cursor()
-        src_curs.execute("select pgq.unregister_consumer(%s, %s)",
-            [self.pgq_queue_name, self.consumer_id])
-        src_db.commit()
-
-    def subscriber_show_tables(self):
-        """print out subscriber table list, with state and snapshot"""
-        dst_db = self.get_database('subscriber_db')
-        dst_curs = dst_db.cursor()
-        list = self.fetch_subscriber_tables(dst_curs)
-        dst_db.commit()
-
-        format = "%-30s   %20s"
-        print format % ("Table", "State")
-        for tbl in list:
-            print format % (tbl['table_name'],
-                            tbl['merge_state'] or '-')
-
-    def subscriber_missing_tables(self):
-        provider_tables = self.get_provider_table_list()
-        subscriber_tables = self.get_subscriber_table_list()
-        for tbl in provider_tables:
-            if tbl not in subscriber_tables:
-                print "Table: %s" % tbl
-        provider_seqs = self.get_provider_seqs()
-        subscriber_seqs = self.get_subscriber_seq_list()
-        for seq in provider_seqs:
-            if seq not in subscriber_seqs:
-                print "Sequence: %s" % seq
-
     def find_missing_subscriber_tables(self, pattern='*'):
         src_db = self.get_database('subscriber_db')
         src_curs = src_db.cursor()
             list.append(row[0])
         return list
 
-    def subscriber_add_tables(self, table_list):
-        provider_tables = self.get_provider_table_list()
-        subscriber_tables = self.get_subscriber_table_list()
-
-        if not table_list and self.options.all:
-            table_list = ['*.*']
-            for tbl in provider_tables:
-                if tbl not in subscriber_tables:
-                    table_list.append(tbl)
-        
-        tbls = []
-        for tbl in table_list:
-            more = self.find_missing_subscriber_tables(skytools.fq_name(tbl))
-            if more == []:
-                self.log.info("No tables found that match %s" % tbl)
-            tbls.extend(more)
-        tbls = list(set(tbls))
-
-        err = 0
-        table_list = []
-        for tbl in tbls:
-            if tbl not in provider_tables:
-                err = 1
-                self.log.error("Table %s not attached to queue" % tbl)
-                if not self.options.force:
-                    continue
-            table_list.append(tbl)
-                
-        if err:
-            if self.options.force:
-                self.log.warning('--force used, ignoring errors')
-
-        err = self.check_tables(table_list)
-        if err:
-            if self.options.force:
-                self.log.warning('--force used, ignoring errors')
-            else:
-                sys.exit(1)
-
-        dst_db = self.get_database('subscriber_db')
-        dst_curs = dst_db.cursor()
-        for tbl in table_list:
-            if tbl in subscriber_tables:
-                self.log.info("Table %s already added" % tbl)
-            else:
-                self.log.info("Adding %s" % tbl)
-                self.subscriber_add_one_table(dst_curs, tbl)
-            dst_db.commit()
-
-    def subscriber_remove_tables(self, table_list):
-        subscriber_tables = self.get_subscriber_table_list()
-        if not table_list and self.options.all:
-            table_list = ['*.*']
-            
-        for tbl in table_list:
-            tbls = self.get_subscriber_table_list(skytools.fq_name(tbl))
-            for tbl in tbls:
-                if tbl in subscriber_tables:
-                    self.log.info("Removing: %s" % tbl)
-                    self.subscriber_remove_one_table(tbl)
-                else:
-                    self.log.info("Table %s already removed" % tbl)
-
-    def subscriber_resync_tables(self, table_list):
-        dst_db = self.get_database('subscriber_db')
-        dst_curs = dst_db.cursor()
-        list = self.fetch_subscriber_tables(dst_curs)
-
-        if not table_list and self.options.all:
-            table_list = self.get_subscriber_table_list()
-
-        for tbl in table_list:
-            tbl = skytools.fq_name(tbl)
-            tbl_row = None
-            for row in list:
-                if row['table_name'] == tbl:
-                    tbl_row = row
-                    break
-            if not tbl_row:
-                self.log.warning("Table %s not found" % tbl)
-            elif tbl_row['merge_state'] != 'ok':
-                self.log.warning("Table %s is not in stable state" % tbl)
-            else:
-                self.log.info("Resyncing %s" % tbl)
-                q = "select londiste.subscriber_set_table_state(%s, %s, NULL, NULL)"
-                dst_curs.execute(q, [self.pgq_queue_name, tbl])
-        dst_db.commit()
-
-    def subscriber_add_one_table(self, dst_curs, tbl):
-        q_add = "select londiste.subscriber_add_table(%s, %s)"
-        q_triggers = "select londiste.subscriber_drop_all_table_triggers(%s)"
-
-        if self.options.expect_sync and self.options.skip_truncate:
-            self.log.error("Too many options: --expect-sync and --skip-truncate")
-            sys.exit(1)
-
-        dst_curs.execute(q_add, [self.pgq_queue_name, tbl])
-        if self.options.expect_sync:
-            q = "select londiste.subscriber_set_table_state(%s, %s, null, 'ok')"
-            dst_curs.execute(q, [self.pgq_queue_name, tbl])
-            return
-
-        dst_curs.execute(q_triggers, [tbl])
-        if self.options.skip_truncate:
-            q = "select londiste.subscriber_set_skip_truncate(%s, %s, true)"
-            dst_curs.execute(q, [self.pgq_queue_name, tbl])
-
-    def subscriber_remove_one_table(self, tbl):
-        q_remove = "select londiste.subscriber_remove_table(%s, %s)"
-        q_triggers = "select londiste.subscriber_restore_all_table_triggers(%s)"
-
-        dst_db = self.get_database('subscriber_db')
-        dst_curs = dst_db.cursor()
-        dst_curs.execute(q_remove, [self.pgq_queue_name, tbl])
-        dst_curs.execute(q_triggers, [tbl])
-        dst_db.commit()
-
-    def get_subscriber_seq_list(self):
-        dst_db = self.get_database('subscriber_db')
-        dst_curs = dst_db.cursor()
-        q = "SELECT * from londiste.subscriber_get_seq_list(%s)"
-        dst_curs.execute(q, [self.pgq_queue_name])
-        list = dst_curs.fetchall()
-        dst_db.commit()
-        res = []
-        for row in list:
-            res.append(row[0])
-        return res
-
-    def subscriber_list_seqs(self):
-        list = self.get_subscriber_seq_list()
-        for seq in list:
-            print seq
-
-    def subscriber_add_seq(self, seq_list):
-        src_db = self.get_database('provider_db')
-        src_curs = src_db.cursor()
-        dst_db = self.get_database('subscriber_db')
-        dst_curs = dst_db.cursor()
-        
-        prov_list = self.get_provider_seqs()
-
-        full_list = self.get_all_seqs(dst_curs)
-        cur_list = self.get_subscriber_seq_list()
-
-        if not seq_list and self.options.all:
-            seq_list = prov_list
-        
-        for seq in seq_list:
-            seq = skytools.fq_name(seq)
-            if seq not in prov_list:
-                self.log.error('Seq %s does not exist on provider side' % seq)
-                continue
-            if seq not in full_list:
-                self.log.error('Seq %s does not exist on subscriber side' % seq)
-                continue
-            if seq in cur_list:
-                self.log.info('Seq %s already subscribed' % seq)
-                continue
-
-            self.log.info('Adding sequence: %s' % seq)
-            q = "select londiste.subscriber_add_seq(%s, %s)"
-            dst_curs.execute(q, [self.pgq_queue_name, seq])
-
-        dst_db.commit()
-
-    def subscriber_remove_seq(self, seq_list):
-        dst_db = self.get_database('subscriber_db')
-        dst_curs = dst_db.cursor()
-        cur_list = self.get_subscriber_seq_list()
-
-        if not seq_list and self.options.all:
-            seq_list = cur_list
-
-        for seq in seq_list:
-            seq = skytools.fq_name(seq)
-            if seq not in cur_list:
-                self.log.warning('Seq %s not subscribed')
-            else:
-                self.log.info('Removing sequence: %s' % seq)
-                q = "select londiste.subscriber_remove_seq(%s, %s)"
-                dst_curs.execute(q, [self.pgq_queue_name, seq])
-        dst_db.commit()
-