table name and columns change support for full_copy
authorEgon Valdmees <egon.valdmees@skype.net>
Tue, 8 Feb 2011 13:20:11 +0000 (15:20 +0200)
committerMarko Kreen <markokr@gmail.com>
Wed, 11 May 2011 09:39:35 +0000 (12:39 +0300)
added support for destination table and columns with
different name than source

python/skytools/sqltools.py

index d7b6efbdeca2aeb730105e0aa4651c2e9485ee88..b8a179be9fa830cdea98fa685a3c639301ca4be2 100644 (file)
@@ -304,7 +304,7 @@ def magic_insert(curs, tablename, data, fields = None, use_insert = 0, quoted_ta
     if curs == None and use_insert == 0:
         fmt = "COPY %s (%s) FROM STDIN;\n"
         buf.write(fmt % (qtablename, ",".join(qfields)))
+
     # process data
     for row in data:
         buf.write(row_func(qtablename, row, fields, qfields))
@@ -377,19 +377,37 @@ class CopyPipe(object):
         self.buf.seek(0)
         self.buf.truncate()
 
-def full_copy(tablename, src_curs, dst_curs, column_list = [], condition = None):
+def full_copy(tablename, src_curs, dst_curs, column_list = [], condition = None,
+        dst_tablename = None, dst_column_list = None):
     """COPY table from one db to another."""
 
-    qtable = skytools.quote_fqident(tablename)
-    if column_list:
-        qfields = ",".join([skytools.quote_ident(f) for f in column_list])
-        src = dst = "%s (%s)" % (qtable, qfields)
-    else:
-        qfields = '*'
-        src = dst = qtable
+    # default dst table and dst columns to source ones
+    dst_tablename = dst_tablename or tablename
+    dst_column_list = dst_column_list or column_list[:]
+    if len(dst_column_list) != len(column_list):
+        raise Exception('src and dst column lists must match in length')
+
+    def build_qfields(cols):
+        if cols:
+            return ",".join([skytools.quote_ident(f) for f in cols])
+        else:
+            return "*"
 
+    def build_statement(table, cols):
+        qtable = skytools.quote_fqident(table)
+        if cols:
+            qfields = build_qfields(cols)
+            return "%s (%s)" % (qtable, qfields)
+        else:
+            return qtable
+
+    dst = build_statement(dst_tablename, dst_column_list)
     if condition:
-        src = "(SELECT %s FROM %s WHERE %s)" % (qfields, qtable, condition)
+        src = "(SELECT %s FROM %s WHERE %s)" % (build_qfields(cols),
+                                                skytools.quote_fqident(tablename),
+                                                condition)
+    else:
+        src = build_statement(tablename, column_list)
 
     if hasattr(src_curs, 'copy_expert'):
         sql_to = "COPY %s TO stdout" % src