londiste: initial implementation of shard-ed Dispatch handler
authormartinko <gamato@users.sf.net>
Thu, 18 Apr 2013 14:06:56 +0000 (16:06 +0200)
committermartinko <gamato@users.sf.net>
Thu, 18 Apr 2013 14:06:56 +0000 (16:06 +0200)
python/londiste/handler.py
python/londiste/handlers/dispatch.py
python/londiste/handlers/part.py

index 51fa603ae4cadb692298c44fd555c4fa9b9af6a3..bdf0c40809902d9325a3e0987d3bd3090fac5a30 100644 (file)
@@ -178,9 +178,9 @@ class TableHandler(BaseHandler):
 
         enc = args.get('encoding')
         if enc:
-            self.enc = EncodingValidator(self.log, enc)
+            self.encoding_validator = EncodingValidator(self.log, enc)
         else:
-            self.enc = None
+            self.encoding_validator = None
 
     def process_event(self, ev, sql_queue_func, arg):
         row = self.parse_row_data(ev)
@@ -212,13 +212,13 @@ class TableHandler(BaseHandler):
         if len(ev.type) == 1:
             if not self.allow_sql_event:
                 raise Exception('SQL events not supported by this handler')
-            if self.enc:
-                return self.enc.validate_string(ev.data, self.table_name)
+            if self.encoding_validator:
+                return self.encoding_validator.validate_string(ev.data, self.table_name)
             return ev.data
         else:
             row = skytools.db_urldecode(ev.data)
-            if self.enc:
-                return self.enc.validate_dict(row, self.table_name)
+            if self.encoding_validator:
+                return self.encoding_validator.validate_dict(row, self.table_name)
             return row
 
     def real_copy(self, src_tablename, src_curs, dst_curs, column_list):
@@ -226,9 +226,9 @@ class TableHandler(BaseHandler):
         copied
         """
 
-        if self.enc:
+        if self.encoding_validator:
             def _write_hook(obj, data):
-                return self.enc.validate_copy(data, column_list, src_tablename)
+                return self.encoding_validator.validate_copy(data, column_list, src_tablename)
         else:
             _write_hook = None
         condition = self.get_copy_condition(src_curs, dst_curs)
index 758034c7ed33b1f2decd00b50b0a99de423f0d77..b50b95a20500a94f3b9a8d93f84350bd3e64b403 100644 (file)
@@ -153,17 +153,20 @@ creating or coping initial data to destination table.  --expect-sync and
 --skip-truncate should be used and --create switch is to be avoided.
 """
 
-import sys
-import datetime
 import codecs
+import datetime
 import re
+import sys
+from functools import partial
+
 import skytools
-from londiste.handler import BaseHandler, EncodingValidator
 from skytools import quote_ident, quote_fqident, UsageError
 from skytools.dbstruct import *
 from skytools.utf8 import safe_utf8_decode
-from functools import partial
+
+from londiste.handler import EncodingValidator
 from londiste.handlers import handler_args, update
+from londiste.handlers.part import PartHandler
 
 
 __all__ = ['Dispatcher']
