skytools: support for psycopg2
authorMarko Kreen <markokr@gmail.com>
Mon, 16 Jul 2007 14:07:35 +0000 (14:07 +0000)
committerMarko Kreen <markokr@gmail.com>
Mon, 16 Jul 2007 14:07:35 +0000 (14:07 +0000)
python/skytools/__init__.py
python/skytools/dbstruct.py
python/skytools/quoting.py
python/skytools/scripting.py
python/skytools/skylog.py
python/skytools/sqltools.py

index ed2b39bcc9fbd009bb109ebb6b57c4239ac21d93..f1ef7197cf7ca1d8927d5e7313cb1eb29487cdfa 100644 (file)
@@ -4,7 +4,7 @@
 from config import *
 from dbstruct import *
 from gzlog import *
-from quoting import *
 from scripting import *
 from sqltools import *
+from quoting import *
 
index 223334299a3a5b4ca4b9851a82ed776c81924bd2..ece0ea5454956e94523e0f6d9f18c98c841153f3 100644 (file)
@@ -364,8 +364,8 @@ class TableStruct(object):
         return res
 
 def test():
-    import psycopg
-    db = psycopg.connect("dbname=fooz")
+    from skytools import connect_database
+    db = connect_database("dbname=fooz")
     curs = db.cursor()
     
     s = TableStruct(curs, "public.data1")
index 0ce500803624f4b4a9f7d70c9139468f89cd6215..7df56fee5068756033c0fbb7a843804306a26216 100644 (file)
@@ -2,7 +2,12 @@
 
 """Various helpers for string quoting/unquoting."""
 
-import psycopg, urllib, re
+import urllib, re
+
+try:
+    from psycopg2.extensions import QuotedString
+except:
+    from psycopg import QuotedString
 
 __all__ = [
     "quote_literal", "quote_copy", "quote_bytea_raw",
@@ -23,7 +28,7 @@ def quote_literal(s):
 
     if s == None:
         return "null"
-    s = psycopg.QuotedString(str(s))
+    s = QuotedString(str(s))
     return str(s)
 
 def quote_copy(s):
index 456fff69fbb08562b674df1ea895baae9b2381ab..136c89ef41583e434b6bc6e33c63197dc9e006bd 100644 (file)
@@ -1,10 +1,11 @@
 
 """Useful functions and classes for database scripts."""
 
-import sys, os, signal, psycopg, optparse, traceback, time
+import sys, os, signal, optparse, traceback, time
 import logging, logging.handlers, logging.config
 
 from skytools.config import *
+from skytools.sqltools import connect_database
 import skytools.skylog
 
 __all__ = ['daemonize', 'run_single_process', 'DBScript',
@@ -178,7 +179,7 @@ class DBCachedConn(object):
         # new conn?
         if not self.conn:
             self.isolation_level = isolation_level
-            self.conn = psycopg.connect(self.loc)
+            self.conn = connect_database(self.loc)
 
             self.conn.set_isolation_level(isolation_level)
             self.conn_time = time.time()
index 2f6344aedf8c121c83da8ccbc94b9a4ccb5f7028..bde6f1534a766b58337db030eb5ac43057c11c3f 100644 (file)
@@ -1,10 +1,12 @@
 """Our log handlers for Python's logging package.
 """
 
-import sys, os, time, socket, psycopg
+import sys, os, time, socket
 import logging, logging.handlers
 
-from quoting import quote_json
+from skytools.quoting import quote_json
+from skytools.sqltools import connect_database
+
 
 # configurable file logger
 class EasyRotatingFileHandler(logging.handlers.RotatingFileHandler):
@@ -93,9 +95,9 @@ class LogDBHandler(logging.handlers.SocketHandler):
 
     def makeSocket(self):
         """Create server connection.
-        In this case its not socket but psycopg conection."""
+        In this case its not socket but database connection."""
 
-        db = psycopg.connect(self.connect_string)
+        db = connect_database(self.connect_string)
         db.autocommit(1)
         return db
 
index f9144031d7800b535f15db86ee0656dccddfaacb..e0d53bf89d635ed9b4572cc864e38c3bb7264cb4 100644 (file)
@@ -12,9 +12,33 @@ __all__ = [
     "exists_function", "exists_language", "Snapshot", "magic_insert",
     "db_copy_from_dict", "db_copy_from_list", "CopyPipe", "full_copy",
     "DBObject", "DBSchema", "DBTable", "DBFunction", "DBLanguage",
-    "db_install"
+    "db_install", "connect_database"
 ]
 
+
+try:
+    ##from psycopg2.psycopg1 import connect as _pgconnect
+    # psycopg2.psycopg1.cursor is too backwards compatible,
+    # to the point of avoiding optimized access.
+
+    ## only backwards compat thing we need is dict* methods
+    import psycopg2.extensions, psycopg2.extras
+    class _CompatCursor(psycopg2.extras.DictCursor):
+        """Regular psycopg2 DictCursor with dict* methods."""
+        dictfetchone = psycopg2.extras.DictCursor.fetchone
+        dictfetchall = psycopg2.extras.DictCursor.fetchall
+        dictfetchmany = psycopg2.extras.DictCursor.fetchmany
+    class _CompatConnection(psycopg2.extensions.connection):
+        """Connection object that uses _CompatCursor."""
+        def cursor(self):
+            return psycopg2.extensions.connection.cursor(self, cursor_factory = _CompatCursor)
+    def _pgconnect(cstr):
+        """Create a psycopg2 connection."""
+        return _CompatConnection(cstr)
+except ImportError:
+    # use psycopg 1
+    from psycopg import connect as _pgconnect
+
 #
 # Fully qualified table name
 #
@@ -412,3 +436,16 @@ def db_install(curs, list, log = None):
             if log:
                 log.info('%s is installed' % obj.name)
 
+def connect_database(connstr):
+    """Create a db connection with connect_timeout option.
+    
+    Default connect_timeout is 15, to change put it directly into dsn.
+    """
+
+    # allow override
+    if connstr.find("connect_timeout") < 0:
+        connstr += " connect_timeout=15"
+
+    # create connection
+    return _pgconnect(connstr)
+