londiste: globbing, --all
author Marko Kreen <markokr@gmail.com>
Mon, 1 Jun 2009 12:51:54 +0000 (15:51 +0300)
committer Marko Kreen <markokr@gmail.com>
Mon, 1 Jun 2009 13:01:45 +0000 (16:01 +0300)
python/londiste/setup.py

index 9171f59cee952657af72f18e1a26ceb4fa73ca12..ae8567cd48e998bed270f388ac29e3c91f9f6efa 100644 (file)
@@ -3,9 +3,10 @@
 """Londiste setup and sanity checker.
 """
 
-import sys, os, skytools
+import sys, os, re, skytools
 
 from pgq.cascade.admin import CascadeAdmin
+from skytools.scripting import UsageError
 
 __all__ = ['LondisteSetup']
 
@@ -92,10 +93,11 @@ class LondisteSetup(CascadeAdmin):
         self.sync_table_list(dst_curs, src_tbls, dst_tbls)
         dst_db.commit()
 
+        args = self.expand_arg_list(dst_db, 'r', False, args)
+
         # dont check for exist/not here (root handling)
         problems = False
         for tbl in args:
-            tbl = skytools.fq_name(tbl)
             if (tbl in src_tbls) and not src_tbls[tbl]:
                 self.log.error("Table %s does not exist on provider, need to switch to different provider" % tbl)
                 problems = True
@@ -150,7 +152,7 @@ class LondisteSetup(CascadeAdmin):
         q = "select * from londiste.local_add_table(%s, %s)"
         self.exec_cmd(dst_curs, q, [self.set_name, tbl])
         dst_db.commit()
-    
+
     def sync_table_list(self, dst_curs, src_tbls, dst_tbls):
         for tbl in src_tbls.keys():
             q = "select * from londiste.global_add_table(%s, %s)"
@@ -175,8 +177,9 @@ class LondisteSetup(CascadeAdmin):
 
     def cmd_remove_table(self, *args):
         """Detach table(s) from local node."""
-        q = "select * from londiste.local_remove_table(%s, %s)"
         db = self.get_database('db')
+        args = self.expand_arg_list(db, 'r', True, args)
+        q = "select * from londiste.local_remove_table(%s, %s)"
         self.exec_cmd_many(db, q, [self.set_name], args)
 
     def cmd_add_seq(self, *args):
@@ -192,6 +195,8 @@ class LondisteSetup(CascadeAdmin):
         self.sync_seq_list(dst_curs, src_seqs, dst_seqs)
         dst_db.commit()
 
+        args = self.expand_arg_list(dst_db, 'S', False, args)
+
         # pick proper create flags
         create = self.options.create_only
         if not create and self.options.create:
@@ -263,11 +268,13 @@ class LondisteSetup(CascadeAdmin):
         """Detach seqs(s) from local node."""
         q = "select * from londiste.local_remove_seq(%s, %s)"
         db = self.get_database('db')
+        args = self.expand_arg_list(db, 'S', True, args)
         self.exec_cmd_many(db, q, [self.set_name], args)
 
     def cmd_resync(self, *args):
         """Reload data from provider node.."""
         db = self.get_database('db')
+        args = self.expand_arg_list(db, 'r', True, args)
         q = "select * from londiste.local_set_table_state(%s, %s, null, null)"
         self.exec_cmd_many(db, q, [self.set_name], args)
 
@@ -325,168 +332,78 @@ class LondisteSetup(CascadeAdmin):
             self.provider_location = res[0]['provider_location']
         return self.get_database('provider_db', connstr = self.provider_location)
 
