Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit aa74a22

Browse files
authored
Merge pull request #253 from datafold/refactor_tests_oct2022
Refactor tests oct2022
2 parents 6ab6f5f + 1fb79da commit aa74a22

17 files changed

+662
-655
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ jobs:
4646
env:
4747
DATADIFF_SNOWFLAKE_URI: '${{ secrets.DATADIFF_SNOWFLAKE_URI }}'
4848
DATADIFF_PRESTO_URI: '${{ secrets.DATADIFF_PRESTO_URI }}'
49-
DATADIFF_TRINO_URI: '${{ secrets.DATADIFF_TRINO_URI }}'
5049
DATADIFF_CLICKHOUSE_URI: 'clickhouse://clickhouse:Password1@localhost:9000/clickhouse'
5150
DATADIFF_VERTICA_URI: 'vertica://vertica:Password1@localhost:5433/vertica'
5251
run: |

data_diff/databases/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import datetime
12
import math
23
import sys
34
import logging
@@ -120,6 +121,10 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list):
120121
compiler = Compiler(self)
121122
if isinstance(sql_ast, Generator):
122123
sql_code = ThreadLocalInterpreter(compiler, sql_ast)
124+
elif isinstance(sql_ast, list):
125+
for i in sql_ast[:-1]:
126+
self.query(i)
127+
return self.query(sql_ast[-1], res_type)
123128
else:
124129
sql_code = compiler.compile(sql_ast)
125130
if sql_code is SKIP:
@@ -249,7 +254,7 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe
249254
if not text_columns:
250255
return
251256

252-
fields = [self.normalize_uuid(c, String_UUID()) for c in text_columns]
257+
fields = [self.normalize_uuid(self.quote(c), String_UUID()) for c in text_columns]
253258
samples_by_row = self.query(table(*table_path).select(*fields).where(where or SKIP).limit(sample_size), list)
254259
if not samples_by_row:
255260
raise ValueError(f"Table {table_path} is empty.")
@@ -329,6 +334,7 @@ def type_repr(self, t) -> str:
329334
str: "VARCHAR",
330335
bool: "BOOLEAN",
331336
float: "FLOAT",
337+
datetime: "TIMESTAMP",
332338
}[t]
333339

334340
def _query_cursor(self, c, sql_code: str):

data_diff/databases/bigquery.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,6 @@ def is_autocommit(self) -> bool:
109109

110110
def type_repr(self, t) -> str:
111111
try:
112-
return {
113-
str: "STRING",
114-
}[t]
112+
return {str: "STRING", float: "FLOAT64"}[t]
115113
except KeyError:
116114
return super().type_repr(t)