@@ -618,7 +621,7 @@ ROW_HANDLERS = {'plain': RowHandler,
 #------------------------------------------------------------------------------
 
 
-class Dispatcher(BaseHandler):
+class Dispatcher (PartHandler):
     """Partitioned loader.
     Splits events into partitions, if requested.
     Then applies them without further processing.
@@ -630,7 +633,7 @@ class Dispatcher(BaseHandler):
         # compat for dest-table
         dest_table = args.get('table', dest_table)
 
-        BaseHandler.__init__(self, table_name, args, dest_table)
+        super(Dispatcher, self).__init__(table_name, args, dest_table)
 
         # show args
         self.log.debug("dispatch.init: table_name=%r, args=%r", table_name, args)
@@ -641,11 +644,6 @@ class Dispatcher(BaseHandler):
         self.conf = self.get_config()
         hdlr_cls = ROW_HANDLERS[self.conf.row_mode]
         self.row_handler = hdlr_cls(self.log)
-        if self.conf.encoding:
-            self.encoding_validator = EncodingValidator(self.log,
-                                                        self.conf.encoding)
-        else:
-            self.encoding_validator = None
 
     def _parse_args_from_doc (self):
         doc = __doc__
@@ -717,8 +715,6 @@ class Dispatcher(BaseHandler):
                     conf.field_map[tmp[0]] = tmp[0]
                 else:
                     conf.field_map[tmp[0]] = tmp[1]
-        # encoding validator
-        conf.encoding = self.args.get('encoding')
         return conf
 
     def get_arg(self, name, value_list, default = None):
@@ -728,17 +724,21 @@ class Dispatcher(BaseHandler):
             raise Exception('Bad argument %s value %r' % (name, val))
         return val
 
+    def _validate_key(self):
+        pass
+
     def reset(self):
         """Called before starting to process a batch.
         Should clean any pending data."""
-        BaseHandler.reset(self)
+        super(Dispatcher, self).reset()
 
     def prepare_batch(self, batch_info, dst_curs):
         """Called on first event for this table in current batch."""
         if self.conf.table_mode != 'ignore':
             self.batch_info = batch_info
             self.dst_curs = dst_curs
-        #BaseHandler.prepare_batch(self, batch_info, dst_curs)
+        if self.key is not None:
+            super(Dispatcher, self).prepare_batch(batch_info, dst_curs)
 
     def filter_data(self, data):
         """Process with fields skip and map"""
@@ -763,7 +763,7 @@ class Dispatcher(BaseHandler):
             pkeys = [fmap[p] for p in pkeys if p in fmap]
         return pkeys
 
-    def process_event(self, ev, sql_queue_func, arg):
+    def _process_event(self, ev, sql_queue_func, arg):
         """Process a event.
         Event should be added to sql_queue or executed directly.
         """
@@ -798,13 +798,12 @@ class Dispatcher(BaseHandler):
             self.row_handler.add_table(dst, LOADERS[self.conf.load_mode],
                                     self.pkeys, self.conf)
         self.row_handler.process(dst, op, data)
-        #BaseHandler.process_event(self, ev, sql_queue_func, arg)
 
     def finish_batch(self, batch_info, dst_curs):
         """Called when batch finishes."""
         if self.conf.table_mode != 'ignore':
             self.row_handler.flush(dst_curs)
-        #BaseHandler.finish_batch(self, batch_info, dst_curs)
+        #super(Dispatcher, self).finish_batch(batch_info, dst_curs)
 
     def get_part_name(self):
         # if custom part name template given, use it
@@ -918,12 +917,19 @@ class Dispatcher(BaseHandler):
         if res:
             self.log.info("Dropped tables: %s", ", ".join(res))
 
+    def get_copy_condition(self, src_curs, dst_curs):
+        """ Prepare where condition for copy and replay filtering.
+        """
+        if self.key is not None:
+            return super(Dispatcher, self).get_copy_condition(src_curs, dst_curs)
+        return ''
+
     def real_copy(self, tablename, src_curs, dst_curs, column_list):
         """do actual table copy and return tuple with number of bytes and rows
         copied
         """
         _src_cols = _dst_cols = column_list
-        condition = ''
+        condition = self.get_copy_condition (src_curs, dst_curs)
 
         if self.conf.skip_fields:
             _src_cols = [col for col in column_list
@@ -940,7 +946,8 @@ class Dispatcher(BaseHandler):
         else:
             _write_hook = None
 
-        return skytools.full_copy(tablename, src_curs, dst_curs, _src_cols, condition,
+        return skytools.full_copy(tablename, src_curs, dst_curs,
+                                  _src_cols, condition,
                                   dst_tablename = self.dest_table,
                                   dst_column_list = _dst_cols,
                                   write_hook = _write_hook)
index 247256e467d1711b4f7305f041f1d8f20cdad70c..366675a3991acdb256db232398a336fd1a96c7d1 100644 (file)
@@ -39,8 +39,7 @@ class PartHandler(TableHandler):
 
         # primary key columns
         self.key = args.get('key')
-        if self.key is None:
-            raise Exception('Specify key field as key argument')
+        self._validate_key()
 
         # hash function & full expression
         hashfunc = args.get('hashfunc', self.DEFAULT_HASHFUNC)
@@ -49,6 +48,10 @@ class PartHandler(TableHandler):
                 skytools.quote_ident(self.key))
         self.hashexpr = args.get('hashexpr', self.hashexpr)
 
+    def _validate_key(self):
+        if self.key is None:
+            raise Exception('Specify key field as key argument')
+
     def reset(self):
         """Forget config info."""
         self.max_part = None
@@ -57,7 +60,6 @@ class PartHandler(TableHandler):
 
     def add(self, trigger_arg_list):
         """Let trigger put hash into extra3"""
-
         arg = "ev_extra3='hash='||%s" % self.hashexpr
         trigger_arg_list.append(arg)
         TableHandler.add(self, trigger_arg_list)
@@ -70,13 +72,16 @@ class PartHandler(TableHandler):
 
     def process_event(self, ev, sql_queue_func, arg):
         """Filter event by hash in extra3, apply only local part."""
-        if ev.extra3:
+        if ev.extra3 and self.key is not None:
             meta = skytools.db_urldecode(ev.extra3)
             self.log.debug('part.process_event: hash=%d, max_part=%s, local_part=%d',
                            int(meta['hash']), self.max_part, self.local_part)
             if (int(meta['hash']) & self.max_part) != self.local_part:
                 self.log.debug('part.process_event: not my event')
                 return
+        self._process_event(ev, sql_queue_func, arg)
+
+    def _process_event(self, ev, sql_queue_func, arg):
         self.log.debug('part.process_event: my event, processing')
         TableHandler.process_event(self, ev, sql_queue_func, arg)