-#
-# Old commands
-#
-
-#class LondisteSetup_tmp(LondisteSetup):
-#
-#    def find_missing_provider_tables(self, pattern='*'):
-#        src_db = self.get_database('provider_db')
-#        src_curs = src_db.cursor()
-#        q = """select schemaname || '.' || tablename as full_name from pg_tables
-#                where schemaname not in ('pgq', 'londiste', 'pg_catalog', 'information_schema')
-#                  and schemaname !~ 'pg_.*'
-#                  and (schemaname || '.' || tablename) ~ %s
-#                except select table_name from londiste.provider_get_table_list(%s)"""
-#        src_curs.execute(q, [glob2regex(pattern), self.queue_name])
-#        rows = src_curs.fetchall()
-#        src_db.commit()
-#        list = []
-#        for row in rows:
-#            list.append(row[0])
-#        return list
-#                
-#    def admin(self):
-#        cmd = self.args[2]
-#        if cmd == "tables":
-#            self.subscriber_show_tables()
-#        elif cmd == "missing":
-#            self.subscriber_missing_tables()
-#        elif cmd == "add":
-#            self.subscriber_add_tables(self.args[3:])
-#        elif cmd == "remove":
-#            self.subscriber_remove_tables(self.args[3:])
-#        elif cmd == "resync":
-#            self.subscriber_resync_tables(self.args[3:])
-#        elif cmd == "register":
-#            self.subscriber_register()
-#        elif cmd == "unregister":
-#            self.subscriber_unregister()
-#        elif cmd == "install":
-#            self.subscriber_install()
-#        elif cmd == "check":
-#            self.check_tables(self.get_provider_table_list())
-#        elif cmd in ["fkeys", "triggers"]:
-#            self.collect_meta(self.get_provider_table_list(), cmd, self.args[3:])
-#        elif cmd == "seqs":
-#            self.subscriber_list_seqs()
-#        elif cmd == "add-seq":
-#            self.subscriber_add_seq(self.args[3:])
-#        elif cmd == "remove-seq":
-#            self.subscriber_remove_seq(self.args[3:])
-#        elif cmd == "restore-triggers":
-#            self.restore_triggers(self.args[3], self.args[4:])
-#        else:
-#            self.log.error('bad subcommand: ' + cmd)
-#            sys.exit(1)
-#
-#    def collect_meta(self, table_list, meta, args):
-#        """Display fkey/trigger info."""
-#
-#        if args == []:
-#            args = ['pending', 'active']
-#            
-#        field_map = {'triggers': ['table_name', 'trigger_name', 'trigger_def'],
-#                     'fkeys': ['from_table', 'to_table', 'fkey_name', 'fkey_def']}
-#        
-#        query_map = {'pending': "select %s from londiste.subscriber_get_table_pending_%s(%%s)",
-#                     'active' : "select %s from londiste.find_table_%s(%%s)"}
-#
-#        table_list = self.clean_subscriber_tables(table_list)
-#        if len(table_list) == 0:
-#            self.log.info("No tables, no fkeys")
-#            return
-#
-#        dst_db = self.get_database('subscriber_db')
-#        dst_curs = dst_db.cursor()
-#
-#        for which in args:
-#            union_list = []
-#            fields = field_map[meta]
-#            q = query_map[which] % (",".join(fields), meta)
-#            for tbl in table_list:
-#                union_list.append(q % skytools.quote_literal(tbl))
-#
-#            # use union as fkey may appear in duplicate
-#            sql = " union ".join(union_list) + " order by 1"
-#            desc = "%s %s" % (which, meta)
-#            self.display_table(desc, dst_curs, fields, sql)
-#        dst_db.commit()
-#
-#    def check_tables(self, table_list):
-#        src_db = self.get_database('provider_db')
-#        src_curs = src_db.cursor()
-#        dst_db = self.get_database('subscriber_db')
-#        dst_curs = dst_db.cursor()
-#
-#        failed = 0
-#        for tbl in table_list:
-#            self.log.info('Checking %s' % tbl)
-#            if not skytools.exists_table(src_curs, tbl):
-#                self.log.error('Table %s missing from provider side' % tbl)
-#                failed += 1
-#            elif not skytools.exists_table(dst_curs, tbl):
-#                self.log.error('Table %s missing from subscriber side' % tbl)
-#                failed += 1
-#            else:
-#                failed += self.check_table_columns(src_curs, dst_curs, tbl)
-#
-#        src_db.commit()
-#        dst_db.commit()
-#
-#        return failed
-#
-#    def check_table_columns(self, src_curs, dst_curs, tbl):
-#        src_colrows = find_column_types(src_curs, tbl)
-#        dst_colrows = find_column_types(dst_curs, tbl)
-#
-#        src_cols = make_type_string(src_colrows)
-#        dst_cols = make_type_string(dst_colrows)
-#        if src_cols.find('k') < 0:
-#            self.log.error('provider table %s has no primary key (%s)' % (
-#                             tbl, src_cols))
-#            return 1
-#        if dst_cols.find('k') < 0:
-#            self.log.error('subscriber table %s has no primary key (%s)' % (
-#                             tbl, dst_cols))
-#            return 1
-#
-#        if src_cols != dst_cols:
-#            self.log.warning('table %s structure is not same (%s/%s)'\
-#                 ', trying to continue' % (tbl, src_cols, dst_cols))
-#
-#        err = 0
-#        for row in src_colrows:
-#            found = 0
-#            for row2 in dst_colrows:
-#                if row2['name'] == row['name']:
-#                    found = 1
-#                    break
-#            if not found:
-#                err = 1
-#                self.log.error('%s: column %s on provider not on subscriber'
-#                                    % (tbl, row['name']))
-#            elif row['type'] != row2['type']:
-#                err = 1
-#                self.log.error('%s: pk different on column %s'
-#                                    % (tbl, row['name']))
-#
-#        return err
-#
-#    def find_missing_subscriber_tables(self, pattern='*'):
-#        src_db = self.get_database('subscriber_db')
-#        src_curs = src_db.cursor()
-#        q = """select schemaname || '.' || tablename as full_name from pg_tables
-#                where schemaname not in ('pgq', 'londiste', 'pg_catalog', 'information_schema')
-#                  and schemaname !~ 'pg_.*'
-#                  and schemaname || '.' || tablename ~ %s
-#                except select table_name from londiste.provider_get_table_list(%s)"""
-#        src_curs.execute(q, [glob2regex(pattern), self.queue_name])
-#        rows = src_curs.fetchall()
-#        src_db.commit()
-#        list = []
-#        for row in rows:
-#            list.append(row[0])
-#        return list
-#
+
+    def expand_arg_list(self, db, kind, existing, args):
+        """Turn command-line object names / globs into a concrete name list.
+
+        db       - database connection; a cursor is opened on it and the
+                   transaction is committed after the lookups.
+        kind     - 'S' to work on sequences, 'r' to work on tables.
+        existing - if true, candidates are the objects that are already
+                   local on this node; otherwise candidates are the ones
+                   reported by londiste.local_show_missing().
+        args     - names and/or glob patterns given by the user.
+
+        When no args are given and the --all option is set, the whole
+        candidate list is returned.  Otherwise the args are resolved
+        against the candidates by solve_globbing().
+
+        Raises Exception("bug") on an unknown kind.
+        """
+        curs = db.cursor()
+
+        list_queries = {
+            'S': "select seq_name, local from londiste.get_seq_list(%s) where local",
+            'r': "select table_name, local from londiste.get_table_list(%s) where local",
+        }
+        if kind not in list_queries:
+            raise Exception("bug")
+        q_exists = list_queries[kind]
+        q_missing = "select obj_name from londiste.local_show_missing(%%s) where obj_kind = '%s'" % kind
+
+        # objects already attached to the local node
+        curs.execute(q_exists, [self.set_name])
+        names_exists = [row[0] for row in curs.fetchall()]
+        seen_exists = dict.fromkeys(names_exists, 1)
+
+        # objects that could still be attached
+        curs.execute(q_missing, [self.set_name])
+        names_missing = [row[0] for row in curs.fetchall()]
+        seen_missing = dict.fromkeys(names_missing, 1)
+
+        db.commit()
+
+        # pick the candidate set and the "opposite" set
+        if existing:
+            candidates, cand_map, other_map = names_exists, seen_exists, seen_missing
+        else:
+            candidates, cand_map, other_map = names_missing, seen_missing, seen_exists
+
+        if not args and self.options.all:
+            return candidates
+        return self.solve_globbing(args, candidates, cand_map, other_map)
+
+
+    def solve_globbing(self, args, full_list, full_map, reverse_map):
+        """Resolve names and glob patterns in args against full_list.
+
+        args        - names or glob patterns ('*' = any run of characters,
+                      '?' = exactly one character).
+        full_list   - candidate names, in the order results should follow.
+        full_map    - dict keyed by the same candidate names.
+        reverse_map - dict keyed by names from the opposite set; a plain
+                      name found only here is logged as already processed
+                      and skipped.
+
+        A pattern without a dot gets the 'public.' schema prepended.  Plain
+        names are passed through skytools.fq_name() (presumably to
+        schema-qualify them - confirm) before lookup.
+
+        Returns the matched names without duplicates, in first-seen order.
+        Raises UsageError if a plain name is in neither map.
+        """
+        def glob2regex(s):
+            # Translate only the two wildcards; every other character is
+            # escaped so that names containing regex metacharacters
+            # (eg. '$' or '+') are matched literally.
+            parts = []
+            for c in s:
+                if c == '*':
+                    parts.append('.*')
+                elif c == '?':
+                    parts.append('.')
+                else:
+                    parts.append(re.escape(c))
+            return '^%s$' % ''.join(parts)
+
+        res_map = {}
+        res_list = []
+        err = 0
+        for a in args:
+            if '*' in a or '?' in a:
+                if '.' not in a:
+                    a = 'public.' + a
+                rc = re.compile(glob2regex(a))
+                for x in full_list:
+                    if rc.match(x) and x not in res_map:
+                        res_map[x] = 1
+                        res_list.append(x)
+            else:
+                a = skytools.fq_name(a)
+                if a in res_map:
+                    continue
+                elif a in full_map:
+                    res_list.append(a)
+                    res_map[a] = 1
+                elif a in reverse_map:
+                    self.log.info("%s already processed" % a)
+                else:
+                    # keep going so that all bad names get reported
+                    self.log.warning("%s not available" % a)
+                    err = 1
+        if err:
+            raise UsageError("Cannot proceed")
+        return res_list
+