bulk-handler: sync with new api, add GP compat
authorMarko Kreen <markokr@gmail.com>
Fri, 11 Feb 2011 13:36:20 +0000 (15:36 +0200)
committerMarko Kreen <markokr@gmail.com>
Fri, 11 Feb 2011 13:54:22 +0000 (15:54 +0200)
python/londiste/handlers/bulk.py

index 0b45606f7d828e5a771bdbdf89e498ec4b9cf6b5..5f1177d5980f3d574c5e235628e1b08df28e2f33 100644 (file)
@@ -55,28 +55,19 @@ class BulkLoader(BaseHandler):
     """
     handler_name = 'bulk'
     fake_seq = 0
-    def __init__(self, table_name, next, args, log):
+    def __init__(self, table_name, args, log):
         """Init per-batch table data cache."""
 
-        BaseHandler.__init__(self, table_name, next, args, log)
-
-        self.method = DEFAULT_METHOD
+        BaseHandler.__init__(self, table_name, args, log)
 
         self.pkey_list = None
         self.dist_fields = None
         self.col_list = None
 
         self.pkey_ev_map = {}
-
-        for a in args:
-            k, v = a.split('=')
-            if k == 'method':
-                m = int(v)
-                if m not in (0,1,2):
-                    raise Exception('unknown method: %s' % v)
-                self.method = int(v)
-            else:
-                raise Exception('unknown argument: %s' % a)
+        self.method = int(args.get('method', DEFAULT_METHOD))
+        if not self.method in (0,1,2):
+            raise Exception('unknown method: %s' % self.method)
 
         self.log.debug('bulk_init(%s), method=%d' % (repr(args), self.method))
 
@@ -95,12 +86,13 @@ class BulkLoader(BaseHandler):
         if op not in 'IUD':
             raise Exception('Unknown event type: '+ev.ev_type)
         self.log.debug('bulk.process_event: %s/%s' % (ev.ev_type, ev.ev_data))
-        pkey_list = ev.ev_type[2:].split(',')
+        pkey_list = ev.ev_type[2:].split(',')
         data = skytools.db_urldecode(ev.ev_data)
 
         # get pkey value
         if self.pkey_list is None:
-            self.pkey_list = pkey_list
+            #self.pkey_list = pkey_list
+            self.pkey_list = ev.ev_type[2:].split(',')
         if len(self.pkey_list) > 0:
             pk_data = tuple(data[k] for k in self.pkey_list)
         elif op == 'I':
@@ -157,7 +149,7 @@ class BulkLoader(BaseHandler):
 
             # take last event
             ev = ev_list[-1]
-            
+
             # generate needed commands
             if exists_before and exists_after:
                 upd_list.append(ev.data)
@@ -204,7 +196,7 @@ class BulkLoader(BaseHandler):
 
         qtbl = quote_fqident(self.table_name)
         qtemp = quote_ident(temp)
-        
+
         # where expr must have pkey and dist fields
         klist = []
         for pk in key_fields:
@@ -318,9 +310,9 @@ class BulkLoader(BaseHandler):
             if skytools.exists_temp_table(curs, tempname):
                 self.log.debug("bulk: Using existing temp table %s" % tempname)
                 return tempname
-    
+
         # bizgres crashes on delete rows
-        arg = "on commit delete rows"
+        # removed arg = "on commit delete rows"
         arg = "on commit preserve rows"
         # create temp table for loading
         q = "create temp table %s (like %s) %s" % (
@@ -330,12 +322,12 @@ class BulkLoader(BaseHandler):
         return tempname
 
     def find_dist_fields(self, curs):
-        if not skytools.exists_table(curs, "pg_catalog.mpp_distribution_policy"):
+        if not skytools.exists_table(curs, "pg_catalog.gp_distribution_policy"):
             return []
         schema, name = skytools.fq_name_parts(self.table_name)
         q = "select a.attname"\
             "  from pg_class t, pg_namespace n, pg_attribute a,"\
-            "       mpp_distribution_policy p"\
+            "       gp_distribution_policy p"\
             " where n.oid = t.relnamespace"\
             "   and p.localoid = t.oid"\
             "   and a.attrelid = t.oid"\