Skip to content
21 changes: 20 additions & 1 deletion pyhive/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import getpass
import logging
import requests
import datetime
from requests.auth import HTTPBasicAuth

try: # Python 3
Expand All @@ -31,7 +32,25 @@
paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s

_logger = logging.getLogger(__name__)
_escaper = common.ParamEscaper()


class PrestoParamEscaper(common.ParamEscaper):
def escape_item(self, item):
if isinstance(item, datetime.datetime):
return self.escape_datetime(item)
elif isinstance(item, datetime.date):

Check warning on line 41 in pyhive/presto.py

View check run for this annotation

Codecov / codecov/patch

pyhive/presto.py#L41

Added line #L41 was not covered by tests
return self.escape_date(item)
else:

Check warning on line 43 in pyhive/presto.py

View check run for this annotation

Codecov / codecov/patch

pyhive/presto.py#L43

Added line #L43 was not covered by tests
return super(PrestoParamEscaper, self).escape_item(item)

def escape_date(self, item):
return "date '{}'".format(item)

Check warning on line 48 in pyhive/presto.py

View check run for this annotation

Codecov / codecov/patch

pyhive/presto.py#L48

Added line #L48 was not covered by tests
def escape_datetime(self, item):
return "timestamp '{}'".format(item)

Check warning on line 51 in pyhive/presto.py

View check run for this annotation

Codecov / codecov/patch

pyhive/presto.py#L51

Added line #L51 was not covered by tests

_escaper = PrestoParamEscaper()


def connect(*args, **kwargs):
Expand Down
33 changes: 33 additions & 0 deletions pyhive/sqlalchemy_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
from sqlalchemy import exc
from sqlalchemy import types
from sqlalchemy import util

# TODO shouldn't use mysql type
from sqlalchemy.databases import mysql
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.expression import Alias

from pyhive import presto
from pyhive.common import UniversalSet
Expand Down Expand Up @@ -46,6 +48,37 @@
def visit_char_length_func(self, fn, **kw):
return 'length{}'.format(self.function_argspec(fn, **kw))

def visit_column(self, column, add_to_result_map=None, include_table=True, **kwargs):
sql = super(PrestoCompiler, self).visit_column(
column, add_to_result_map, include_table, **kwargs
)
table = column.table
return self.__add_catalog(sql, table)

def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
fromhints=None, use_schema=True, **kwargs):
sql = super(PrestoCompiler, self).visit_table(
table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs
)
return self.__add_catalog(sql, table)

def __add_catalog(self, sql, table):
if table is None:
return sql

Check warning on line 67 in pyhive/sqlalchemy_presto.py

View check run for this annotation

Codecov / codecov/patch

pyhive/sqlalchemy_presto.py#L67

Added line #L67 was not covered by tests

if isinstance(table, Alias):
return sql

Check warning on line 70 in pyhive/sqlalchemy_presto.py

View check run for this annotation

Codecov / codecov/patch

pyhive/sqlalchemy_presto.py#L70

Added line #L70 was not covered by tests

if (
"presto" not in table.dialect_options
or "catalog" not in table.dialect_options["presto"]._non_defaults
):
return sql

catalog = table.dialect_options["presto"]._non_defaults["catalog"]
sql = "\"{catalog}\".{sql}".format(catalog=catalog, sql=sql)
return sql

Check warning on line 80 in pyhive/sqlalchemy_presto.py

View check run for this annotation

Codecov / codecov/patch

pyhive/sqlalchemy_presto.py#L78-L80

Added lines #L78 - L80 were not covered by tests


class PrestoTypeCompiler(compiler.GenericTypeCompiler):
def visit_CLOB(self, type_, **kw):
Expand Down