data_diff/databases/clickhouse.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,7 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
150150
)
151151
"""
152152
return value
153+
154+
@property
155+
def is_autocommit(self) -> bool:
156+
return True

data_diff/databases/databricks.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,7 @@ def parse_table_name(self, name: str) -> DbPath:
151151

152152
def close(self):
153153
self._conn.close()
154+
155+
@property
156+
def is_autocommit(self) -> bool:
157+
return True

data_diff/databases/presto.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,9 @@ def is_autocommit(self) -> bool:
153153

154154
def explain_as_text(self, query: str) -> str:
155155
return f"EXPLAIN (FORMAT TEXT) {query}"
156+
157+
def type_repr(self, t) -> str:
158+
try:
159+
return {float: "REAL"}[t]
160+
except KeyError:
161+
return super().type_repr(t)

data_diff/joindiff_tables.py

Lines changed: 5 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010

1111
from runtype import dataclass
1212

13-
from data_diff.databases.database_types import DbPath, NumericType
14-
from data_diff.databases.base import QueryError
13+
from .databases.database_types import DbPath, NumericType
14+
from .query_utils import append_to_table, drop_table
1515

1616

1717
from .utils import safezip
@@ -48,7 +48,7 @@ def sample(table_expr):
4848
return table_expr.order_by(Random()).limit(10)
4949

5050

51-
def create_temp_table(c: Compiler, path: TablePath, expr: Expr):
51+
def create_temp_table(c: Compiler, path: TablePath, expr: Expr) -> str:
5252
db = c.database
5353
if isinstance(db, BigQuery):
5454
return f"create table {c.compile(path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}"
@@ -60,42 +60,6 @@ def create_temp_table(c: Compiler, path: TablePath, expr: Expr):
6060
return f"create temporary table {c.compile(path)} as {c.compile(expr)}"
6161

6262

63-
def drop_table_oracle(name: DbPath):
64-
t = table(name)
65-
# Experience shows double drop is necessary
66-
with suppress(QueryError):
67-
yield t.drop()
68-
yield t.drop()
69-
yield commit
70-
71-
72-
def drop_table(name: DbPath):
73-
t = table(name)
74-
yield t.drop(if_exists=True)
75-
yield commit
76-
77-
78-
def append_to_table_oracle(path: DbPath, expr: Expr):
79-
"""See append_to_table"""
80-
assert expr.schema, expr
81-
t = table(path, schema=expr.schema)
82-
with suppress(QueryError):
83-
yield t.create() # uses expr.schema
84-
yield commit
85-
yield t.insert_expr(expr)
86-
yield commit
87-
88-
89-
def append_to_table(path: DbPath, expr: Expr):
90-
"""Append to table"""
91-
assert expr.schema, expr
92-
t = table(path, schema=expr.schema)
93-
yield t.create(if_not_exists=True) # uses expr.schema
94-
yield commit
95-
yield t.insert_expr(expr)
96-
yield commit
97-
98-
9963
def bool_to_int(x):
10064
return if_(x, 1, 0)
10165

@@ -170,10 +134,7 @@ def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult
170134

171135
bg_funcs = [partial(self._test_duplicate_keys, table1, table2)] if self.validate_unique_key else []
172136
if self.materialize_to_table:
173-
if isinstance(db, Oracle):
174-
db.query(drop_table_oracle(self.materialize_to_table))
175-
else:
176-
db.query(drop_table(self.materialize_to_table))
137+
drop_table(db, self.materialize_to_table)
177138

178139
with self._run_in_background(*bg_funcs):
179140

@@ -348,6 +309,5 @@ def exclusive_rows(expr):
348309
def _materialize_diff(self, db, diff_rows, segment_index=None):
349310
assert self.materialize_to_table
350311

351-
f = append_to_table_oracle if isinstance(db, Oracle) else append_to_table
352-
db.query(f(self.materialize_to_table, diff_rows.limit(self.write_limit)))
312+
append_to_table(db, self.materialize_to_table, diff_rows.limit(self.write_limit))
353313
logger.info("Materialized diff to table '%s'.", ".".join(self.materialize_to_table))

data_diff/queries/api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from typing import Optional
2+
3+
from data_diff.utils import CaseAwareMapping, CaseSensitiveDict
24
from .ast_classes import *
35
from .base import args_as_tuple
46

@@ -30,11 +32,14 @@ def cte(expr: Expr, *, name: Optional[str] = None, params: Sequence[str] = None)
3032
return Cte(expr, name, params)
3133

3234

33-
def table(*path: str, schema: Schema = None) -> TablePath:
35+
def table(*path: str, schema: Union[dict, CaseAwareMapping] = None) -> TablePath:
3436
if len(path) == 1 and isinstance(path[0], tuple):
3537
(path,) = path
3638
if not all(isinstance(i, str) for i in path):
3739
raise TypeError(f"All elements of table path must be of type 'str'. Got: {path}")
40+
if schema and not isinstance(schema, CaseAwareMapping):
41+
assert isinstance(schema, dict)
42+
schema = CaseSensitiveDict(schema)
3843
return TablePath(path, schema)
3944

4045

data_diff/queries/ast_classes.py

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from dataclasses import field
22
from datetime import datetime
3-
from typing import Any, Generator, Optional, Sequence, Tuple, Union
3+
from typing import Any, Generator, List, Optional, Sequence, Tuple, Union
4+
from uuid import UUID
45

56
from runtype import dataclass
67

@@ -298,18 +299,29 @@ class TablePath(ExprNode, ITable):
298299
path: DbPath
299300
schema: Optional[Schema] = field(default=None, repr=False)
300301

301-
def create(self, if_not_exists=False):
302-
if not self.schema:
303-
raise ValueError("Schema must have a value to create table")
304-
return CreateTable(self, if_not_exists=if_not_exists)
302+
def create(self, source_table: ITable = None, *, if_not_exists=False):
303+
if source_table is None and not self.schema:
304+
raise ValueError("Either schema or source table needed to create table")
305+
if isinstance(source_table, TablePath):
306+
source_table = source_table.select()
307+
return CreateTable(self, source_table, if_not_exists=if_not_exists)
305308

306309
def drop(self, if_exists=False):
307310
return DropTable(self, if_exists=if_exists)
308311

309-
def insert_values(self, rows):
310-
raise NotImplementedError()
312+
def truncate(self):
313+
return TruncateTable(self)
314+
315+
def insert_rows(self, rows, *, columns=None):
316+
rows = list(rows)
317+
return InsertToTable(self, ConstantTable(rows), columns=columns)
318+
319+
def insert_row(self, *values, columns=None):
320+
return InsertToTable(self, ConstantTable([values]), columns=columns)
311321

312322
def insert_expr(self, expr: Expr):
323+
if isinstance(expr, TablePath):
324+
expr = expr.select()
313325
return InsertToTable(self, expr)
314326

315327
@property
@@ -592,6 +604,29 @@ def compile(self, c: Compiler) -> str:
592604
return c.database.random()
593605

594606

607+
@dataclass
608+
class ConstantTable(ExprNode):
609+
rows: Sequence[Sequence]
610+
611+
def compile(self, c: Compiler) -> str:
612+
raise NotImplementedError()
613+
614+
def _value(self, v):
615+
if v is None:
616+
return "NULL"
617+
elif isinstance(v, str):
618+
return f"'{v}'"
619+
elif isinstance(v, datetime):
620+
return f"timestamp '{v}'"
621+
elif isinstance(v, UUID):
622+
return f"'{v}'"
623+
return repr(v)
624+
625+
def compile_for_insert(self, c: Compiler):
626+
values = ", ".join("(%s)" % ", ".join(self._value(v) for v in row) for row in self.rows)
627+
return f"VALUES {values}"
628+
629+
595630
@dataclass
596631
class Explain(ExprNode):
597632
select: Select
@@ -610,11 +645,15 @@ class Statement(Compilable):
610645
@dataclass
611646
class CreateTable(Statement):
612647
path: TablePath
648+
source_table: Expr = None
613649
if_not_exists: bool = False
614650

615651
def compile(self, c: Compiler) -> str:
616-
schema = ", ".join(f"{k} {c.database.type_repr(v)}" for k, v in self.path.schema.items())
617652
ne = "IF NOT EXISTS " if self.if_not_exists else ""
653+
if self.source_table:
654+
return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}"
655+
656+
schema = ", ".join(f"{c.database.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items())
618657
return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})"
619658

620659

@@ -628,14 +667,30 @@ def compile(self, c: Compiler) -> str:
628667
return f"DROP TABLE {ie}{c.compile(self.path)}"
629668

630669

670+
@dataclass
671+
class TruncateTable(Statement):
672+
path: TablePath
673+
674+
def compile(self, c: Compiler) -> str:
675+
return f"TRUNCATE TABLE {c.compile(self.path)}"
676+
677+
631678
@dataclass
632679
class InsertToTable(Statement):
633680
# TODO Support insert for only some columns
634681
path: TablePath
635682
expr: Expr
683+
columns: List[str] = None
636684

637685
def compile(self, c: Compiler) -> str:
638-
return f"INSERT INTO {c.compile(self.path)} {c.compile(self.expr)}"
686+
if isinstance(self.expr, ConstantTable):
687+
expr = self.expr.compile_for_insert(c)
688+
else:
689+
expr = c.compile(self.expr)
690+
691+
columns = f"(%s)" % ", ".join(map(c.quote, self.columns)) if self.columns is not None else ""
692+
693+
return f"INSERT INTO {c.compile(self.path)}{columns} {expr}"
639694

640695

641696
@dataclass

data_diff/query_utils.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"Module for query utilities that didn't make it into the query-builder (yet)"
2+
3+
from contextlib import suppress
4+
5+
from data_diff.databases.database_types import DbPath
6+
from data_diff.databases.base import QueryError
7+
8+
from .databases import Oracle
9+
from .queries import table, commit, Expr
10+
11+
def _drop_table_oracle(name: DbPath):
12+
t = table(name)
13+
# Experience shows double drop is necessary
14+
with suppress(QueryError):
15+
yield t.drop()
16+
yield t.drop()
17+
yield commit
18+
19+
20+
def _drop_table(name: DbPath):
21+
t = table(name)
22+
yield t.drop(if_exists=True)
23+
yield commit
24+
25+
26+
def drop_table(db, tbl):
27+
if isinstance(db, Oracle):
28+
db.query(_drop_table_oracle(tbl))
29+
else:
30+
db.query(_drop_table(tbl))
31+
32+
33+
def _append_to_table_oracle(path: DbPath, expr: Expr):
34+
"""See append_to_table"""
35+
assert expr.schema, expr
36+
t = table(path, schema=expr.schema)
37+
with suppress(QueryError):
38+
yield t.create() # uses expr.schema
39+
yield commit
40+
yield t.insert_expr(expr)
41+
yield commit
42+
43+
44+
def _append_to_table(path: DbPath, expr: Expr):
45+
"""Append to table"""
46+
assert expr.schema, expr
47+
t = table(path, schema=expr.schema)
48+
yield t.create(if_not_exists=True) # uses expr.schema
49+
yield commit
50+
yield t.insert_expr(expr)
51+
yield commit
52+
53+
def append_to_table(db, path, expr):
54+
f = _append_to_table_oracle if isinstance(db, Oracle) else _append_to_table
55+
db.query(f(path, expr))

0 commit comments

Comments
 (0)