londiste: new --all switch for provider/subscriber add/remove
authorMarko Kreen <markokr@gmail.com>
Tue, 31 Jul 2007 19:05:07 +0000 (19:05 +0000)
committerMarko Kreen <markokr@gmail.com>
Tue, 31 Jul 2007 19:05:07 +0000 (19:05 +0000)
provider add --all adds all tables that exist in db, otherwise works on registered tables

By Hans-Juergen Schoenig, plus some cleanup from me

python/londiste.py
python/londiste/setup.py

index 916b104f7e4138b132b0e3f03ab7481a7c194b6c..ef0eb9dfaaae08ae465e42ab88f308a8f2dd1d5b 100755 (executable)
@@ -110,6 +110,8 @@ class Londiste(skytools.DBScript):
         p.set_usage(command_usage.strip())
 
         g = optparse.OptionGroup(p, "expert options")
+        g.add_option("--all", action="store_true",
+                help = "add: include add possible tables")
         g.add_option("--force", action="store_true",
                 help = "add: ignore table differences, repair: ignore lag")
         g.add_option("--expect-sync", action="store_true", dest="expect_sync",
index deda3117f5096cd135b2ae6477a3bcd1652e070e..2e20b164ca405ca24a94a925a5cf988ae2bfc568 100644 (file)
@@ -124,6 +124,8 @@ class CommonSetup(skytools.DBScript):
                     help = "dont delete old data", default=False)
         p.add_option("--force", action="store_true",
                     help="force", default=False)
+        p.add_option("--all", action="store_true",
+                    help="include all tables", default=False)
         return p
 
 
@@ -175,9 +177,27 @@ class ProviderSetup(CommonSetup):
         q = "select pgq.create_queue(%s)"
         self.exec_provider(q, [self.pgq_queue_name])
 
+    def find_missing_provider_tables(self):
+        src_db = self.get_database('provider_db')
+        src_curs = src_db.cursor()
+        q = """select schemaname || '.' || tablename as full_name from pg_tables
+                where schemaname not in ('pgq', 'londiste', 'pg_catalog', 'information_schema')
+                  and schemaname !~ 'pg_.*'
+                except select table_name from londiste.provider_get_table_list(%s)"""
+        src_curs.execute(q, [self.pgq_queue_name])
+        rows = src_curs.fetchall()
+        src_db.commit()
+        list = []
+        for row in rows:
+            list.append(row[0])
+        return list
+
     def provider_add_tables(self, table_list):
         self.check_provider_queue()
 
+        if self.options.all and not table_list:
+            table_list = self.find_missing_provider_tables()
+
         cur_list = self.get_provider_table_list()
         for tbl in table_list:
             if tbl.find('.') < 0:
@@ -193,6 +213,10 @@ class ProviderSetup(CommonSetup):
         self.check_provider_queue()
 
         cur_list = self.get_provider_table_list()
+
+        if not table_list and self.options.all:
+            table_list = cur_list
+
         for tbl in table_list:
             if tbl.find('.') < 0:
                 tbl = "public." + tbl
@@ -451,6 +475,12 @@ class SubscriberSetup(CommonSetup):
         provider_tables = self.get_provider_table_list()
         subscriber_tables = self.get_subscriber_table_list()
 
+        if not table_list and self.options.all:
+            table_list = []
+            for tbl in provider_tables:
+                if tbl not in subscriber_tables:
+                    table_list.append(tbl)
+
         err = 0
         for tbl in table_list:
             tbl = skytools.fq_name(tbl)
@@ -483,9 +513,12 @@ class SubscriberSetup(CommonSetup):
 
     def subscriber_remove_tables(self, table_list):
         subscriber_tables = self.get_subscriber_table_list()
+        if not table_list and self.options.all:
+            table_list = subscriber_tables
         for tbl in table_list:
             tbl = skytools.fq_name(tbl)
             if tbl in subscriber_tables:
+                self.log.info("Removing: %s" % tbl)
                 self.subscriber_remove_one_table(tbl)
             else:
                 self.log.info("Table %s already removed" % tbl)
@@ -494,6 +527,10 @@ class SubscriberSetup(CommonSetup):
         dst_db = self.get_database('subscriber_db')
         dst_curs = dst_db.cursor()
         list = self.fetch_subscriber_tables(dst_curs)
+
+        if not table_list and self.options.all:
+            table_list = self.get_subscriber_table_list()
+
         for tbl in table_list:
             tbl = skytools.fq_name(tbl)
             tbl_row = None