From 75ed605bb2125dec97f04aae0e68718e07c0d7d2 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 31 Aug 2022 17:29:37 +0200 Subject: [PATCH 01/93] Query builder package (still incomplete) --- data_diff/databases/base.py | 3 + data_diff/databases/database_types.py | 26 +- data_diff/queries/__init__.py | 4 + data_diff/queries/api.py | 58 +++ data_diff/queries/ast_classes.py | 493 ++++++++++++++++++++++++++ data_diff/queries/base.py | 18 + data_diff/queries/compiler.py | 60 ++++ data_diff/queries/extras.py | 61 ++++ data_diff/sql.py | 2 - tests/test_query.py | 130 +++++++ tests/test_sql.py | 1 - 11 files changed, 843 insertions(+), 13 deletions(-) create mode 100644 data_diff/queries/__init__.py create mode 100644 data_diff/queries/api.py create mode 100644 data_diff/queries/ast_classes.py create mode 100644 data_diff/queries/base.py create mode 100644 data_diff/queries/compiler.py create mode 100644 data_diff/queries/extras.py create mode 100644 tests/test_query.py diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index b114937a..fd6ec2c0 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -287,6 +287,9 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: return f"TRIM({value})" return self.to_string(value) + def random(self) -> str: + return "RANDOM()" + class ThreadedDatabase(Database): """Access the database through singleton threads. diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index e93e380e..ca2734fc 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -140,7 +140,7 @@ class UnknownColType(ColType): supported = False -class AbstractDatabase(ABC): +class AbstractDialect(ABC): name: str @abstractmethod @@ -148,11 +148,6 @@ def quote(self, s: str): "Quote SQL name (implementation specific)" ... - @abstractmethod - def to_string(self, s: str) -> str: - "Provide SQL for casting a column to string" - ... - @abstractmethod def concat(self, l: List[str]) -> str: "Provide SQL for concatenating a bunch of column into a string" @@ -163,6 +158,21 @@ def is_distinct_from(self, a: str, b: str) -> str: "Provide SQL for a comparison where NULL = NULL is true" ... + @abstractmethod + def to_string(self, s: str) -> str: + "Provide SQL for casting a column to string" + ... + + @abstractmethod + def random(self) -> str: + "Provide SQL for generating a random number" + + @abstractmethod + def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): + "Provide SQL fragment for limit and offset inside a select" + ... + +class AbstractDatabase(AbstractDialect): @abstractmethod def timestamp_value(self, t: DbTime) -> str: "Provide SQL for the given timestamp value" @@ -173,10 +183,6 @@ def md5_to_int(self, s: str) -> str: "Provide SQL for computing md5 and returning an int" ... - @abstractmethod - def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): - "Provide SQL fragment for limit and offset inside a select" - ... 
@abstractmethod def _query(self, sql_code: str) -> list: diff --git a/data_diff/queries/__init__.py b/data_diff/queries/__init__.py new file mode 100644 index 00000000..93299b26 --- /dev/null +++ b/data_diff/queries/__init__.py @@ -0,0 +1,4 @@ +from .compiler import Compiler +from .api import this, join, outerjoin, table, SKIP, sum_, avg, min_, max_, cte +from .ast_classes import Expr, ExprNode, Select, Count, BinOp, Explain, In +from .extras import Checksum, NormalizeAsString, ApplyFuncAndNormalizeAsString diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py new file mode 100644 index 00000000..7c617af4 --- /dev/null +++ b/data_diff/queries/api.py @@ -0,0 +1,58 @@ +from typing import Optional +from .ast_classes import * +from .base import args_as_tuple + + +this = This() + + +def join(*tables: ITable): + "Joins each table into a 'struct'" + return Join(tables) + + +def outerjoin(*tables: ITable): + "Outerjoins each table into a 'struct'" + return Join(tables, "FULL OUTER") + + +def cte(expr: Expr, *, name: Optional[str] = None, params: Sequence[str] = None): + return Cte(expr, name, params) + + +def table(*path: str, schema: Schema = None) -> ITable: + assert all(isinstance(i, str) for i in path), path + return TablePath(path, schema) + + +def or_(*exprs: Expr): + exprs = args_as_tuple(exprs) + if len(exprs) == 1: + return exprs[0] + return BinOp("OR", exprs) + +def and_(*exprs: Expr): + exprs = args_as_tuple(exprs) + if len(exprs) == 1: + return exprs[0] + return BinOp("AND", exprs) + + +def sum_(expr: Expr): + return Func("sum", [expr]) + + +def avg(expr: Expr): + return Func("avg", [expr]) + + +def min_(expr: Expr): + return Func("min", [expr]) + + +def max_(expr: Expr): + return Func("max", [expr]) + + +def if_(cond: Expr, then: Expr, else_: Optional[Expr] = None): + return CaseWhen([(cond, then)], else_=else_) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py new file mode 100644 index 00000000..a5c1008c --- /dev/null +++ b/data_diff/queries/ast_classes.py @@ -0,0 +1,493 @@ +from datetime import datetime +from typing import Any, Generator, Sequence, Tuple, Union + +from runtype import dataclass + +from data_diff.utils import ArithString, join_iter + +from .compiler import Compilable, Compiler +from .base import SKIP, CompileError, DbPath, Schema, args_as_tuple + + +class ExprNode(Compilable): + type: Any = None + + def _dfs_values(self): + yield self + for k, vs in dict(self).items(): # __dict__ provided by runtype.dataclass + if k == "source_table": + # Skip data-sources, we're only interested in data-parameters + continue + if not isinstance(vs, (list, tuple)): + vs = [vs] + for v in vs: + if isinstance(v, ExprNode): + yield from v._dfs_values() + + def cast_to(self, to): + return Cast(self, to) + + +Expr = Union[ExprNode, str, bool, int, datetime, ArithString, None] + + +@dataclass +class Alias(ExprNode): + expr: Expr + name: str + + def compile(self, c: Compiler) -> str: + return f"{c.compile(self.expr)} AS {c.quote(self.name)}" + + +def _drop_skips(exprs): + return [e for e in exprs if e is not SKIP] + + +def _drop_skips_dict(exprs_dict): + return {k: v for k, v in exprs_dict.items() if v is not SKIP} + + +class ITable: + source_table: Any + schema: Schema = None + + def select(self, *exprs, **named_exprs): + exprs = args_as_tuple(exprs) + exprs = _drop_skips(exprs) + named_exprs = _drop_skips_dict(named_exprs) + exprs += _named_exprs_as_aliases(named_exprs) + resolve_names(self.source_table, exprs) + return Select.make(self, 
columns=exprs) + + def where(self, *exprs): + exprs = args_as_tuple(exprs) + exprs = _drop_skips(exprs) + if not exprs: + return self + + resolve_names(self.source_table, exprs) + return Select.make(self, where_exprs=exprs, _concat=True) + + def at(self, *exprs): + # TODO + exprs = _drop_skips(exprs) + if not exprs: + return self + + raise NotImplementedError() + + def join(self, target): + return Join(self, target) + + def group_by(self, *, keys=None, values=None): + # TODO + assert keys or values + raise NotImplementedError() + + def with_schema(self): + # TODO + raise NotImplementedError() + + def _get_column(self, name: str): + if self.schema: + name = self.schema.get_key(name) # Get the actual name. Might be case-insensitive. + return Column(self, name) + + # def __getattr__(self, column): + # return self._get_column(column) + + def __getitem__(self, column): + if not isinstance(column, str): + raise TypeError() + return self._get_column(column) + + def count(self): + return Select(self, [Count()]) + + +@dataclass +class Concat(ExprNode): + args: list + sep: str = None + + def compile(self, c: Compiler) -> str: + # We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL + items = [f"coalesce({c.compile(c.database.to_string(expr))}, '')" for expr in self.exprs] + assert items + if len(items) == 1: + return items[0] + + if self.sep: + items = list(join_iter(f"'{self.sep}'", items)) + return c.database.concat(items) + +@dataclass +class Count(ExprNode): + expr: Expr = '*' + distinct: bool = False + + def compile(self, c: Compiler) -> str: + expr = c.compile(self.expr) + if self.distinct: + return f"count(distinct {expr})" + + return f"count({expr})" + + +@dataclass +class Func(ExprNode): + name: str + args: Sequence[Expr] + + def compile(self, c: Compiler) -> str: + args = ", ".join(c.compile(e) for e in self.args) + return f"{self.name}({args})" + + +@dataclass +class CaseWhen(ExprNode): + cases: Sequence[Tuple[Expr, Expr]] + else_: Expr = None + + def compile(self, c: Compiler) -> str: + assert self.cases + when_thens = " ".join(f"WHEN {c.compile(when)} THEN {c.compile(then)}" for when, then in self.cases) + else_ = (" " + c.compile(self.else_)) if self.else_ else "" + return f"CASE {when_thens}{else_} END" + + +class LazyOps: + def __add__(self, other): + return BinOp("+", [self, other]) + + def __gt__(self, other): + return BinOp(">", [self, other]) + + def __ge__(self, other): + return BinOp(">=", [self, other]) + + def __eq__(self, other): + if other is None: + return BinOp("IS", [self, None]) + return BinOp("=", [self, other]) + + def __lt__(self, other): + return BinOp("<", [self, other]) + + def __le__(self, other): + return BinOp("<=", [self, other]) + + def __or__(self, other): + return BinOp("OR", [self, other]) + + def is_distinct_from(self, other): + return IsDistinctFrom(self, other) + + def sum(self): + return Func("SUM", [self]) + + +@dataclass(eq=False, order=False) +class IsDistinctFrom(ExprNode, LazyOps): + a: Expr + b: Expr + + def compile(self, c: Compiler) -> str: + return c.database.is_distinct_from(c.compile(self.a), c.compile(self.b)) + + +@dataclass(eq=False, order=False) +class BinOp(ExprNode, LazyOps): + op: str + args: Sequence[Expr] + + def __post_init__(self): + assert len(self.args) == 2, self.args + + def compile(self, c: Compiler) -> str: + a, b = self.args + return f"({c.compile(a)} {self.op} {c.compile(b)})" + + +@dataclass(eq=False, order=False) +class Column(ExprNode, LazyOps): + source_table: ITable + name: str + + @property + def 
type(self): + if self.source_table.schema is None: + raise RuntimeError(f"Schema required for table {self.source_table}") + return self.source_table.schema[self.name] + + def compile(self, c: Compiler) -> str: + if c._table_context: + if len(c._table_context) > 1: + aliases = [ + t for t in c._table_context if isinstance(t, TableAlias) and t.source_table is self.source_table + ] + if not aliases: + raise CompileError(f"No aliased table found for column {self.name}") # TODO better error + elif len(aliases) > 1: + raise CompileError(f"Too many aliases for column {self.name}") + (alias,) = aliases + + return f"{c.quote(alias.name)}.{c.quote(self.name)}" + + return c.quote(self.name) + + +@dataclass +class TablePath(ExprNode, ITable): + path: DbPath + schema: Schema = None + + def insert_values(self, rows): + pass + + def insert_query(self, query): + pass + + @property + def source_table(self): + return self + + def compile(self, c: Compiler) -> str: + path = self.path # c.database._normalize_table_path(self.name) + return ".".join(map(c.quote, path)) + + +@dataclass +class TableAlias(ExprNode, ITable): + source_table: ITable + name: str + + def compile(self, c: Compiler) -> str: + return f"{c.compile(self.source_table)} {c.quote(self.name)}" + + +@dataclass +class Join(ExprNode, ITable): + source_tables: Sequence[ITable] + op: str = None + on_exprs: Sequence[Expr] = None + columns: Sequence[Expr] = None + + @property + def source_table(self): + return self # TODO is this right? + + @property + def schema(self): + # TODO combine both tables + return None + + def on(self, *exprs): + if len(exprs) == 1: + (e,) = exprs + if isinstance(e, Generator): + exprs = tuple(e) + + exprs = _drop_skips(exprs) + if not exprs: + return self + + return self.replace(on_exprs=(self.on_exprs or []) + exprs) + + def select(self, *exprs, **named_exprs): + if self.columns is not None: + # join-select already applied + return super().select(*exprs, **named_exprs) + + exprs = _drop_skips(exprs) + named_exprs = _drop_skips_dict(named_exprs) + exprs += _named_exprs_as_aliases(named_exprs) + # resolve_names(self.source_table, exprs) + # TODO Ensure exprs <= self.columns ? 
+ return self.replace(columns=exprs) + + def compile(self, parent_c: Compiler) -> str: + tables = [ + t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in self.source_tables + ] + c = parent_c.add_table_context(*tables) + op = " JOIN " if self.op is None else f" {self.op} JOIN " + joined = op.join(c.compile(t) for t in tables) + + if self.on_exprs: + on = " AND ".join(c.compile(e) for e in self.on_exprs) + res = f"{joined} ON {on}" + else: + res = joined + + columns = "*" if self.columns is None else ", ".join(map(c.compile, self.columns)) + select = f"SELECT {columns} FROM {res}" + + if parent_c.in_select: + select = f"({select}) {c.new_unique_name()}" + return select + + +class GroupBy(ITable): + def having(self): + pass + + +@dataclass +class Select(ExprNode, ITable): + table: Expr = None + columns: Sequence[Expr] = None + where_exprs: Sequence[Expr] = None + order_by_exprs: Sequence[Expr] = None + group_by_exprs: Sequence[Expr] = None + limit_expr: int = None + + @property + def source_table(self): + return self + + @property + def schema(self): + return self.table.schema + + def compile(self, parent_c: Compiler) -> str: + c = parent_c.replace(in_select=True).add_table_context(self.table) + + columns = ", ".join(map(c.compile, self.columns)) if self.columns else "*" + select = f"SELECT {columns}" + + if self.table: + select += " FROM " + c.compile(self.table) + + if self.where_exprs: + select += " WHERE " + " AND ".join(map(c.compile, self.where_exprs)) + + if self.group_by_exprs: + select += " GROUP BY " + ", ".join(map(c.compile, self.group_by_exprs)) + + if self.order_by_exprs: + select += " ORDER BY " + ", ".join(map(c.compile, self.order_by_exprs)) + + if self.limit_expr is not None: + select += " " + c.database.offset_limit(0, self.limit_expr) + + if parent_c.in_select: + select = f"({select})" + return select + + @classmethod + def make(cls, table: ITable, _concat: bool = False, **kwargs): + if not isinstance(table, cls): + return cls(table, **kwargs) + + # Fill in missing attributes, instead of creating a new instance. 
+ for k, v in kwargs.items(): + if getattr(table, k) is not None: + if _concat: + kwargs[k] = getattr(table, k) + v + else: + raise ValueError("...") + + return table.replace(**kwargs) + + +@dataclass +class Cte(ExprNode, ITable): + source_table: Expr + name: str = None + params: Sequence[str] = None + + def compile(self, parent_c: Compiler) -> str: + c = parent_c.replace(_table_context=[], in_select=False) + compiled = c.compile(self.source_table) + + name = self.name or parent_c.new_unique_name() + name_params = f"{name}({', '.join(self.params)})" if self.params else name + parent_c._subqueries[name_params] = compiled + + return name + + @property + def schema(self): + # TODO add cte to schema + return self.source_table.schema + + +def _named_exprs_as_aliases(named_exprs): + return [Alias(expr, name) for name, expr in named_exprs.items()] + + +def resolve_names(source_table, exprs): + i = 0 + for expr in exprs: + # Iterate recursively and update _ResolveColumn with the right expression + if isinstance(expr, ExprNode): + for v in expr._dfs_values(): + if isinstance(v, _ResolveColumn): + v.resolve(source_table._get_column(v.name)) + i += 1 + + +@dataclass(frozen=False, eq=False, order=False) +class _ResolveColumn(ExprNode, LazyOps): + name: str + resolved: Expr = None + + def resolve(self, expr): + assert self.resolved is None + self.resolved = expr + + def compile(self, c: Compiler) -> str: + if self.resolved is None: + raise RuntimeError(f"Column not resolved: {self.name}") + return self.resolved.compile(c) + + @property + def type(self): + if self.resolved is None: + raise RuntimeError(f"Column not resolved: {self.name}") + return self.resolved.type + + +class This: + def __getattr__(self, name): + return _ResolveColumn(name) + + def __getitem__(self, name): + if isinstance(name, list): + return [_ResolveColumn(n) for n in name] + return _ResolveColumn(name) + + +@dataclass +class Explain(ExprNode): + sql: Select + + def compile(self, c: Compiler) -> str: + return f"EXPLAIN {c.compile(self.sql)}" + + +@dataclass +class In(ExprNode): + expr: Expr + list: Sequence[Expr] + + def compile(self, c: Compiler): + elems = ", ".join(map(c.compile, self.list)) + return f"({c.compile(self.expr)} IN ({elems}))" + + +@dataclass +class Cast(ExprNode): + expr: Expr + target_type: Expr + + def compile(self, c: Compiler) -> str: + return f"cast({c.compile(self.expr)} as {c.compile(self.target_type)})" + + +@dataclass +class Random(ExprNode): + def compile(self, c: Compiler) -> str: + return c.database.random() diff --git a/data_diff/queries/base.py b/data_diff/queries/base.py new file mode 100644 index 00000000..50a57e2f --- /dev/null +++ b/data_diff/queries/base.py @@ -0,0 +1,18 @@ +from typing import Generator + +from data_diff.databases.database_types import DbPath, DbKey, Schema + + +SKIP = object() + + +class CompileError(Exception): + pass + + +def args_as_tuple(exprs): + if len(exprs) == 1: + (e,) = exprs + if isinstance(e, Generator): + return tuple(e) + return exprs diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py new file mode 100644 index 00000000..2a37d09f --- /dev/null +++ b/data_diff/queries/compiler.py @@ -0,0 +1,60 @@ +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Dict, Sequence, List + +from runtype import dataclass + +from data_diff.databases.database_types import AbstractDialect + + +@dataclass +class Compiler: + database: AbstractDialect + in_select: bool = False # Compilation + + _table_context: List = [] # 
List[ITable] + _subqueries: Dict[str, Any] = {} # XXX not thread-safe + root: bool = True + + _counter: List = [0] + + def quote(self, s: str): + return self.database.quote(s) + + def compile(self, elem) -> str: + res = self._compile(elem) + if self.root and self._subqueries: + subq = ", ".join(f"\n {k} AS ({v})" for k, v in self._subqueries.items()) + self._subqueries.clear() + return f"WITH {subq}\n{res}" + return res + + def _compile(self, elem) -> str: + if elem is None: + return "NULL" + elif isinstance(elem, Compilable): + return elem.compile(self.replace(root=False)) + elif isinstance(elem, str): + return elem + elif isinstance(elem, int): + return str(elem) + elif isinstance(elem, datetime): + return self.database.timestamp_value(elem) + elif isinstance(elem, bytes): + return f"b'{elem.decode()}'" + elif isinstance(elem, ArithString): + return f"'{elem}'" + assert False, elem + + def new_unique_name(self, prefix="tmp"): + self._counter[0] += 1 + return f"{prefix}{self._counter[0]}" + + def add_table_context(self, *tables: Sequence): + return self.replace(_table_context=self._table_context + list(tables)) + + +class Compilable(ABC): + @abstractmethod + def compile(self, c: Compiler) -> str: + ... diff --git a/data_diff/queries/extras.py b/data_diff/queries/extras.py new file mode 100644 index 00000000..9b5189e1 --- /dev/null +++ b/data_diff/queries/extras.py @@ -0,0 +1,61 @@ +"Useful AST classes that don't quite fall within the scope of regular SQL" + +from typing import Callable, Sequence +from runtype import dataclass + +from data_diff.databases.database_types import ColType, Native_UUID + +from .compiler import Compiler +from .ast_classes import Expr, ExprNode, Concat + + +@dataclass +class NormalizeAsString(ExprNode): + expr: ExprNode + type: ColType = None + + def compile(self, c: Compiler) -> str: + expr = c.compile(self.expr) + return c.database.normalize_value_by_type(expr, self.type or self.expr.type) + + +@dataclass +class ApplyFuncAndNormalizeAsString(ExprNode): + expr: ExprNode + apply_func: Callable = None + + def compile(self, c: Compiler) -> str: + expr = self.expr + expr_type = expr.type + + if isinstance(expr_type, Native_UUID): + # Normalize first, apply template after (for uuids) + # Needed because min/max(uuid) fails in postgresql + expr = NormalizeAsString(expr, expr_type) + if self.apply_func is not None: + expr = self.apply_func(expr) # Apply template using Python's string formatting + + else: + # Apply template before normalizing (for ints) + if self.apply_func is not None: + expr = self.apply_func(expr) # Apply template using Python's string formatting + expr = NormalizeAsString(expr, expr_type) + + return c.compile(expr) + + +@dataclass +class Checksum(ExprNode): + exprs: Sequence[Expr] + + def compile(self, c: Compiler): + if len(self.exprs) > 1: + exprs = [f"coalesce({c.compile(expr)}, '')" for expr in self.exprs] + # exprs = [c.compile(e) for e in exprs] + expr = Concat(exprs, "|") + else: + # No need to coalesce - safe to assume that key cannot be null + (expr,) = self.exprs + expr = c.compile(expr) + md5 = c.database.md5_to_int(expr) + return f"sum({md5})" diff --git a/data_diff/sql.py b/data_diff/sql.py index 46332797..6240ca9b 100644 --- a/data_diff/sql.py +++ b/data_diff/sql.py @@ -17,8 +17,6 @@ class Sql: SqlOrStr = Union[Sql, str] -CONCAT_SEP = "|" - @dataclass class Compiler: diff --git a/tests/test_query.py b/tests/test_query.py new file mode 100644 index 00000000..b6b90394 --- /dev/null +++ b/tests/test_query.py @@ -0,0 +1,130 @@ +from cmath 
import exp +from typing import List, Optional +import unittest +from data_diff.databases.database_types import AbstractDialect, CaseInsensitiveDict, CaseSensitiveDict + +from data_diff.queries import this, table, Compiler, outerjoin, cte +from data_diff.queries.ast_classes import Random + + +def normalize_spaces(s: str): + return " ".join(s.split()) + + +class MockDialect(AbstractDialect): + def quote(self, s: str): + return s + + def concat(self, l: List[str]) -> str: + s = ", ".join(l) + return f"concat({s})" + + def to_string(self, s: str) -> str: + return f"cast({s} as varchar)" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"{a} is distinct from {b}" + + def random(self) -> str: + return "random()" + + def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): + x = offset and f"offset {offset}", limit and f"limit {limit}" + return " ".join(filter(None, x)) + + +class TestQuery(unittest.TestCase): + def setUp(self): + pass + + def test_basic(self): + c = Compiler(MockDialect()) + + t = table("point") + t2 = t.select(x=this.x + 1, y=t["y"] + this.x) + assert c.compile(t2) == "SELECT (x + 1) AS x, (y + x) AS y FROM point" + + t = table("point").where(this.x == 1, this.y == 2) + assert c.compile(t) == "SELECT * FROM point WHERE (x = 1) AND (y = 2)" + + t = table("point").select("x", "y") + assert c.compile(t) == "SELECT x, y FROM point" + + def test_outerjoin(self): + c = Compiler(MockDialect()) + + a = table("a") + b = table("b") + keys = ["x", "y"] + cols = ["u", "v"] + + j = outerjoin(a, b).on(a[k] == b[k] for k in keys) + + self.assertEqual( + c.compile(j), "SELECT * FROM a tmp1 FULL OUTER JOIN b tmp2 ON (tmp1.x = tmp2.x) AND (tmp1.y = tmp2.y)" + ) + + # diffed = j.select("*", **{f"is_diff_col_{c}": a[c].is_distinct_from(b[c]) for c in cols}) + + # t = diffed.select( + # **{f"total_diff_col_{c}": diffed[f"is_diff_col_{c}"].sum() for c in cols}, + # total_diff=or_(diffed[f"is_diff_col_{c}"] for c in cols).sum(), + # ) + + # print(c.compile(t)) + + # t.group_by(keys=[this.x], values=[this.py]) + + def test_schema(self): + c = Compiler(MockDialect()) + schema = dict(id="int", comment="varchar") + + t = table("a", schema=CaseInsensitiveDict(schema)) + q = t.select(this.Id, t["COMMENT"]) + assert c.compile(q) == "SELECT id, comment FROM a" + + t = table("a", schema=CaseSensitiveDict(schema)) + self.assertRaises(KeyError, t.__getitem__, "Id") + self.assertRaises(KeyError, t.select, this.Id) + + def test_commutable_select(self): + # c = Compiler(MockDialect()) + + t = table("a") + q1 = t.select("a").where("b") + q2 = t.where("b").select("a") + assert q1 == q2, (q1, q2) + + def test_cte(self): + c = Compiler(MockDialect()) + + t = table("a") + + # single cte + t2 = cte(t.select(this.x)) + t3 = t2.select(this.x) + + expected = "WITH tmp1 AS (SELECT x FROM a) SELECT x FROM tmp1" + assert normalize_spaces(c.compile(t3)) == expected + + # nested cte + c = Compiler(MockDialect()) + t4 = cte(t3).select(this.x) + + expected = "WITH tmp1 AS (SELECT x FROM a), tmp2 AS (SELECT x FROM tmp1) SELECT x FROM tmp2" + assert normalize_spaces(c.compile(t4)) == expected + + # parameterized cte + c = Compiler(MockDialect()) + t2 = cte(t.select(this.x), params=["y"]) + t3 = t2.select(this.y) + + expected = "WITH tmp1(y) AS (SELECT x FROM a) SELECT y FROM tmp1" + assert normalize_spaces(c.compile(t3)) == expected + + def test_funcs(self): + c = Compiler(MockDialect()) + t = table("a") + + q = c.compile(t.order_by(Random()).limit(10)) + assert q == "SELECT * FROM a ORDER 
BY random() limit 10" diff --git a/tests/test_sql.py b/tests/test_sql.py index bc4828c0..67c5637d 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -5,7 +5,6 @@ from .common import TEST_MYSQL_CONN_STRING - class TestSQL(unittest.TestCase): def setUp(self): self.mysql = connect_to_uri(TEST_MYSQL_CONN_STRING) From 70c595210ab14346eda09f6c2627639b3678bb9b Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 2 Sep 2022 17:42:38 +0200 Subject: [PATCH 02/93] data-diff now uses new 'data_diff.queries' modules instead of 'data_diff.sql' --- data_diff/databases/base.py | 8 +- data_diff/queries/api.py | 1 + data_diff/queries/ast_classes.py | 19 ++- data_diff/queries/compiler.py | 1 + data_diff/sql.py | 196 ------------------------------- data_diff/table_segment.py | 95 ++++----------- tests/test_sql.py | 45 ++++--- 7 files changed, 71 insertions(+), 294 deletions(-) delete mode 100644 data_diff/sql.py diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index fd6ec2c0..181a80e5 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -24,8 +24,10 @@ UnknownColType, Text, DbTime, + DbPath, ) -from data_diff.sql import DbPath, SqlOrStr, Compiler, Explain, Select, TableName + +from data_diff.queries import Expr, Compiler, table, Select, SKIP logger = logging.getLogger("database") @@ -87,7 +89,7 @@ class Database(AbstractDatabase): def name(self): return type(self).__name__ - def query(self, sql_ast: SqlOrStr, res_type: type): + def query(self, sql_ast: Expr, res_type: type): "Query the given SQL code/AST, and attempt to convert the result to type 'res_type'" compiler = Compiler(self) @@ -213,7 +215,7 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe return fields = [self.normalize_uuid(c, String_UUID()) for c in text_columns] - samples_by_row = self.query(Select(fields, TableName(table_path), limit=16, where=where and [where]), list) + samples_by_row = self.query(table(*table_path).select(*fields).where(where or SKIP).limit(16), list) if not samples_by_row: raise ValueError(f"Table {table_path} is empty.") diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index 7c617af4..76aaf5d2 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -31,6 +31,7 @@ def or_(*exprs: Expr): return exprs[0] return BinOp("OR", exprs) + def and_(*exprs: Expr): exprs = args_as_tuple(exprs) if len(exprs) == 1: diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index a5c1008c..019227f7 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -69,6 +69,20 @@ def where(self, *exprs): resolve_names(self.source_table, exprs) return Select.make(self, where_exprs=exprs, _concat=True) + def order_by(self, *exprs): + exprs = _drop_skips(exprs) + if not exprs: + return self + + resolve_names(self.source_table, exprs) + return Select.make(self, order_by_exprs=exprs) + + def limit(self, limit: int): + if limit is SKIP: + return self + + return Select.make(self, limit_expr=limit) + def at(self, *exprs): # TODO exprs = _drop_skips(exprs) @@ -108,7 +122,7 @@ def count(self): @dataclass class Concat(ExprNode): - args: list + exprs: list sep: str = None def compile(self, c: Compiler) -> str: @@ -122,9 +136,10 @@ def compile(self, c: Compiler) -> str: items = list(join_iter(f"'{self.sep}'", items)) return c.database.concat(items) + @dataclass class Count(ExprNode): - expr: Expr = '*' + expr: Expr = "*" distinct: bool = False def compile(self, c: Compiler) -> str: diff 
--git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index 2a37d09f..8ea0e7a5 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -4,6 +4,7 @@ from runtype import dataclass +from data_diff.utils import ArithString from data_diff.databases.database_types import AbstractDialect diff --git a/data_diff/sql.py b/data_diff/sql.py deleted file mode 100644 index 6240ca9b..00000000 --- a/data_diff/sql.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Provides classes for a pseudo-SQL AST that compiles to SQL code -""" - -from typing import Sequence, Union, Optional -from datetime import datetime - -from runtype import dataclass - -from .utils import join_iter, ArithString - -from .databases.database_types import AbstractDatabase, DbPath - - -class Sql: - pass - - -SqlOrStr = Union[Sql, str] - - -@dataclass -class Compiler: - """Provides a set of utility methods for compiling SQL - - For internal use. - """ - - database: AbstractDatabase - in_select: bool = False # Compilation - - def quote(self, s: str): - return self.database.quote(s) - - def compile(self, elem): - if isinstance(elem, Sql): - return elem.compile(self) - elif isinstance(elem, str): - return elem - elif isinstance(elem, int): - return str(elem) - assert False - - -@dataclass -class TableName(Sql): - name: DbPath - - def compile(self, c: Compiler): - path = c.database._normalize_table_path(self.name) - return ".".join(map(c.quote, path)) - - -@dataclass -class ColumnName(Sql): - name: str - - def compile(self, c: Compiler): - return c.quote(self.name) - - -@dataclass -class Value(Sql): - value: object # Primitive - - def compile(self, c: Compiler): - if isinstance(self.value, bytes): - return f"b'{self.value.decode()}'" - elif isinstance(self.value, str): - return f"'{self.value}'" % self.value - elif isinstance(self.value, ArithString): - return f"'{self.value}'" - return str(self.value) - - -@dataclass -class Select(Sql): - columns: Sequence[SqlOrStr] - table: SqlOrStr = None - where: Sequence[SqlOrStr] = None - order_by: Sequence[SqlOrStr] = None - group_by: Sequence[SqlOrStr] = None - limit: int = None - - def compile(self, parent_c: Compiler): - c = parent_c.replace(in_select=True) - columns = ", ".join(map(c.compile, self.columns)) - select = f"SELECT {columns}" - - if self.table: - select += " FROM " + c.compile(self.table) - - if self.where: - select += " WHERE " + " AND ".join(map(c.compile, self.where)) - - if self.group_by: - select += " GROUP BY " + ", ".join(map(c.compile, self.group_by)) - - if self.order_by: - select += " ORDER BY " + ", ".join(map(c.compile, self.order_by)) - - if self.limit is not None: - select += " " + c.database.offset_limit(0, self.limit) - - if parent_c.in_select: - select = "(%s)" % select - return select - - -@dataclass -class Enum(Sql): - table: DbPath - order_by: SqlOrStr - - def compile(self, c: Compiler): - table = ".".join(map(c.quote, self.table)) - order = c.compile(self.order_by) - return f"(SELECT *, (row_number() over (ORDER BY {order})) as idx FROM {table} ORDER BY {order}) tmp" - - -@dataclass -class Checksum(Sql): - exprs: Sequence[SqlOrStr] - - def compile(self, c: Compiler): - if len(self.exprs) > 1: - compiled_exprs = [f"coalesce({c.compile(expr)}, '')" for expr in self.exprs] - separated = list(join_iter(f"'|'", compiled_exprs)) - expr = c.database.concat(separated) - else: - # No need to coalesce - safe to assume that key cannot be null - (expr,) = self.exprs - expr = c.compile(expr) - md5 = c.database.md5_to_int(expr) - return 
f"sum({md5})" - - -@dataclass -class Compare(Sql): - op: str - a: SqlOrStr - b: SqlOrStr - - def compile(self, c: Compiler): - return f"({c.compile(self.a)} {self.op} {c.compile(self.b)})" - - -@dataclass -class In(Sql): - expr: SqlOrStr - list: Sequence # List[SqlOrStr] - - def compile(self, c: Compiler): - elems = ", ".join(map(c.compile, self.list)) - return f"({c.compile(self.expr)} IN ({elems}))" - - -@dataclass -class Count(Sql): - column: Optional[SqlOrStr] = None - - def compile(self, c: Compiler): - if self.column: - return f"count({c.compile(self.column)})" - return "count(*)" - - -@dataclass -class Min(Sql): - column: SqlOrStr - - def compile(self, c: Compiler): - return f"min({c.compile(self.column)})" - - -@dataclass -class Max(Sql): - column: SqlOrStr - - def compile(self, c: Compiler): - return f"max({c.compile(self.column)})" - - -@dataclass -class Time(Sql): - time: datetime - - def compile(self, c: Compiler): - return c.database.timestamp_value(self.time) - - -@dataclass -class Explain(Sql): - sql: Select - - def compile(self, c: Compiler): - return f"EXPLAIN {c.compile(self.sql)}" diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index 8b95458f..761b3a74 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -4,11 +4,11 @@ from runtype import dataclass -from .utils import ArithString, split_space, ArithAlphanumeric - +from .utils import ArithString, split_space from .databases.base import Database -from .databases.database_types import DbPath, DbKey, DbTime, Native_UUID, Schema, create_schema -from .sql import Select, Checksum, Compare, Count, TableName, Time, Value +from .databases.database_types import DbPath, DbKey, DbTime, Schema, create_schema +from .queries import Count, Checksum, SKIP, table, this, Expr, min_, max_ +from .queries.extras import ApplyFuncAndNormalizeAsString, NormalizeAsString logger = logging.getLogger("table_segment") @@ -66,38 +66,6 @@ def __post_init__(self): f"Error: min_update expected to be smaller than max_update! ({self.min_update} >= {self.max_update})" ) - @property - def _update_column(self): - return self._quote_column(self.update_column) - - def _quote_column(self, c: str) -> str: - if self._schema: - c = self._schema.get_key(c) # Get the actual name. Might be case-insensitive. - return self.database.quote(c) - - def _normalize_column(self, name: str, template: str = None) -> str: - if not self._schema: - raise RuntimeError( - "Cannot compile query when the schema is unknown. Please use TableSegment.with_schema()." 
- ) - - col_type = self._schema[name] - col = self._quote_column(name) - - if isinstance(col_type, Native_UUID): - # Normalize first, apply template after (for uuids) - # Needed because min/max(uuid) fails in postgresql - col = self.database.normalize_value_by_type(col, col_type) - if template is not None: - col = template % col # Apply template using Python's string formatting - return col - - # Apply template before normalizing (for ints) - if template is not None: - col = template % col # Apply template using Python's string formatting - - return self.database.normalize_value_by_type(col, col_type) - def _with_raw_schema(self, raw_schema: dict) -> "TableSegment": schema = self.database._process_table_schema(self.table_path, raw_schema, self._relevant_columns, self.where) return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive)) @@ -111,37 +79,26 @@ def with_schema(self) -> "TableSegment": def _make_key_range(self): if self.min_key is not None: - yield Compare("<=", Value(self.min_key), self._quote_column(self.key_column)) + yield self.min_key <= this[self.key_column] if self.max_key is not None: - yield Compare("<", self._quote_column(self.key_column), Value(self.max_key)) + yield this[self.key_column] < self.max_key def _make_update_range(self): if self.min_update is not None: - yield Compare("<=", Time(self.min_update), self._update_column) + yield self.min_update <= this[self.update_column] if self.max_update is not None: - yield Compare("<", self._update_column, Time(self.max_update)) - - def _make_select(self, *, table=None, columns=None, where=None, group_by=None, order_by=None): - if columns is None: - columns = [self._normalize_column(self.key_column)] - where = [ - *self._make_key_range(), - *self._make_update_range(), - *([] if where is None else [where]), - *([] if self.where is None else [self.where]), - ] - order_by = None if order_by is None else [order_by] - return Select( - table=table or TableName(self.table_path), - where=where, - columns=columns, - group_by=group_by, - order_by=order_by, - ) + yield this[self.update_column] < self.max_update + + @property + def source_table(self): + return table(*self.table_path, schema=self._schema) + + def _make_select(self): + return self.source_table.where(*self._make_key_range(), *self._make_update_range(), self.where or SKIP) def get_values(self) -> list: "Download all the relevant values of the segment from the database" - select = self._make_select(columns=self._relevant_columns_repr) + select = self._make_select().select(*self._relevant_columns_repr) return self.database.query(select, List[Tuple]) def choose_checkpoints(self, count: int) -> List[DbKey]: @@ -185,19 +142,18 @@ def _relevant_columns(self) -> List[str]: return [self.key_column] + extras @property - def _relevant_columns_repr(self) -> List[str]: - return [self._normalize_column(c) for c in self._relevant_columns] + def _relevant_columns_repr(self) -> List[Expr]: + return [NormalizeAsString(this[c]) for c in self._relevant_columns] def count(self) -> Tuple[int, int]: """Count how many rows are in the segment, in one pass.""" - return self.database.query(self._make_select(columns=[Count()]), int) + return self.database.query(self._make_select().select(Count()), int) def count_and_checksum(self) -> Tuple[int, int]: """Count and checksum the rows in the segment, in one pass.""" start = time.monotonic() - count, checksum = self.database.query( - self._make_select(columns=[Count(), Checksum(self._relevant_columns_repr)]), tuple 
- ) + q = self._make_select().select(Count(), Checksum(self._relevant_columns_repr)) + count, checksum = self.database.query(q, tuple) duration = time.monotonic() - start if duration > RECOMMENDED_CHECKSUM_DURATION: logger.warning( @@ -212,11 +168,10 @@ def count_and_checksum(self) -> Tuple[int, int]: def query_key_range(self) -> Tuple[int, int]: """Query database for minimum and maximum key. This is used for setting the initial bounds.""" # Normalizes the result (needed for UUIDs) after the min/max computation - select = self._make_select( - columns=[ - self._normalize_column(self.key_column, "min(%s)"), - self._normalize_column(self.key_column, "max(%s)"), - ] + # TODO better error if there is no schema + select = self._make_select().select( + ApplyFuncAndNormalizeAsString(this[self.key_column], min_), + ApplyFuncAndNormalizeAsString(this[self.key_column], max_), ) min_key, max_key = self.database.query(select, tuple) diff --git a/tests/test_sql.py b/tests/test_sql.py index 67c5637d..fe17940b 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -1,10 +1,11 @@ import unittest from data_diff.databases import connect_to_uri -from data_diff.sql import Checksum, Compare, Compiler, Count, Enum, Explain, In, Select, TableName - from .common import TEST_MYSQL_CONN_STRING +from data_diff.queries import Compiler, Count, Explain, Select, table, In, BinOp + + class TestSQL(unittest.TestCase): def setUp(self): self.mysql = connect_to_uri(TEST_MYSQL_CONN_STRING) @@ -17,7 +18,7 @@ def test_compile_int(self): self.assertEqual("1", self.compiler.compile(1)) def test_compile_table_name(self): - self.assertEqual("`marine_mammals`.`walrus`", self.compiler.compile(TableName(("marine_mammals", "walrus")))) + self.assertEqual("`marine_mammals`.`walrus`", self.compiler.compile(table("marine_mammals", "walrus"))) def test_compile_select(self): expected_sql = "SELECT name FROM `marine_mammals`.`walrus`" @@ -25,23 +26,23 @@ def test_compile_select(self): expected_sql, self.compiler.compile( Select( + table("marine_mammals", "walrus"), ["name"], - TableName(("marine_mammals", "walrus")), ) ), ) - def test_enum(self): - expected_sql = "(SELECT *, (row_number() over (ORDER BY id)) as idx FROM `walrus` ORDER BY id) tmp" - self.assertEqual( - expected_sql, - self.compiler.compile( - Enum( - ("walrus",), - "id", - ) - ), - ) + # def test_enum(self): + # expected_sql = "(SELECT *, (row_number() over (ORDER BY id)) as idx FROM `walrus` ORDER BY id) tmp" + # self.assertEqual( + # expected_sql, + # self.compiler.compile( + # Enum( + # ("walrus",), + # "id", + # ) + # ), + # ) # def test_checksum(self): # expected_sql = "SELECT name, sum(cast(conv(substring(md5(concat(cast(id as char), cast(timestamp as char))), 18), 16, 10) as unsigned)) FROM `marine_mammals`.`walrus`" @@ -61,9 +62,9 @@ def test_compare(self): expected_sql, self.compiler.compile( Select( + table("marine_mammals", "walrus"), ["name"], - TableName(("marine_mammals", "walrus")), - [Compare("<=", "id", "1000"), Compare(">", "id", "1")], + [BinOp("<=", ["id", "1000"]), BinOp(">", ["id", "1"])], ) ), ) @@ -72,23 +73,21 @@ def test_in(self): expected_sql = "SELECT name FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" self.assertEqual( expected_sql, - self.compiler.compile(Select(["name"], TableName(("marine_mammals", "walrus")), [In("id", [1, 2, 3])])), + self.compiler.compile(Select(table("marine_mammals", "walrus"), ["name"], [In("id", [1, 2, 3])])), ) def test_count(self): expected_sql = "SELECT count(*) FROM `marine_mammals`.`walrus` WHERE (id IN 
(1, 2, 3))" self.assertEqual( expected_sql, - self.compiler.compile(Select([Count()], TableName(("marine_mammals", "walrus")), [In("id", [1, 2, 3])])), + self.compiler.compile(Select(table("marine_mammals", "walrus"), [Count()], [In("id", [1, 2, 3])])), ) def test_count_with_column(self): expected_sql = "SELECT count(id) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" self.assertEqual( expected_sql, - self.compiler.compile( - Select([Count("id")], TableName(("marine_mammals", "walrus")), [In("id", [1, 2, 3])]) - ), + self.compiler.compile(Select(table("marine_mammals", "walrus"), [Count("id")], [In("id", [1, 2, 3])])), ) def test_explain(self): @@ -96,6 +95,6 @@ def test_explain(self): self.assertEqual( expected_sql, self.compiler.compile( - Explain(Select([Count("id")], TableName(("marine_mammals", "walrus")), [In("id", [1, 2, 3])])) + Explain(Select(table("marine_mammals", "walrus"), [Count("id")], [In("id", [1, 2, 3])])) ), ) From e5ace37250ca4e8ab341d49c5954bafb9df91c54 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 8 Sep 2022 14:09:56 +0200 Subject: [PATCH 03/93] Join-diff implementation --- data_diff/__init__.py | 3 - data_diff/__main__.py | 19 +++- data_diff/diff_tables.py | 79 ++++++++----- data_diff/joindiff_tables.py | 207 +++++++++++++++++++++++++++++++++++ tests/test_diff_tables.py | 8 +- tests/test_joindiff.py | 168 ++++++++++++++++++++++++++++ 6 files changed, 441 insertions(+), 43 deletions(-) create mode 100644 data_diff/joindiff_tables.py create mode 100644 tests/test_joindiff.py diff --git a/data_diff/__init__.py b/data_diff/__init__.py index bc5677c8..2af199db 100644 --- a/data_diff/__init__.py +++ b/data_diff/__init__.py @@ -55,8 +55,6 @@ def diff_tables( # Maximum size of each threadpool. None = auto. Only relevant when threaded is True. # There may be many pools, so number of actual threads can be a lot higher. max_threadpool_size: Optional[int] = 1, - # Enable/disable debug prints - debug: bool = False, ) -> Iterator: """Efficiently finds the diff between table1 and table2. 
@@ -86,7 +84,6 @@ def diff_tables( differ = TableDiffer( bisection_factor=bisection_factor, bisection_threshold=bisection_threshold, - debug=debug, threaded=threaded, max_threadpool_size=max_threadpool_size, ) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index bccd132f..a5715397 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -192,12 +192,19 @@ def _main( logging.error(f"Error while parsing age expression: {e}") return - differ = TableDiffer( - bisection_factor=bisection_factor, - bisection_threshold=bisection_threshold, - threaded=threaded, - max_threadpool_size=threads and threads * 2, - ) + if algorithm == Algorithm.JOINDIFF: + differ = JoinDiffer( + threaded=threaded, + max_threadpool_size=threads and threads * 2, + ) + else: + assert algorithm == Algorithm.HASHDIFF + differ = TableDiffer( + bisection_factor=bisection_factor, + bisection_threshold=bisection_threshold, + threaded=threaded, + max_threadpool_size=threads and threads * 2, + ) if database1 is None or database2 is None: logging.error( diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 96c6c624..7f67cfc2 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -1,7 +1,9 @@ """Provides classes for performing a table diff """ +from contextlib import contextmanager import time +import threading import os from numbers import Number from operator import attrgetter, methodcaller @@ -44,7 +46,53 @@ def diff_sets(a: set, b: set) -> Iterator: @dataclass -class TableDiffer: +class ThreadBase: + "Provides utility methods for optional threading" + + threaded: bool = True + max_threadpool_size: Optional[int] = 1 + + def _thread_map(self, func, iterable): + if not self.threaded: + return map(func, iterable) + + with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool: + return task_pool.map(func, iterable) + + def _threaded_call(self, func, iterable): + "Calls a method for each object in iterable." + return list(self._thread_map(methodcaller(func), iterable)) + + def _thread_as_completed(self, func, iterable): + if not self.threaded: + yield from map(func, iterable) + return + + with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool: + futures = [task_pool.submit(func, item) for item in iterable] + for future in as_completed(futures): + yield future.result() + + def _threaded_call_as_completed(self, func, iterable): + "Calls a method for each object in iterable. Returned in order of completion." 
+ return self._thread_as_completed(methodcaller(func), iterable) + + def _run_thread(self, threadfunc, *args, daemon=False) -> threading.Thread: + th = threading.Thread(target=threadfunc, args=args) + if daemon: + th.daemon = True + th.start() + return th + + @contextmanager + def _run_in_background(self, threadfunc, *args, daemon=False): + t = self._run_thread(threadfunc, *args, daemon=daemon) + yield t + t.join() + + +@dataclass +class TableDiffer(ThreadBase): """Finds the diff between two SQL tables The algorithm uses hashing to quickly check if the tables are different, and then applies a @@ -62,11 +110,6 @@ class TableDiffer: bisection_factor: int = DEFAULT_BISECTION_FACTOR bisection_threshold: Number = DEFAULT_BISECTION_THRESHOLD # Accepts inf for tests - threaded: bool = True - max_threadpool_size: Optional[int] = 1 - - # Enable/disable debug prints - debug: bool = False stats: dict = {} @@ -291,27 +334,3 @@ def _diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableS if checksum1 != checksum2: return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max(count1, count2)) - def _thread_map(self, func, iterable): - if not self.threaded: - return map(func, iterable) - - with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool: - return task_pool.map(func, iterable) - - def _threaded_call(self, func, iterable): - "Calls a method for each object in iterable." - return list(self._thread_map(methodcaller(func), iterable)) - - def _thread_as_completed(self, func, iterable): - if not self.threaded: - yield from map(func, iterable) - return - - with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool: - futures = [task_pool.submit(func, item) for item in iterable] - for future in as_completed(futures): - yield future.result() - - def _threaded_call_as_completed(self, func, iterable): - "Calls a method for each object in iterable. Returned in order of completion." 
- return self._thread_as_completed(methodcaller(func), iterable) diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py new file mode 100644 index 00000000..7f684ebb --- /dev/null +++ b/data_diff/joindiff_tables.py @@ -0,0 +1,207 @@ +"""Provides classes for performing a table diff using JOIN + +""" + +from decimal import Decimal +import logging +from contextlib import contextmanager +from typing import Dict, List + +from runtype import dataclass + +from .utils import safezip +from .databases.base import Database +from .table_segment import TableSegment +from .diff_tables import ThreadBase, DiffResult + +from .queries import table, sum_, min_, max_, avg +from .queries.api import and_, if_, or_, outerjoin, this +from .queries.ast_classes import Concat, Count, Expr, Random +from .queries.compiler import Compiler +from .queries.extras import NormalizeAsString + + +logger = logging.getLogger("joindiff_tables") + + +def merge_dicts(dicts): + i = iter(dicts) + res = next(i) + for d in i: + res.update(d) + return res + + +@dataclass(frozen=False) +class Stats: + exclusive_count: int + exclusive_sample: List[tuple] + diff_ratio_by_column: Dict[str, float] + diff_ratio_total: float + metrics: Dict[str, float] + + +def sample(table): + # TODO + return table.order_by(Random()).limit(10) + + +@contextmanager +def temp_table(db: Database, expr: Expr): + c = Compiler(db) + name = c.new_unique_name("tmp_table") + db.query(f"create temporary table {c.quote(name)} as {c.compile(expr)}", None) + try: + yield table(name, schema=expr.source_table.schema) + finally: + db.query(f"drop table {c.quote(name)}", None) + + +def _slice_tuple(t, *sizes): + i = 0 + for size in sizes: + yield t[i : i + size] + i += size + assert i == len(t) + + +def json_friendly_value(v): + if isinstance(v, Decimal): + return float(v) + return v + + +@dataclass +class JoinDifferBase(ThreadBase): + """Finds the diff between two SQL tables using JOINs""" + + stats: dict = {} + + def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: + table1, table2 = self._threaded_call("with_schema", [table1, table2]) + + if table1.database is not table2.database: + raise ValueError("Join-diff only works when both tables are in the same database") + + with self._run_in_background(self._test_null_or_duplicate_keys, table1, table2): + with self._run_in_background(self._collect_stats, 1, table1): + with self._run_in_background(self._collect_stats, 2, table2): + yield from self._outer_join(table1, table2) + + logger.info("Diffing complete") + + def _test_null_or_duplicate_keys(self, table1, table2): + logger.info("Testing for null or duplicate keys") + + # Test null or duplicate keys + for ts in [table1, table2]: + t = table(*ts.table_path, schema=ts._schema) + key_columns = [ts.key_column] # XXX + + q = t.select(total=Count(), total_distinct=Count(Concat(key_columns), distinct=True)) + total, total_distinct = ts.database.query(q, tuple) + if total != total_distinct: + raise ValueError("Duplicate primary keys") + + q = t.select(*key_columns).where(or_(this[k] == None for k in key_columns)) + nulls = ts.database.query(q, list) + if nulls: + raise ValueError(f"NULL values in one or more primary keys: {nulls}") + + logger.debug("Done testing for null or duplicate keys") + + def _collect_stats(self, i, table): + logger.info(f"Collecting stats for table #{i}") + db = table.database + + # Metrics + col_exprs = merge_dicts( + { + f"sum_{c}": sum_(c), + f"avg_{c}": avg(c), + f"min_{c}": min_(c), + f"max_{c}": max_(c), + } 
+ for c in table._relevant_columns + if c == "id" # TODO just if the right type + ) + col_exprs["count"] = Count() + + res = db.query(table._make_select().select(**col_exprs), tuple) + res = dict(zip([f"table{i}_{n}" for n in col_exprs], map(json_friendly_value, res))) + self.stats.update(res) + + logger.debug(f"Done collecting stats for table #{i}") + + # stats.diff_ratio_by_column = diff_stats + # stats.diff_ratio_total = diff_stats['total_diff'] + + +def bool_to_int(x): + return if_(x, 1, 0) + + +class JoinDiffer(JoinDifferBase): + def _outer_join(self, table1, table2): + db = table1.database + if db is not table2.database: + raise ValueError("Joindiff only applies to tables within the same database") + + keys1 = [table1.key_column] # XXX + keys2 = [table2.key_column] # XXX + if len(keys1) != len(keys2): + raise ValueError("The provided key columns are of a different count") + + cols1 = table1._relevant_columns + cols2 = table2._relevant_columns + if len(cols1) != len(cols2): + raise ValueError("The provided columns are of a different count") + + a = table1._make_select() + b = table2._make_select() + + is_diff_cols = { + f"is_diff_col_{c1}": bool_to_int(a[c1].is_distinct_from(b[c2])) for c1, c2 in safezip(cols1, cols2) + } + + a_cols = {f"table1_{c}": NormalizeAsString(a[c]) for c in cols1} + b_cols = {f"table2_{c}": NormalizeAsString(b[c]) for c in cols2} + + diff_rows = ( + outerjoin(a, b) + .on(a[k1] == b[k2] for k1, k2 in safezip(keys1, keys2)) + .select( + is_exclusive_a=and_(b[k] == None for k in keys2), + is_exclusive_b=and_(a[k] == None for k in keys1), + **is_diff_cols, + **a_cols, + **b_cols, + ) + .where(or_(this[c] == 1 for c in is_diff_cols)) + ) + + with self._run_in_background(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols): + with self._run_in_background(self._count_diff_per_column, db, diff_rows, is_diff_cols): + + logger.info("Querying for different rows") + for is_xa, is_xb, *x in db.query(diff_rows, list): + assert not (is_xa and is_xb) # Can't both be exclusive + is_diff, a_row, b_row = _slice_tuple(x, len(is_diff_cols), len(a_cols), len(b_cols)) + if not is_xb: + yield "-", tuple(a_row) + if not is_xa: + yield "+", tuple(b_row) + + def _count_diff_per_column(self, db, diff_rows, is_diff_cols): + logger.info("Counting differences per column") + is_diff_cols_counts = db.query(diff_rows.select(sum_(this[c]) for c in is_diff_cols), tuple) + for name, count in safezip(is_diff_cols, is_diff_cols_counts): + self.stats[f"count_{name}"] = count + + def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): + logger.info("Counting and sampling exclusive rows") + exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b) + with temp_table(db, exclusive_rows_query) as exclusive_rows: + self.stats["exclusive_count"] = db.query(exclusive_rows.count(), int) + sample_rows = db.query(sample(exclusive_rows.select(*this[list(a_cols)], *this[list(b_cols)])), list) + self.stats["exclusive_sample"] = sample_rows diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index de0cde5d..3ac37bd0 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -176,7 +176,7 @@ def test_init(self): ) def test_basic(self): - differ = TableDiffer(10, 100) + differ = TableDiffer(bisection_factor=10, bisection_threshold=100) a = TableSegment(self.connection, self.table_src_path, "id", "datetime", case_sensitive=False) b = TableSegment(self.connection, self.table_dst_path, "id", "datetime", case_sensitive=False) assert 
a.count() == 6 @@ -186,7 +186,7 @@ def test_basic(self): self.assertEqual(len(list(differ.diff_tables(a, b))), 1) def test_offset(self): - differ = TableDiffer(2, 10) + differ = TableDiffer(bisection_factor=2, bisection_threshold=10) sec1 = self.now.shift(seconds=-1).datetime a = TableSegment(self.connection, self.table_src_path, "id", "datetime", max_update=sec1, case_sensitive=False) b = TableSegment(self.connection, self.table_dst_path, "id", "datetime", max_update=sec1, case_sensitive=False) @@ -250,7 +250,7 @@ def setUp(self): self.table = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) self.table2 = TableSegment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) - self.differ = TableDiffer(3, 4) + self.differ = TableDiffer(bisection_factor=3, bisection_threshold=4) def test_properties_on_empty_table(self): table = self.table.with_schema() @@ -287,7 +287,7 @@ def test_diff_small_tables(self): self.assertEqual(1, self.differ.stats["table2_count"]) def test_non_threaded(self): - differ = TableDiffer(3, 4, threaded=False) + differ = TableDiffer(bisection_factor=3, bisection_threshold=4, threaded=False) time = "2022-01-01 00:00:00" time_str = f"timestamp '{time}'" diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py new file mode 100644 index 00000000..72d604cd --- /dev/null +++ b/tests/test_joindiff.py @@ -0,0 +1,168 @@ +from parameterized import parameterized_class + +from data_diff.databases.connect import connect +from data_diff.table_segment import TableSegment, split_space +from data_diff import databases as db +from data_diff.utils import ArithAlphanumeric +from data_diff.joindiff_tables import JoinDiffer + +from .test_diff_tables import TestPerDatabase, _get_float_type, _get_text_type, _commit, _insert_row, _insert_rows + +from .common import ( + str_to_checksum, + CONN_STRINGS, + N_THREADS, +) + +DATABASE_INSTANCES = None +DATABASE_URIS = {k.__name__: v for k, v in CONN_STRINGS.items()} + + +def init_instances(): + global DATABASE_INSTANCES + if DATABASE_INSTANCES is not None: + return + + DATABASE_INSTANCES = {k.__name__: connect(v, N_THREADS) for k, v in CONN_STRINGS.items()} + + +TEST_DATABASES = {x.__name__ for x in (db.PostgreSQL,)} + +_class_per_db_dec = parameterized_class( + ("name", "db_name"), [(name, name) for name in DATABASE_URIS if name in TEST_DATABASES] +) + + +def test_per_database(cls): + return _class_per_db_dec(cls) + + +@test_per_database +class TestJoindiff(TestPerDatabase): + def setUp(self): + super().setUp() + + float_type = _get_float_type(self.connection) + + self.connection.query( + f"create table {self.table_src}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", + None, + ) + self.connection.query( + f"create table {self.table_dst}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", + None, + ) + _commit(self.connection) + + self.table = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) + self.table2 = TableSegment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) + + self.differ = JoinDiffer() + + def test_diff_small_tables(self): + time = "2022-01-01 00:00:00" + time_str = f"timestamp '{time}'" + + cols = "id userid movieid rating timestamp".split() + _insert_rows(self.connection, self.table_src, cols, [[1, 1, 1, 9, time_str], [2, 2, 2, 9, time_str]]) + _insert_rows(self.connection, self.table_dst, cols, [[1, 1, 1, 9, time_str]]) + 
_commit(self.connection) + diff = list(self.differ.diff_tables(self.table, self.table2)) + expected = [("-", ("2", time + ".000000"))] + self.assertEqual(expected, diff) + self.assertEqual(2, self.differ.stats["table1_count"]) + self.assertEqual(1, self.differ.stats["table2_count"]) + + def test_diff_table_above_bisection_threshold(self): + time = "2022-01-01 00:00:00" + time_str = f"timestamp '{time}'" + + cols = "id userid movieid rating timestamp".split() + _insert_rows( + self.connection, + self.table_src, + cols, + [ + [1, 1, 1, 9, time_str], + [2, 2, 2, 9, time_str], + [3, 3, 3, 9, time_str], + [4, 4, 4, 9, time_str], + [5, 5, 5, 9, time_str], + ], + ) + + _insert_rows( + self.connection, + self.table_dst, + cols, + [ + [1, 1, 1, 9, time_str], + [2, 2, 2, 9, time_str], + [3, 3, 3, 9, time_str], + [4, 4, 4, 9, time_str], + ], + ) + _commit(self.connection) + + diff = list(self.differ.diff_tables(self.table, self.table2)) + expected = [("-", ("5", time + ".000000"))] + self.assertEqual(expected, diff) + self.assertEqual(5, self.differ.stats["table1_count"]) + self.assertEqual(4, self.differ.stats["table2_count"]) + + def test_return_empty_array_when_same(self): + time = "2022-01-01 00:00:00" + time_str = f"timestamp '{time}'" + + cols = "id userid movieid rating timestamp".split() + + _insert_row(self.connection, self.table_src, cols, [1, 1, 1, 9, time_str]) + _insert_row(self.connection, self.table_dst, cols, [1, 1, 1, 9, time_str]) + + diff = list(self.differ.diff_tables(self.table, self.table2)) + self.assertEqual([], diff) + + def test_diff_sorted_by_key(self): + time = "2022-01-01 00:00:00" + time2 = "2021-01-01 00:00:00" + + time_str = f"timestamp '{time}'" + time_str2 = f"timestamp '{time2}'" + + cols = "id userid movieid rating timestamp".split() + + _insert_rows( + self.connection, + self.table_src, + cols, + [ + [1, 1, 1, 9, time_str], + [2, 2, 2, 9, time_str2], + [3, 3, 3, 9, time_str], + [4, 4, 4, 9, time_str2], + [5, 5, 5, 9, time_str], + ], + ) + + _insert_rows( + self.connection, + self.table_dst, + cols, + [ + [1, 1, 1, 9, time_str], + [2, 2, 2, 9, time_str], + [3, 3, 3, 9, time_str], + [4, 4, 4, 9, time_str], + [5, 5, 5, 9, time_str], + ], + ) + _commit(self.connection) + + diff = list(self.differ.diff_tables(self.table, self.table2)) + expected = [ + ("-", ("2", time2 + ".000000")), + ("+", ("2", time + ".000000")), + ("-", ("4", time2 + ".000000")), + ("+", ("4", time + ".000000")), + ] + self.assertEqual(expected, diff) From b6170bf8c64916126173cc25a4c1c67691f1e314 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Sat, 10 Sep 2022 11:51:07 +0300 Subject: [PATCH 04/93] Integrate joindiff into main --- data_diff/__main__.py | 76 ++++++++++++++++++++++++--- data_diff/databases/database_types.py | 2 +- data_diff/diff_tables.py | 10 ++-- 3 files changed, 74 insertions(+), 14 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index a5715397..7ad156ea 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -1,13 +1,17 @@ from copy import deepcopy +from enum import Enum import sys import time import json import logging from itertools import islice +from typing import Optional import rich import click +from data_diff.joindiff_tables import JoinDiffer + from .utils import remove_password_from_url, safezip, match_like from .diff_tables import TableDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR @@ -28,6 +32,12 @@ } +class Algorithm(Enum): + AUTO = "auto" + JOINDIFF = "joindiff" + HASHDIFF = "hashdiff" + + def 
_remove_passwords_in_dict(d: dict): for k, v in d.items(): if k == "password": @@ -43,13 +53,30 @@ def _get_schema(pair): return db.query_table_schema(table_path) -@click.command() +class MyHelpFormatter(click.HelpFormatter): + def __init__(self, **kwargs): + super().__init__(self, **kwargs) + self.indent_increment = 6 + + def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -> None: + self.write(f"data-diff - efficiently diff rows across database tables.\n\n") + self.write(f"Usage:\n") + self.write(f" * In-db diff: {prog} [OPTIONS]\n") + self.write(f" * Cross-db diff: {prog} [OPTIONS]\n") + self.write(f" * Using config: {prog} --conf PATH [--run NAME] [OPTIONS]\n") + # s = super().write_usage(prog, args, prefix) + + +click.Context.formatter_class = MyHelpFormatter + + +@click.command(no_args_is_help=True) @click.argument("database1", required=False) @click.argument("table1", required=False) @click.argument("database2", required=False) @click.argument("table2", required=False) -@click.option("-k", "--key-column", default=None, help="Name of primary key column. Default='id'.") -@click.option("-t", "--update-column", default=None, help="Name of updated_at/last_updated column") +@click.option("-k", "--key-column", default=None, help="Name of primary key column. Default='id'.", metavar="NAME") +@click.option("-t", "--update-column", default=None, help="Name of updated_at/last_updated column", metavar="NAME") @click.option( "-c", "--columns", @@ -58,13 +85,20 @@ def _get_schema(pair): help="Names of extra columns to compare." "Can be used more than once in the same command. " "Accepts a name or a pattern like in SQL. Example: -c col% -c another_col", + metavar="NAME", +) +@click.option("-l", "--limit", default=None, help="Maximum number of differences to find", metavar="NUM") +@click.option( + "--bisection-factor", + default=None, + help=f"Segments per iteration. Default={DEFAULT_BISECTION_FACTOR}.", + metavar="NUM", ) -@click.option("-l", "--limit", default=None, help="Maximum number of differences to find") -@click.option("--bisection-factor", default=None, help=f"Segments per iteration. Default={DEFAULT_BISECTION_FACTOR}.") @click.option( "--bisection-threshold", default=None, help=f"Minimal bisection threshold. Below it, data-diff will download the data and compare it locally. Default={DEFAULT_BISECTION_THRESHOLD}.", + metavar="NUM", ) @click.option( "--min-age", @@ -72,8 +106,11 @@ def _get_schema(pair): help="Considers only rows older than specified. Useful for specifying replication lag." "Example: --min-age=5min ignores rows from the last 5 minutes. " f"\nValid units: {UNITS_STR}", + metavar="AGE", +) +@click.option( + "--max-age", default=None, help="Considers only rows younger than specified. See --min-age.", metavar="AGE" ) -@click.option("--max-age", default=None, help="Considers only rows younger than specified. See --min-age.") @click.option("-s", "--stats", is_flag=True, help="Print stats instead of a detailed diff") @click.option("-d", "--debug", is_flag=True, help="Print debug info") @click.option("--json", "json_output", is_flag=True, help="Print JSONL output for machine readability") @@ -92,21 +129,39 @@ def _get_schema(pair): help="Number of worker threads to use per database. Default=1. " "A higher number will increase performance, but take more capacity from your database. 
" "'serial' guarantees a single-threaded execution of the algorithm (useful for debugging).", + metavar="COUNT", +) +@click.option( + "-w", "--where", default=None, help="An additional 'where' expression to restrict the search space.", metavar="EXPR" ) -@click.option("-w", "--where", default=None, help="An additional 'where' expression to restrict the search space.") +@click.option("-a", "--algorithm", default=Algorithm.AUTO.value, type=click.Choice([i.value for i in Algorithm])) @click.option( "--conf", default=None, help="Path to a configuration.toml file, to provide a default configuration, and a list of possible runs.", + metavar="PATH", ) @click.option( "--run", default=None, help="Name of run-configuration to run. If used, CLI arguments for database and table must be omitted.", + metavar="NAME", ) def main(conf, run, **kw): + indb_syntax = False + if kw["table2"] is None and kw["database2"]: + # Use the "database table table" form + kw["table2"] = kw["database2"] + kw["database2"] = kw["database1"] + indb_syntax = True + if conf: kw = apply_config_from_file(conf, run, kw) + + kw["algorithm"] = Algorithm(kw["algorithm"]) + if kw["algorithm"] == Algorithm.AUTO: + kw["algorithm"] = Algorithm.JOINDIFF if indb_syntax else Algorithm.HASHDIFF + return _main(**kw) @@ -119,6 +174,7 @@ def _main( update_column, columns, limit, + algorithm, bisection_factor, bisection_threshold, min_age, @@ -214,7 +270,10 @@ def _main( try: db1 = connect(database1, threads1 or threads) - db2 = connect(database2, threads2 or threads) + if database1 == database2: + db2 = db1 + else: + db2 = connect(database2, threads2 or threads) except Exception as e: logging.error(e) return @@ -277,6 +336,7 @@ def _main( "different_+": plus, "different_-": minus, "total": max_table_count, + "stats": differ.stats, } print(json.dumps(json_output)) else: diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index ca2734fc..1e9c973e 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -172,6 +172,7 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None "Provide SQL fragment for limit and offset inside a select" ... + class AbstractDatabase(AbstractDialect): @abstractmethod def timestamp_value(self, t: DbTime) -> str: @@ -183,7 +184,6 @@ def md5_to_int(self, s: str) -> str: "Provide SQL for computing md5 and returning an int" ... 
- @abstractmethod def _query(self, sql_code: str) -> list: "Send query to database and return result" diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 7f67cfc2..a98a6508 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -78,11 +78,11 @@ def _threaded_call_as_completed(self, func, iterable): return self._thread_as_completed(methodcaller(func), iterable) def _run_thread(self, threadfunc, *args, daemon=False) -> threading.Thread: - th = threading.Thread(target=threadfunc, args=args) - if daemon: - th.daemon = True - th.start() - return th + th = threading.Thread(target=threadfunc, args=args) + if daemon: + th.daemon = True + th.start() + return th @contextmanager def _run_in_background(self, threadfunc, *args, daemon=False): From becf36c6ddda0fbfc8f7a51c0077b4388657d428 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 14 Sep 2022 11:20:36 +0300 Subject: [PATCH 05/93] Refactor diff_tables.TableDiffer -> hashdiff_tables.HashDiffer --- data_diff/__init__.py | 33 ++-- data_diff/__main__.py | 17 +-- data_diff/diff_tables.py | 282 +--------------------------------- data_diff/hashdiff_tables.py | 283 +++++++++++++++++++++++++++++++++++ tests/common.py | 3 +- tests/test_database_types.py | 7 +- tests/test_diff_tables.py | 33 ++-- tests/test_postgresql.py | 5 +- 8 files changed, 341 insertions(+), 322 deletions(-) create mode 100644 data_diff/hashdiff_tables.py diff --git a/data_diff/__init__.py b/data_diff/__init__.py index 2af199db..f22ab039 100644 --- a/data_diff/__init__.py +++ b/data_diff/__init__.py @@ -3,7 +3,10 @@ from .tracking import disable_tracking from .databases.connect import connect from .databases.database_types import DbKey, DbTime, DbPath -from .diff_tables import TableSegment, TableDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR +from .diff_tables import Algorithm +from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR +from .joindiff_tables import JoinDiffer +from .table_segment import TableSegment def connect_to_table( @@ -46,9 +49,11 @@ def diff_tables( # Start/end update_column values, used to restrict the segment min_update: DbTime = None, max_update: DbTime = None, - # Into how many segments to bisect per iteration + # Algorithm + algorithm: Algorithm = Algorithm.HASHDIFF, + # Into how many segments to bisect per iteration (hashdiff only) bisection_factor: int = DEFAULT_BISECTION_FACTOR, - # When should we stop bisecting and compare locally (in row count) + # When should we stop bisecting and compare locally (in row count; hashdiff only) bisection_threshold: int = DEFAULT_BISECTION_THRESHOLD, # Enable/disable threaded diffing. Needed to take advantage of database threads. 
threaded: bool = True, @@ -81,10 +86,20 @@ def diff_tables( segments = [t.new(**override_attrs) for t in tables] if override_attrs else tables - differ = TableDiffer( - bisection_factor=bisection_factor, - bisection_threshold=bisection_threshold, - threaded=threaded, - max_threadpool_size=max_threadpool_size, - ) + algorithm = Algorithm(algorithm) + if algorithm == Algorithm.HASHDIFF: + differ = HashDiffer( + bisection_factor=bisection_factor, + bisection_threshold=bisection_threshold, + threaded=threaded, + max_threadpool_size=max_threadpool_size, + ) + elif algorithm == Algorithm.JOINDIFF: + differ = JoinDiffer( + threaded=threaded, + max_threadpool_size=max_threadpool_size, + ) + else: + raise ValueError(f"Unknown algorithm: {algorithm}") + return differ.diff_tables(*segments) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 7ad156ea..adb5bee9 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -1,5 +1,4 @@ from copy import deepcopy -from enum import Enum import sys import time import json @@ -10,11 +9,10 @@ import rich import click -from data_diff.joindiff_tables import JoinDiffer - - from .utils import remove_password_from_url, safezip, match_like -from .diff_tables import TableDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR +from .diff_tables import Algorithm +from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR +from .joindiff_tables import JoinDiffer from .table_segment import TableSegment from .databases.database_types import create_schema from .databases.connect import connect @@ -32,12 +30,6 @@ } -class Algorithm(Enum): - AUTO = "auto" - JOINDIFF = "joindiff" - HASHDIFF = "hashdiff" - - def _remove_passwords_in_dict(d: dict): for k, v in d.items(): if k == "password": @@ -64,7 +56,6 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - self.write(f" * In-db diff: {prog} [OPTIONS]\n") self.write(f" * Cross-db diff: {prog} [OPTIONS]\n") self.write(f" * Using config: {prog} --conf PATH [--run NAME] [OPTIONS]\n") - # s = super().write_usage(prog, args, prefix) click.Context.formatter_class = MyHelpFormatter @@ -255,7 +246,7 @@ def _main( ) else: assert algorithm == Algorithm.HASHDIFF - differ = TableDiffer( + differ = HashDiffer( bisection_factor=bisection_factor, bisection_threshold=bisection_threshold, threaded=threaded, diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index a98a6508..04a95fe7 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -1,45 +1,20 @@ """Provides classes for performing a table diff """ +from enum import Enum from contextlib import contextmanager -import time import threading -import os -from numbers import Number -from operator import attrgetter, methodcaller -from collections import defaultdict +from operator import methodcaller from typing import Tuple, Iterator, Optional -import logging from concurrent.futures import ThreadPoolExecutor, as_completed from runtype import dataclass -from .utils import safezip, run_as_daemon -from .thread_utils import ThreadedYielder -from .databases.database_types import IKey, NumericType, PrecisionType, StringType -from .table_segment import TableSegment -from .tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled -logger = logging.getLogger("diff_tables") - -BENCHMARK = os.environ.get("BENCHMARK", False) -DEFAULT_BISECTION_THRESHOLD = 1024 * 16 -DEFAULT_BISECTION_FACTOR = 32 - - -def diff_sets(a: set, b: set) -> Iterator: - 
s1 = set(a) - s2 = set(b) - d = defaultdict(list) - - # The first item is always the key (see TableDiffer._relevant_columns) - for i in s1 - s2: - d[i[0]].append(("-", i)) - for i in s2 - s1: - d[i[0]].append(("+", i)) - - for _k, v in sorted(d.items(), key=lambda i: i[0]): - yield from v +class Algorithm(Enum): + AUTO = "auto" + JOINDIFF = "joindiff" + HASHDIFF = "hashdiff" DiffResult = Iterator[Tuple[str, tuple]] # Iterator[Tuple[Literal["+", "-"], tuple]] @@ -89,248 +64,3 @@ def _run_in_background(self, threadfunc, *args, daemon=False): t = self._run_thread(threadfunc, *args, daemon=daemon) yield t t.join() - - -@dataclass -class TableDiffer(ThreadBase): - """Finds the diff between two SQL tables - - The algorithm uses hashing to quickly check if the tables are different, and then applies a - bisection search recursively to find the differences efficiently. - - Works best for comparing tables that are mostly the same, with minor discrepencies. - - Parameters: - bisection_factor (int): Into how many segments to bisect per iteration. - bisection_threshold (Number): When should we stop bisecting and compare locally (in row count). - threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads. - max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. Only relevant when `threaded` is ``True``. - There may be many pools, so number of actual threads can be a lot higher. - """ - - bisection_factor: int = DEFAULT_BISECTION_FACTOR - bisection_threshold: Number = DEFAULT_BISECTION_THRESHOLD # Accepts inf for tests - - stats: dict = {} - - def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: - """Diff the given tables. - - Parameters: - table1 (TableSegment): The "before" table to compare. Or: source table - table2 (TableSegment): The "after" table to compare. Or: target table - - Returns: - An iterator that yield pair-tuples, representing the diff. Items can be either - ('-', columns) for items in table1 but not in table2 - ('+', columns) for items in table2 but not in table1 - Where `columns` is a tuple of values for the involved columns, i.e. (id, ...extra) - """ - # Validate options - if self.bisection_factor >= self.bisection_threshold: - raise ValueError("Incorrect param values (bisection factor must be lower than threshold)") - if self.bisection_factor < 2: - raise ValueError("Must have at least two segments per iteration (i.e. 
bisection_factor >= 2)") - - if is_tracking_enabled(): - options = dict(self) - event_json = create_start_event_json(options) - run_as_daemon(send_event_json, event_json) - - self.stats["diff_count"] = 0 - start = time.monotonic() - error = None - try: - - # Query and validate schema - table1, table2 = self._threaded_call("with_schema", [table1, table2]) - self._validate_and_adjust_columns(table1, table2) - - key_type = table1._schema[table1.key_column] - key_type2 = table2._schema[table2.key_column] - if not isinstance(key_type, IKey): - raise NotImplementedError(f"Cannot use column of type {key_type} as a key") - if not isinstance(key_type2, IKey): - raise NotImplementedError(f"Cannot use column of type {key_type2} as a key") - assert key_type.python_type is key_type2.python_type - - # Query min/max values - key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2]) - - # Start with the first completed value, so we don't waste time waiting - min_key1, max_key1 = self._parse_key_range_result(key_type, next(key_ranges)) - - table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)] - - logger.info( - f"Diffing tables | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}. " - f"key-range: {table1.min_key}..{table2.max_key}, " - f"size: table1 <= {table1.approximate_size()}, table2 <= {table2.approximate_size()}" - ) - - ti = ThreadedYielder(self.max_threadpool_size) - # Bisect (split) the table into segments, and diff them recursively. - ti.submit(self._bisect_and_diff_tables, ti, table1, table2) - - # Now we check for the second min-max, to diff the portions we "missed". - min_key2, max_key2 = self._parse_key_range_result(key_type, next(key_ranges)) - - if min_key2 < min_key1: - pre_tables = [t.new(min_key=min_key2, max_key=min_key1) for t in (table1, table2)] - ti.submit(self._bisect_and_diff_tables, ti, *pre_tables) - - if max_key2 > max_key1: - post_tables = [t.new(min_key=max_key1, max_key=max_key2) for t in (table1, table2)] - ti.submit(self._bisect_and_diff_tables, ti, *post_tables) - - yield from ti - - except BaseException as e: # Catch KeyboardInterrupt too - error = e - finally: - if is_tracking_enabled(): - runtime = time.monotonic() - start - table1_count = self.stats.get("table1_count") - table2_count = self.stats.get("table2_count") - diff_count = self.stats.get("diff_count") - err_message = str(error)[:20] # Truncate possibly sensitive information. 
- event_json = create_end_event_json( - error is None, - runtime, - table1.database.name, - table2.database.name, - table1_count, - table2_count, - diff_count, - err_message, - ) - send_event_json(event_json) - - if error: - raise error - - def _parse_key_range_result(self, key_type, key_range): - mn, mx = key_range - cls = key_type.make_value - # We add 1 because our ranges are exclusive of the end (like in Python) - try: - return cls(mn), cls(mx) + 1 - except (TypeError, ValueError) as e: - raise type(e)(f"Cannot apply {key_type} to {mn}, {mx}.") from e - - def _validate_and_adjust_columns(self, table1, table2): - for c1, c2 in safezip(table1._relevant_columns, table2._relevant_columns): - if c1 not in table1._schema: - raise ValueError(f"Column '{c1}' not found in schema for table {table1}") - if c2 not in table2._schema: - raise ValueError(f"Column '{c2}' not found in schema for table {table2}") - - # Update schemas to minimal mutual precision - col1 = table1._schema[c1] - col2 = table2._schema[c2] - if isinstance(col1, PrecisionType): - if not isinstance(col2, PrecisionType): - raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") - - lowest = min(col1, col2, key=attrgetter("precision")) - - if col1.precision != col2.precision: - logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}") - - table1._schema[c1] = col1.replace(precision=lowest.precision, rounds=lowest.rounds) - table2._schema[c2] = col2.replace(precision=lowest.precision, rounds=lowest.rounds) - - elif isinstance(col1, NumericType): - if not isinstance(col2, NumericType): - raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") - - lowest = min(col1, col2, key=attrgetter("precision")) - - if col1.precision != col2.precision: - logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}") - - table1._schema[c1] = col1.replace(precision=lowest.precision) - table2._schema[c2] = col2.replace(precision=lowest.precision) - - elif isinstance(col1, StringType): - if not isinstance(col2, StringType): - raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") - - for t in [table1, table2]: - for c in t._relevant_columns: - ctype = t._schema[c] - if not ctype.supported: - logger.warning( - f"[{t.database.name}] Column '{c}' of type '{ctype}' has no compatibility handling. " - "If encoding/formatting differs between databases, it may result in false positives." - ) - - def _bisect_and_diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None): - assert table1.is_bounded and table2.is_bounded - - if max_rows is None: - # We can be sure that row_count <= max_rows - max_rows = max(table1.approximate_size(), table2.approximate_size()) - - # If count is below the threshold, just download and compare the columns locally - # This saves time, as bisection speed is limited by ping and query performance. - if max_rows < self.bisection_threshold: - rows1, rows2 = self._threaded_call("get_values", [table1, table2]) - diff = list(diff_sets(rows1, rows2)) - - # Initial bisection_threshold larger than count. Normally we always - # checksum and count segments, even if we get the values. At the - # first level, however, that won't be true. - if level == 0: - self.stats["table1_count"] = len(rows1) - self.stats["table2_count"] = len(rows2) - - self.stats["diff_count"] += len(diff) - - logger.info(". 
" * level + f"Diff found {len(diff)} different rows.") - self.stats["rows_downloaded"] = self.stats.get("rows_downloaded", 0) + max(len(rows1), len(rows2)) - return diff - - # Choose evenly spaced checkpoints (according to min_key and max_key) - checkpoints = table1.choose_checkpoints(self.bisection_factor - 1) - - # Create new instances of TableSegment between each checkpoint - segmented1 = table1.segment_by_checkpoints(checkpoints) - segmented2 = table2.segment_by_checkpoints(checkpoints) - - # Recursively compare each pair of corresponding segments between table1 and table2 - for i, (t1, t2) in enumerate(safezip(segmented1, segmented2)): - ti.submit(self._diff_tables, ti, t1, t2, max_rows, level + 1, i + 1, len(segmented1), priority=level) - - def _diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): - logger.info( - ". " * level + f"Diffing segment {segment_index}/{segment_count}, " - f"key-range: {table1.min_key}..{table2.max_key}, " - f"size <= {max_rows}" - ) - - # When benchmarking, we want the ability to skip checksumming. This - # allows us to download all rows for comparison in performance. By - # default, data-diff will checksum the section first (when it's below - # the threshold) and _then_ download it. - if BENCHMARK: - if max_rows < self.bisection_threshold: - return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max_rows) - - (count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2]) - - if count1 == 0 and count2 == 0: - # logger.warning( - # f"Uneven distribution of keys detected in segment {table1.min_key}..{table2.max_key}. (big gaps in the key column). " - # "For better performance, we recommend to increase the bisection-threshold." 
- # ) - assert checksum1 is None and checksum2 is None - return - - if level == 1: - self.stats["table1_count"] = self.stats.get("table1_count", 0) + count1 - self.stats["table2_count"] = self.stats.get("table2_count", 0) + count2 - - if checksum1 != checksum2: - return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max(count1, count2)) - diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py new file mode 100644 index 00000000..f4867a74 --- /dev/null +++ b/data_diff/hashdiff_tables.py @@ -0,0 +1,283 @@ +import os +import time +from numbers import Number +import logging +from collections import defaultdict +from typing import Iterator +from operator import attrgetter + +from runtype import dataclass + +from .utils import safezip, run_as_daemon +from .thread_utils import ThreadedYielder +from .databases.database_types import IKey, NumericType, PrecisionType, StringType +from .table_segment import TableSegment +from .tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled + +from .diff_tables import ThreadBase, DiffResult + +BENCHMARK = os.environ.get("BENCHMARK", False) + +DEFAULT_BISECTION_THRESHOLD = 1024 * 16 +DEFAULT_BISECTION_FACTOR = 32 + +logger = logging.getLogger("hashdiff_tables") + + +def diff_sets(a: set, b: set) -> Iterator: + s1 = set(a) + s2 = set(b) + d = defaultdict(list) + + # The first item is always the key (see TableDiffer._relevant_columns) + for i in s1 - s2: + d[i[0]].append(("-", i)) + for i in s2 - s1: + d[i[0]].append(("+", i)) + + for _k, v in sorted(d.items(), key=lambda i: i[0]): + yield from v + + +@dataclass +class HashDiffer(ThreadBase): + """Finds the diff between two SQL tables + + The algorithm uses hashing to quickly check if the tables are different, and then applies a + bisection search recursively to find the differences efficiently. + + Works best for comparing tables that are mostly the same, with minor discrepencies. + + Parameters: + bisection_factor (int): Into how many segments to bisect per iteration. + bisection_threshold (Number): When should we stop bisecting and compare locally (in row count). + threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads. + max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. Only relevant when `threaded` is ``True``. + There may be many pools, so number of actual threads can be a lot higher. + """ + + bisection_factor: int = DEFAULT_BISECTION_FACTOR + bisection_threshold: Number = DEFAULT_BISECTION_THRESHOLD # Accepts inf for tests + + stats: dict = {} + + def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: + """Diff the given tables. + + Parameters: + table1 (TableSegment): The "before" table to compare. Or: source table + table2 (TableSegment): The "after" table to compare. Or: target table + + Returns: + An iterator that yield pair-tuples, representing the diff. Items can be either + ('-', columns) for items in table1 but not in table2 + ('+', columns) for items in table2 but not in table1 + Where `columns` is a tuple of values for the involved columns, i.e. (id, ...extra) + """ + # Validate options + if self.bisection_factor >= self.bisection_threshold: + raise ValueError("Incorrect param values (bisection factor must be lower than threshold)") + if self.bisection_factor < 2: + raise ValueError("Must have at least two segments per iteration (i.e. 
bisection_factor >= 2)") + + if is_tracking_enabled(): + options = dict(self) + event_json = create_start_event_json(options) + run_as_daemon(send_event_json, event_json) + + self.stats["diff_count"] = 0 + start = time.monotonic() + error = None + try: + + # Query and validate schema + table1, table2 = self._threaded_call("with_schema", [table1, table2]) + self._validate_and_adjust_columns(table1, table2) + + key_type = table1._schema[table1.key_column] + key_type2 = table2._schema[table2.key_column] + if not isinstance(key_type, IKey): + raise NotImplementedError(f"Cannot use column of type {key_type} as a key") + if not isinstance(key_type2, IKey): + raise NotImplementedError(f"Cannot use column of type {key_type2} as a key") + assert key_type.python_type is key_type2.python_type + + # Query min/max values + key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2]) + + # Start with the first completed value, so we don't waste time waiting + min_key1, max_key1 = self._parse_key_range_result(key_type, next(key_ranges)) + + table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)] + + logger.info( + f"Diffing tables | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}. " + f"key-range: {table1.min_key}..{table2.max_key}, " + f"size: table1 <= {table1.approximate_size()}, table2 <= {table2.approximate_size()}" + ) + + ti = ThreadedYielder(self.max_threadpool_size) + # Bisect (split) the table into segments, and diff them recursively. + ti.submit(self._bisect_and_diff_tables, ti, table1, table2) + + # Now we check for the second min-max, to diff the portions we "missed". + min_key2, max_key2 = self._parse_key_range_result(key_type, next(key_ranges)) + + if min_key2 < min_key1: + pre_tables = [t.new(min_key=min_key2, max_key=min_key1) for t in (table1, table2)] + ti.submit(self._bisect_and_diff_tables, ti, *pre_tables) + + if max_key2 > max_key1: + post_tables = [t.new(min_key=max_key1, max_key=max_key2) for t in (table1, table2)] + ti.submit(self._bisect_and_diff_tables, ti, *post_tables) + + yield from ti + + except BaseException as e: # Catch KeyboardInterrupt too + error = e + finally: + if is_tracking_enabled(): + runtime = time.monotonic() - start + table1_count = self.stats.get("table1_count") + table2_count = self.stats.get("table2_count") + diff_count = self.stats.get("diff_count") + err_message = str(error)[:20] # Truncate possibly sensitive information. 
+ event_json = create_end_event_json( + error is None, + runtime, + table1.database.name, + table2.database.name, + table1_count, + table2_count, + diff_count, + err_message, + ) + send_event_json(event_json) + + if error: + raise error + + def _parse_key_range_result(self, key_type, key_range): + mn, mx = key_range + cls = key_type.make_value + # We add 1 because our ranges are exclusive of the end (like in Python) + try: + return cls(mn), cls(mx) + 1 + except (TypeError, ValueError) as e: + raise type(e)(f"Cannot apply {key_type} to {mn}, {mx}.") from e + + def _validate_and_adjust_columns(self, table1, table2): + for c1, c2 in safezip(table1._relevant_columns, table2._relevant_columns): + if c1 not in table1._schema: + raise ValueError(f"Column '{c1}' not found in schema for table {table1}") + if c2 not in table2._schema: + raise ValueError(f"Column '{c2}' not found in schema for table {table2}") + + # Update schemas to minimal mutual precision + col1 = table1._schema[c1] + col2 = table2._schema[c2] + if isinstance(col1, PrecisionType): + if not isinstance(col2, PrecisionType): + raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") + + lowest = min(col1, col2, key=attrgetter("precision")) + + if col1.precision != col2.precision: + logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}") + + table1._schema[c1] = col1.replace(precision=lowest.precision, rounds=lowest.rounds) + table2._schema[c2] = col2.replace(precision=lowest.precision, rounds=lowest.rounds) + + elif isinstance(col1, NumericType): + if not isinstance(col2, NumericType): + raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") + + lowest = min(col1, col2, key=attrgetter("precision")) + + if col1.precision != col2.precision: + logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}") + + table1._schema[c1] = col1.replace(precision=lowest.precision) + table2._schema[c2] = col2.replace(precision=lowest.precision) + + elif isinstance(col1, StringType): + if not isinstance(col2, StringType): + raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") + + for t in [table1, table2]: + for c in t._relevant_columns: + ctype = t._schema[c] + if not ctype.supported: + logger.warning( + f"[{t.database.name}] Column '{c}' of type '{ctype}' has no compatibility handling. " + "If encoding/formatting differs between databases, it may result in false positives." + ) + + def _bisect_and_diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None): + assert table1.is_bounded and table2.is_bounded + + if max_rows is None: + # We can be sure that row_count <= max_rows + max_rows = max(table1.approximate_size(), table2.approximate_size()) + + # If count is below the threshold, just download and compare the columns locally + # This saves time, as bisection speed is limited by ping and query performance. + if max_rows < self.bisection_threshold: + rows1, rows2 = self._threaded_call("get_values", [table1, table2]) + diff = list(diff_sets(rows1, rows2)) + + # Initial bisection_threshold larger than count. Normally we always + # checksum and count segments, even if we get the values. At the + # first level, however, that won't be true. + if level == 0: + self.stats["table1_count"] = len(rows1) + self.stats["table2_count"] = len(rows2) + + self.stats["diff_count"] += len(diff) + + logger.info(". 
" * level + f"Diff found {len(diff)} different rows.") + self.stats["rows_downloaded"] = self.stats.get("rows_downloaded", 0) + max(len(rows1), len(rows2)) + return diff + + # Choose evenly spaced checkpoints (according to min_key and max_key) + checkpoints = table1.choose_checkpoints(self.bisection_factor - 1) + + # Create new instances of TableSegment between each checkpoint + segmented1 = table1.segment_by_checkpoints(checkpoints) + segmented2 = table2.segment_by_checkpoints(checkpoints) + + # Recursively compare each pair of corresponding segments between table1 and table2 + for i, (t1, t2) in enumerate(safezip(segmented1, segmented2)): + ti.submit(self._diff_tables, ti, t1, t2, max_rows, level + 1, i + 1, len(segmented1), priority=level) + + def _diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): + logger.info( + ". " * level + f"Diffing segment {segment_index}/{segment_count}, " + f"key-range: {table1.min_key}..{table2.max_key}, " + f"size <= {max_rows}" + ) + + # When benchmarking, we want the ability to skip checksumming. This + # allows us to download all rows for comparison in performance. By + # default, data-diff will checksum the section first (when it's below + # the threshold) and _then_ download it. + if BENCHMARK: + if max_rows < self.bisection_threshold: + return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max_rows) + + (count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2]) + + if count1 == 0 and count2 == 0: + # logger.warning( + # f"Uneven distribution of keys detected in segment {table1.min_key}..{table2.max_key}. (big gaps in the key column). " + # "For better performance, we recommend to increase the bisection-threshold." 
+ # ) + assert checksum1 is None and checksum2 is None + return + + if level == 1: + self.stats["table1_count"] = self.stats.get("table1_count", 0) + count1 + self.stats["table2_count"] = self.stats.get("table2_count", 0) + count2 + + if checksum1 != checksum2: + return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max(count1, count2)) diff --git a/tests/common.py b/tests/common.py index 2aad4be1..44a15cf2 100644 --- a/tests/common.py +++ b/tests/common.py @@ -43,7 +43,8 @@ def get_git_revision_short_hash() -> str: level = getattr(logging, os.environ["LOG_LEVEL"].upper()) logging.basicConfig(level=level) -logging.getLogger("diff_tables").setLevel(level) +logging.getLogger("hashdiff_tables").setLevel(level) +logging.getLogger("joindiff_tables").setLevel(level) logging.getLogger("table_segment").setLevel(level) logging.getLogger("database").setLevel(level) diff --git a/tests/test_database_types.py b/tests/test_database_types.py index ce273182..4ac8d5f4 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -15,7 +15,8 @@ from data_diff import databases as db from data_diff.databases import postgresql, oracle from data_diff.utils import number_to_human, accumulate -from data_diff.diff_tables import TableDiffer, TableSegment, DEFAULT_BISECTION_THRESHOLD +from data_diff.hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD +from data_diff.table_segment import TableSegment from .common import ( CONN_STRINGS, N_SAMPLES, @@ -667,7 +668,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego ch_factor = min(max(int(N_SAMPLES / 250_000), 2), 128) if BENCHMARK else 2 ch_threshold = min(DEFAULT_BISECTION_THRESHOLD, int(N_SAMPLES / ch_factor)) if BENCHMARK else 3 ch_threads = N_THREADS - differ = TableDiffer( + differ = HashDiffer( bisection_threshold=ch_threshold, bisection_factor=ch_factor, max_threadpool_size=ch_threads, @@ -688,7 +689,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego dl_factor = max(int(N_SAMPLES / 100_000), 2) if BENCHMARK else 2 dl_threshold = int(N_SAMPLES / dl_factor) + 1 if BENCHMARK else math.inf dl_threads = N_THREADS - differ = TableDiffer( + differ = HashDiffer( bisection_threshold=dl_threshold, bisection_factor=dl_factor, max_threadpool_size=dl_threads ) start = time.monotonic() diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 3ac37bd0..63195efb 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -7,7 +7,7 @@ import arrow # comes with preql from data_diff.databases.connect import connect -from data_diff.diff_tables import TableDiffer +from data_diff.hashdiff_tables import HashDiffer from data_diff.table_segment import TableSegment, split_space from data_diff import databases as db from data_diff.utils import ArithAlphanumeric, numberToAlphanum @@ -176,7 +176,7 @@ def test_init(self): ) def test_basic(self): - differ = TableDiffer(bisection_factor=10, bisection_threshold=100) + differ = HashDiffer(bisection_factor=10, bisection_threshold=100) a = TableSegment(self.connection, self.table_src_path, "id", "datetime", case_sensitive=False) b = TableSegment(self.connection, self.table_dst_path, "id", "datetime", case_sensitive=False) assert a.count() == 6 @@ -186,7 +186,7 @@ def test_basic(self): self.assertEqual(len(list(differ.diff_tables(a, b))), 1) def test_offset(self): - differ = TableDiffer(bisection_factor=2, bisection_threshold=10) + differ = HashDiffer(bisection_factor=2, bisection_threshold=10) sec1 
= self.now.shift(seconds=-1).datetime a = TableSegment(self.connection, self.table_src_path, "id", "datetime", max_update=sec1, case_sensitive=False) b = TableSegment(self.connection, self.table_dst_path, "id", "datetime", max_update=sec1, case_sensitive=False) @@ -250,7 +250,7 @@ def setUp(self): self.table = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) self.table2 = TableSegment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) - self.differ = TableDiffer(bisection_factor=3, bisection_threshold=4) + self.differ = HashDiffer(bisection_factor=3, bisection_threshold=4) def test_properties_on_empty_table(self): table = self.table.with_schema() @@ -287,7 +287,7 @@ def test_diff_small_tables(self): self.assertEqual(1, self.differ.stats["table2_count"]) def test_non_threaded(self): - differ = TableDiffer(bisection_factor=3, bisection_threshold=4, threaded=False) + differ = HashDiffer(bisection_factor=3, bisection_threshold=4, threaded=False) time = "2022-01-01 00:00:00" time_str = f"timestamp '{time}'" @@ -384,7 +384,7 @@ def test_diff_sorted_by_key(self): ) _commit(self.connection) - differ = TableDiffer() + differ = HashDiffer() diff = list(differ.diff_tables(self.table, self.table2)) expected = [ ("-", ("2", time2 + ".000000")), @@ -444,7 +444,7 @@ def test_diff_column_names(self): table1 = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) table2 = TableSegment(self.connection, self.table_dst_path, "id2", "timestamp2", case_sensitive=False) - differ = TableDiffer() + differ = HashDiffer() diff = list(differ.diff_tables(table1, table2)) assert diff == [] @@ -480,7 +480,7 @@ def setUp(self): self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_string_keys(self): - differ = TableDiffer() + differ = HashDiffer() diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, [("-", (str(self.new_uuid), "This one is different"))]) @@ -493,7 +493,7 @@ def test_string_keys(self): def test_where_sampling(self): a = self.a.replace(where="1=1") - differ = TableDiffer() + differ = HashDiffer() diff = list(differ.diff_tables(a, self.b)) self.assertEqual(diff, [("-", (str(self.new_uuid), "This one is different"))]) @@ -534,7 +534,7 @@ def setUp(self): def test_alphanum_keys(self): - differ = TableDiffer(bisection_factor=2, bisection_threshold=3) + differ = HashDiffer(bisection_factor=2, bisection_threshold=3) diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, [("-", (str(self.new_alphanum), "This one is different"))]) @@ -590,8 +590,7 @@ def test_varying_alphanum_keys(self): for a in alphanums: assert a - a == 0 - # Test with the differ - differ = TableDiffer(threaded=False) + differ = HashDiffer() diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, [("-", (str(self.new_alphanum), "This one is different"))]) @@ -669,7 +668,7 @@ def setUp(self): self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_uuid_column_with_nulls(self): - differ = TableDiffer() + differ = HashDiffer() diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, [("-", (str(self.null_uuid), None))]) @@ -719,7 +718,7 @@ def test_uuid_columns_with_nulls(self): diff results, but it's not. This test helps to detect such cases. 
""" - differ = TableDiffer(bisection_factor=2, bisection_threshold=3) + differ = HashDiffer(bisection_factor=2, bisection_threshold=3) diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, [("-", (str(self.null_uuid), None))]) @@ -783,7 +782,7 @@ def test_tables_are_different(self): value, it may lead that concat(pk_i, i, NULL) == concat(pk_i, i-diff, NULL). This test handle such cases. """ - differ = TableDiffer(bisection_factor=2, bisection_threshold=4) + differ = HashDiffer(bisection_factor=2, bisection_threshold=4) diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, self.diffs) @@ -814,7 +813,7 @@ def setUp(self): self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_right_table_empty(self): - differ = TableDiffer() + differ = HashDiffer() self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) def test_left_table_empty(self): @@ -827,5 +826,5 @@ def test_left_table_empty(self): _commit(self.connection) - differ = TableDiffer() + differ = HashDiffer() self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 2feecb02..529de055 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -1,7 +1,6 @@ import unittest -from data_diff.databases.connect import connect -from data_diff import TableSegment, TableDiffer +from data_diff import TableSegment, HashDiffer, connect from .common import TEST_POSTGRESQL_CONN_STRING, random_table_suffix @@ -40,7 +39,7 @@ def test_uuid(self): a = TableSegment(self.connection, (self.table_src,), "id", "comment") b = TableSegment(self.connection, (self.table_dst,), "id", "comment") - differ = TableDiffer() + differ = HashDiffer() diff = list(differ.diff_tables(a, b)) uuid = diff[0][1][0] self.assertEqual(diff, [("-", (uuid, "This one is different"))]) From 74f31e8e2387f374df2afa24dc5347239bb0eae7 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 14 Sep 2022 17:11:04 +0300 Subject: [PATCH 06/93] Adjustments to joindiff implementation --- data_diff/diff_tables.py | 17 +++++------- data_diff/joindiff_tables.py | 52 +++++++++++++++++++++--------------- tests/test_joindiff.py | 26 ++++++++++++++++++ 3 files changed, 63 insertions(+), 32 deletions(-) diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 04a95fe7..430f027e 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -52,15 +52,10 @@ def _threaded_call_as_completed(self, func, iterable): "Calls a method for each object in iterable. Returned in order of completion." 
return self._thread_as_completed(methodcaller(func), iterable) - def _run_thread(self, threadfunc, *args, daemon=False) -> threading.Thread: - th = threading.Thread(target=threadfunc, args=args) - if daemon: - th.daemon = True - th.start() - return th - @contextmanager - def _run_in_background(self, threadfunc, *args, daemon=False): - t = self._run_thread(threadfunc, *args, daemon=daemon) - yield t - t.join() + def _run_in_background(self, *funcs): + with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool: + futures = [task_pool.submit(f) for f in funcs] + yield futures + for f in futures: + f.result() diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 7f684ebb..52d055d3 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -3,6 +3,7 @@ """ from decimal import Decimal +from functools import partial import logging from contextlib import contextmanager from typing import Dict, List @@ -83,10 +84,12 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: if table1.database is not table2.database: raise ValueError("Join-diff only works when both tables are in the same database") - with self._run_in_background(self._test_null_or_duplicate_keys, table1, table2): - with self._run_in_background(self._collect_stats, 1, table1): - with self._run_in_background(self._collect_stats, 2, table2): - yield from self._outer_join(table1, table2) + with self._run_in_background( + partial(self._test_null_or_duplicate_keys, table1, table2), + partial(self._collect_stats, 1, table1), + partial(self._collect_stats, 2, table2) + ): + yield from self._outer_join(table1, table2) logger.info("Diffing complete") @@ -106,7 +109,7 @@ def _test_null_or_duplicate_keys(self, table1, table2): q = t.select(*key_columns).where(or_(this[k] == None for k in key_columns)) nulls = ts.database.query(q, list) if nulls: - raise ValueError(f"NULL values in one or more primary keys: {nulls}") + raise ValueError(f"NULL values in one or more primary keys") logger.debug("Done testing for null or duplicate keys") @@ -161,7 +164,7 @@ def _outer_join(self, table1, table2): b = table2._make_select() is_diff_cols = { - f"is_diff_col_{c1}": bool_to_int(a[c1].is_distinct_from(b[c2])) for c1, c2 in safezip(cols1, cols2) + f"is_diff_{c1}": bool_to_int(a[c1].is_distinct_from(b[c2])) for c1, c2 in safezip(cols1, cols2) } a_cols = {f"table1_{c}": NormalizeAsString(a[c]) for c in cols1} @@ -180,23 +183,30 @@ def _outer_join(self, table1, table2): .where(or_(this[c] == 1 for c in is_diff_cols)) ) - with self._run_in_background(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols): - with self._run_in_background(self._count_diff_per_column, db, diff_rows, is_diff_cols): - - logger.info("Querying for different rows") - for is_xa, is_xb, *x in db.query(diff_rows, list): - assert not (is_xa and is_xb) # Can't both be exclusive - is_diff, a_row, b_row = _slice_tuple(x, len(is_diff_cols), len(a_cols), len(b_cols)) - if not is_xb: - yield "-", tuple(a_row) - if not is_xa: - yield "+", tuple(b_row) - - def _count_diff_per_column(self, db, diff_rows, is_diff_cols): + with self._run_in_background( + partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols), + partial(self._count_diff_per_column, db, diff_rows, cols1, is_diff_cols) + ): + + logger.info("Querying for different rows") + for is_xa, is_xb, *x in db.query(diff_rows, list): + if is_xa and is_xb: + # Can't both be exclusive, meaning a pk is NULL + # This can happen if the 
explicit null test didn't finish running yet + raise ValueError(f"NULL values in one or more primary keys") + is_diff, a_row, b_row = _slice_tuple(x, len(is_diff_cols), len(a_cols), len(b_cols)) + if not is_xb: + yield "-", tuple(a_row) + if not is_xa: + yield "+", tuple(b_row) + + def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols): logger.info("Counting differences per column") is_diff_cols_counts = db.query(diff_rows.select(sum_(this[c]) for c in is_diff_cols), tuple) - for name, count in safezip(is_diff_cols, is_diff_cols_counts): - self.stats[f"count_{name}"] = count + diff_counts = {} + for name, count in safezip(cols, is_diff_cols_counts): + diff_counts[name] = count + self.stats['diff_counts'] = diff_counts def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): logger.info("Counting and sampling exclusive rows") diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index 72d604cd..d37cea58 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -166,3 +166,29 @@ def test_diff_sorted_by_key(self): ("+", ("4", time + ".000000")), ] self.assertEqual(expected, diff) + + def test_dup_pks(self): + time = "2022-01-01 00:00:00" + time_str = f"timestamp '{time}'" + + cols = "id rating timestamp".split() + + _insert_row(self.connection, self.table_src, cols, [1, 9, time_str]) + _insert_row(self.connection, self.table_src, cols, [1, 10, time_str]) + _insert_row(self.connection, self.table_dst, cols, [1, 9, time_str]) + + x = self.differ.diff_tables(self.table, self.table2) + self.assertRaises(ValueError, list, x) + + + def test_null_pks(self): + time = "2022-01-01 00:00:00" + time_str = f"timestamp '{time}'" + + cols = "id rating timestamp".split() + + _insert_row(self.connection, self.table_src, cols, ['null', 9, time_str]) + _insert_row(self.connection, self.table_dst, cols, [1, 9, time_str]) + + x = self.differ.diff_tables(self.table, self.table2) + self.assertRaises(ValueError, list, x) From 686b1f730e332cb97b054ab1dd22779614185ed5 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 21 Sep 2022 14:48:21 +0300 Subject: [PATCH 07/93] refactor tablediffer --- data_diff/diff_tables.py | 21 ++++++++++++++++++++- data_diff/hashdiff_tables.py | 16 ++-------------- data_diff/joindiff_tables.py | 10 ++++++++-- docs/python-api.rst | 7 +++++-- tests/test_query.py | 2 +- 5 files changed, 36 insertions(+), 20 deletions(-) diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 430f027e..3a92d708 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -1,13 +1,15 @@ """Provides classes for performing a table diff """ +from abc import ABC, abstractmethod from enum import Enum from contextlib import contextmanager -import threading from operator import methodcaller from typing import Tuple, Iterator, Optional from concurrent.futures import ThreadPoolExecutor, as_completed +from .table_segment import TableSegment + from runtype import dataclass @@ -59,3 +61,20 @@ def _run_in_background(self, *funcs): yield futures for f in futures: f.result() + + +class TableDiffer(ThreadBase, ABC): + @abstractmethod + def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: + """Diff the given tables. + + Parameters: + table1 (TableSegment): The "before" table to compare. Or: source table + table2 (TableSegment): The "after" table to compare. Or: target table + + Returns: + An iterator that yield pair-tuples, representing the diff. Items can be either - + ('-', row) for items in table1 but not in table2. 
+ ('+', row) for items in table2 but not in table1. + Where `row` is a tuple of values, corresponding to the diffed columns. + """ diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py index f4867a74..0f2e8cb7 100644 --- a/data_diff/hashdiff_tables.py +++ b/data_diff/hashdiff_tables.py @@ -14,7 +14,7 @@ from .table_segment import TableSegment from .tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled -from .diff_tables import ThreadBase, DiffResult +from .diff_tables import TableDiffer, DiffResult BENCHMARK = os.environ.get("BENCHMARK", False) @@ -40,7 +40,7 @@ def diff_sets(a: set, b: set) -> Iterator: @dataclass -class HashDiffer(ThreadBase): +class HashDiffer(TableDiffer): """Finds the diff between two SQL tables The algorithm uses hashing to quickly check if the tables are different, and then applies a @@ -62,18 +62,6 @@ class HashDiffer(ThreadBase): stats: dict = {} def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: - """Diff the given tables. - - Parameters: - table1 (TableSegment): The "before" table to compare. Or: source table - table2 (TableSegment): The "after" table to compare. Or: target table - - Returns: - An iterator that yield pair-tuples, representing the diff. Items can be either - ('-', columns) for items in table1 but not in table2 - ('+', columns) for items in table2 but not in table1 - Where `columns` is a tuple of values for the involved columns, i.e. (id, ...extra) - """ # Validate options if self.bisection_factor >= self.bisection_threshold: raise ValueError("Incorrect param values (bisection factor must be lower than threshold)") diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 52d055d3..0099de6e 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -13,7 +13,7 @@ from .utils import safezip from .databases.base import Database from .table_segment import TableSegment -from .diff_tables import ThreadBase, DiffResult +from .diff_tables import TableDiffer, DiffResult from .queries import table, sum_, min_, max_, avg from .queries.api import and_, if_, or_, outerjoin, this @@ -73,7 +73,7 @@ def json_friendly_value(v): @dataclass -class JoinDifferBase(ThreadBase): +class JoinDifferBase(TableDiffer): """Finds the diff between two SQL tables using JOINs""" stats: dict = {} @@ -145,6 +145,12 @@ def bool_to_int(x): class JoinDiffer(JoinDifferBase): + """Finds the diff between two SQL tables in the same database. + + The algorithm uses an OUTER JOIN (or equivalent) with extra checks and statistics. + + """ + def _outer_join(self, table1, table2): db = table1.database if db is not table2.database: diff --git a/docs/python-api.rst b/docs/python-api.rst index d2b18636..f28b18d1 100644 --- a/docs/python-api.rst +++ b/docs/python-api.rst @@ -5,11 +5,14 @@ Python API Reference .. autofunction:: connect -.. autoclass:: TableDiffer +.. autoclass:: HashDiffer + :members: __init__, diff_tables + +.. autoclass:: JoinDiffer :members: __init__, diff_tables .. autoclass:: TableSegment - :members: __init__, get_values, choose_checkpoints, segment_by_checkpoints, count, count_and_checksum, is_bounded, new + :members: __init__, get_values, choose_checkpoints, segment_by_checkpoints, count, count_and_checksum, is_bounded, new, with_schema .. 
autoclass:: data_diff.databases.database_types.AbstractDatabase :members: diff --git a/tests/test_query.py b/tests/test_query.py index b6b90394..f31f5417 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -12,7 +12,7 @@ def normalize_spaces(s: str): class MockDialect(AbstractDialect): - def quote(self, s: str): + def quote(self, s: str) -> str: return s def concat(self, l: List[str]) -> str: From b830afc8fceba1a0d47dcd237cd2ca1121bf8f3e Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 21 Sep 2022 10:24:59 +0300 Subject: [PATCH 08/93] joindiff now working for all major databases: new: - mysql - bigquery - presto - verica - trino - oracle - redshift --- data_diff/databases/bigquery.py | 3 ++ data_diff/databases/mysql.py | 3 ++ data_diff/databases/oracle.py | 6 ++++ data_diff/databases/presto.py | 32 ++++++++++++++---- data_diff/databases/redshift.py | 3 ++ data_diff/databases/vertica.py | 3 ++ data_diff/joindiff_tables.py | 58 ++++++++++++++++++++++++-------- data_diff/queries/api.py | 10 +++++- data_diff/queries/ast_classes.py | 21 +++++++++++- data_diff/queries/compiler.py | 5 +++ tests/test_diff_tables.py | 2 ++ tests/test_joindiff.py | 3 +- tests/test_query.py | 8 +++++ 13 files changed, 132 insertions(+), 25 deletions(-) diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 411ae795..218c9cb4 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -95,3 +95,6 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: def parse_table_name(self, name: str) -> DbPath: path = parse_table_name(name) return self._normalize_table_path(path) + + def random(self) -> str: + return "RAND()" diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index 7e89b184..07c34aaf 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -73,3 +73,6 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: def is_distinct_from(self, a: str, b: str) -> str: return f"not ({a} <=> {b})" + + def random(self) -> str: + return "RAND()" diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 76387010..79f7bf31 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -124,3 +124,9 @@ def timestamp_value(self, t: DbTime) -> str: def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Cast is necessary for correct MD5 (trimming not enough) return f"CAST(TRIM({value}) AS VARCHAR(36))" + + def random(self) -> str: + return "dbms_random.value" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"DECODE({a}, {b}, 1, 0) = 0" diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index 5ee98770..c990e06e 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -1,6 +1,7 @@ import re -from ..utils import match_regexps +from data_diff.utils import match_regexps +from data_diff.queries import ThreadLocalInterpreter from .database_types import * from .base import Database, import_helper @@ -10,6 +11,14 @@ TIMESTAMP_PRECISION_POS, ) +def query_cursor(c, sql_code): + c.execute(sql_code) + if sql_code.lower().startswith("select"): + return c.fetchall() + # Required for the query to actually run 🤯 + if re.match(r"(insert|create|truncate|drop)", sql_code, re.IGNORECASE): + return c.fetchone() + @import_helper("presto") def import_presto(): @@ -63,12 +72,21 @@ def to_string(self, s: str): def _query(self, sql_code: str) -> list: "Uses the standard SQL cursor interface" c 
= self._conn.cursor() - c.execute(sql_code) - if sql_code.lower().startswith("select"): - return c.fetchall() - # Required for the query to actually run 🤯 - if re.match(r"(insert|create|truncate|drop)", sql_code, re.IGNORECASE): - return c.fetchone() + + if isinstance(sql_code, ThreadLocalInterpreter): + # TODO reuse code from base.py + g = sql_code.interpret() + q = next(g) + while True: + res = query_cursor(c, q) + try: + q = g.send(res) + except StopIteration: + break + return + + return query_cursor(c, sql_code) + def close(self): self._conn.close() diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index a512c123..f11b950c 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -46,3 +46,6 @@ def select_table_schema(self, path: DbPath) -> str: "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM information_schema.columns " f"WHERE table_name = '{table.lower()}' AND table_schema = '{schema.lower()}'" ) + + def is_distinct_from(self, a: str, b: str) -> str: + return f"{a} IS NULL AND NOT {b} IS NULL OR {b} IS NULL OR {a}!={b}" diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index 78a52363..cc606511 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -123,3 +123,6 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Trim doesn't work on CHAR type return f"TRIM(CAST({value} AS VARCHAR))" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"not ({a} <=> {b})" diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 0099de6e..53acd954 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -10,13 +10,15 @@ from runtype import dataclass + from .utils import safezip from .databases.base import Database +from .databases import MySQL, BigQuery, Presto, Oracle from .table_segment import TableSegment from .diff_tables import TableDiffer, DiffResult from .queries import table, sum_, min_, max_, avg -from .queries.api import and_, if_, or_, outerjoin, this +from .queries.api import and_, if_, or_, outerjoin, leftjoin, rightjoin, this, ITable from .queries.ast_classes import Concat, Count, Expr, Random from .queries.compiler import Compiler from .queries.extras import NormalizeAsString @@ -43,18 +45,29 @@ class Stats: def sample(table): - # TODO return table.order_by(Random()).limit(10) @contextmanager def temp_table(db: Database, expr: Expr): c = Compiler(db) - name = c.new_unique_name("tmp_table") - db.query(f"create temporary table {c.quote(name)} as {c.compile(expr)}", None) + + name = c.new_unique_table_name("temp_table") + + if isinstance(db, BigQuery): + name = f"{db.default_schema}.{name}" + db.query(f"create table {c.quote(name)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}", None) + elif isinstance(db, Presto): + db.query(f"create table {c.quote(name)} as {c.compile(expr)}", None) + elif isinstance(db, Oracle): + db.query(f"create global temporary table {c.quote(name)} as {c.compile(expr)}", None) + else: + db.query(f"create temporary table {c.quote(name)} as {c.compile(expr)}", None) + try: yield table(name, schema=expr.source_table.schema) finally: + # Only drops if create table succeeded (meaning, the table didn't already exist) db.query(f"drop table {c.quote(name)}", None) @@ -144,6 +157,28 @@ def bool_to_int(x): return if_(x, 
1, 0) +def _outerjoin(db: Database, a: ITable, b: ITable, keys1: List[str], keys2: List[str], select_fields: dict) -> ITable: + on = [a[k1] == b[k2] for k1, k2 in safezip(keys1, keys2)] + + if isinstance(db, Oracle): + is_exclusive_a = and_(bool_to_int(b[k] == None) for k in keys2) + is_exclusive_b = and_(bool_to_int(a[k] == None) for k in keys1) + else: + is_exclusive_a = and_(b[k] == None for k in keys2) + is_exclusive_b = and_(a[k] == None for k in keys1) + + if isinstance(db, MySQL): + # No outer join + l = leftjoin(a, b).on(*on).select(is_exclusive_a=is_exclusive_a, is_exclusive_b=False, **select_fields) + r = rightjoin(a, b).on(*on).select(is_exclusive_a=False, is_exclusive_b=is_exclusive_b, **select_fields) + return l.union(r) + + return ( + outerjoin(a, b).on(*on) + .select(is_exclusive_a=is_exclusive_a, is_exclusive_b=is_exclusive_b, **select_fields) + ) + + class JoinDiffer(JoinDifferBase): """Finds the diff between two SQL tables in the same database. @@ -177,15 +212,7 @@ def _outer_join(self, table1, table2): b_cols = {f"table2_{c}": NormalizeAsString(b[c]) for c in cols2} diff_rows = ( - outerjoin(a, b) - .on(a[k1] == b[k2] for k1, k2 in safezip(keys1, keys2)) - .select( - is_exclusive_a=and_(b[k] == None for k in keys2), - is_exclusive_b=and_(a[k] == None for k in keys1), - **is_diff_cols, - **a_cols, - **b_cols, - ) + _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **a_cols, **b_cols}) .where(or_(this[c] == 1 for c in is_diff_cols)) ) @@ -216,7 +243,10 @@ def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols): def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): logger.info("Counting and sampling exclusive rows") - exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b) + if isinstance(db, Oracle): + exclusive_rows_query = diff_rows.where((this.is_exclusive_a==1) | (this.is_exclusive_b==1)) + else: + exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b) with temp_table(db, exclusive_rows_query) as exclusive_rows: self.stats["exclusive_count"] = db.query(exclusive_rows.count(), int) sample_rows = db.query(sample(exclusive_rows.select(*this[list(a_cols)], *this[list(b_cols)])), list) diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index 76aaf5d2..136807eb 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -11,8 +11,16 @@ def join(*tables: ITable): return Join(tables) +def leftjoin(*tables: ITable): + "Left-joins each table into a 'struct'" + return Join(tables, "LEFT") + +def rightjoin(*tables: ITable): + "Right-joins each table into a 'struct'" + return Join(tables, "RIGHT") + def outerjoin(*tables: ITable): - "Outerjoins each table into a 'struct'" + "Outer-joins each table into a 'struct'" return Join(tables, "FULL OUTER") diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 019227f7..a3383ad2 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Generator, Sequence, Tuple, Union +from typing import Any, Generator, ItemsView, Sequence, Tuple, Union from runtype import dataclass @@ -119,6 +119,9 @@ def __getitem__(self, column): def count(self): return Select(self, [Count()]) + def union(self, other: 'ITable'): + return Union(self, other) + @dataclass class Concat(ExprNode): @@ -348,6 +351,22 @@ class GroupBy(ITable): def having(self): pass +@dataclass +class Union(ExprNode, ITable): + table1: ITable + table2: ITable 
+ + @property + def source_table(self): + return self # TODO is this right? + + def compile(self, parent_c: Compiler) -> str: + c = parent_c.replace(in_select=False) + union_all = f"{c.compile(self.table1)} UNION {c.compile(self.table2)}" + if parent_c.in_select: + union_all = f"({union_all}) {c.new_unique_name()}" + return union_all + @dataclass class Select(ExprNode, ITable): diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index 8ea0e7a5..64e24650 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -1,3 +1,4 @@ +import random from abc import ABC, abstractmethod from datetime import datetime from typing import Any, Dict, Sequence, List @@ -51,6 +52,10 @@ def new_unique_name(self, prefix="tmp"): self._counter[0] += 1 return f"{prefix}{self._counter[0]}" + def new_unique_table_name(self, prefix="tmp"): + self._counter[0] += 1 + return f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}" + def add_table_context(self, *tables: Sequence): return self.replace(_table_context=self._table_context + list(tables)) diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 63195efb..7668ec61 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -81,6 +81,8 @@ def _get_text_type(conn): def _get_float_type(conn): if isinstance(conn, db.BigQuery): return "FLOAT64" + elif isinstance(conn, db.Presto): + return "REAL" return "float" diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index d37cea58..e8db3167 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -3,7 +3,6 @@ from data_diff.databases.connect import connect from data_diff.table_segment import TableSegment, split_space from data_diff import databases as db -from data_diff.utils import ArithAlphanumeric from data_diff.joindiff_tables import JoinDiffer from .test_diff_tables import TestPerDatabase, _get_float_type, _get_text_type, _commit, _insert_row, _insert_rows @@ -26,7 +25,7 @@ def init_instances(): DATABASE_INSTANCES = {k.__name__: connect(v, N_THREADS) for k, v in CONN_STRINGS.items()} -TEST_DATABASES = {x.__name__ for x in (db.PostgreSQL,)} +TEST_DATABASES = {x.__name__ for x in (db.PostgreSQL, db.Snowflake, db.MySQL, db.BigQuery, db.Presto, db.Vertica, db.Trino, db.Oracle, db.Redshift)} _class_per_db_dec = parameterized_class( ("name", "db_name"), [(name, name) for name in DATABASE_URIS if name in TEST_DATABASES] diff --git a/tests/test_query.py b/tests/test_query.py index f31f5417..4ae4b82e 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -128,3 +128,11 @@ def test_funcs(self): q = c.compile(t.order_by(Random()).limit(10)) assert q == "SELECT * FROM a ORDER BY random() limit 10" + + def test_union_all(self): + c = Compiler(MockDialect()) + a = table("a").select('x') + b = table("b").select('y') + + q = c.compile(a.union(b)) + assert q == "SELECT x FROM a UNION SELECT y FROM b" From 4f441f02caf4040149bff60d16cccedda68492b4 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 22 Sep 2022 17:11:30 +0300 Subject: [PATCH 09/93] Fix in queries --- data_diff/queries/ast_classes.py | 29 ++++++++++++++++------------- data_diff/queries/compiler.py | 3 ++- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index a3383ad2..56213dd3 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -1,5 +1,6 @@ +from dataclasses import field from datetime import datetime -from typing import Any, Generator, 
ItemsView, Sequence, Tuple, Union +from typing import Any, Generator, ItemsView, Optional, Sequence, Tuple, Union from runtype import dataclass @@ -246,7 +247,7 @@ def compile(self, c: Compiler) -> str: t for t in c._table_context if isinstance(t, TableAlias) and t.source_table is self.source_table ] if not aliases: - raise CompileError(f"No aliased table found for column {self.name}") # TODO better error + return c.quote(self.name) elif len(aliases) > 1: raise CompileError(f"Too many aliases for column {self.name}") (alias,) = aliases @@ -259,7 +260,7 @@ def compile(self, c: Compiler) -> str: @dataclass class TablePath(ExprNode, ITable): path: DbPath - schema: Schema = None + schema: Optional[Schema] = field(default=None, repr=False) def insert_values(self, rows): pass @@ -329,7 +330,7 @@ def compile(self, parent_c: Compiler) -> str: tables = [ t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in self.source_tables ] - c = parent_c.add_table_context(*tables) + c = parent_c.add_table_context(*tables).replace(in_join=True, in_select=False) op = " JOIN " if self.op is None else f" {self.op} JOIN " joined = op.join(c.compile(t) for t in tables) @@ -344,6 +345,8 @@ def compile(self, parent_c: Compiler) -> str: if parent_c.in_select: select = f"({select}) {c.new_unique_name()}" + elif parent_c.in_join: + select = f"({select})" return select @@ -365,34 +368,32 @@ def compile(self, parent_c: Compiler) -> str: union_all = f"{c.compile(self.table1)} UNION {c.compile(self.table2)}" if parent_c.in_select: union_all = f"({union_all}) {c.new_unique_name()}" + elif parent_c.in_join: + union_all = f"({union_all})" return union_all @dataclass class Select(ExprNode, ITable): - table: Expr = None + source_table: Expr = None columns: Sequence[Expr] = None where_exprs: Sequence[Expr] = None order_by_exprs: Sequence[Expr] = None group_by_exprs: Sequence[Expr] = None limit_expr: int = None - @property - def source_table(self): - return self - @property def schema(self): - return self.table.schema + return self.source_table.schema def compile(self, parent_c: Compiler) -> str: - c = parent_c.replace(in_select=True).add_table_context(self.table) + c = parent_c.replace(in_select=True) #.add_table_context(self.table) columns = ", ".join(map(c.compile, self.columns)) if self.columns else "*" select = f"SELECT {columns}" - if self.table: - select += " FROM " + c.compile(self.table) + if self.source_table: + select += " FROM " + c.compile(self.source_table) if self.where_exprs: select += " WHERE " + " AND ".join(map(c.compile, self.where_exprs)) @@ -407,6 +408,8 @@ def compile(self, parent_c: Compiler) -> str: select += " " + c.database.offset_limit(0, self.limit_expr) if parent_c.in_select: + select = f"({select}) {c.new_unique_name()}" + elif parent_c.in_join: select = f"({select})" return select diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index 64e24650..5133301c 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -12,7 +12,8 @@ @dataclass class Compiler: database: AbstractDialect - in_select: bool = False # Compilation + in_select: bool = False # Compilation runtime flag + in_join: bool = False # Compilation runtime flag _table_context: List = [] # List[ITable] _subqueries: Dict[str, Any] = {} # XXX not thread-safe From de26e56dd86579cb903c492c3b9250efea2236fd Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 22 Sep 2022 17:12:59 +0300 Subject: [PATCH 10/93] Joindiff now support tracking and bisection --- 
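A small sketch of what the subquery handling from the previous commit ("Fix in queries") does in practice, written in the style of tests/test_query.py. MockDialect is the stub dialect defined in that test file, and the generated alias name depends on the compiler's internal counter:

    from data_diff.queries import Compiler, table
    from tests.test_query import MockDialect  # stub dialect used by the query tests

    c = Compiler(MockDialect())
    u = table("a").select("x").union(table("b").select("y"))

    # Compiled at the top level, the union stays bare (matches test_union_all):
    print(c.compile(u))          # SELECT x FROM a UNION SELECT y FROM b

    # Nested under another select it is compiled with in_select=True,
    # so it gets parenthesized and given a generated alias:
    print(c.compile(u.count()))  # roughly: SELECT count(*) FROM
                                 #          (SELECT x FROM a UNION SELECT y FROM b) tmp1
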
data_diff/diff_tables.py | 130 ++++++++++++++++++++++++++- data_diff/hashdiff_tables.py | 164 ++++++++--------------------------- data_diff/joindiff_tables.py | 70 +++++++++++---- data_diff/table_segment.py | 4 + data_diff/utils.py | 8 ++ tests/common.py | 1 + 6 files changed, 228 insertions(+), 149 deletions(-) diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 3a92d708..5cd21302 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -1,6 +1,7 @@ """Provides classes for performing a table diff """ +import time from abc import ABC, abstractmethod from enum import Enum from contextlib import contextmanager @@ -8,10 +9,15 @@ from typing import Tuple, Iterator, Optional from concurrent.futures import ThreadPoolExecutor, as_completed -from .table_segment import TableSegment - from runtype import dataclass +from .utils import run_as_daemon, safezip, getLogger +from .thread_utils import ThreadedYielder +from .table_segment import TableSegment +from .tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled +from .databases.database_types import IKey + +logger = getLogger(__name__) class Algorithm(Enum): AUTO = "auto" @@ -64,7 +70,8 @@ def _run_in_background(self, *funcs): class TableDiffer(ThreadBase, ABC): - @abstractmethod + bisection_factor = 32 + def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: """Diff the given tables. @@ -78,3 +85,120 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: ('+', row) for items in table2 but not in table1. Where `row` is a tuple of values, corresponding to the diffed columns. """ + + if is_tracking_enabled(): + options = dict(self) + event_json = create_start_event_json(options) + run_as_daemon(send_event_json, event_json) + + self.stats["diff_count"] = 0 + start = time.monotonic() + error = None + try: + + # Query and validate schema + table1, table2 = self._threaded_call("with_schema", [table1, table2]) + self._validate_and_adjust_columns(table1, table2) + + yield from self._diff_tables(table1, table2) + + except BaseException as e: # Catch KeyboardInterrupt too + error = e + finally: + if is_tracking_enabled(): + runtime = time.monotonic() - start + table1_count = self.stats.get("table1_count") + table2_count = self.stats.get("table2_count") + diff_count = self.stats.get("diff_count") + err_message = str(error)[:20] # Truncate possibly sensitive information. + event_json = create_end_event_json( + error is None, + runtime, + table1.database.name, + table2.database.name, + table1_count, + table2_count, + diff_count, + err_message, + ) + send_event_json(event_json) + + if error: + raise error + + def _validate_and_adjust_columns(self, table1: TableSegment, table2: TableSegment) -> DiffResult: + pass + + def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: + return self._bisect_and_diff_tables(table1, table2) + + + @abstractmethod + def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): + ... 
+ + + def _bisect_and_diff_tables(self, table1, table2): + key_type = table1._schema[table1.key_column] + key_type2 = table2._schema[table2.key_column] + if not isinstance(key_type, IKey): + raise NotImplementedError(f"Cannot use column of type {key_type} as a key") + if not isinstance(key_type2, IKey): + raise NotImplementedError(f"Cannot use column of type {key_type2} as a key") + assert key_type.python_type is key_type2.python_type + + # Query min/max values + key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2]) + + # Start with the first completed value, so we don't waste time waiting + min_key1, max_key1 = self._parse_key_range_result(key_type, next(key_ranges)) + + table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)] + + logger.info( + # f"Diffing tables | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}. " + f"Diffing segments at key-range: {table1.min_key}..{table2.max_key}. " + f"size: table1 <= {table1.approximate_size()}, table2 <= {table2.approximate_size()}" + ) + + ti = ThreadedYielder(self.max_threadpool_size) + # Bisect (split) the table into segments, and diff them recursively. + ti.submit(self._bisect_and_diff_segments, ti, table1, table2) + + # Now we check for the second min-max, to diff the portions we "missed". + min_key2, max_key2 = self._parse_key_range_result(key_type, next(key_ranges)) + + if min_key2 < min_key1: + pre_tables = [t.new(min_key=min_key2, max_key=min_key1) for t in (table1, table2)] + ti.submit(self._bisect_and_diff_segments, ti, *pre_tables) + + if max_key2 > max_key1: + post_tables = [t.new(min_key=max_key1, max_key=max_key2) for t in (table1, table2)] + ti.submit(self._bisect_and_diff_segments, ti, *post_tables) + + return ti + + + def _parse_key_range_result(self, key_type, key_range): + mn, mx = key_range + cls = key_type.make_value + # We add 1 because our ranges are exclusive of the end (like in Python) + try: + return cls(mn), cls(mx) + 1 + except (TypeError, ValueError) as e: + raise type(e)(f"Cannot apply {key_type} to {mn}, {mx}.") from e + + + def _bisect_and_diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None): + assert table1.is_bounded and table2.is_bounded + + # Choose evenly spaced checkpoints (according to min_key and max_key) + checkpoints = table1.choose_checkpoints(self.bisection_factor - 1) + + # Create new instances of TableSegment between each checkpoint + segmented1 = table1.segment_by_checkpoints(checkpoints) + segmented2 = table2.segment_by_checkpoints(checkpoints) + + # Recursively compare each pair of corresponding segments between table1 and table2 + for i, (t1, t2) in enumerate(safezip(segmented1, segmented2)): + ti.submit(self._diff_segments, ti, t1, t2, max_rows, level + 1, i + 1, len(segmented1), priority=level) diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py index 0f2e8cb7..64b05b67 100644 --- a/data_diff/hashdiff_tables.py +++ b/data_diff/hashdiff_tables.py @@ -1,5 +1,4 @@ import os -import time from numbers import Number import logging from collections import defaultdict @@ -8,13 +7,12 @@ from runtype import dataclass -from .utils import safezip, run_as_daemon +from .utils import safezip from .thread_utils import ThreadedYielder from .databases.database_types import IKey, NumericType, PrecisionType, StringType from .table_segment import TableSegment -from .tracking import create_end_event_json, create_start_event_json, send_event_json, 
is_tracking_enabled -from .diff_tables import TableDiffer, DiffResult +from .diff_tables import TableDiffer BENCHMARK = os.environ.get("BENCHMARK", False) @@ -61,98 +59,14 @@ class HashDiffer(TableDiffer): stats: dict = {} - def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: + def __post_init__(self): # Validate options if self.bisection_factor >= self.bisection_threshold: raise ValueError("Incorrect param values (bisection factor must be lower than threshold)") if self.bisection_factor < 2: raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)") - if is_tracking_enabled(): - options = dict(self) - event_json = create_start_event_json(options) - run_as_daemon(send_event_json, event_json) - - self.stats["diff_count"] = 0 - start = time.monotonic() - error = None - try: - - # Query and validate schema - table1, table2 = self._threaded_call("with_schema", [table1, table2]) - self._validate_and_adjust_columns(table1, table2) - - key_type = table1._schema[table1.key_column] - key_type2 = table2._schema[table2.key_column] - if not isinstance(key_type, IKey): - raise NotImplementedError(f"Cannot use column of type {key_type} as a key") - if not isinstance(key_type2, IKey): - raise NotImplementedError(f"Cannot use column of type {key_type2} as a key") - assert key_type.python_type is key_type2.python_type - - # Query min/max values - key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2]) - - # Start with the first completed value, so we don't waste time waiting - min_key1, max_key1 = self._parse_key_range_result(key_type, next(key_ranges)) - - table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)] - - logger.info( - f"Diffing tables | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}. " - f"key-range: {table1.min_key}..{table2.max_key}, " - f"size: table1 <= {table1.approximate_size()}, table2 <= {table2.approximate_size()}" - ) - - ti = ThreadedYielder(self.max_threadpool_size) - # Bisect (split) the table into segments, and diff them recursively. - ti.submit(self._bisect_and_diff_tables, ti, table1, table2) - - # Now we check for the second min-max, to diff the portions we "missed". - min_key2, max_key2 = self._parse_key_range_result(key_type, next(key_ranges)) - - if min_key2 < min_key1: - pre_tables = [t.new(min_key=min_key2, max_key=min_key1) for t in (table1, table2)] - ti.submit(self._bisect_and_diff_tables, ti, *pre_tables) - - if max_key2 > max_key1: - post_tables = [t.new(min_key=max_key1, max_key=max_key2) for t in (table1, table2)] - ti.submit(self._bisect_and_diff_tables, ti, *post_tables) - - yield from ti - - except BaseException as e: # Catch KeyboardInterrupt too - error = e - finally: - if is_tracking_enabled(): - runtime = time.monotonic() - start - table1_count = self.stats.get("table1_count") - table2_count = self.stats.get("table2_count") - diff_count = self.stats.get("diff_count") - err_message = str(error)[:20] # Truncate possibly sensitive information. 
- event_json = create_end_event_json( - error is None, - runtime, - table1.database.name, - table2.database.name, - table1_count, - table2_count, - diff_count, - err_message, - ) - send_event_json(event_json) - - if error: - raise error - - def _parse_key_range_result(self, key_type, key_range): - mn, mx = key_range - cls = key_type.make_value - # We add 1 because our ranges are exclusive of the end (like in Python) - try: - return cls(mn), cls(mx) + 1 - except (TypeError, ValueError) as e: - raise type(e)(f"Cannot apply {key_type} to {mn}, {mx}.") from e + def _validate_and_adjust_columns(self, table1, table2): for c1, c2 in safezip(table1._relevant_columns, table2._relevant_columns): @@ -201,44 +115,8 @@ def _validate_and_adjust_columns(self, table1, table2): "If encoding/formatting differs between databases, it may result in false positives." ) - def _bisect_and_diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None): - assert table1.is_bounded and table2.is_bounded - - if max_rows is None: - # We can be sure that row_count <= max_rows - max_rows = max(table1.approximate_size(), table2.approximate_size()) - - # If count is below the threshold, just download and compare the columns locally - # This saves time, as bisection speed is limited by ping and query performance. - if max_rows < self.bisection_threshold: - rows1, rows2 = self._threaded_call("get_values", [table1, table2]) - diff = list(diff_sets(rows1, rows2)) - - # Initial bisection_threshold larger than count. Normally we always - # checksum and count segments, even if we get the values. At the - # first level, however, that won't be true. - if level == 0: - self.stats["table1_count"] = len(rows1) - self.stats["table2_count"] = len(rows2) - - self.stats["diff_count"] += len(diff) - - logger.info(". " * level + f"Diff found {len(diff)} different rows.") - self.stats["rows_downloaded"] = self.stats.get("rows_downloaded", 0) + max(len(rows1), len(rows2)) - return diff - - # Choose evenly spaced checkpoints (according to min_key and max_key) - checkpoints = table1.choose_checkpoints(self.bisection_factor - 1) - - # Create new instances of TableSegment between each checkpoint - segmented1 = table1.segment_by_checkpoints(checkpoints) - segmented2 = table2.segment_by_checkpoints(checkpoints) - # Recursively compare each pair of corresponding segments between table1 and table2 - for i, (t1, t2) in enumerate(safezip(segmented1, segmented2)): - ti.submit(self._diff_tables, ti, t1, t2, max_rows, level + 1, i + 1, len(segmented1), priority=level) - - def _diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): + def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): logger.info( ". " * level + f"Diffing segment {segment_index}/{segment_count}, " f"key-range: {table1.min_key}..{table2.max_key}, " @@ -251,7 +129,7 @@ def _diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableS # the threshold) and _then_ download it. 
if BENCHMARK: if max_rows < self.bisection_threshold: - return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max_rows) + return self._bisect_and_diff_segments(ti, table1, table2, level=level, max_rows=max_rows) (count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2]) @@ -268,4 +146,32 @@ def _diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableS self.stats["table2_count"] = self.stats.get("table2_count", 0) + count2 if checksum1 != checksum2: - return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max(count1, count2)) + return self._bisect_and_diff_segments(ti, table1, table2, level=level, max_rows=max(count1, count2)) + + def _bisect_and_diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None): + assert table1.is_bounded and table2.is_bounded + + if max_rows is None: + # We can be sure that row_count <= max_rows + max_rows = max(table1.approximate_size(), table2.approximate_size()) + + # If count is below the threshold, just download and compare the columns locally + # This saves time, as bisection speed is limited by ping and query performance. + if max_rows < self.bisection_threshold: + rows1, rows2 = self._threaded_call("get_values", [table1, table2]) + diff = list(diff_sets(rows1, rows2)) + + # Initial bisection_threshold larger than count. Normally we always + # checksum and count segments, even if we get the values. At the + # first level, however, that won't be true. + if level == 0: + self.stats["table1_count"] = len(rows1) + self.stats["table2_count"] = len(rows2) + + self.stats["diff_count"] += len(diff) + + logger.info(". " * level + f"Diff found {len(diff)} different rows.") + self.stats["rows_downloaded"] = self.stats.get("rows_downloaded", 0) + max(len(rows1), len(rows2)) + return diff + + return super()._bisect_and_diff_segments(ti, table1, table2, level, max_rows) diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 53acd954..988cbaec 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -2,6 +2,7 @@ """ +from collections import defaultdict from decimal import Decimal from functools import partial import logging @@ -13,9 +14,10 @@ from .utils import safezip from .databases.base import Database -from .databases import MySQL, BigQuery, Presto, Oracle +from .databases import MySQL, BigQuery, Presto, Oracle, PostgreSQL, Snowflake from .table_segment import TableSegment from .diff_tables import TableDiffer, DiffResult +from .thread_utils import ThreadedYielder from .queries import table, sum_, min_, max_, avg from .queries.api import and_, if_, or_, outerjoin, leftjoin, rightjoin, this, ITable @@ -67,8 +69,10 @@ def temp_table(db: Database, expr: Expr): try: yield table(name, schema=expr.source_table.schema) finally: - # Only drops if create table succeeded (meaning, the table didn't already exist) - db.query(f"drop table {c.quote(name)}", None) + if isinstance(db, (BigQuery, Presto)): + # Only drops if create table succeeded (meaning, the table didn't already exist) + # And if the table won't delete itself + db.query(f"drop table {c.quote(name)}", None) def _slice_tuple(t, *sizes): @@ -90,28 +94,50 @@ class JoinDifferBase(TableDiffer): """Finds the diff between two SQL tables using JOINs""" stats: dict = {} + validate_unique_key: bool = True - def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: - table1, table2 = 
self._threaded_call("with_schema", [table1, table2]) + def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: + db = table1.database if table1.database is not table2.database: raise ValueError("Join-diff only works when both tables are in the same database") + table1, table2 = self._threaded_call("with_schema", [table1, table2]) + + + bg_funcs = [partial(self._test_duplicate_keys, table1, table2)] if self.validate_unique_key else [] + + with self._run_in_background(*bg_funcs): + if isinstance(db, (Snowflake, BigQuery)): + # Don't segment the table; let the database handling parallelization + yield from self._diff_segments(None, table1, table2, None) + else: + yield from self._bisect_and_diff_tables(table1, table2) + + def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): + assert table1.database is table2.database + + logger.info( + ". " * level + f"Diffing segment {segment_index}/{segment_count}, " + f"key-range: {table1.min_key}..{table2.max_key}, " + f"size <= {max_rows}" + ) + with self._run_in_background( - partial(self._test_null_or_duplicate_keys, table1, table2), - partial(self._collect_stats, 1, table1), - partial(self._collect_stats, 2, table2) - ): + partial(self._collect_stats, 1, table1), + partial(self._collect_stats, 2, table2), + partial(self._test_null_keys, table1, table2), + ): yield from self._outer_join(table1, table2) logger.info("Diffing complete") - def _test_null_or_duplicate_keys(self, table1, table2): - logger.info("Testing for null or duplicate keys") + def _test_duplicate_keys(self, table1, table2): + logger.debug("Testing for duplicate keys") - # Test null or duplicate keys + # Test duplicate keys for ts in [table1, table2]: - t = table(*ts.table_path, schema=ts._schema) + t = ts._make_select() key_columns = [ts.key_column] # XXX q = t.select(total=Count(), total_distinct=Count(Concat(key_columns), distinct=True)) @@ -119,12 +145,19 @@ def _test_null_or_duplicate_keys(self, table1, table2): if total != total_distinct: raise ValueError("Duplicate primary keys") + def _test_null_keys(self, table1, table2): + logger.debug("Testing for null keys") + + # Test null keys + for ts in [table1, table2]: + t = ts._make_select() + key_columns = [ts.key_column] # XXX + q = t.select(*key_columns).where(or_(this[k] == None for k in key_columns)) nulls = ts.database.query(q, list) if nulls: raise ValueError(f"NULL values in one or more primary keys") - logger.debug("Done testing for null or duplicate keys") def _collect_stats(self, i, table): logger.info(f"Collecting stats for table #{i}") @@ -145,7 +178,9 @@ def _collect_stats(self, i, table): res = db.query(table._make_select().select(**col_exprs), tuple) res = dict(zip([f"table{i}_{n}" for n in col_exprs], map(json_friendly_value, res))) - self.stats.update(res) + for k, v in res.items(): + self.stats[k] = self.stats.get(k, 0) + (v or 0) + # self.stats.update(res) logger.debug(f"Done collecting stats for table #{i}") @@ -221,7 +256,7 @@ def _outer_join(self, table1, table2): partial(self._count_diff_per_column, db, diff_rows, cols1, is_diff_cols) ): - logger.info("Querying for different rows") + logger.debug("Querying for different rows") for is_xa, is_xb, *x in db.query(diff_rows, list): if is_xa and is_xb: # Can't both be exclusive, meaning a pk is NULL @@ -238,7 +273,7 @@ def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols): is_diff_cols_counts = 
db.query(diff_rows.select(sum_(this[c]) for c in is_diff_cols), tuple) diff_counts = {} for name, count in safezip(cols, is_diff_cols_counts): - diff_counts[name] = count + diff_counts[name] = diff_counts.get(name, 0) + (count or 0) self.stats['diff_counts'] = diff_counts def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): @@ -247,6 +282,7 @@ def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): exclusive_rows_query = diff_rows.where((this.is_exclusive_a==1) | (this.is_exclusive_b==1)) else: exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b) + with temp_table(db, exclusive_rows_query) as exclusive_rows: self.stats["exclusive_count"] = db.query(exclusive_rows.count(), int) sample_rows = db.query(sample(exclusive_rows.select(*this[list(a_cols)], *this[list(b_cols)])), list) diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index 761b3a74..8b51e3f6 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -103,6 +103,10 @@ def get_values(self) -> list: def choose_checkpoints(self, count: int) -> List[DbKey]: "Suggests a bunch of evenly-spaced checkpoints to split by (not including start, end)" + + if self.max_key - self.min_key <= count: + count = 1 + assert self.is_bounded if isinstance(self.min_key, ArithString): assert type(self.min_key) is type(self.max_key) diff --git a/data_diff/utils.py b/data_diff/utils.py index 5911f8f8..2e346fa3 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -1,3 +1,4 @@ +import logging import re import math from typing import Iterable, Tuple, Union, Any, Sequence, Dict @@ -214,6 +215,9 @@ def __setitem__(self, key: str, value: V): def __contains__(self, key: str) -> bool: ... + def __repr__(self): + return repr(dict(self.items())) + class CaseInsensitiveDict(CaseAwareMapping): def __init__(self, initial): @@ -285,3 +289,7 @@ def run_as_daemon(threadfunc, *args): th.daemon = True th.start() return th + + +def getLogger(name): + return logging.getLogger(name.rsplit('.', 1)[-1]) diff --git a/tests/common.py b/tests/common.py index 44a15cf2..5cce3964 100644 --- a/tests/common.py +++ b/tests/common.py @@ -45,6 +45,7 @@ def get_git_revision_short_hash() -> str: logging.basicConfig(level=level) logging.getLogger("hashdiff_tables").setLevel(level) logging.getLogger("joindiff_tables").setLevel(level) +logging.getLogger("diff_tables").setLevel(level) logging.getLogger("table_segment").setLevel(level) logging.getLogger("database").setLevel(level) From 7c7e5bd963e8e94470aa91668dc48eaee2e06881 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 23 Sep 2022 11:48:20 +0300 Subject: [PATCH 11/93] Added diffing schemas (when same db, for mutual columns) --- data_diff/__main__.py | 19 ++++++++++++++++++- data_diff/joindiff_tables.py | 14 +++++++------- data_diff/utils.py | 4 +++- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index adb5bee9..39951b06 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -45,6 +45,19 @@ def _get_schema(pair): return db.query_table_schema(table_path) +def diff_schemas(schema1, schema2, columns): + logging.info('Diffing schemas...') + attrs = 'name', 'type', 'datetime_precision', 'numeric_precision', 'numeric_scale' + for c in columns: + if c is None: # Skip for convenience + continue + diffs = [] + for attr, v1, v2 in safezip(attrs, schema1[c], schema2[c]): + if v1 != v2: + diffs.append(f"{attr}:({v1} != {v2})") + if diffs: + logging.warning(f"Schema mismatch in 
column '{c}': {', '.join(diffs)}") + class MyHelpFormatter(click.HelpFormatter): def __init__(self, **kwargs): super().__init__(self, **kwargs) @@ -300,7 +313,11 @@ def _main( columns = tuple(expanded_columns - {key_column, update_column}) - logging.info(f"Diffing columns: key={key_column} update={update_column} extra={columns}") + if db1 is db2: + diff_schemas(schema1, schema2, (key_column, update_column,) + columns) + + + logging.info(f"Diffing using columns: key={key_column} update={update_column} extra={columns}") segments = [ TableSegment(db, table_path, key_column, update_column, columns, **options)._with_raw_schema(raw_schema) diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 988cbaec..32b47a14 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -113,15 +113,17 @@ def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult yield from self._diff_segments(None, table1, table2, None) else: yield from self._bisect_and_diff_tables(table1, table2) + logger.info("Diffing complete") def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): assert table1.database is table2.database - logger.info( - ". " * level + f"Diffing segment {segment_index}/{segment_count}, " - f"key-range: {table1.min_key}..{table2.max_key}, " - f"size <= {max_rows}" - ) + if segment_index or table1.min_key or max_rows: + logger.info( + ". " * level + f"Diffing segment {segment_index}/{segment_count}, " + f"key-range: {table1.min_key}..{table2.max_key}, " + f"size <= {max_rows}" + ) with self._run_in_background( partial(self._collect_stats, 1, table1), @@ -130,8 +132,6 @@ def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: Tabl ): yield from self._outer_join(table1, table2) - logger.info("Diffing complete") - def _test_duplicate_keys(self, table1, table2): logger.debug("Testing for duplicate keys") diff --git a/data_diff/utils.py b/data_diff/utils.py index 2e346fa3..642a4b7b 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -15,7 +15,9 @@ def safezip(*args): "zip but makes sure all sequences are the same length" - assert len(set(map(len, args))) == 1 + lens = list(map(len, args)) + if len(set(lens)) != 1: + raise ValueError(f"Mismatching lengths in arguments to safezip: {lens}") return zip(*args) From bee5479d59f486efddb7a2571cf88cfdf24326be Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 23 Sep 2022 16:09:07 +0300 Subject: [PATCH 12/93] Joindiff: Added Interpreter; Fixed exclusive_rows to use temp_table in an interpreter. 
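The interpreter added here lets a Python generator drive a sequence of statements over a single connection/cursor: each yield hands one query (a raw SQL string or a query-builder expression) to the database, and that statement's result is sent back into the generator. A minimal sketch of the protocol, assuming `db` is an already-connected Database and using placeholder table names:

    def make_and_count(tmp_name, source):
        # Each yield is executed as one statement; the value sent back is its result.
        yield f"create temporary table {tmp_name} as select * from {source}"
        count = yield f"select count(*) from {tmp_name}"   # fetchall() result
        yield f"drop table {tmp_name}"
        print("rows copied:", count[0][0])

    # Passing a generator to db.query() compiles it into a ThreadLocalInterpreter,
    # which runs all three statements on the same cursor.
    db.query(make_and_count("tmp_example", "src_table"), None)
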
--- data_diff/databases/base.py | 34 +++++++++++++++++----- data_diff/joindiff_tables.py | 55 ++++++++++++++++++----------------- data_diff/queries/__init__.py | 2 +- data_diff/queries/compiler.py | 26 +++++++++++++++-- 4 files changed, 80 insertions(+), 37 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 181a80e5..288bf6fd 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -1,7 +1,7 @@ import math import sys import logging -from typing import Dict, Tuple, Optional, Sequence, Type, List +from typing import Dict, Generator, Tuple, Optional, Sequence, Type, List, Union from functools import wraps from concurrent.futures import ThreadPoolExecutor import threading @@ -27,7 +27,7 @@ DbPath, ) -from data_diff.queries import Expr, Compiler, table, Select, SKIP +from data_diff.queries import Expr, Compiler, table, Select, SKIP, ThreadLocalInterpreter logger = logging.getLogger("database") @@ -66,11 +66,29 @@ def _one(seq): return x -def _query_conn(conn, sql_code: str) -> list: +def _query_cursor(c, sql_code): + try: + c.execute(sql_code) + if sql_code.lower().startswith("select"): + return c.fetchall() + except Exception as e: + logger.exception(e) + raise + +def _query_conn(conn, sql_code: Union[str, ThreadLocalInterpreter]) -> list: c = conn.cursor() - c.execute(sql_code) - if sql_code.lower().startswith("select"): - return c.fetchall() + + if isinstance(sql_code, ThreadLocalInterpreter): + g = sql_code.interpret() + q = next(g) + while True: + res = _query_cursor(c, q) + try: + q = g.send(res) + except StopIteration: + break + else: + return _query_cursor(c, sql_code) class Database(AbstractDatabase): @@ -312,11 +330,11 @@ def set_conn(self): except ModuleNotFoundError as e: self._init_error = e - def _query(self, sql_code: str): + def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): r = self._queue.submit(self._query_in_worker, sql_code) return r.result() - def _query_in_worker(self, sql_code: str): + def _query_in_worker(self, sql_code: Union[str, ThreadLocalInterpreter]): "This method runs in a worker thread" if self._init_error: raise self._init_error diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 32b47a14..2b83201f 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -2,11 +2,9 @@ """ -from collections import defaultdict from decimal import Decimal from functools import partial import logging -from contextlib import contextmanager from typing import Dict, List from runtype import dataclass @@ -49,30 +47,17 @@ class Stats: def sample(table): return table.order_by(Random()).limit(10) - -@contextmanager -def temp_table(db: Database, expr: Expr): - c = Compiler(db) - - name = c.new_unique_table_name("temp_table") - +def create_temp_table(c: Compiler, name: str, expr: Expr): + db = c.database if isinstance(db, BigQuery): name = f"{db.default_schema}.{name}" - db.query(f"create table {c.quote(name)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}", None) + return f"create table {c.quote(name)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}" elif isinstance(db, Presto): - db.query(f"create table {c.quote(name)} as {c.compile(expr)}", None) + return f"create table {c.quote(name)} as {c.compile(expr)}" elif isinstance(db, Oracle): - db.query(f"create global temporary table {c.quote(name)} as {c.compile(expr)}", None) + return f"create global temporary 
table {c.quote(name)} as {c.compile(expr)}" else: - db.query(f"create temporary table {c.quote(name)} as {c.compile(expr)}", None) - - try: - yield table(name, schema=expr.source_table.schema) - finally: - if isinstance(db, (BigQuery, Presto)): - # Only drops if create table succeeded (meaning, the table didn't already exist) - # And if the table won't delete itself - db.query(f"drop table {c.quote(name)}", None) + return f"create temporary table {c.quote(name)} as {c.compile(expr)}" def _slice_tuple(t, *sizes): @@ -95,6 +80,7 @@ class JoinDifferBase(TableDiffer): stats: dict = {} validate_unique_key: bool = True + sample_exclusive_rows: bool = True def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: db = table1.database @@ -277,13 +263,30 @@ def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols): self.stats['diff_counts'] = diff_counts def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): - logger.info("Counting and sampling exclusive rows") if isinstance(db, Oracle): exclusive_rows_query = diff_rows.where((this.is_exclusive_a==1) | (this.is_exclusive_b==1)) else: exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b) - with temp_table(db, exclusive_rows_query) as exclusive_rows: - self.stats["exclusive_count"] = db.query(exclusive_rows.count(), int) - sample_rows = db.query(sample(exclusive_rows.select(*this[list(a_cols)], *this[list(b_cols)])), list) - self.stats["exclusive_sample"] = sample_rows + if not self.sample_exclusive_rows: + logger.info("Counting exclusive rows") + self.stats["exclusive_count"] = db.query(exclusive_rows_query.count(), int) + return + + logger.info("Counting and sampling exclusive rows") + def exclusive_rows(expr): + c = Compiler(db) + name = c.new_unique_table_name("temp_table") + yield create_temp_table(c, name, expr) + exclusive_rows = table(name, schema=expr.source_table.schema) + + count = yield exclusive_rows.count() + self.stats["exclusive_count"] = self.stats.get('exclusive_count', 0) + count[0][0] + sample_rows = yield sample(exclusive_rows.select(*this[list(a_cols)], *this[list(b_cols)])) + self.stats["exclusive_sample"] = self.stats.get('exclusive_sample', []) + sample_rows + + # Only drops if create table succeeded (meaning, the table didn't already exist) + yield f"drop table {c.quote(name)}" + + # Run as a sequence of thread-local queries (compiled into a ThreadLocalInterpreter) + db.query(exclusive_rows(exclusive_rows_query), None) diff --git a/data_diff/queries/__init__.py b/data_diff/queries/__init__.py index 93299b26..64a6e60f 100644 --- a/data_diff/queries/__init__.py +++ b/data_diff/queries/__init__.py @@ -1,4 +1,4 @@ -from .compiler import Compiler +from .compiler import Compiler, ThreadLocalInterpreter from .api import this, join, outerjoin, table, SKIP, sum_, avg, min_, max_, cte from .ast_classes import Expr, ExprNode, Select, Count, BinOp, Explain, In from .extras import Checksum, NormalizeAsString, ApplyFuncAndNormalizeAsString diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index 5133301c..62430880 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -1,7 +1,7 @@ import random from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, Dict, Sequence, List +from typing import Any, Dict, Generator, Sequence, List, Union from runtype import dataclass @@ -32,7 +32,7 @@ def compile(self, elem) -> str: return f"WITH {subq}\n{res}" return res - def _compile(self, elem) 
-> str: + def _compile(self, elem) -> Union[str, 'ThreadLocalInterpreter']: if elem is None: return "NULL" elif isinstance(elem, Compilable): @@ -47,6 +47,8 @@ def _compile(self, elem) -> str: return f"b'{elem.decode()}'" elif isinstance(elem, ArithString): return f"'{elem}'" + elif isinstance(elem, Generator): + return ThreadLocalInterpreter(self, elem) assert False, elem def new_unique_name(self, prefix="tmp"): @@ -65,3 +67,23 @@ class Compilable(ABC): @abstractmethod def compile(self, c: Compiler) -> str: ... + + +class ThreadLocalInterpreter: + """An interpeter used to execute a sequence of queries within the same thread. + + Useful for cursor-sensitive operations, such as creating a temporary table. + """ + + def __init__(self, compiler: Compiler, gen: Generator): + self.gen = gen + self.compiler = compiler + + def interpret(self): + q = next(self.gen) + while True: + try: + res = yield self.compiler.compile(q) + q = self.gen.send(res) + except StopIteration: + break From 4c80e5d48156650a675b95eeb37c4f52899dd6a3 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 30 Sep 2022 15:09:04 +0300 Subject: [PATCH 13/93] Tracking: Errors now provide more info, with truncated values --- data_diff/__main__.py | 2 +- data_diff/diff_tables.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 39951b06..e6d37253 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -307,7 +307,7 @@ def _main( m1 = None if any(match_like(c, schema1.keys())) else f"{db1}/{table1}" m2 = None if any(match_like(c, schema2.keys())) else f"{db2}/{table2}" not_matched = ", ".join(m for m in [m1, m2] if m) - raise ValueError(f"Column {c} not found in: {not_matched}") + raise ValueError(f"Column '{c}' not found in: {not_matched}") expanded_columns |= match diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 5cd21302..78440a95 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -1,6 +1,7 @@ """Provides classes for performing a table diff """ +import re import time from abc import ABC, abstractmethod from enum import Enum @@ -27,6 +28,10 @@ class Algorithm(Enum): DiffResult = Iterator[Tuple[str, tuple]] # Iterator[Tuple[Literal["+", "-"], tuple]] +def truncate_error(error: str): + first_line = error.split('\n', 1)[0] + return re.sub("'(.*?)'", "'***'", first_line) + @dataclass class ThreadBase: @@ -110,7 +115,7 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: table1_count = self.stats.get("table1_count") table2_count = self.stats.get("table2_count") diff_count = self.stats.get("diff_count") - err_message = str(error)[:20] # Truncate possibly sensitive information. 
+ err_message = truncate_error(repr(error)) event_json = create_end_event_json( error is None, runtime, @@ -186,7 +191,7 @@ def _parse_key_range_result(self, key_type, key_range): try: return cls(mn), cls(mx) + 1 except (TypeError, ValueError) as e: - raise type(e)(f"Cannot apply {key_type} to {mn}, {mx}.") from e + raise type(e)(f"Cannot apply {key_type} to '{mn}', '{mx}'.") from e def _bisect_and_diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None): From 179ce547d3167f2e167a8f1e944fc22f52f5d38d Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 30 Sep 2022 17:38:30 +0300 Subject: [PATCH 14/93] Better docs and docstrings --- data_diff/__init__.py | 35 ++++++++++++++++++++++++++++++++--- data_diff/joindiff_tables.py | 10 +++++++++- data_diff/table_segment.py | 3 ++- docs/conf.py | 1 + docs/python-api.rst | 8 ++++++++ docs/requirements.txt | 2 +- 6 files changed, 53 insertions(+), 6 deletions(-) diff --git a/data_diff/__init__.py b/data_diff/__init__.py index f22ab039..3e8451ba 100644 --- a/data_diff/__init__.py +++ b/data_diff/__init__.py @@ -15,14 +15,17 @@ def connect_to_table( key_column: str = "id", thread_count: Optional[int] = 1, **kwargs, -): +) -> TableSegment: """Connects to the given database, and creates a TableSegment instance Parameters: db_info: Either a URI string, or a dict of connection options. table_name: Name of the table as a string, or a tuple that signifies the path. key_column: Name of the key column - thread_count: Number of threads for this connection (only if using a threadpooled implementation) + thread_count: Number of threads for this connection (only if using a threadpooled db implementation) + + See Also: + :meth:`connect` """ db = connect(db_info, thread_count=thread_count) @@ -61,13 +64,39 @@ def diff_tables( # There may be many pools, so number of actual threads can be a lot higher. max_threadpool_size: Optional[int] = 1, ) -> Iterator: - """Efficiently finds the diff between table1 and table2. + """Finds the diff between table1 and table2. + + Parameters: + key_column (str): Name of the key column, which uniquely identifies each row (usually id) + update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update). + Used by `min_update` and `max_update`. + extra_columns (Tuple[str, ...], optional): Extra columns to compare + min_key (:data:`DbKey`, optional): Lowest key_column value, used to restrict the segment + max_key (:data:`DbKey`, optional): Highest key_column value, used to restrict the segment + min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment + max_update (:data:`DbTime`, optional): Highest update_column value, used to restrict the segment + algorithm (:class:`Algorithm`): Which diffing algorithm to use (`HASHDIFF` or `JOINDIFF`) + bisection_factor (int): Into how many segments to bisect per iteration. (when algorithm is `HASHDIFF`) + bisection_threshold (Number): When should we stop bisecting and compare locally (when algorithm is `HASHDIFF`; in row count). + threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads. + max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. Only relevant when `threaded` is ``True``. + There may be many pools, so number of actual threads can be a lot higher. 
+ + Note: + The following parameters are used to override the corresponding attributes of the given :class:`TableSegment` instances: + `key_column`, `update_column`, `extra_columns`, `min_key`, `max_key`. If different values are needed per table, it's + possible to omit them here, and instead set them directly when creating each :class:`TableSegment`. Example: >>> table1 = connect_to_table('postgresql:///', 'Rating', 'id') >>> list(diff_tables(table1, table1)) [] + See Also: + :class:`TableSegment` + :class:`HashDiffer` + :class:`JoinDiffer` + """ tables = [table1, table2] override_attrs = { diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 2b83201f..622d002f 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -76,7 +76,15 @@ def json_friendly_value(v): @dataclass class JoinDifferBase(TableDiffer): - """Finds the diff between two SQL tables using JOINs""" + """Finds the diff between two SQL tables using JOINs + + The two tables must reside in the same database, and their primary keys must be unique and not null. + + Parameters: + threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads. + max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. Only relevant when `threaded` is ``True``. + There may be many pools, so number of actual threads can be a lot higher. + """ stats: dict = {} validate_unique_key: bool = True diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index 8b51e3f6..3a4ddbe4 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -23,7 +23,8 @@ class TableSegment: database (Database): Database instance. See :meth:`connect` table_path (:data:`DbPath`): Path to table in form of a tuple. e.g. `('my_dataset', 'table_name')` key_column (str): Name of the key column, which uniquely identifies each row (usually id) - update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update) + update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update). + Used by `min_update` and `max_update`. extra_columns (Tuple[str, ...], optional): Extra columns to compare min_key (:data:`DbKey`, optional): Lowest key_column value, used to restrict the segment max_key (:data:`DbKey`, optional): Highest key_column value, used to restrict the segment diff --git a/docs/conf.py b/docs/conf.py index ef75ecc0..dc58fb90 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -41,6 +41,7 @@ "recommonmark", "sphinx_markdown_tables", "sphinx_copybutton", + "enum_tools.autoenum", # 'sphinx_gallery.gen_gallery' ] diff --git a/docs/python-api.rst b/docs/python-api.rst index f28b18d1..ada633d1 100644 --- a/docs/python-api.rst +++ b/docs/python-api.rst @@ -5,6 +5,10 @@ Python API Reference .. autofunction:: connect +.. autofunction:: connect_to_table + +.. autofunction:: diff_tables + .. autoclass:: HashDiffer :members: __init__, diff_tables @@ -17,6 +21,10 @@ Python API Reference .. autoclass:: data_diff.databases.database_types.AbstractDatabase :members: +.. autoclass:: data_diff.databases.database_types.AbstractDialect + :members: + .. autodata:: DbKey .. autodata:: DbTime .. autodata:: DbPath +.. 
autoenum:: Algorithm diff --git a/docs/requirements.txt b/docs/requirements.txt index 0d1d793a..252c7acb 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,6 +4,6 @@ sphinx_markdown_tables sphinx-copybutton sphinx-rtd-theme recommonmark +enum-tools[sphinx] -# Requirements. TODO Use poetry instead of this redundant list data_diff From 073333ce624a8f88dd5cb5972395cf61ed449039 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 3 Oct 2022 10:20:26 +0300 Subject: [PATCH 15/93] Refactor joindiff --- data_diff/joindiff_tables.py | 65 +++++++++++++++++------------------- 1 file changed, 30 insertions(+), 35 deletions(-) diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 622d002f..d997730f 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -60,6 +60,32 @@ def create_temp_table(c: Compiler, name: str, expr: Expr): return f"create temporary table {c.quote(name)} as {c.compile(expr)}" +def bool_to_int(x): + return if_(x, 1, 0) + + +def _outerjoin(db: Database, a: ITable, b: ITable, keys1: List[str], keys2: List[str], select_fields: dict) -> ITable: + on = [a[k1] == b[k2] for k1, k2 in safezip(keys1, keys2)] + + if isinstance(db, Oracle): + is_exclusive_a = and_(bool_to_int(b[k] == None) for k in keys2) + is_exclusive_b = and_(bool_to_int(a[k] == None) for k in keys1) + else: + is_exclusive_a = and_(b[k] == None for k in keys2) + is_exclusive_b = and_(a[k] == None for k in keys1) + + if isinstance(db, MySQL): + # No outer join + l = leftjoin(a, b).on(*on).select(is_exclusive_a=is_exclusive_a, is_exclusive_b=False, **select_fields) + r = rightjoin(a, b).on(*on).select(is_exclusive_a=False, is_exclusive_b=is_exclusive_b, **select_fields) + return l.union(r) + + return ( + outerjoin(a, b).on(*on) + .select(is_exclusive_a=is_exclusive_a, is_exclusive_b=is_exclusive_b, **select_fields) + ) + + def _slice_tuple(t, *sizes): i = 0 for size in sizes: @@ -74,10 +100,12 @@ def json_friendly_value(v): return v + @dataclass -class JoinDifferBase(TableDiffer): - """Finds the diff between two SQL tables using JOINs +class JoinDiffer(TableDiffer): + """Finds the diff between two SQL tables in the same database, using JOINs. + The algorithm uses an OUTER JOIN (or equivalent) with extra checks and statistics. The two tables must reside in the same database, and their primary keys must be unique and not null. Parameters: @@ -182,39 +210,6 @@ def _collect_stats(self, i, table): # stats.diff_ratio_total = diff_stats['total_diff'] -def bool_to_int(x): - return if_(x, 1, 0) - - -def _outerjoin(db: Database, a: ITable, b: ITable, keys1: List[str], keys2: List[str], select_fields: dict) -> ITable: - on = [a[k1] == b[k2] for k1, k2 in safezip(keys1, keys2)] - - if isinstance(db, Oracle): - is_exclusive_a = and_(bool_to_int(b[k] == None) for k in keys2) - is_exclusive_b = and_(bool_to_int(a[k] == None) for k in keys1) - else: - is_exclusive_a = and_(b[k] == None for k in keys2) - is_exclusive_b = and_(a[k] == None for k in keys1) - - if isinstance(db, MySQL): - # No outer join - l = leftjoin(a, b).on(*on).select(is_exclusive_a=is_exclusive_a, is_exclusive_b=False, **select_fields) - r = rightjoin(a, b).on(*on).select(is_exclusive_a=False, is_exclusive_b=is_exclusive_b, **select_fields) - return l.union(r) - - return ( - outerjoin(a, b).on(*on) - .select(is_exclusive_a=is_exclusive_a, is_exclusive_b=is_exclusive_b, **select_fields) - ) - - -class JoinDiffer(JoinDifferBase): - """Finds the diff between two SQL tables in the same database. 
- - The algorithm uses an OUTER JOIN (or equivalent) with extra checks and statistics. - - """ - def _outer_join(self, table1, table2): db = table1.database if db is not table2.database: From c1e171d74e3f1340289f4e2b172b894d93a00d51 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 3 Oct 2022 13:15:49 +0300 Subject: [PATCH 16/93] Queries: Derive schemas (WIP) --- data_diff/queries/ast_classes.py | 56 ++++++++++++++++++++++++-------- data_diff/queries/extras.py | 5 +-- tests/test_query.py | 14 ++++++++ 3 files changed, 60 insertions(+), 15 deletions(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 56213dd3..7c891a96 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -40,6 +40,10 @@ class Alias(ExprNode): def compile(self, c: Compiler) -> str: return f"{c.compile(self.expr)} AS {c.quote(self.name)}" + @property + def type(self): + return self.expr.type + def _drop_skips(exprs): return [e for e in exprs if e is not SKIP] @@ -163,6 +167,10 @@ def compile(self, c: Compiler) -> str: args = ", ".join(c.compile(e) for e in self.args) return f"{self.name}({args})" +def _expr_type(e: Expr): + if isinstance(e, ExprNode): + return e.type + return type(e) @dataclass class CaseWhen(ExprNode): @@ -175,30 +183,40 @@ def compile(self, c: Compiler) -> str: else_ = (" " + c.compile(self.else_)) if self.else_ else "" return f"CASE {when_thens}{else_} END" + @property + def type(self): + when_types = {_expr_type(w) for _c,w in self.cases } + if self.else_: + when_types |= _expr_type(self.else_) + if len(when_types) > 1: + raise RuntimeError(f"Non-matching types in when: {when_types}") + t ,= when_types + return t + class LazyOps: def __add__(self, other): return BinOp("+", [self, other]) def __gt__(self, other): - return BinOp(">", [self, other]) + return BinBoolOp(">", [self, other]) def __ge__(self, other): - return BinOp(">=", [self, other]) + return BinBoolOp(">=", [self, other]) def __eq__(self, other): if other is None: - return BinOp("IS", [self, None]) - return BinOp("=", [self, other]) + return BinBoolOp("IS", [self, None]) + return BinBoolOp("=", [self, other]) def __lt__(self, other): - return BinOp("<", [self, other]) + return BinBoolOp("<", [self, other]) def __le__(self, other): - return BinOp("<=", [self, other]) + return BinBoolOp("<=", [self, other]) def __or__(self, other): - return BinOp("OR", [self, other]) + return BinBoolOp("OR", [self, other]) def is_distinct_from(self, other): return IsDistinctFrom(self, other) @@ -211,6 +229,7 @@ def sum(self): class IsDistinctFrom(ExprNode, LazyOps): a: Expr b: Expr + type = bool def compile(self, c: Compiler) -> str: return c.database.is_distinct_from(c.compile(self.a), c.compile(self.b)) @@ -228,6 +247,9 @@ def compile(self, c: Compiler) -> str: a, b = self.args return f"({c.compile(a)} {self.op} {c.compile(b)})" +class BinBoolOp(BinOp): + type = bool + @dataclass(eq=False, order=False) class Column(ExprNode, LazyOps): @@ -299,8 +321,9 @@ def source_table(self): @property def schema(self): - # TODO combine both tables - return None + assert self.columns # TODO Implement SELECT * + s = self.source_tables[0].schema # XXX + return type(s)({c.name: c.type for c in self.columns}) def on(self, *exprs): if len(exprs) == 1: @@ -375,7 +398,7 @@ def compile(self, parent_c: Compiler) -> str: @dataclass class Select(ExprNode, ITable): - source_table: Expr = None + table: Expr = None columns: Sequence[Expr] = None where_exprs: Sequence[Expr] = None order_by_exprs: Sequence[Expr] = 
None @@ -384,7 +407,14 @@ class Select(ExprNode, ITable): @property def schema(self): - return self.source_table.schema + s = self.table.schema + if s is None or self.columns is None: + return s + return type(s)({c.name: c.type for c in self.columns}) + + @property + def source_table(self): + return self def compile(self, parent_c: Compiler) -> str: c = parent_c.replace(in_select=True) #.add_table_context(self.table) @@ -392,8 +422,8 @@ def compile(self, parent_c: Compiler) -> str: columns = ", ".join(map(c.compile, self.columns)) if self.columns else "*" select = f"SELECT {columns}" - if self.source_table: - select += " FROM " + c.compile(self.source_table) + if self.table: + select += " FROM " + c.compile(self.table) if self.where_exprs: select += " WHERE " + " AND ".join(map(c.compile, self.where_exprs)) diff --git a/data_diff/queries/extras.py b/data_diff/queries/extras.py index 9b5189e1..bcd426df 100644 --- a/data_diff/queries/extras.py +++ b/data_diff/queries/extras.py @@ -12,11 +12,12 @@ @dataclass class NormalizeAsString(ExprNode): expr: ExprNode - type: ColType = None + expr_type: ColType = None + type = str def compile(self, c: Compiler) -> str: expr = c.compile(self.expr) - return c.database.normalize_value_by_type(expr, self.type or self.expr.type) + return c.database.normalize_value_by_type(expr, self.expr_type or self.expr.type) @dataclass diff --git a/tests/test_query.py b/tests/test_query.py index 4ae4b82e..3e895bc5 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -79,6 +79,7 @@ def test_schema(self): c = Compiler(MockDialect()) schema = dict(id="int", comment="varchar") + # test table t = table("a", schema=CaseInsensitiveDict(schema)) q = t.select(this.Id, t["COMMENT"]) assert c.compile(q) == "SELECT id, comment FROM a" @@ -87,6 +88,19 @@ def test_schema(self): self.assertRaises(KeyError, t.__getitem__, "Id") self.assertRaises(KeyError, t.select, this.Id) + # test select + q = t.select(this.id) + self.assertRaises(KeyError, q.__getitem__, "comment") + + # test join + s = CaseInsensitiveDict({'x': int, 'y': int}) + a = table("a", schema=s) + b = table("b", schema=s) + keys = ["x", "y"] + j = outerjoin(a, b).on(a[k] == b[k] for k in keys).select(a['x'], b['y'], xsum=a['x'] + b['x']) + j['x'], j['y'], j['xsum'] + self.assertRaises(KeyError, j.__getitem__, "ysum") + def test_commutable_select(self): # c = Compiler(MockDialect()) From da6c2df0d7205af13b97a0b6bf7354e02f6bd1c0 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 4 Oct 2022 09:35:44 +0300 Subject: [PATCH 17/93] Queries: DDL initial (drop/create table, insert) --- data_diff/queries/ast_classes.py | 49 +++++++++++++++++++++++++++++--- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 7c891a96..cc2463f0 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -1,6 +1,6 @@ from dataclasses import field from datetime import datetime -from typing import Any, Generator, ItemsView, Optional, Sequence, Tuple, Union +from typing import Any, Generator, Optional, Sequence, Tuple, Union from runtype import dataclass @@ -10,6 +10,7 @@ from .base import SKIP, CompileError, DbPath, Schema, args_as_tuple + class ExprNode(Compilable): type: Any = None @@ -284,11 +285,16 @@ class TablePath(ExprNode, ITable): path: DbPath schema: Optional[Schema] = field(default=None, repr=False) + def create(self, if_not_exists=False): + if not self.schema: + raise ValueError("Schema must have a value to create table") + 
return CreateTable(self, if_not_exists=if_not_exists) + def insert_values(self, rows): - pass + raise NotImplementedError() - def insert_query(self, query): - pass + def insert_expr(self, expr: Expr): + return InsertToTable(self, expr) @property def source_table(self): @@ -558,3 +564,38 @@ def compile(self, c: Compiler) -> str: class Random(ExprNode): def compile(self, c: Compiler) -> str: return c.database.random() + + +# DDL + +class Statement(Compilable): + type = None + +def to_sql_type(t): + if isinstance(t, str): + return t + return { + int: "int", + str: "varchar", + bool: "boolean", + }[t] + + +@dataclass +class CreateTable(Statement): + path: TablePath + if_not_exists: bool = False + + def compile(self, c: Compiler) -> str: + schema = ', '.join(f'{k} {to_sql_type(v)}' for k, v in self.path.schema.items()) + ne = 'IF NOT EXISTS ' if self.if_not_exists else '' + return f'CREATE TABLE {ne}{c.compile(self.path)}({schema})' + +@dataclass +class InsertToTable(Statement): + # TODO Support insert for only some columns + path: TablePath + expr: Expr + + def compile(self, c: Compiler) -> str: + return f'INSERT INTO {c.compile(self.path)} {c.compile(self.expr)}' From 9f404a06a1522509f1f26b378401ee28ffec0539 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 5 Oct 2022 11:08:46 +0300 Subject: [PATCH 18/93] Queries: Fix in .type --- data_diff/queries/ast_classes.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index cc2463f0..8a07f4fb 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -32,6 +32,10 @@ def cast_to(self, to): Expr = Union[ExprNode, str, bool, int, datetime, ArithString, None] +def get_type(e: Expr) -> type: + if isinstance(e, ExprNode): + return e.type + return type(e) @dataclass class Alias(ExprNode): @@ -43,7 +47,7 @@ def compile(self, c: Compiler) -> str: @property def type(self): - return self.expr.type + return get_type(self.expr) def _drop_skips(exprs): @@ -392,6 +396,17 @@ class Union(ExprNode, ITable): def source_table(self): return self # TODO is this right? 
+ @property + def type(self): + return self.table1.type + + @property + def schema(self): + s1 = self.table1.schema + s2 = self.table2.schema + assert len(s1) == len(s2) + return s1 + def compile(self, parent_c: Compiler) -> str: c = parent_c.replace(in_select=False) union_all = f"{c.compile(self.table1)} UNION {c.compile(self.table2)}" @@ -576,7 +591,7 @@ def to_sql_type(t): return t return { int: "int", - str: "varchar", + str: "varchar(1024)", bool: "boolean", }[t] From 5cd424dd49c0013cb509d7979d6b4fecda2f5e1b Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 3 Oct 2022 13:16:01 +0300 Subject: [PATCH 19/93] Joindiff: Added support to materialize results as tables (-m) --- data_diff/__main__.py | 14 ++++- data_diff/databases/base.py | 2 +- data_diff/diff_tables.py | 2 +- data_diff/joindiff_tables.py | 89 +++++++++++++++++++++----------- data_diff/queries/ast_classes.py | 29 +++++++++-- data_diff/utils.py | 6 +++ 6 files changed, 104 insertions(+), 38 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index e6d37253..437bc67a 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -9,7 +9,9 @@ import rich import click -from .utils import remove_password_from_url, safezip, match_like +from data_diff.databases.base import parse_table_name + +from .utils import eval_name_template, remove_password_from_url, safezip, match_like from .diff_tables import Algorithm from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR from .joindiff_tables import JoinDiffer @@ -104,6 +106,7 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - help=f"Minimal bisection threshold. Below it, data-diff will download the data and compare it locally. Default={DEFAULT_BISECTION_THRESHOLD}.", metavar="NUM", ) +@click.option("-m", "--materialize", default=None, metavar="TABLE_NAME", help="Materialize the diff results into a new table in the database.") @click.option( "--min-age", default=None, @@ -126,6 +129,11 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - is_flag=True, help="Column names are treated as case-sensitive. 
Otherwise, data-diff corrects their case according to schema.", ) +@click.option( + "--assume-unique-key", + is_flag=True, + help="Skip validating the uniqueness of the key column during joindiff, which is costly in non-cloud dbs.", +) @click.option( "-j", "--threads", @@ -192,6 +200,8 @@ def _main( case_sensitive, json_output, where, + assume_unique_key, + materialize, threads1=None, threads2=None, __conf__=None, @@ -256,6 +266,8 @@ def _main( differ = JoinDiffer( threaded=threaded, max_threadpool_size=threads and threads * 2, + validate_unique_key = not assume_unique_key, + materialize_to_table = materialize and parse_table_name(eval_name_template(materialize)), ) else: assert algorithm == Algorithm.HASHDIFF diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 288bf6fd..bd33165f 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -107,7 +107,7 @@ class Database(AbstractDatabase): def name(self): return type(self).__name__ - def query(self, sql_ast: Expr, res_type: type): + def query(self, sql_ast: Expr, res_type: type = None): "Query the given SQL code/AST, and attempt to convert the result to type 'res_type'" compiler = Compiler(self) diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 78440a95..5ecd1667 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -68,7 +68,7 @@ def _threaded_call_as_completed(self, func, iterable): @contextmanager def _run_in_background(self, *funcs): with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool: - futures = [task_pool.submit(f) for f in funcs] + futures = [task_pool.submit(f) for f in funcs if f is not None] yield futures for f in futures: f.result() diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index d997730f..1afb9467 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -5,10 +5,12 @@ from decimal import Decimal from functools import partial import logging -from typing import Dict, List +from typing import Dict, List, Optional from runtype import dataclass +from data_diff.databases.database_types import DbPath, Schema + from .utils import safezip from .databases.base import Database @@ -17,15 +19,16 @@ from .diff_tables import TableDiffer, DiffResult from .thread_utils import ThreadedYielder -from .queries import table, sum_, min_, max_, avg +from .queries import table, sum_, min_, max_, avg, SKIP from .queries.api import and_, if_, or_, outerjoin, leftjoin, rightjoin, this, ITable -from .queries.ast_classes import Concat, Count, Expr, Random +from .queries.ast_classes import Concat, Count, Expr, Random, TablePath from .queries.compiler import Compiler from .queries.extras import NormalizeAsString - logger = logging.getLogger("joindiff_tables") +WRITE_LIMIT = 1000 + def merge_dicts(dicts): i = iter(dicts) @@ -60,6 +63,18 @@ def create_temp_table(c: Compiler, name: str, expr: Expr): return f"create temporary table {c.quote(name)} as {c.compile(expr)}" +def drop_table(db, name: DbPath): + t = TablePath(name) + db.query(t.drop(if_exists=True)) + +def append_to_table(name: DbPath, expr: Expr): + t = TablePath(name, expr.schema) + yield t.create(if_not_exists=True) # uses expr.schema + yield 'commit' + yield t.insert_expr(expr) + yield 'commit' + + def bool_to_int(x): return if_(x, 1, 0) @@ -117,6 +132,8 @@ class JoinDiffer(TableDiffer): stats: dict = {} validate_unique_key: bool = True sample_exclusive_rows: bool = True + materialize_to_table: DbPath = None + write_limit: int = WRITE_LIMIT 
def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: db = table1.database @@ -128,8 +145,12 @@ def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult bg_funcs = [partial(self._test_duplicate_keys, table1, table2)] if self.validate_unique_key else [] + if self.materialize_to_table: + drop_table(db, self.materialize_to_table) + db.query('COMMIT') with self._run_in_background(*bg_funcs): + if isinstance(db, (Snowflake, BigQuery)): # Don't segment the table; let the database handling parallelization yield from self._diff_segments(None, table1, table2, None) @@ -147,12 +168,29 @@ def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: Tabl f"size <= {max_rows}" ) + db = table1.database + diff_rows, a_cols, b_cols, is_diff_cols = self._create_outer_join(table1, table2) + with self._run_in_background( partial(self._collect_stats, 1, table1), partial(self._collect_stats, 2, table2), partial(self._test_null_keys, table1, table2), + partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols), + partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols), + partial(self._materialize_diff, db, diff_rows, segment_index=segment_index) if self.materialize_to_table else None, ): - yield from self._outer_join(table1, table2) + + logger.debug("Querying for different rows") + for is_xa, is_xb, *x in db.query(diff_rows, list): + if is_xa and is_xb: + # Can't both be exclusive, meaning a pk is NULL + # This can happen if the explicit null test didn't finish running yet + raise ValueError(f"NULL values in one or more primary keys") + is_diff, a_row, b_row = _slice_tuple(x, len(is_diff_cols), len(a_cols), len(b_cols)) + if not is_xb: + yield "-", tuple(a_row) + if not is_xa: + yield "+", tuple(b_row) def _test_duplicate_keys(self, table1, table2): logger.debug("Testing for duplicate keys") @@ -162,7 +200,7 @@ def _test_duplicate_keys(self, table1, table2): t = ts._make_select() key_columns = [ts.key_column] # XXX - q = t.select(total=Count(), total_distinct=Count(Concat(key_columns), distinct=True)) + q = t.select(total=Count(), total_distinct=Count(Concat(this[key_columns]), distinct=True)) total, total_distinct = ts.database.query(q, tuple) if total != total_distinct: raise ValueError("Duplicate primary keys") @@ -175,7 +213,7 @@ def _test_null_keys(self, table1, table2): t = ts._make_select() key_columns = [ts.key_column] # XXX - q = t.select(*key_columns).where(or_(this[k] == None for k in key_columns)) + q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns)) nulls = ts.database.query(q, list) if nulls: raise ValueError(f"NULL values in one or more primary keys") @@ -188,10 +226,10 @@ def _collect_stats(self, i, table): # Metrics col_exprs = merge_dicts( { - f"sum_{c}": sum_(c), - f"avg_{c}": avg(c), - f"min_{c}": min_(c), - f"max_{c}": max_(c), + f"sum_{c}": sum_(this[c]), + f"avg_{c}": avg(this[c]), + f"min_{c}": min_(this[c]), + f"max_{c}": max_(this[c]), } for c in table._relevant_columns if c == "id" # TODO just if the right type @@ -209,8 +247,7 @@ def _collect_stats(self, i, table): # stats.diff_ratio_by_column = diff_stats # stats.diff_ratio_total = diff_stats['total_diff'] - - def _outer_join(self, table1, table2): + def _create_outer_join(self, table1, table2): db = table1.database if db is not table2.database: raise ValueError("Joindiff only applies to tables within the same database") @@ -239,23 +276,8 @@ def _outer_join(self, table1, table2): _outerjoin(db, a, b, 
keys1, keys2, {**is_diff_cols, **a_cols, **b_cols}) .where(or_(this[c] == 1 for c in is_diff_cols)) ) + return diff_rows, a_cols, b_cols, is_diff_cols - with self._run_in_background( - partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols), - partial(self._count_diff_per_column, db, diff_rows, cols1, is_diff_cols) - ): - - logger.debug("Querying for different rows") - for is_xa, is_xb, *x in db.query(diff_rows, list): - if is_xa and is_xb: - # Can't both be exclusive, meaning a pk is NULL - # This can happen if the explicit null test didn't finish running yet - raise ValueError(f"NULL values in one or more primary keys") - is_diff, a_row, b_row = _slice_tuple(x, len(is_diff_cols), len(a_cols), len(b_cols)) - if not is_xb: - yield "-", tuple(a_row) - if not is_xa: - yield "+", tuple(b_row) def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols): logger.info("Counting differences per column") @@ -280,7 +302,7 @@ def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): def exclusive_rows(expr): c = Compiler(db) name = c.new_unique_table_name("temp_table") - yield create_temp_table(c, name, expr) + yield create_temp_table(c, name, expr.limit(self.write_limit)) exclusive_rows = table(name, schema=expr.source_table.schema) count = yield exclusive_rows.count() @@ -293,3 +315,10 @@ def exclusive_rows(expr): # Run as a sequence of thread-local queries (compiled into a ThreadLocalInterpreter) db.query(exclusive_rows(exclusive_rows_query), None) + + def _materialize_diff(self, db, diff_rows, segment_index=None): + assert self.materialize_to_table + + db.query(append_to_table(self.materialize_to_table, diff_rows.limit(self.write_limit))) + logger.info(f"Materialized diff to table '{'.'.join(self.materialize_to_table)}'.") + diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 8a07f4fb..eec3a200 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -140,7 +140,7 @@ class Concat(ExprNode): def compile(self, c: Compiler) -> str: # We coalesce because on some DBs (e.g. 
MySQL) concat('a', NULL) is NULL - items = [f"coalesce({c.compile(c.database.to_string(expr))}, '')" for expr in self.exprs] + items = [f"coalesce({c.compile(c.database.to_string(c.compile(expr)))}, '')" for expr in self.exprs] assert items if len(items) == 1: return items[0] @@ -294,6 +294,9 @@ def create(self, if_not_exists=False): raise ValueError("Schema must have a value to create table") return CreateTable(self, if_not_exists=if_not_exists) + def drop(self, if_exists=False): + return DropTable(self, if_exists=if_exists) + def insert_values(self, rows): raise NotImplementedError() @@ -513,13 +516,13 @@ def resolve_names(source_table, exprs): if isinstance(expr, ExprNode): for v in expr._dfs_values(): if isinstance(v, _ResolveColumn): - v.resolve(source_table._get_column(v.name)) + v.resolve(source_table._get_column(v.resolve_name)) i += 1 @dataclass(frozen=False, eq=False, order=False) class _ResolveColumn(ExprNode, LazyOps): - name: str + resolve_name: str resolved: Expr = None def resolve(self, expr): @@ -528,15 +531,22 @@ def resolve(self, expr): def compile(self, c: Compiler) -> str: if self.resolved is None: - raise RuntimeError(f"Column not resolved: {self.name}") + raise RuntimeError(f"Column not resolved: {self.resolve_name}") return self.resolved.compile(c) @property def type(self): if self.resolved is None: - raise RuntimeError(f"Column not resolved: {self.name}") + raise RuntimeError(f"Column not resolved: {self.resolve_name}") return self.resolved.type + @property + def name(self): + if self.resolved is None: + raise RuntimeError(f"Column not resolved: {self.name}") + return self.resolved.name + + class This: def __getattr__(self, name): @@ -606,6 +616,15 @@ def compile(self, c: Compiler) -> str: ne = 'IF NOT EXISTS ' if self.if_not_exists else '' return f'CREATE TABLE {ne}{c.compile(self.path)}({schema})' +@dataclass +class DropTable(Statement): + path: TablePath + if_exists: bool = False + + def compile(self, c: Compiler) -> str: + ie = 'IF EXISTS ' if self.if_exists else '' + return f'DROP TABLE {ie}{c.compile(self.path)}' + @dataclass class InsertToTable(Statement): # TODO Support insert for only some columns diff --git a/data_diff/utils.py b/data_diff/utils.py index 642a4b7b..ca05e051 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -9,6 +9,7 @@ import operator import string import threading +from datetime import datetime alphanums = " -" + string.digits + string.ascii_uppercase + "_" + string.ascii_lowercase @@ -295,3 +296,8 @@ def run_as_daemon(threadfunc, *args): def getLogger(name): return logging.getLogger(name.rsplit('.', 1)[-1]) + +def eval_name_template(name): + def get_timestamp(m): + return datetime.now().isoformat('_', 'seconds').replace(':', '_') + return re.sub('%t', get_timestamp, name) From 733972a7c7428800ff1cbe2d6f5ce387b15aa764 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 4 Oct 2022 17:45:47 +0300 Subject: [PATCH 20/93] Queries: Ran black --- data_diff/queries/api.py | 2 ++ data_diff/queries/ast_classes.py | 34 +++++++++++++++++++------------- data_diff/queries/compiler.py | 4 ++-- tests/test_query.py | 10 +++++----- 4 files changed, 29 insertions(+), 21 deletions(-) diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index 136807eb..a07a9084 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -15,10 +15,12 @@ def leftjoin(*tables: ITable): "Left-joins each table into a 'struct'" return Join(tables, "LEFT") + def rightjoin(*tables: ITable): "Right-joins each table into a 'struct'" return 
Join(tables, "RIGHT") + def outerjoin(*tables: ITable): "Outer-joins each table into a 'struct'" return Join(tables, "FULL OUTER") diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index eec3a200..b3552620 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -10,7 +10,6 @@ from .base import SKIP, CompileError, DbPath, Schema, args_as_tuple - class ExprNode(Compilable): type: Any = None @@ -129,7 +128,7 @@ def __getitem__(self, column): def count(self): return Select(self, [Count()]) - def union(self, other: 'ITable'): + def union(self, other: "ITable"): return Union(self, other) @@ -172,11 +171,13 @@ def compile(self, c: Compiler) -> str: args = ", ".join(c.compile(e) for e in self.args) return f"{self.name}({args})" + def _expr_type(e: Expr): if isinstance(e, ExprNode): return e.type return type(e) + @dataclass class CaseWhen(ExprNode): cases: Sequence[Tuple[Expr, Expr]] @@ -190,12 +191,12 @@ def compile(self, c: Compiler) -> str: @property def type(self): - when_types = {_expr_type(w) for _c,w in self.cases } + when_types = {_expr_type(w) for _c, w in self.cases} if self.else_: when_types |= _expr_type(self.else_) if len(when_types) > 1: raise RuntimeError(f"Non-matching types in when: {when_types}") - t ,= when_types + (t,) = when_types return t @@ -252,6 +253,7 @@ def compile(self, c: Compiler) -> str: a, b = self.args return f"({c.compile(a)} {self.op} {c.compile(b)})" + class BinBoolOp(BinOp): type = bool @@ -334,8 +336,8 @@ def source_table(self): @property def schema(self): - assert self.columns # TODO Implement SELECT * - s = self.source_tables[0].schema # XXX + assert self.columns # TODO Implement SELECT * + s = self.source_tables[0].schema # XXX return type(s)({c.name: c.type for c in self.columns}) def on(self, *exprs): @@ -390,6 +392,7 @@ class GroupBy(ITable): def having(self): pass + @dataclass class Union(ExprNode, ITable): table1: ITable @@ -441,7 +444,7 @@ def source_table(self): return self def compile(self, parent_c: Compiler) -> str: - c = parent_c.replace(in_select=True) #.add_table_context(self.table) + c = parent_c.replace(in_select=True) # .add_table_context(self.table) columns = ", ".join(map(c.compile, self.columns)) if self.columns else "*" select = f"SELECT {columns}" @@ -547,7 +550,6 @@ def name(self): return self.resolved.name - class This: def __getattr__(self, name): return _ResolveColumn(name) @@ -593,9 +595,11 @@ def compile(self, c: Compiler) -> str: # DDL + class Statement(Compilable): type = None + def to_sql_type(t): if isinstance(t, str): return t @@ -612,9 +616,10 @@ class CreateTable(Statement): if_not_exists: bool = False def compile(self, c: Compiler) -> str: - schema = ', '.join(f'{k} {to_sql_type(v)}' for k, v in self.path.schema.items()) - ne = 'IF NOT EXISTS ' if self.if_not_exists else '' - return f'CREATE TABLE {ne}{c.compile(self.path)}({schema})' + schema = ", ".join(f"{k} {to_sql_type(v)}" for k, v in self.path.schema.items()) + ne = "IF NOT EXISTS " if self.if_not_exists else "" + return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})" + @dataclass class DropTable(Statement): @@ -622,8 +627,9 @@ class DropTable(Statement): if_exists: bool = False def compile(self, c: Compiler) -> str: - ie = 'IF EXISTS ' if self.if_exists else '' - return f'DROP TABLE {ie}{c.compile(self.path)}' + ie = "IF EXISTS " if self.if_exists else "" + return f"DROP TABLE {ie}{c.compile(self.path)}" + @dataclass class InsertToTable(Statement): @@ -632,4 +638,4 @@ class InsertToTable(Statement): 
expr: Expr def compile(self, c: Compiler) -> str: - return f'INSERT INTO {c.compile(self.path)} {c.compile(self.expr)}' + return f"INSERT INTO {c.compile(self.path)} {c.compile(self.expr)}" diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index 62430880..2c48cb86 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -13,7 +13,7 @@ class Compiler: database: AbstractDialect in_select: bool = False # Compilation runtime flag - in_join: bool = False # Compilation runtime flag + in_join: bool = False # Compilation runtime flag _table_context: List = [] # List[ITable] _subqueries: Dict[str, Any] = {} # XXX not thread-safe @@ -32,7 +32,7 @@ def compile(self, elem) -> str: return f"WITH {subq}\n{res}" return res - def _compile(self, elem) -> Union[str, 'ThreadLocalInterpreter']: + def _compile(self, elem) -> Union[str, "ThreadLocalInterpreter"]: if elem is None: return "NULL" elif isinstance(elem, Compilable): diff --git a/tests/test_query.py b/tests/test_query.py index 3e895bc5..5091843e 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -93,12 +93,12 @@ def test_schema(self): self.assertRaises(KeyError, q.__getitem__, "comment") # test join - s = CaseInsensitiveDict({'x': int, 'y': int}) + s = CaseInsensitiveDict({"x": int, "y": int}) a = table("a", schema=s) b = table("b", schema=s) keys = ["x", "y"] - j = outerjoin(a, b).on(a[k] == b[k] for k in keys).select(a['x'], b['y'], xsum=a['x'] + b['x']) - j['x'], j['y'], j['xsum'] + j = outerjoin(a, b).on(a[k] == b[k] for k in keys).select(a["x"], b["y"], xsum=a["x"] + b["x"]) + j["x"], j["y"], j["xsum"] self.assertRaises(KeyError, j.__getitem__, "ysum") def test_commutable_select(self): @@ -145,8 +145,8 @@ def test_funcs(self): def test_union_all(self): c = Compiler(MockDialect()) - a = table("a").select('x') - b = table("b").select('y') + a = table("a").select("x") + b = table("b").select("y") q = c.compile(a.union(b)) assert q == "SELECT x FROM a UNION SELECT y FROM b" From 00ee4158fc07d2e454145e0c401d607d410e3052 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 4 Oct 2022 17:47:11 +0300 Subject: [PATCH 21/93] Joindiff: Ran black --- data_diff/__main__.py | 30 ++++++++++++----- data_diff/databases/base.py | 3 +- data_diff/databases/presto.py | 2 +- data_diff/diff_tables.py | 23 +++++++++---- data_diff/hashdiff_tables.py | 18 +++++++--- data_diff/joindiff_tables.py | 63 ++++++++++++++++++----------------- data_diff/utils.py | 8 +++-- tests/test_joindiff.py | 18 ++++++++-- 8 files changed, 107 insertions(+), 58 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 437bc67a..c481fdee 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -48,10 +48,10 @@ def _get_schema(pair): def diff_schemas(schema1, schema2, columns): - logging.info('Diffing schemas...') - attrs = 'name', 'type', 'datetime_precision', 'numeric_precision', 'numeric_scale' + logging.info("Diffing schemas...") + attrs = "name", "type", "datetime_precision", "numeric_precision", "numeric_scale" for c in columns: - if c is None: # Skip for convenience + if c is None: # Skip for convenience continue diffs = [] for attr, v1, v2 in safezip(attrs, schema1[c], schema2[c]): @@ -60,6 +60,7 @@ def diff_schemas(schema1, schema2, columns): if diffs: logging.warning(f"Schema mismatch in column '{c}': {', '.join(diffs)}") + class MyHelpFormatter(click.HelpFormatter): def __init__(self, **kwargs): super().__init__(self, **kwargs) @@ -106,7 +107,13 @@ def write_usage(self, prog: str, args: str = 
"", prefix: Optional[str] = None) - help=f"Minimal bisection threshold. Below it, data-diff will download the data and compare it locally. Default={DEFAULT_BISECTION_THRESHOLD}.", metavar="NUM", ) -@click.option("-m", "--materialize", default=None, metavar="TABLE_NAME", help="Materialize the diff results into a new table in the database.") +@click.option( + "-m", + "--materialize", + default=None, + metavar="TABLE_NAME", + help="Materialize the diff results into a new table in the database. (joindiff only)", +) @click.option( "--min-age", default=None, @@ -266,8 +273,8 @@ def _main( differ = JoinDiffer( threaded=threaded, max_threadpool_size=threads and threads * 2, - validate_unique_key = not assume_unique_key, - materialize_to_table = materialize and parse_table_name(eval_name_template(materialize)), + validate_unique_key=not assume_unique_key, + materialize_to_table=materialize and parse_table_name(eval_name_template(materialize)), ) else: assert algorithm == Algorithm.HASHDIFF @@ -326,8 +333,15 @@ def _main( columns = tuple(expanded_columns - {key_column, update_column}) if db1 is db2: - diff_schemas(schema1, schema2, (key_column, update_column,) + columns) - + diff_schemas( + schema1, + schema2, + ( + key_column, + update_column, + ) + + columns, + ) logging.info(f"Diffing using columns: key={key_column} update={update_column} extra={columns}") diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index bd33165f..2956ab63 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -27,7 +27,7 @@ DbPath, ) -from data_diff.queries import Expr, Compiler, table, Select, SKIP, ThreadLocalInterpreter +from data_diff.queries import Expr, Compiler, table, Select, SKIP, ThreadLocalInterpreter logger = logging.getLogger("database") @@ -75,6 +75,7 @@ def _query_cursor(c, sql_code): logger.exception(e) raise + def _query_conn(conn, sql_code: Union[str, ThreadLocalInterpreter]) -> list: c = conn.cursor() diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index c990e06e..85ec4c7c 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -11,6 +11,7 @@ TIMESTAMP_PRECISION_POS, ) + def query_cursor(c, sql_code): c.execute(sql_code) if sql_code.lower().startswith("select"): @@ -87,7 +88,6 @@ def _query(self, sql_code: str) -> list: return query_cursor(c, sql_code) - def close(self): self._conn.close() diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 5ecd1667..7ca0646a 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -20,6 +20,7 @@ logger = getLogger(__name__) + class Algorithm(Enum): AUTO = "auto" JOINDIFF = "joindiff" @@ -28,8 +29,9 @@ class Algorithm(Enum): DiffResult = Iterator[Tuple[str, tuple]] # Iterator[Tuple[Literal["+", "-"], tuple]] + def truncate_error(error: str): - first_line = error.split('\n', 1)[0] + first_line = error.split("\n", 1)[0] return re.sub("'(.*?)'", "'***'", first_line) @@ -137,12 +139,19 @@ def _validate_and_adjust_columns(self, table1: TableSegment, table2: TableSegmen def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: return self._bisect_and_diff_tables(table1, table2) - @abstractmethod - def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): + def _diff_segments( + self, + ti: ThreadedYielder, + table1: TableSegment, + table2: TableSegment, + max_rows: int, + level=0, + segment_index=None, + segment_count=None, + ): 
... - def _bisect_and_diff_tables(self, table1, table2): key_type = table1._schema[table1.key_column] key_type2 = table2._schema[table2.key_column] @@ -183,7 +192,6 @@ def _bisect_and_diff_tables(self, table1, table2): return ti - def _parse_key_range_result(self, key_type, key_range): mn, mx = key_range cls = key_type.make_value @@ -193,8 +201,9 @@ def _parse_key_range_result(self, key_type, key_range): except (TypeError, ValueError) as e: raise type(e)(f"Cannot apply {key_type} to '{mn}', '{mx}'.") from e - - def _bisect_and_diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None): + def _bisect_and_diff_segments( + self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None + ): assert table1.is_bounded and table2.is_bounded # Choose evenly spaced checkpoints (according to min_key and max_key) diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py index 64b05b67..78b33e17 100644 --- a/data_diff/hashdiff_tables.py +++ b/data_diff/hashdiff_tables.py @@ -66,8 +66,6 @@ def __post_init__(self): if self.bisection_factor < 2: raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)") - - def _validate_and_adjust_columns(self, table1, table2): for c1, c2 in safezip(table1._relevant_columns, table2._relevant_columns): if c1 not in table1._schema: @@ -115,8 +113,16 @@ def _validate_and_adjust_columns(self, table1, table2): "If encoding/formatting differs between databases, it may result in false positives." ) - - def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): + def _diff_segments( + self, + ti: ThreadedYielder, + table1: TableSegment, + table2: TableSegment, + max_rows: int, + level=0, + segment_index=None, + segment_count=None, + ): logger.info( ". 
" * level + f"Diffing segment {segment_index}/{segment_count}, " f"key-range: {table1.min_key}..{table2.max_key}, " @@ -148,7 +154,9 @@ def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: Tabl if checksum1 != checksum2: return self._bisect_and_diff_segments(ti, table1, table2, level=level, max_rows=max(count1, count2)) - def _bisect_and_diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None): + def _bisect_and_diff_segments( + self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None + ): assert table1.is_bounded and table2.is_bounded if max_rows is None: diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 1afb9467..f488e945 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -50,6 +50,7 @@ class Stats: def sample(table): return table.order_by(Random()).limit(10) + def create_temp_table(c: Compiler, name: str, expr: Expr): db = c.database if isinstance(db, BigQuery): @@ -67,12 +68,13 @@ def drop_table(db, name: DbPath): t = TablePath(name) db.query(t.drop(if_exists=True)) + def append_to_table(name: DbPath, expr: Expr): t = TablePath(name, expr.schema) yield t.create(if_not_exists=True) # uses expr.schema - yield 'commit' + yield "commit" yield t.insert_expr(expr) - yield 'commit' + yield "commit" def bool_to_int(x): @@ -95,10 +97,7 @@ def _outerjoin(db: Database, a: ITable, b: ITable, keys1: List[str], keys2: List r = rightjoin(a, b).on(*on).select(is_exclusive_a=False, is_exclusive_b=is_exclusive_b, **select_fields) return l.union(r) - return ( - outerjoin(a, b).on(*on) - .select(is_exclusive_a=is_exclusive_a, is_exclusive_b=is_exclusive_b, **select_fields) - ) + return outerjoin(a, b).on(*on).select(is_exclusive_a=is_exclusive_a, is_exclusive_b=is_exclusive_b, **select_fields) def _slice_tuple(t, *sizes): @@ -115,7 +114,6 @@ def json_friendly_value(v): return v - @dataclass class JoinDiffer(TableDiffer): """Finds the diff between two SQL tables in the same database, using JOINs. 
@@ -143,11 +141,10 @@ def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult table1, table2 = self._threaded_call("with_schema", [table1, table2]) - bg_funcs = [partial(self._test_duplicate_keys, table1, table2)] if self.validate_unique_key else [] if self.materialize_to_table: drop_table(db, self.materialize_to_table) - db.query('COMMIT') + db.query("COMMIT") with self._run_in_background(*bg_funcs): @@ -158,7 +155,16 @@ def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult yield from self._bisect_and_diff_tables(table1, table2) logger.info("Diffing complete") - def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): + def _diff_segments( + self, + ti: ThreadedYielder, + table1: TableSegment, + table2: TableSegment, + max_rows: int, + level=0, + segment_index=None, + segment_count=None, + ): assert table1.database is table2.database if segment_index or table1.min_key or max_rows: @@ -172,13 +178,15 @@ def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: Tabl diff_rows, a_cols, b_cols, is_diff_cols = self._create_outer_join(table1, table2) with self._run_in_background( - partial(self._collect_stats, 1, table1), - partial(self._collect_stats, 2, table2), - partial(self._test_null_keys, table1, table2), - partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols), - partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols), - partial(self._materialize_diff, db, diff_rows, segment_index=segment_index) if self.materialize_to_table else None, - ): + partial(self._collect_stats, 1, table1), + partial(self._collect_stats, 2, table2), + partial(self._test_null_keys, table1, table2), + partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols), + partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols), + partial(self._materialize_diff, db, diff_rows, segment_index=segment_index) + if self.materialize_to_table + else None, + ): logger.debug("Querying for different rows") for is_xa, is_xb, *x in db.query(diff_rows, list): @@ -218,7 +226,6 @@ def _test_null_keys(self, table1, table2): if nulls: raise ValueError(f"NULL values in one or more primary keys") - def _collect_stats(self, i, table): logger.info(f"Collecting stats for table #{i}") db = table.database @@ -265,31 +272,27 @@ def _create_outer_join(self, table1, table2): a = table1._make_select() b = table2._make_select() - is_diff_cols = { - f"is_diff_{c1}": bool_to_int(a[c1].is_distinct_from(b[c2])) for c1, c2 in safezip(cols1, cols2) - } + is_diff_cols = {f"is_diff_{c1}": bool_to_int(a[c1].is_distinct_from(b[c2])) for c1, c2 in safezip(cols1, cols2)} a_cols = {f"table1_{c}": NormalizeAsString(a[c]) for c in cols1} b_cols = {f"table2_{c}": NormalizeAsString(b[c]) for c in cols2} - diff_rows = ( - _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **a_cols, **b_cols}) - .where(or_(this[c] == 1 for c in is_diff_cols)) + diff_rows = _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **a_cols, **b_cols}).where( + or_(this[c] == 1 for c in is_diff_cols) ) return diff_rows, a_cols, b_cols, is_diff_cols - def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols): logger.info("Counting differences per column") is_diff_cols_counts = db.query(diff_rows.select(sum_(this[c]) for c in is_diff_cols), tuple) diff_counts = {} for name, count in safezip(cols, is_diff_cols_counts): diff_counts[name] = 
diff_counts.get(name, 0) + (count or 0) - self.stats['diff_counts'] = diff_counts + self.stats["diff_counts"] = diff_counts def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): if isinstance(db, Oracle): - exclusive_rows_query = diff_rows.where((this.is_exclusive_a==1) | (this.is_exclusive_b==1)) + exclusive_rows_query = diff_rows.where((this.is_exclusive_a == 1) | (this.is_exclusive_b == 1)) else: exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b) @@ -299,6 +302,7 @@ def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): return logger.info("Counting and sampling exclusive rows") + def exclusive_rows(expr): c = Compiler(db) name = c.new_unique_table_name("temp_table") @@ -306,9 +310,9 @@ def exclusive_rows(expr): exclusive_rows = table(name, schema=expr.source_table.schema) count = yield exclusive_rows.count() - self.stats["exclusive_count"] = self.stats.get('exclusive_count', 0) + count[0][0] + self.stats["exclusive_count"] = self.stats.get("exclusive_count", 0) + count[0][0] sample_rows = yield sample(exclusive_rows.select(*this[list(a_cols)], *this[list(b_cols)])) - self.stats["exclusive_sample"] = self.stats.get('exclusive_sample', []) + sample_rows + self.stats["exclusive_sample"] = self.stats.get("exclusive_sample", []) + sample_rows # Only drops if create table succeeded (meaning, the table didn't already exist) yield f"drop table {c.quote(name)}" @@ -321,4 +325,3 @@ def _materialize_diff(self, db, diff_rows, segment_index=None): db.query(append_to_table(self.materialize_to_table, diff_rows.limit(self.write_limit))) logger.info(f"Materialized diff to table '{'.'.join(self.materialize_to_table)}'.") - diff --git a/data_diff/utils.py b/data_diff/utils.py index ca05e051..8224d270 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -295,9 +295,11 @@ def run_as_daemon(threadfunc, *args): def getLogger(name): - return logging.getLogger(name.rsplit('.', 1)[-1]) + return logging.getLogger(name.rsplit(".", 1)[-1]) + def eval_name_template(name): def get_timestamp(m): - return datetime.now().isoformat('_', 'seconds').replace(':', '_') - return re.sub('%t', get_timestamp, name) + return datetime.now().isoformat("_", "seconds").replace(":", "_") + + return re.sub("%t", get_timestamp, name) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index e8db3167..d9726c85 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -25,7 +25,20 @@ def init_instances(): DATABASE_INSTANCES = {k.__name__: connect(v, N_THREADS) for k, v in CONN_STRINGS.items()} -TEST_DATABASES = {x.__name__ for x in (db.PostgreSQL, db.Snowflake, db.MySQL, db.BigQuery, db.Presto, db.Vertica, db.Trino, db.Oracle, db.Redshift)} +TEST_DATABASES = { + x.__name__ + for x in ( + db.PostgreSQL, + db.Snowflake, + db.MySQL, + db.BigQuery, + db.Presto, + db.Vertica, + db.Trino, + db.Oracle, + db.Redshift, + ) +} _class_per_db_dec = parameterized_class( ("name", "db_name"), [(name, name) for name in DATABASE_URIS if name in TEST_DATABASES] @@ -179,14 +192,13 @@ def test_dup_pks(self): x = self.differ.diff_tables(self.table, self.table2) self.assertRaises(ValueError, list, x) - def test_null_pks(self): time = "2022-01-01 00:00:00" time_str = f"timestamp '{time}'" cols = "id rating timestamp".split() - _insert_row(self.connection, self.table_src, cols, ['null', 9, time_str]) + _insert_row(self.connection, self.table_src, cols, ["null", 9, time_str]) _insert_row(self.connection, self.table_dst, cols, [1, 9, time_str]) x = 
self.differ.diff_tables(self.table, self.table2) From 78e4c84068da493cec3e99b21ed59120e79db16b Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 5 Oct 2022 11:08:56 +0300 Subject: [PATCH 22/93] Many fixes; Added materialize tests; Now works for : postgresql, mysql, bigquery, presto, trino, snowflake, oracle, redshift --- data_diff/__main__.py | 34 +++++----- data_diff/databases/base.py | 93 ++++++++++++++++++++------- data_diff/databases/bigquery.py | 21 +++++- data_diff/databases/database_types.py | 6 +- data_diff/databases/databricks.py | 4 +- data_diff/databases/mysql.py | 8 +++ data_diff/databases/oracle.py | 12 +++- data_diff/databases/presto.py | 18 ++---- data_diff/databases/snowflake.py | 9 ++- data_diff/joindiff_tables.py | 57 +++++++++++----- data_diff/queries/__init__.py | 4 +- data_diff/queries/api.py | 2 + data_diff/queries/ast_classes.py | 21 +++--- data_diff/queries/base.py | 6 +- data_diff/queries/compiler.py | 31 ++------- data_diff/utils.py | 3 + tests/test_diff_tables.py | 8 +++ tests/test_joindiff.py | 19 +++++- 18 files changed, 233 insertions(+), 123 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index c481fdee..c6c8fefe 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -9,8 +9,6 @@ import rich import click -from data_diff.databases.base import parse_table_name - from .utils import eval_name_template, remove_password_from_url, safezip, match_like from .diff_tables import Algorithm from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR @@ -269,22 +267,6 @@ def _main( logging.error(f"Error while parsing age expression: {e}") return - if algorithm == Algorithm.JOINDIFF: - differ = JoinDiffer( - threaded=threaded, - max_threadpool_size=threads and threads * 2, - validate_unique_key=not assume_unique_key, - materialize_to_table=materialize and parse_table_name(eval_name_template(materialize)), - ) - else: - assert algorithm == Algorithm.HASHDIFF - differ = HashDiffer( - bisection_factor=bisection_factor, - bisection_threshold=bisection_threshold, - threaded=threaded, - max_threadpool_size=threads and threads * 2, - ) - if database1 is None or database2 is None: logging.error( f"Error: Databases not specified. Got {database1} and {database2}. Use --help for more information." 
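
For context on the materialize flag handled in this patch: the -m/--materialize value may contain %t, which eval_name_template() (utils.py, above) expands to a timestamp before the name is parsed into a table path with parse_table_name(). A rough illustration, with an arbitrary timestamp value:

    >>> eval_name_template("diff_results_%t")
    'diff_results_2022-10-03_10_20_26'
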
@@ -307,6 +289,22 @@ def _main( for db in dbs: db.enable_interactive() + if algorithm == Algorithm.JOINDIFF: + differ = JoinDiffer( + threaded=threaded, + max_threadpool_size=threads and threads * 2, + validate_unique_key=not assume_unique_key, + materialize_to_table=materialize and db1.parse_table_name(eval_name_template(materialize)), + ) + else: + assert algorithm == Algorithm.HASHDIFF + differ = HashDiffer( + bisection_factor=bisection_factor, + bisection_threshold=bisection_threshold, + threaded=threaded, + max_threadpool_size=threads and threads * 2, + ) + table_names = table1, table2 table_paths = [db.parse_table_name(t) for db, t in safezip(dbs, table_names)] diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 2956ab63..79a17578 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -1,8 +1,8 @@ import math import sys import logging -from typing import Dict, Generator, Tuple, Optional, Sequence, Type, List, Union -from functools import wraps +from typing import Any, Callable, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union +from functools import partial, wraps from concurrent.futures import ThreadPoolExecutor import threading from abc import abstractmethod @@ -27,7 +27,7 @@ DbPath, ) -from data_diff.queries import Expr, Compiler, table, Select, SKIP, ThreadLocalInterpreter +from data_diff.queries import Expr, Compiler, table, Select, SKIP logger = logging.getLogger("database") @@ -66,30 +66,39 @@ def _one(seq): return x -def _query_cursor(c, sql_code): - try: - c.execute(sql_code) - if sql_code.lower().startswith("select"): - return c.fetchall() - except Exception as e: - logger.exception(e) - raise +class ThreadLocalInterpreter: + """An interpeter used to execute a sequence of queries within the same thread. + Useful for cursor-sensitive operations, such as creating a temporary table. 
+ """ -def _query_conn(conn, sql_code: Union[str, ThreadLocalInterpreter]) -> list: - c = conn.cursor() + def __init__(self, compiler: Compiler, gen: Generator): + self.gen = gen + self.compiler = compiler - if isinstance(sql_code, ThreadLocalInterpreter): - g = sql_code.interpret() - q = next(g) + def apply_queries(self, callback: Callable[[str], Any]): + q: Expr = next(self.gen) while True: - res = _query_cursor(c, q) + sql = self.compiler.compile(q) try: - q = g.send(res) + try: + res = callback(sql) if sql is not SKIP else SKIP + except Exception as e: + q = self.gen.throw(type(e), e) + else: + q = self.gen.send(res) except StopIteration: break + + +def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocalInterpreter]) -> list: + if isinstance(sql_code, ThreadLocalInterpreter): + return sql_code.apply_queries(callback) else: - return _query_cursor(c, sql_code) + return callback(sql_code) + + + class Database(AbstractDatabase): @@ -108,11 +117,17 @@ class Database(AbstractDatabase): def name(self): return type(self).__name__ - def query(self, sql_ast: Expr, res_type: type = None): + def query(self, sql_ast: Union[Expr, Generator], res_type: type = None): "Query the given SQL code/AST, and attempt to convert the result to type 'res_type'" compiler = Compiler(self) - sql_code = compiler.compile(sql_ast) + if isinstance(sql_ast, Generator): + sql_code = ThreadLocalInterpreter(compiler, sql_ast) + else: + sql_code = compiler.compile(sql_ast) + if sql_code is SKIP: + return SKIP + logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code) if getattr(self, "_interactive", False) and isinstance(sql_ast, Select): explained_sql = compiler.compile(Explain(sql_ast)) @@ -134,7 +149,7 @@ def query(self, sql_ast: Expr, res_type: type = None): elif getattr(res_type, "__origin__", None) is list and len(res_type.__args__) == 1: if res_type.__args__ in ((int,), (str,)): return [_one(row) for row in res] - elif res_type.__args__ == (Tuple,): + elif res_type.__args__ in [(Tuple,), (tuple,)]: return [tuple(row) for row in res] else: raise ValueError(res_type) @@ -311,6 +326,34 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: def random(self) -> str: return "RANDOM()" + def type_repr(self, t) -> str: + if isinstance(t, str): + return t + return { + int: "INT", + str: "VARCHAR", + bool: "BOOLEAN", + float: "FLOAT", + }[t] + + def _query_cursor(self, c, sql_code: str): + assert isinstance(sql_code, str), sql_code + try: + c.execute(sql_code) + if sql_code.lower().startswith("select"): + return c.fetchall() + except Exception as e: + # logger.exception(e) + # logger.error(f'Caused by SQL: {sql_code}') + raise + + def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> list: + c = conn.cursor() + callback = partial(self._query_cursor, c) + return apply_query(callback, sql_code) + + + class ThreadedDatabase(Database): """Access the database through singleton threads. 
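The apply_queries() loop above is driven by plain generator functions: every yielded expression is compiled and executed through the same callback, and therefore on the same cursor, with results sent back into the generator and exceptions thrown back in. A rough sketch of a caller, assuming an already-connected Database instance named db and an existing table named ratings (both illustrative names):

    from data_diff.queries import table, commit

    def count_in_one_session(name):
        t = table(name)
        rows = yield t.count()   # runs on the shared cursor; result rows come back via gen.send()
        yield commit             # compiles to SKIP on autocommit databases and is then not executed
        print("row count:", rows[0][0])

    # Database.query() detects a generator argument and wraps it in a ThreadLocalInterpreter:
    # db.query(count_in_one_session("ratings"))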
@@ -339,7 +382,7 @@ def _query_in_worker(self, sql_code: Union[str, ThreadLocalInterpreter]): "This method runs in a worker thread" if self._init_error: raise self._init_error - return _query_conn(self.thread_local.conn, sql_code) + return self._query_conn(self.thread_local.conn, sql_code) @abstractmethod def create_connection(self): @@ -348,6 +391,10 @@ def create_connection(self): def close(self): self._queue.shutdown() + @property + def is_autocommit(self) -> bool: + return False + CHECKSUM_HEXDIGITS = 15 # Must be 15 or lower MD5_HEXDIGITS = 32 diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 218c9cb4..7044c084 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -1,6 +1,6 @@ from .database_types import * -from .base import Database, import_helper, parse_table_name, ConnectError -from .base import TIMESTAMP_PRECISION_POS +from .base import Database, import_helper, parse_table_name, ConnectError, apply_query +from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter @import_helper(text="Please install BigQuery and configure your google-cloud access.") @@ -47,7 +47,7 @@ def _normalize_returned_value(self, value): return value.decode() return value - def _query(self, sql_code: str): + def _query_atom(self, sql_code: str): from google.cloud import bigquery try: @@ -60,6 +60,9 @@ def _query(self, sql_code: str): res = [tuple(self._normalize_returned_value(v) for v in row.values()) for row in res] return res + def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): + return apply_query(self._query_atom, sql_code) + def to_string(self, s: str): return f"cast({s} as string)" @@ -98,3 +101,15 @@ def parse_table_name(self, name: str) -> DbPath: def random(self) -> str: return "RAND()" + + @property + def is_autocommit(self) -> bool: + return True + + def type_repr(self, t) -> str: + try: + return { + str: "STRING", + }[t] + except KeyError: + return super().type_repr(t) diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index 1e9c973e..7fe436ae 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -1,6 +1,6 @@ import logging import decimal -from abc import ABC, abstractmethod +from abc import ABC, abstractmethod, abstractproperty from typing import Sequence, Optional, Tuple, Union, Dict, List from datetime import datetime @@ -293,6 +293,10 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: def _normalize_table_path(self, path: DbPath) -> DbPath: ... + @abstractproperty + def is_autocommit(self) -> bool: + ... 
+ Schema = CaseAwareMapping diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index b0ee9fa5..5d381b66 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -1,7 +1,7 @@ import logging from .database_types import * -from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Database, import_helper, _query_conn, parse_table_name +from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Database, import_helper, parse_table_name @import_helper(text="You can install it using 'pip install databricks-sql-connector'") @@ -52,7 +52,7 @@ def __init__( def _query(self, sql_code: str) -> list: "Uses the standard SQL cursor interface" - return _query_conn(self._conn, sql_code) + return self._query_conn(self._conn, sql_code) def quote(self, s: str): return f"`{s}`" diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index 07c34aaf..b34afb36 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -76,3 +76,11 @@ def is_distinct_from(self, a: str, b: str) -> str: def random(self) -> str: return "RAND()" + + def type_repr(self, t) -> str: + try: + return { + str: "VARCHAR(1024)", + }[t] + except KeyError: + return super().type_repr(t) diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 79f7bf31..59004412 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -43,9 +43,9 @@ def create_connection(self): except Exception as e: raise ConnectError(*e.args) from e - def _query(self, sql_code: str): + def _query_cursor(self, c, sql_code: str): try: - return super()._query(sql_code) + return super()._query_cursor(c, sql_code) except self._oracle.DatabaseError as e: raise QueryError(e) @@ -130,3 +130,11 @@ def random(self) -> str: def is_distinct_from(self, a: str, b: str) -> str: return f"DECODE({a}, {b}, 1, 0) = 0" + + def type_repr(self, t) -> str: + try: + return { + str: "VARCHAR(1024)", + }[t] + except KeyError: + return super().type_repr(t) diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index 85ec4c7c..2fb041fc 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -1,10 +1,10 @@ +from functools import partial import re from data_diff.utils import match_regexps -from data_diff.queries import ThreadLocalInterpreter from .database_types import * -from .base import Database, import_helper +from .base import Database, import_helper, ThreadLocalInterpreter from .base import ( MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, @@ -75,16 +75,7 @@ def _query(self, sql_code: str) -> list: c = self._conn.cursor() if isinstance(sql_code, ThreadLocalInterpreter): - # TODO reuse code from base.py - g = sql_code.interpret() - q = next(g) - while True: - res = query_cursor(c, q) - try: - q = g.send(res) - except StopIteration: - break - return + return sql_code.apply_queries(partial(query_cursor, c)) return query_cursor(c, sql_code) @@ -142,3 +133,6 @@ def _parse_type( def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Trim doesn't work on CHAR type return f"TRIM(CAST({value} AS VARCHAR))" + + def is_autocommit(self) -> bool: + return False diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 9b03d833..bbd0958c 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,7 +1,7 @@ import logging from .database_types import * -from .base import ConnectError, Database, import_helper, _query_conn, CHECKSUM_MASK +from .base import ConnectError, 
Database, import_helper, CHECKSUM_MASK, ThreadLocalInterpreter @import_helper("snowflake") @@ -60,9 +60,9 @@ def __init__(self, *, schema: str, **kw): def close(self): self._conn.close() - def _query(self, sql_code: str) -> list: + def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): "Uses the standard SQL cursor interface" - return _query_conn(self._conn, sql_code) + return self._query_conn(self._conn, sql_code) def quote(self, s: str): return f'"{s}"' @@ -87,3 +87,6 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: def normalize_number(self, value: str, coltype: FractionalType) -> str: return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") + + def is_autocommit(self) -> bool: + return True diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index f488e945..c42826c6 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -2,6 +2,7 @@ """ +from contextlib import suppress from decimal import Decimal from functools import partial import logging @@ -10,6 +11,7 @@ from runtype import dataclass from data_diff.databases.database_types import DbPath, Schema +from data_diff.databases.base import QueryError from .utils import safezip @@ -19,7 +21,7 @@ from .diff_tables import TableDiffer, DiffResult from .thread_utils import ThreadedYielder -from .queries import table, sum_, min_, max_, avg, SKIP +from .queries import table, sum_, min_, max_, avg, SKIP, commit from .queries.api import and_, if_, or_, outerjoin, leftjoin, rightjoin, this, ITable from .queries.ast_classes import Concat, Count, Expr, Random, TablePath from .queries.compiler import Compiler @@ -51,30 +53,48 @@ def sample(table): return table.order_by(Random()).limit(10) -def create_temp_table(c: Compiler, name: str, expr: Expr): +def create_temp_table(c: Compiler, table: TablePath, expr: Expr): db = c.database if isinstance(db, BigQuery): - name = f"{db.default_schema}.{name}" - return f"create table {c.quote(name)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}" + return f"create table {c.compile(table)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}" elif isinstance(db, Presto): - return f"create table {c.quote(name)} as {c.compile(expr)}" + return f"create table {c.compile(table)} as {c.compile(expr)}" elif isinstance(db, Oracle): - return f"create global temporary table {c.quote(name)} as {c.compile(expr)}" + return f"create global temporary table {c.compile(table)} as {c.compile(expr)}" else: - return f"create temporary table {c.quote(name)} as {c.compile(expr)}" + return f"create temporary table {c.compile(table)} as {c.compile(expr)}" -def drop_table(db, name: DbPath): +def drop_table_oracle(name: DbPath): t = TablePath(name) - db.query(t.drop(if_exists=True)) + # Experience shows double drop is necessary + with suppress(QueryError): + yield t.drop() + yield t.drop() + yield commit +def drop_table(name: DbPath): + t = TablePath(name) + yield t.drop(if_exists=True) + yield commit + + +def append_to_table_oracle(name: DbPath, expr: Expr): + assert expr.schema, expr + t = TablePath(name, expr.schema) + with suppress(QueryError): + yield t.create() # uses expr.schema + yield commit + yield t.insert_expr(expr) + yield commit def append_to_table(name: DbPath, expr: Expr): + assert expr.schema, expr t = TablePath(name, expr.schema) yield t.create(if_not_exists=True) # uses expr.schema - yield "commit" + yield commit yield 
t.insert_expr(expr) - yield "commit" + yield commit def bool_to_int(x): @@ -143,8 +163,10 @@ def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult bg_funcs = [partial(self._test_duplicate_keys, table1, table2)] if self.validate_unique_key else [] if self.materialize_to_table: - drop_table(db, self.materialize_to_table) - db.query("COMMIT") + if isinstance(db, Oracle): + db.query(drop_table_oracle(self.materialize_to_table)) + else: + db.query(drop_table(self.materialize_to_table)) with self._run_in_background(*bg_funcs): @@ -306,8 +328,8 @@ def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): def exclusive_rows(expr): c = Compiler(db) name = c.new_unique_table_name("temp_table") - yield create_temp_table(c, name, expr.limit(self.write_limit)) - exclusive_rows = table(name, schema=expr.source_table.schema) + exclusive_rows = TablePath(name, schema=expr.source_table.schema) + yield create_temp_table(c, exclusive_rows, expr.limit(self.write_limit)) count = yield exclusive_rows.count() self.stats["exclusive_count"] = self.stats.get("exclusive_count", 0) + count[0][0] @@ -315,7 +337,7 @@ def exclusive_rows(expr): self.stats["exclusive_sample"] = self.stats.get("exclusive_sample", []) + sample_rows # Only drops if create table succeeded (meaning, the table didn't already exist) - yield f"drop table {c.quote(name)}" + yield exclusive_rows.drop() # Run as a sequence of thread-local queries (compiled into a ThreadLocalInterpreter) db.query(exclusive_rows(exclusive_rows_query), None) @@ -323,5 +345,6 @@ def exclusive_rows(expr): def _materialize_diff(self, db, diff_rows, segment_index=None): assert self.materialize_to_table - db.query(append_to_table(self.materialize_to_table, diff_rows.limit(self.write_limit))) + f = append_to_table_oracle if isinstance(db, Oracle) else append_to_table + db.query(f(self.materialize_to_table, diff_rows.limit(self.write_limit))) logger.info(f"Materialized diff to table '{'.'.join(self.materialize_to_table)}'.") diff --git a/data_diff/queries/__init__.py b/data_diff/queries/__init__.py index 64a6e60f..172e73e4 100644 --- a/data_diff/queries/__init__.py +++ b/data_diff/queries/__init__.py @@ -1,4 +1,4 @@ -from .compiler import Compiler, ThreadLocalInterpreter -from .api import this, join, outerjoin, table, SKIP, sum_, avg, min_, max_, cte +from .compiler import Compiler +from .api import this, join, outerjoin, table, SKIP, sum_, avg, min_, max_, cte, commit from .ast_classes import Expr, ExprNode, Select, Count, BinOp, Explain, In from .extras import Checksum, NormalizeAsString, ApplyFuncAndNormalizeAsString diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index a07a9084..c433f548 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -67,3 +67,5 @@ def max_(expr: Expr): def if_(cond: Expr, then: Expr, else_: Optional[Expr] = None): return CaseWhen([(cond, then)], else_=else_) + +commit = Commit() diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index b3552620..7bd9a520 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -186,7 +186,7 @@ class CaseWhen(ExprNode): def compile(self, c: Compiler) -> str: assert self.cases when_thens = " ".join(f"WHEN {c.compile(when)} THEN {c.compile(then)}" for when, then in self.cases) - else_ = (" " + c.compile(self.else_)) if self.else_ else "" + else_ = (" ELSE " + c.compile(self.else_)) if self.else_ is not None else "" return f"CASE {when_thens}{else_} END" @property @@ -600,23 +600,13 @@ 
class Statement(Compilable): type = None -def to_sql_type(t): - if isinstance(t, str): - return t - return { - int: "int", - str: "varchar(1024)", - bool: "boolean", - }[t] - - @dataclass class CreateTable(Statement): path: TablePath if_not_exists: bool = False def compile(self, c: Compiler) -> str: - schema = ", ".join(f"{k} {to_sql_type(v)}" for k, v in self.path.schema.items()) + schema = ", ".join(f"{k} {c.database.type_repr(v)}" for k, v in self.path.schema.items()) ne = "IF NOT EXISTS " if self.if_not_exists else "" return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})" @@ -639,3 +629,10 @@ class InsertToTable(Statement): def compile(self, c: Compiler) -> str: return f"INSERT INTO {c.compile(self.path)} {c.compile(self.expr)}" + + +@dataclass +class Commit(Statement): + + def compile(self, c: Compiler) -> str: + return "COMMIT" if not c.database.is_autocommit else SKIP diff --git a/data_diff/queries/base.py b/data_diff/queries/base.py index 50a57e2f..b5d02bb6 100644 --- a/data_diff/queries/base.py +++ b/data_diff/queries/base.py @@ -3,7 +3,11 @@ from data_diff.databases.database_types import DbPath, DbKey, Schema -SKIP = object() +class _SKIP: + def __repr__(self): + return 'SKIP' + +SKIP = _SKIP() class CompileError(Exception): diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index 2c48cb86..e6e3e236 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -1,12 +1,12 @@ import random from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, Dict, Generator, Sequence, List, Union +from typing import Any, Dict, Sequence, List, Union from runtype import dataclass from data_diff.utils import ArithString -from data_diff.databases.database_types import AbstractDialect +from data_diff.databases.database_types import AbstractDialect, DbPath @dataclass @@ -32,7 +32,7 @@ def compile(self, elem) -> str: return f"WITH {subq}\n{res}" return res - def _compile(self, elem) -> Union[str, "ThreadLocalInterpreter"]: + def _compile(self, elem) -> str: if elem is None: return "NULL" elif isinstance(elem, Compilable): @@ -47,17 +47,15 @@ def _compile(self, elem) -> Union[str, "ThreadLocalInterpreter"]: return f"b'{elem.decode()}'" elif isinstance(elem, ArithString): return f"'{elem}'" - elif isinstance(elem, Generator): - return ThreadLocalInterpreter(self, elem) assert False, elem def new_unique_name(self, prefix="tmp"): self._counter[0] += 1 return f"{prefix}{self._counter[0]}" - def new_unique_table_name(self, prefix="tmp"): + def new_unique_table_name(self, prefix="tmp") -> DbPath: self._counter[0] += 1 - return f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}" + return self.database.parse_table_name(f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}") def add_table_context(self, *tables: Sequence): return self.replace(_table_context=self._table_context + list(tables)) @@ -68,22 +66,3 @@ class Compilable(ABC): def compile(self, c: Compiler) -> str: ... - -class ThreadLocalInterpreter: - """An interpeter used to execute a sequence of queries within the same thread. - - Useful for cursor-sensitive operations, such as creating a temporary table. 
- """ - - def __init__(self, compiler: Compiler, gen: Generator): - self.gen = gen - self.compiler = compiler - - def interpret(self): - q = next(self.gen) - while True: - try: - res = yield self.compiler.compile(q) - q = self.gen.send(res) - except StopIteration: - break diff --git a/data_diff/utils.py b/data_diff/utils.py index 8224d270..b572db1b 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -247,6 +247,9 @@ def keys(self) -> Iterable[str]: def items(self) -> Iterable[Tuple[str, V]]: return ((k, v[1]) for k, v in self._dict.items()) + def __len__(self): + return len(self._dict) + class CaseSensitiveDict(dict, CaseAwareMapping): def get_key(self, key): diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 7668ec61..639587de 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -515,6 +515,10 @@ def setUp(self): ] for i in range(0, 10000, 1000): a = ArithAlphanumeric(numberToAlphanum(i), max_len=10) + if not a and isinstance(self.connection, db.Oracle): + # Skip empty string, because Oracle treats it as NULL .. + continue + queries.append(f"INSERT INTO {self.table_src} VALUES ('{a}', '{i}')") queries += [ @@ -563,6 +567,10 @@ def setUp(self): ] for i in range(0, 10000, 1000): a = ArithAlphanumeric(numberToAlphanum(i * i)) + if not a and isinstance(self.connection, db.Oracle): + # Skip empty string, because Oracle treats it as NULL .. + continue + queries.append(f"INSERT INTO {self.table_src} VALUES ('{a}', '{i}')") queries += [ diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index d9726c85..139f62f0 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -1,6 +1,8 @@ +from typing import List from parameterized import parameterized_class from data_diff.databases.connect import connect +from data_diff.queries.ast_classes import TablePath from data_diff.table_segment import TableSegment, split_space from data_diff import databases as db from data_diff.joindiff_tables import JoinDiffer @@ -8,6 +10,7 @@ from .test_diff_tables import TestPerDatabase, _get_float_type, _get_text_type, _commit, _insert_row, _insert_rows from .common import ( + random_table_suffix, str_to_checksum, CONN_STRINGS, N_THREADS, @@ -80,11 +83,25 @@ def test_diff_small_tables(self): _insert_rows(self.connection, self.table_dst, cols, [[1, 1, 1, 9, time_str]]) _commit(self.connection) diff = list(self.differ.diff_tables(self.table, self.table2)) - expected = [("-", ("2", time + ".000000"))] + expected_row = ("2", time + ".000000") + expected = [("-", expected_row)] self.assertEqual(expected, diff) self.assertEqual(2, self.differ.stats["table1_count"]) self.assertEqual(1, self.differ.stats["table2_count"]) + # Test materialize + materialize_path = self.connection.parse_table_name(f'test_mat_{random_table_suffix()}') + mdiffer = self.differ.replace(materialize_to_table=materialize_path) + diff = list(mdiffer.diff_tables(self.table, self.table2)) + self.assertEqual(expected, diff) + + t = TablePath(materialize_path) + rows = self.connection.query( t.select(), List[tuple] ) + self.connection.query( t.drop() ) + # is_xa, is_xb, is_diff1, is_diff2, row1, row2 + assert rows == [(1, 0, 1, 1) + expected_row + (None, None)], rows + + def test_diff_table_above_bisection_threshold(self): time = "2022-01-01 00:00:00" time_str = f"timestamp '{time}'" From b18dbcb3b913500b1ba4c6ac18574ae329565542 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 5 Oct 2022 19:10:44 +0300 Subject: [PATCH 23/93] black --- data_diff/databases/base.py | 5 ----- 
data_diff/joindiff_tables.py | 2 ++ data_diff/queries/api.py | 1 + data_diff/queries/ast_classes.py | 3 ++- data_diff/queries/base.py | 3 ++- data_diff/queries/compiler.py | 1 - tests/test_joindiff.py | 7 +++---- 7 files changed, 10 insertions(+), 12 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 79a17578..376f5e78 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -98,9 +98,6 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal return callback(sql_code) - - - class Database(AbstractDatabase): """Base abstract class for databases. @@ -353,8 +350,6 @@ def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> lis return apply_query(callback, sql_code) - - class ThreadedDatabase(Database): """Access the database through singleton threads. diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index c42826c6..e0579845 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -73,6 +73,7 @@ def drop_table_oracle(name: DbPath): yield t.drop() yield commit + def drop_table(name: DbPath): t = TablePath(name) yield t.drop(if_exists=True) @@ -88,6 +89,7 @@ def append_to_table_oracle(name: DbPath, expr: Expr): yield t.insert_expr(expr) yield commit + def append_to_table(name: DbPath, expr: Expr): assert expr.schema, expr t = TablePath(name, expr.schema) diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index c433f548..60636346 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -68,4 +68,5 @@ def max_(expr: Expr): def if_(cond: Expr, then: Expr, else_: Optional[Expr] = None): return CaseWhen([(cond, then)], else_=else_) + commit = Commit() diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 7bd9a520..b5081fcb 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -31,11 +31,13 @@ def cast_to(self, to): Expr = Union[ExprNode, str, bool, int, datetime, ArithString, None] + def get_type(e: Expr) -> type: if isinstance(e, ExprNode): return e.type return type(e) + @dataclass class Alias(ExprNode): expr: Expr @@ -633,6 +635,5 @@ def compile(self, c: Compiler) -> str: @dataclass class Commit(Statement): - def compile(self, c: Compiler) -> str: return "COMMIT" if not c.database.is_autocommit else SKIP diff --git a/data_diff/queries/base.py b/data_diff/queries/base.py index b5d02bb6..7b0d96cb 100644 --- a/data_diff/queries/base.py +++ b/data_diff/queries/base.py @@ -5,7 +5,8 @@ class _SKIP: def __repr__(self): - return 'SKIP' + return "SKIP" + SKIP = _SKIP() diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index e6e3e236..02bb48bc 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -65,4 +65,3 @@ class Compilable(ABC): @abstractmethod def compile(self, c: Compiler) -> str: ... 
- diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index 139f62f0..d88c338e 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -90,18 +90,17 @@ def test_diff_small_tables(self): self.assertEqual(1, self.differ.stats["table2_count"]) # Test materialize - materialize_path = self.connection.parse_table_name(f'test_mat_{random_table_suffix()}') + materialize_path = self.connection.parse_table_name(f"test_mat_{random_table_suffix()}") mdiffer = self.differ.replace(materialize_to_table=materialize_path) diff = list(mdiffer.diff_tables(self.table, self.table2)) self.assertEqual(expected, diff) t = TablePath(materialize_path) - rows = self.connection.query( t.select(), List[tuple] ) - self.connection.query( t.drop() ) + rows = self.connection.query(t.select(), List[tuple]) + self.connection.query(t.drop()) # is_xa, is_xb, is_diff1, is_diff2, row1, row2 assert rows == [(1, 0, 1, 1) + expected_row + (None, None)], rows - def test_diff_table_above_bisection_threshold(self): time = "2022-01-01 00:00:00" time_str = f"timestamp '{time}'" From 3a09a779c32e341654719863ea5f6b83a8eb61ad Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 6 Oct 2022 11:15:48 +0300 Subject: [PATCH 24/93] joindiff: docs, refactor --- data_diff/joindiff_tables.py | 27 +++++++++++++++++++-------- data_diff/queries/api.py | 7 +++++-- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index e0579845..2356912b 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -66,7 +66,7 @@ def create_temp_table(c: Compiler, table: TablePath, expr: Expr): def drop_table_oracle(name: DbPath): - t = TablePath(name) + t = table(name) # Experience shows double drop is necessary with suppress(QueryError): yield t.drop() @@ -75,14 +75,15 @@ def drop_table_oracle(name: DbPath): def drop_table(name: DbPath): - t = TablePath(name) + t = table(name) yield t.drop(if_exists=True) yield commit -def append_to_table_oracle(name: DbPath, expr: Expr): +def append_to_table_oracle(path: DbPath, expr: Expr): + """See append_to_table""" assert expr.schema, expr - t = TablePath(name, expr.schema) + t = table(path, schema=expr.schema) with suppress(QueryError): yield t.create() # uses expr.schema yield commit @@ -90,9 +91,11 @@ def append_to_table_oracle(name: DbPath, expr: Expr): yield commit -def append_to_table(name: DbPath, expr: Expr): +def append_to_table(path: DbPath, expr: Expr): + """Append to table + """ assert expr.schema, expr - t = TablePath(name, expr.schema) + t = table(path, schema=expr.schema) yield t.create(if_not_exists=True) # uses expr.schema yield commit yield t.insert_expr(expr) @@ -143,17 +146,25 @@ class JoinDiffer(TableDiffer): The algorithm uses an OUTER JOIN (or equivalent) with extra checks and statistics. The two tables must reside in the same database, and their primary keys must be unique and not null. + All parameters are optional. + Parameters: threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads. max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. Only relevant when `threaded` is ``True``. There may be many pools, so number of actual threads can be a lot higher. + validate_unique_key (bool): Enable/disable validating that the key columns are unique. + Single query, and can't be threaded, so it's very slow on non-cloud dbs. + Future versions will detect UNIQUE constraints in the schema. 
+ sample_exclusive_rows (bool): Enable/disable sampling of exclusive rows. Creates a temporary table. + materialize_to_table (DbPath, optional): Path of new table to write diff results to. Disabled if not provided. + write_limit (int): Maximum number of rows to write when materializing, per thread. """ - stats: dict = {} validate_unique_key: bool = True sample_exclusive_rows: bool = True materialize_to_table: DbPath = None write_limit: int = WRITE_LIMIT + stats: dict = {} def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: db = table1.database @@ -330,7 +341,7 @@ def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): def exclusive_rows(expr): c = Compiler(db) name = c.new_unique_table_name("temp_table") - exclusive_rows = TablePath(name, schema=expr.source_table.schema) + exclusive_rows = table(name, schema=expr.source_table.schema) yield create_temp_table(c, exclusive_rows, expr.limit(self.write_limit)) count = yield exclusive_rows.count() diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index 60636346..f000ec67 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -30,8 +30,11 @@ def cte(expr: Expr, *, name: Optional[str] = None, params: Sequence[str] = None) return Cte(expr, name, params) -def table(*path: str, schema: Schema = None) -> ITable: - assert all(isinstance(i, str) for i in path), path +def table(*path: str, schema: Schema = None) -> TablePath: + if len(path) == 1 and isinstance(path[0], tuple): + path ,= path + if not all(isinstance(i, str) for i in path): + raise TypeError(f"All elements of table path must be of type 'str'. Got: {path}") return TablePath(path, schema) From 90cbfb6ba4bc012761f8ac25d624de57af96d49a Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 6 Oct 2022 20:15:04 +0300 Subject: [PATCH 25/93] Queries fix --- data_diff/queries/ast_classes.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index b5081fcb..173ce933 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -248,12 +248,17 @@ class BinOp(ExprNode, LazyOps): op: str args: Sequence[Expr] - def __post_init__(self): - assert len(self.args) == 2, self.args - def compile(self, c: Compiler) -> str: - a, b = self.args - return f"({c.compile(a)} {self.op} {c.compile(b)})" + expr = f" {self.op} ".join(c.compile(a) for a in self.args) + return f"({expr})" + + @property + def type(self): + types = {get_type(i) for i in self.args} + if len(types) > 1: + raise TypeError(f"Expected all args to have the same type, got {types}") + t ,= types + return t class BinBoolOp(BinOp): From ad48f5dcccea554a9648fa459ee7814b4fbfae60 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 6 Oct 2022 18:39:41 +0300 Subject: [PATCH 26/93] Composite key - initial (WIP); Refactor: TableSegment.key_column -> key_columns Added test --- data_diff/__init__.py | 25 ++++++---- data_diff/__main__.py | 16 +++---- data_diff/diff_tables.py | 11 ++++- data_diff/joindiff_tables.py | 8 ++-- data_diff/queries/ast_classes.py | 2 +- data_diff/table_segment.py | 27 ++++++----- tests/test_database_types.py | 8 ++-- tests/test_diff_tables.py | 79 +++++++++++++++++--------------- tests/test_joindiff.py | 69 ++++++++++++++++++++++------ tests/test_postgresql.py | 6 +-- 10 files changed, 158 insertions(+), 93 deletions(-) diff --git a/data_diff/__init__.py b/data_diff/__init__.py index 3e8451ba..ae6f021d 100644 --- a/data_diff/__init__.py 
+++ b/data_diff/__init__.py @@ -1,4 +1,4 @@ -from typing import Tuple, Iterator, Optional, Union +from typing import Sequence, Tuple, Iterator, Optional, Union from .tracking import disable_tracking from .databases.connect import connect @@ -12,7 +12,7 @@ def connect_to_table( db_info: Union[str, dict], table_name: Union[DbPath, str], - key_column: str = "id", + key_columns: str = ("id",), thread_count: Optional[int] = 1, **kwargs, ) -> TableSegment: @@ -21,19 +21,21 @@ def connect_to_table( Parameters: db_info: Either a URI string, or a dict of connection options. table_name: Name of the table as a string, or a tuple that signifies the path. - key_column: Name of the key column + key_columns: Names of the key columns thread_count: Number of threads for this connection (only if using a threadpooled db implementation) See Also: :meth:`connect` """ + if isinstance(key_columns, str): + key_columns = (key_columns,) db = connect(db_info, thread_count=thread_count) if isinstance(table_name, str): table_name = db.parse_table_name(table_name) - return TableSegment(db, table_name, key_column, **kwargs) + return TableSegment(db, table_name, key_columns, **kwargs) def diff_tables( @@ -41,7 +43,7 @@ def diff_tables( table2: TableSegment, *, # Name of the key column, which uniquely identifies each row (usually id) - key_column: str = None, + key_columns: Sequence[str] = None, # Name of updated column, which signals that rows changed (usually updated_at or last_update) update_column: str = None, # Extra columns to compare @@ -67,12 +69,12 @@ def diff_tables( """Finds the diff between table1 and table2. Parameters: - key_column (str): Name of the key column, which uniquely identifies each row (usually id) + key_columns (Tuple[str, ...]): Name of the key column, which uniquely identifies each row (usually id) update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update). Used by `min_update` and `max_update`. extra_columns (Tuple[str, ...], optional): Extra columns to compare - min_key (:data:`DbKey`, optional): Lowest key_column value, used to restrict the segment - max_key (:data:`DbKey`, optional): Highest key_column value, used to restrict the segment + min_key (:data:`DbKey`, optional): Lowest key value, used to restrict the segment + max_key (:data:`DbKey`, optional): Highest key value, used to restrict the segment min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment max_update (:data:`DbTime`, optional): Highest update_column value, used to restrict the segment algorithm (:class:`Algorithm`): Which diffing algorithm to use (`HASHDIFF` or `JOINDIFF`) @@ -84,7 +86,7 @@ def diff_tables( Note: The following parameters are used to override the corresponding attributes of the given :class:`TableSegment` instances: - `key_column`, `update_column`, `extra_columns`, `min_key`, `max_key`. If different values are needed per table, it's + `key_columns`, `update_column`, `extra_columns`, `min_key`, `max_key`. If different values are needed per table, it's possible to omit them here, and instead set them directly when creating each :class:`TableSegment`. 
Example: @@ -98,11 +100,14 @@ def diff_tables( :class:`JoinDiffer` """ + if isinstance(key_columns, str): + key_columns = (key_columns,) + tables = [table1, table2] override_attrs = { k: v for k, v in dict( - key_column=key_column, + key_columns=key_columns, update_column=update_column, extra_columns=extra_columns, min_key=min_key, diff --git a/data_diff/__main__.py b/data_diff/__main__.py index c6c8fefe..0b7fce7c 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -80,7 +80,7 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - @click.argument("table1", required=False) @click.argument("database2", required=False) @click.argument("table2", required=False) -@click.option("-k", "--key-column", default=None, help="Name of primary key column. Default='id'.", metavar="NAME") +@click.option("-k", "--key-columns", default=[], multiple=True, help="Names of primary key columns. Default='id'.", metavar="NAME") @click.option("-t", "--update-column", default=None, help="Name of updated_at/last_updated column", metavar="NAME") @click.option( "-c", @@ -187,7 +187,7 @@ def _main( table1, database2, table2, - key_column, + key_columns, update_column, columns, limit, @@ -233,7 +233,7 @@ def _main( logging.error("Cannot specify a limit when using the -s/--stats switch") return - key_column = key_column or "id" + key_columns = key_columns or ("id",) bisection_factor = DEFAULT_BISECTION_FACTOR if bisection_factor is None else int(bisection_factor) bisection_threshold = DEFAULT_BISECTION_THRESHOLD if bisection_threshold is None else int(bisection_threshold) @@ -328,23 +328,23 @@ def _main( expanded_columns |= match - columns = tuple(expanded_columns - {key_column, update_column}) + columns = tuple(expanded_columns - {*key_columns, update_column}) if db1 is db2: diff_schemas( schema1, schema2, ( - key_column, + *key_columns, update_column, + *columns, ) - + columns, ) - logging.info(f"Diffing using columns: key={key_column} update={update_column} extra={columns}") + logging.info(f"Diffing using columns: key={key_columns} update={update_column} extra={columns}") segments = [ - TableSegment(db, table_path, key_column, update_column, columns, **options)._with_raw_schema(raw_schema) + TableSegment(db, table_path, key_columns, update_column, columns, **options)._with_raw_schema(raw_schema) for db, table_path, raw_schema in safezip(dbs, table_paths, schemas) ] diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index a90d8ef8..6801e0d2 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -153,8 +153,15 @@ def _diff_segments( ... 
def _bisect_and_diff_tables(self, table1, table2): - key_type = table1._schema[table1.key_column] - key_type2 = table2._schema[table2.key_column] + if len(table1.key_columns) > 1: + raise NotImplementedError("Composite key not supported yet!") + if len(table2.key_columns) > 1: + raise NotImplementedError("Composite key not supported yet!") + key1 ,= table1.key_columns + key2 ,= table2.key_columns + + key_type = table1._schema[key1] + key_type2 = table2._schema[key2] if not isinstance(key_type, IKey): raise NotImplementedError(f"Cannot use column of type {key_type} as a key") if not isinstance(key_type2, IKey): diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 2356912b..3a8175e3 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -241,7 +241,7 @@ def _test_duplicate_keys(self, table1, table2): # Test duplicate keys for ts in [table1, table2]: t = ts._make_select() - key_columns = [ts.key_column] # XXX + key_columns = ts.key_columns q = t.select(total=Count(), total_distinct=Count(Concat(this[key_columns]), distinct=True)) total, total_distinct = ts.database.query(q, tuple) @@ -254,7 +254,7 @@ def _test_null_keys(self, table1, table2): # Test null keys for ts in [table1, table2]: t = ts._make_select() - key_columns = [ts.key_column] # XXX + key_columns = ts.key_columns q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns)) nulls = ts.database.query(q, list) @@ -294,8 +294,8 @@ def _create_outer_join(self, table1, table2): if db is not table2.database: raise ValueError("Joindiff only applies to tables within the same database") - keys1 = [table1.key_column] # XXX - keys2 = [table2.key_column] # XXX + keys1 = table1.key_columns + keys2 = table2.key_columns if len(keys1) != len(keys2): raise ValueError("The provided key columns are of a different count") diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 173ce933..92a17543 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -562,7 +562,7 @@ def __getattr__(self, name): return _ResolveColumn(name) def __getitem__(self, name): - if isinstance(name, list): + if isinstance(name, (list, tuple)): return [_ResolveColumn(n) for n in name] return _ResolveColumn(name) diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index 3a4ddbe4..aa1a1498 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -1,5 +1,5 @@ import time -from typing import List, Tuple +from typing import List, Sequence, Tuple import logging from runtype import dataclass @@ -22,12 +22,12 @@ class TableSegment: Parameters: database (Database): Database instance. See :meth:`connect` table_path (:data:`DbPath`): Path to table in form of a tuple. e.g. `('my_dataset', 'table_name')` - key_column (str): Name of the key column, which uniquely identifies each row (usually id) - update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update). + key_columns (Tuple[str]): Name of the key column, which uniquely identifies each row (usually id) + update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update) Used by `min_update` and `max_update`. 
extra_columns (Tuple[str, ...], optional): Extra columns to compare - min_key (:data:`DbKey`, optional): Lowest key_column value, used to restrict the segment - max_key (:data:`DbKey`, optional): Highest key_column value, used to restrict the segment + min_key (:data:`DbKey`, optional): Lowest key value, used to restrict the segment + max_key (:data:`DbKey`, optional): Highest key value, used to restrict the segment min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment max_update (:data:`DbTime`, optional): Highest update_column value, used to restrict the segment where (str, optional): An additional 'where' expression to restrict the search space. @@ -41,7 +41,7 @@ class TableSegment: table_path: DbPath # Columns - key_column: str + key_columns: Tuple[str, ...] update_column: str = None extra_columns: Tuple[str, ...] = () @@ -80,9 +80,13 @@ def with_schema(self) -> "TableSegment": def _make_key_range(self): if self.min_key is not None: - yield self.min_key <= this[self.key_column] + assert len(self.key_columns) == 1 + k ,= self.key_columns + yield self.min_key <= this[k] if self.max_key is not None: - yield this[self.key_column] < self.max_key + assert len(self.key_columns) == 1 + k ,= self.key_columns + yield this[k] < self.max_key def _make_update_range(self): if self.min_update is not None: @@ -144,7 +148,7 @@ def _relevant_columns(self) -> List[str]: if self.update_column and self.update_column not in extras: extras = [self.update_column] + extras - return [self.key_column] + extras + return list(self.key_columns) + extras @property def _relevant_columns_repr(self) -> List[Expr]: @@ -174,9 +178,10 @@ def query_key_range(self) -> Tuple[int, int]: """Query database for minimum and maximum key. This is used for setting the initial bounds.""" # Normalizes the result (needed for UUIDs) after the min/max computation # TODO better error if there is no schema + k ,= self.key_columns select = self._make_select().select( - ApplyFuncAndNormalizeAsString(this[self.key_column], min_), - ApplyFuncAndNormalizeAsString(this[self.key_column], max_), + ApplyFuncAndNormalizeAsString(this[k], min_), + ApplyFuncAndNormalizeAsString(this[k], max_), ) min_key, max_key = self.database.query(select, tuple) diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 4ac8d5f4..c9e9042c 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -647,11 +647,11 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego insertion_target_duration = time.monotonic() - start if type_category == "uuid": - self.table = TableSegment(self.src_conn, src_table_path, "col", None, ("id",), case_sensitive=False) - self.table2 = TableSegment(self.dst_conn, dst_table_path, "col", None, ("id",), case_sensitive=False) + self.table = TableSegment(self.src_conn, src_table_path, ("col",), None, ("id",), case_sensitive=False) + self.table2 = TableSegment(self.dst_conn, dst_table_path, ("col",), None, ("id",), case_sensitive=False) else: - self.table = TableSegment(self.src_conn, src_table_path, "id", None, ("col",), case_sensitive=False) - self.table2 = TableSegment(self.dst_conn, dst_table_path, "id", None, ("col",), case_sensitive=False) + self.table = TableSegment(self.src_conn, src_table_path, ("id",), None, ("col",), case_sensitive=False) + self.table2 = TableSegment(self.dst_conn, dst_table_path, ("id",), None, ("col",), case_sensitive=False) start = time.monotonic() self.assertEqual(N_SAMPLES, self.table.count()) diff 
--git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 639587de..9dfc7de7 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -44,6 +44,11 @@ def _class_per_db_dec(filter_name=None): ] return parameterized_class(("name", "db_name"), names) +def _table_segment(database, table_path, key_columns, *args, **kw): + if isinstance(key_columns, str): + key_columns = (key_columns,) + return TableSegment(database, table_path, key_columns, *args, **kw) + def test_per_database(cls): return _class_per_db_dec()(cls) @@ -102,7 +107,7 @@ class TestPerDatabase(unittest.TestCase): preql = None def setUp(self): - assert self.db_name + assert self.db_name, self.db_name init_instances() self.connection = DATABASE_INSTANCES[self.db_name] @@ -170,17 +175,17 @@ def setUp(self): self.preql.commit() def test_init(self): - a = TableSegment( + a = _table_segment( self.connection, self.table_src_path, "id", "datetime", max_update=self.now.datetime, case_sensitive=False ) self.assertRaises( - ValueError, TableSegment, self.connection, self.table_src_path, "id", max_update=self.now.datetime + ValueError, _table_segment, self.connection, self.table_src_path, "id", max_update=self.now.datetime ) def test_basic(self): differ = HashDiffer(bisection_factor=10, bisection_threshold=100) - a = TableSegment(self.connection, self.table_src_path, "id", "datetime", case_sensitive=False) - b = TableSegment(self.connection, self.table_dst_path, "id", "datetime", case_sensitive=False) + a = _table_segment(self.connection, self.table_src_path, "id", "datetime", case_sensitive=False) + b = _table_segment(self.connection, self.table_dst_path, "id", "datetime", case_sensitive=False) assert a.count() == 6 assert b.count() == 5 @@ -190,23 +195,23 @@ def test_basic(self): def test_offset(self): differ = HashDiffer(bisection_factor=2, bisection_threshold=10) sec1 = self.now.shift(seconds=-1).datetime - a = TableSegment(self.connection, self.table_src_path, "id", "datetime", max_update=sec1, case_sensitive=False) - b = TableSegment(self.connection, self.table_dst_path, "id", "datetime", max_update=sec1, case_sensitive=False) + a = _table_segment(self.connection, self.table_src_path, "id", "datetime", max_update=sec1, case_sensitive=False) + b = _table_segment(self.connection, self.table_dst_path, "id", "datetime", max_update=sec1, case_sensitive=False) assert a.count() == 4 assert b.count() == 3 assert not list(differ.diff_tables(a, a)) self.assertEqual(len(list(differ.diff_tables(a, b))), 1) - a = TableSegment(self.connection, self.table_src_path, "id", "datetime", min_update=sec1, case_sensitive=False) - b = TableSegment(self.connection, self.table_dst_path, "id", "datetime", min_update=sec1, case_sensitive=False) + a = _table_segment(self.connection, self.table_src_path, "id", "datetime", min_update=sec1, case_sensitive=False) + b = _table_segment(self.connection, self.table_dst_path, "id", "datetime", min_update=sec1, case_sensitive=False) assert a.count() == 2 assert b.count() == 2 assert not list(differ.diff_tables(a, b)) day1 = self.now.shift(days=-1).datetime - a = TableSegment( + a = _table_segment( self.connection, self.table_src_path, "id", @@ -215,7 +220,7 @@ def test_offset(self): max_update=sec1, case_sensitive=False, ) - b = TableSegment( + b = _table_segment( self.connection, self.table_dst_path, "id", @@ -249,8 +254,8 @@ def setUp(self): ) _commit(self.connection) - self.table = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) - self.table2 = 
TableSegment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) + self.table = _table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) + self.table2 = _table_segment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) self.differ = HashDiffer(bisection_factor=3, bisection_threshold=4) @@ -443,8 +448,8 @@ def test_diff_column_names(self): ], ) - table1 = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) - table2 = TableSegment(self.connection, self.table_dst_path, "id2", "timestamp2", case_sensitive=False) + table1 = _table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) + table2 = _table_segment(self.connection, self.table_dst_path, "id2", "timestamp2", case_sensitive=False) differ = HashDiffer() diff = list(differ.diff_tables(table1, table2)) @@ -478,8 +483,8 @@ def setUp(self): _commit(self.connection) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_string_keys(self): differ = HashDiffer() @@ -535,8 +540,8 @@ def setUp(self): _commit(self.connection) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_alphanum_keys(self): @@ -549,8 +554,8 @@ def test_alphanum_keys(self): ) _commit(self.connection) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) self.assertRaises(NotImplementedError, list, differ.diff_tables(self.a, self.b)) @@ -587,8 +592,8 @@ def setUp(self): _commit(self.connection) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_varying_alphanum_keys(self): # Test the class itself @@ -609,8 +614,8 @@ def test_varying_alphanum_keys(self): ) _commit(self.connection) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = 
_table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) self.assertRaises(NotImplementedError, list, differ.diff_tables(self.a, self.b)) @@ -619,8 +624,8 @@ def test_varying_alphanum_keys(self): class TestTableSegment(TestPerDatabase): def setUp(self) -> None: super().setUp() - self.table = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) - self.table2 = TableSegment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) + self.table = _table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) + self.table2 = _table_segment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) def test_table_segment(self): early = datetime.datetime(2021, 1, 1, 0, 0) @@ -641,11 +646,11 @@ def test_case_awareness(self): _insert_rows(self.connection, self.table_src, cols, [[1, 9, time_str], [2, 2, time_str]]) _commit(self.connection) - res = tuple(self.table.replace(key_column="Id", case_sensitive=False).with_schema().query_key_range()) + res = tuple(self.table.replace(key_columns=("Id",), case_sensitive=False).with_schema().query_key_range()) assert res == ("1", "2") self.assertRaises( - KeyError, self.table.replace(key_column="Id", case_sensitive=True).with_schema().query_key_range + KeyError, self.table.replace(key_columns=("Id",), case_sensitive=True).with_schema().query_key_range ) @@ -674,8 +679,8 @@ def setUp(self): _commit(self.connection) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_uuid_column_with_nulls(self): differ = HashDiffer() @@ -704,8 +709,8 @@ def setUp(self): _commit(self.connection) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_uuid_columns_with_nulls(self): """ @@ -762,10 +767,10 @@ def setUp(self): _commit(self.connection) - self.a = TableSegment( + self.a = _table_segment( self.connection, self.table_src_path, "id", extra_columns=("c1", "c2"), case_sensitive=False ) - self.b = TableSegment( + self.b = _table_segment( self.connection, self.table_dst_path, "id", extra_columns=("c1", "c2"), case_sensitive=False ) @@ -819,8 +824,8 @@ def setUp(self): _commit(self.connection) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_right_table_empty(self): differ = HashDiffer() diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index d88c338e..62f7cd8e 100644 --- 
a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -1,3 +1,4 @@ +from functools import wraps from typing import List from parameterized import parameterized_class @@ -28,9 +29,7 @@ def init_instances(): DATABASE_INSTANCES = {k.__name__: connect(v, N_THREADS) for k, v in CONN_STRINGS.items()} -TEST_DATABASES = { - x.__name__ - for x in ( +TEST_DATABASES = ( db.PostgreSQL, db.Snowflake, db.MySQL, @@ -40,17 +39,63 @@ def init_instances(): db.Trino, db.Oracle, db.Redshift, - ) -} - -_class_per_db_dec = parameterized_class( - ("name", "db_name"), [(name, name) for name in DATABASE_URIS if name in TEST_DATABASES] ) -def test_per_database(cls): +def test_per_database(cls, dbs=TEST_DATABASES): + dbs = {db.__name__ for db in dbs} + _class_per_db_dec = parameterized_class( + ("name", "db_name"), [(name, name) for name in DATABASE_URIS if name in dbs] + ) return _class_per_db_dec(cls) +def test_per_database2(*dbs): + @wraps(test_per_database) + def dec(cls): + return test_per_database(cls, dbs) + return dec + + +@test_per_database2(db.Snowflake, db.BigQuery) +class TestCompositeKey(TestPerDatabase): + def setUp(self): + super().setUp() + + float_type = _get_float_type(self.connection) + + self.connection.query( + f"create table {self.table_src}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", + ) + self.connection.query( + f"create table {self.table_dst}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", + ) + _commit(self.connection) + + self.differ = JoinDiffer() + + def test_composite_key(self): + time = "2022-01-01 00:00:00" + time_str = f"timestamp '{time}'" + + cols = "id userid movieid rating timestamp".split() + _insert_rows(self.connection, self.table_src, cols, [[1, 1, 1, 9, time_str], [2, 2, 2, 9, time_str]]) + _insert_rows(self.connection, self.table_dst, cols, [[1, 1, 1, 9, time_str], [2, 3, 2, 9, time_str]]) + _commit(self.connection) + + # Sanity + table1 = TableSegment(self.connection, self.table_src_path, ("id",), "timestamp", ('userid',), case_sensitive=False) + table2 = TableSegment(self.connection, self.table_dst_path, ("id",), "timestamp", ('userid',), case_sensitive=False) + diff = list(self.differ.diff_tables(table1, table2)) + assert len(diff) == 2 + assert self.differ.stats['exclusive_count'] == 0 + + # Test pks diffed, by checking exclusive_count + table1 = TableSegment(self.connection, self.table_src_path, ("id", "userid"), "timestamp", case_sensitive=False) + table2 = TableSegment(self.connection, self.table_dst_path, ("id", "userid"), "timestamp", case_sensitive=False) + diff = list(self.differ.diff_tables(table1, table2)) + assert len(diff) == 2 + assert self.differ.stats['exclusive_count'] == 2 + @test_per_database class TestJoindiff(TestPerDatabase): @@ -61,16 +106,14 @@ def setUp(self): self.connection.query( f"create table {self.table_src}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", - None, ) self.connection.query( f"create table {self.table_dst}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", - None, ) _commit(self.connection) - self.table = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) - self.table2 = TableSegment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) + self.table = TableSegment(self.connection, self.table_src_path, ("id",), "timestamp", case_sensitive=False) + self.table2 = TableSegment(self.connection, self.table_dst_path, ("id",), "timestamp", 
case_sensitive=False) self.differ = JoinDiffer() diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 3a4f4239..21e64b3e 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -37,8 +37,8 @@ def test_uuid(self): for query in queries: self.connection.query(query, None) - a = TableSegment(self.connection, (self.table_src,), "id", "comment") - b = TableSegment(self.connection, (self.table_dst,), "id", "comment") + a = TableSegment(self.connection, (self.table_src,), ("id",), "comment") + b = TableSegment(self.connection, (self.table_dst,), ("id",), "comment") differ = HashDiffer() diff = list(differ.diff_tables(a, b)) @@ -56,7 +56,7 @@ def test_uuid(self): mysql_conn.query(f"INSERT INTO {self.table_dst}(id, comment) VALUES ('{uuid}', '{comment}')", None) mysql_conn.query(f"COMMIT", None) - c = TableSegment(mysql_conn, (self.table_dst,), "id", "comment") + c = TableSegment(mysql_conn, (self.table_dst,), ("id",), "comment") diff = list(differ.diff_tables(a, c)) assert not diff, diff diff = list(differ.diff_tables(c, a)) From 377b4a77714550a06d144af6f7abd4dffbd11cc5 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 7 Oct 2022 11:18:14 +0300 Subject: [PATCH 27/93] Fixed interactive mode and explain --- data_diff/databases/base.py | 14 +++++++++----- data_diff/databases/database_types.py | 5 +++++ data_diff/databases/mysql.py | 3 +++ data_diff/databases/snowflake.py | 3 +++ data_diff/queries/ast_classes.py | 7 +++++++ tests/test_query.py | 3 +++ tests/test_sql.py | 2 +- 7 files changed, 31 insertions(+), 6 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 376f5e78..8b3a465d 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -27,7 +27,7 @@ DbPath, ) -from data_diff.queries import Expr, Compiler, table, Select, SKIP +from data_diff.queries import Expr, Compiler, table, Select, SKIP, Explain logger = logging.getLogger("database") @@ -114,7 +114,7 @@ class Database(AbstractDatabase): def name(self): return type(self).__name__ - def query(self, sql_ast: Union[Expr, Generator], res_type: type = None): + def query(self, sql_ast: Union[Expr, Generator], res_type: type = list): "Query the given SQL code/AST, and attempt to convert the result to type 'res_type'" compiler = Compiler(self) @@ -128,8 +128,9 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = None): logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code) if getattr(self, "_interactive", False) and isinstance(sql_ast, Select): explained_sql = compiler.compile(Explain(sql_ast)) - logger.info("EXPLAIN for SQL SELECT") - logger.info(self._query(explained_sql)) + explain = self._query(explained_sql) + for row, in explain: + logger.debug(f'EXPLAIN: {row}') answer = input("Continue? [y/n] ") if not answer.lower() in ["y", "yes"]: sys.exit(1) @@ -337,7 +338,7 @@ def _query_cursor(self, c, sql_code: str): assert isinstance(sql_code, str), sql_code try: c.execute(sql_code) - if sql_code.lower().startswith("select"): + if sql_code.lower().startswith(("select", "explain", "show")): return c.fetchall() except Exception as e: # logger.exception(e) @@ -349,6 +350,9 @@ def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> lis callback = partial(self._query_cursor, c) return apply_query(callback, sql_code) + def explain_as_text(self, query: str) -> str: + return f"EXPLAIN {query}" + class ThreadedDatabase(Database): """Access the database through singleton threads. 
diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index 7fe436ae..27249f0d 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -172,6 +172,11 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None "Provide SQL fragment for limit and offset inside a select" ... + @abstractmethod + def explain_as_text(self, query: str) -> str: + "Provide SQL for explaining a query, returned in as table(varchar)" + ... + class AbstractDatabase(AbstractDialect): @abstractmethod diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index b34afb36..b666e0c5 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -84,3 +84,6 @@ def type_repr(self, t) -> str: }[t] except KeyError: return super().type_repr(t) + + def explain_as_text(self, query: str) -> str: + return f"EXPLAIN FORMAT=TREE {query}" diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index bbd0958c..714fb5f0 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -90,3 +90,6 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: def is_autocommit(self) -> bool: return True + + def explain_as_text(self, query: str) -> str: + return f"EXPLAIN USING TEXT {query}" diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 92a17543..ac57bbe9 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -600,6 +600,13 @@ def compile(self, c: Compiler) -> str: return c.database.random() +@dataclass +class Explain(ExprNode): + select: Select + + def compile(self, c: Compiler) -> str: + return c.database.explain_as_text(c.compile(self.select)) + # DDL diff --git a/tests/test_query.py b/tests/test_query.py index 5091843e..fe0de696 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -32,6 +32,9 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None x = offset and f"offset {offset}", limit and f"limit {limit}" return " ".join(filter(None, x)) + def explain_as_text(self, query: str) -> str: + return f"explain {query}" + class TestQuery(unittest.TestCase): def setUp(self): diff --git a/tests/test_sql.py b/tests/test_sql.py index fe17940b..0e1e8d13 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -91,7 +91,7 @@ def test_count_with_column(self): ) def test_explain(self): - expected_sql = "EXPLAIN SELECT count(id) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" + expected_sql = "EXPLAIN FORMAT=TREE SELECT count(id) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" self.assertEqual( expected_sql, self.compiler.compile( From 4c16bac24105729b676123ccda0ae2e8aeec5f35 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 7 Oct 2022 14:55:29 +0300 Subject: [PATCH 28/93] Update README --- README.md | 34 +++++++++++++++++++++++++++------- data_diff/__main__.py | 4 ++++ 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index df0d27c5..a36fe630 100644 --- a/README.md +++ b/README.md @@ -17,10 +17,20 @@ rows across two different databases. * 🔍 Outputs [diff of rows](#example-command-and-output) in detail * 🚨 Simple CLI/API to create monitoring and alerts * 🔁 Bridges column types of different formats and levels of precision (e.g. Double ⇆ Float ⇆ Decimal) -* 🔥 Verify 25M+ rows in <10s, and 1B+ rows in ~5min. +* 🔥 Fast! Verify 25M+ rows in <10s, and 1B+ rows in ~5min. 
* ♾️ Works for tables with 10s of billions of rows -**data-diff** splits the table into smaller segments, then checksums each +data-diff can diff tables within the same database, or across different databases. + +**Same-DB Diff**: Uses an outer-join to diff the rows as efficiently and accurately as possible. + +Supports materializing the diff results to a database table. + +Can also collect various extra statistics about the tables. + +**Cross-DB Diff**: Employs a divide and conquer algorithm based on hashing, optimized for few changes. + +data-diff splits the table into smaller segments, then checksums each segment in both databases. When the checksums for a segment aren't equal, it will further divide that segment into yet smaller segments, checksumming those until it gets to the differing row(s). See [Technical Explanation][tech-explain] for more @@ -69,8 +79,8 @@ better than MySQL. may span a half-dozen systems, without verifying each intermediate datastore it's extremely difficult to track down where a row got lost. * **Detecting hard deletes for an `updated_at`-based pipeline**. If you're - copying data to your warehouse based on an `updated_at`-style column, then - you'll miss hard-deletes that **data-diff** can find for you. + copying data to your warehouse based on an `updated_at`-style column, data-diff + can find any hard-deletes that you might have missed. * **Make your replication self-healing.** You can use **data-diff** to self-heal by using the diff output to write/update rows in the target database. @@ -217,7 +227,7 @@ may be case-sensitive. This is the case for the Snowflake schema and table names Options: - `--help` - Show help message and exit. - - `-k` or `--key-column` - Name of the primary key column + - `-k` or `--key-columns` - Name of the primary key column. If none provided, default is 'id'. - `-t` or `--update-column` - Name of updated_at/last_updated column - `-c` or `--columns` - Names of extra columns to compare. Can be used more than once in the same command. Accepts a name or a pattern like in SQL. @@ -232,12 +242,22 @@ Options: Example: `--min-age=5min` ignores rows from the last 5 minutes. Valid units: `d, days, h, hours, min, minutes, mon, months, s, seconds, w, weeks, y, years` - `--max-age` - Considers only rows younger than specified. See `--min-age`. - - `--bisection-factor` - Segments per iteration. When set to 2, it performs binary search. - - `--bisection-threshold` - Minimal bisection threshold. i.e. maximum size of pages to diff locally. - `-j` or `--threads` - Number of worker threads to use per database. Default=1. - `-w`, `--where` - An additional 'where' expression to restrict the search space. - `--conf`, `--run` - Specify the run and configuration from a TOML file. (see below) - `--no-tracking` - data-diff sends home anonymous usage data. Use this to disable it. + - `-a`, `--algorithm` `[auto|joindiff|hashdiff]` - Force algorithm choice + +Same-DB diff only: + - `-m`, `--materialize` - Materialize the diff results into a new table in the database. + Use `%t` in the name to place a timestamp. + Example: `-m test_mat_%t` + - `--assume-unique-key` - Skip validating the uniqueness of the key column during joindiff, which is costly in non-cloud dbs. + +Cross-DB diff only: + - `--bisection-threshold` - Minimal size of segment to be split. Smaller segments will be downloaded and compared locally. + - `--bisection-factor` - Segments per iteration. When set to 2, it performs binary search. 
+ ### How to use with a configuration file diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 0b7fce7c..1380e3b4 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -375,6 +375,10 @@ def _main( print(f"Diff-Total: {len(diff)} changed rows out of {max_table_count}") print(f"Diff-Percent: {percent:.14f}%") print(f"Diff-Split: +{plus} -{minus}") + if differ.stats: + print("Extra-Info:") + for k, v in differ.stats.items(): + print(f' {k} = {v}') else: for op, values in diff_iter: color = COLOR_SCHEME[op] From 472f422b4746ee731ef118ab4fd17717d98b6ab9 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 7 Oct 2022 15:01:41 +0300 Subject: [PATCH 29/93] Added --sample-exclusive-rows switch --- README.md | 1 + data_diff/__main__.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/README.md b/README.md index a36fe630..c9e8d0f4 100644 --- a/README.md +++ b/README.md @@ -253,6 +253,7 @@ Same-DB diff only: Use `%t` in the name to place a timestamp. Example: `-m test_mat_%t` - `--assume-unique-key` - Skip validating the uniqueness of the key column during joindiff, which is costly in non-cloud dbs. + - `--sample-exclusive-rows` - Sample several rows that only appear in one of the tables, but not the other. Use with `-s`. Cross-DB diff only: - `--bisection-threshold` - Minimal size of segment to be split. Smaller segments will be downloaded and compared locally. diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 1380e3b4..2201ee58 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -139,6 +139,11 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - is_flag=True, help="Skip validating the uniqueness of the key column during joindiff, which is costly in non-cloud dbs.", ) +@click.option( + "--sample-exclusive-rows", + is_flag=True, + help="Sample several rows that only appear in one of the tables, but not the other.", +) @click.option( "-j", "--threads", @@ -206,6 +211,7 @@ def _main( json_output, where, assume_unique_key, + sample_exclusive_rows, materialize, threads1=None, threads2=None, @@ -294,6 +300,7 @@ def _main( threaded=threaded, max_threadpool_size=threads and threads * 2, validate_unique_key=not assume_unique_key, + sample_exclusive_rows=sample_exclusive_rows, materialize_to_table=materialize and db1.parse_table_name(eval_name_template(materialize)), ) else: From abaabe84f963080e408606157e0501b7a417aeef Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 7 Oct 2022 16:17:35 +0300 Subject: [PATCH 30/93] README: Updated supported database list --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c9e8d0f4..b54c3648 100644 --- a/README.md +++ b/README.md @@ -138,9 +138,9 @@ $ data-diff \ | PostgreSQL >=10 | `postgresql://:@:5432/` | 💚 | | MySQL | `mysql://:@:5432/` | 💚 | | Snowflake | `"snowflake://[:]@//?warehouse=&role=[&authenticator=externalbrowser]"` | 💚 | +| BigQuery | `bigquery:///` | 💚 | +| Redshift | `redshift://:@:5439/` | 💚 | | Oracle | `oracle://:@/database` | 💛 | -| BigQuery | `bigquery:///` | 💛 | -| Redshift | `redshift://:@:5439/` | 💛 | | Presto | `presto://:@:8080/` | 💛 | | Databricks | `databricks://:@//` | 💛 | | Trino | `trino://:@:8080/` | 💛 | @@ -151,6 +151,8 @@ $ data-diff \ | Pinot | | 📝 | | Druid | | 📝 | | Kafka | | 📝 | +| DuckDB | | 📝 | +| SQLite | | 📝 | * 💚: Implemented and thoroughly tested. * 💛: Implemented, but not thoroughly tested yet. 
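
The joindiff-related changes in the patches above (the same-DB diff described in the README, the `--assume-unique-key` and `--sample-exclusive-rows` flags, and the `Extra-Info` statistics printed by `__main__.py`) map onto the Python API roughly as in the sketch below. It is a minimal illustration only: the connection URI and table names are placeholders, `connect_to_table` is invoked the same way the `diff_tables` docstring example invokes it, and the `JoinDiffer` keyword arguments are the ones `__main__.py` passes through from the CLI flags.

```python
# Minimal sketch, for illustration only. URI and table names are placeholders.
from data_diff import connect_to_table
from data_diff.joindiff_tables import JoinDiffer

# Two tables in the same database -- the same-DB / joindiff case.
ratings = connect_to_table("postgresql:///", "ratings", "id")
ratings_copy = connect_to_table("postgresql:///", "ratings_copy", "id")

differ = JoinDiffer(
    validate_unique_key=False,   # CLI: --assume-unique-key
    sample_exclusive_rows=True,  # CLI: --sample-exclusive-rows
)

# diff_tables() yields ('-', row) / ('+', row) tuples, mirroring the CLI output.
for op, row in differ.diff_tables(ratings, ratings_copy):
    print(op, row)

# The same statistics that the CLI prints under "Extra-Info:".
for name, value in differ.stats.items():
    print(f"{name} = {value}")
```

When a materialization target is given (the `-m`/`--materialize` flag, passed to the differ as `materialize_to_table`), the diff rows are written to a results table instead of only being yielded; the `stats` dict shown above is the same one the tests assert against (e.g. `exclusive_count`, `table1_count`).
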
From 245aeb6c4b2cca10c5ab9350e77fcaf6ab066f2b Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 7 Oct 2022 16:27:51 +0300 Subject: [PATCH 31/93] Updated docs; Ran black --- README.md | 1 + data_diff/__main__.py | 10 ++++++---- data_diff/config.py | 6 +++--- data_diff/databases/base.py | 4 ++-- data_diff/diff_tables.py | 6 +++--- data_diff/joindiff_tables.py | 3 +-- data_diff/queries/api.py | 2 +- data_diff/queries/ast_classes.py | 3 ++- data_diff/table_segment.py | 6 +++--- tests/test_diff_tables.py | 17 +++++++++++++---- tests/test_joindiff.py | 32 +++++++++++++++++++------------- 11 files changed, 54 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index b54c3648..83f93ae9 100644 --- a/README.md +++ b/README.md @@ -252,6 +252,7 @@ Options: Same-DB diff only: - `-m`, `--materialize` - Materialize the diff results into a new table in the database. + If a table exists by that name, it will be replaced. Use `%t` in the name to place a timestamp. Example: `-m test_mat_%t` - `--assume-unique-key` - Skip validating the uniqueness of the key column during joindiff, which is costly in non-cloud dbs. diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 2201ee58..5ca5e15b 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -80,7 +80,9 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - @click.argument("table1", required=False) @click.argument("database2", required=False) @click.argument("table2", required=False) -@click.option("-k", "--key-columns", default=[], multiple=True, help="Names of primary key columns. Default='id'.", metavar="NAME") +@click.option( + "-k", "--key-columns", default=[], multiple=True, help="Names of primary key columns. Default='id'.", metavar="NAME" +) @click.option("-t", "--update-column", default=None, help="Name of updated_at/last_updated column", metavar="NAME") @click.option( "-c", @@ -110,7 +112,7 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - "--materialize", default=None, metavar="TABLE_NAME", - help="Materialize the diff results into a new table in the database. (joindiff only)", + help="(joindiff only) Materialize the diff results into a new table in the database. If a table exists by that name, it will be replaced.", ) @click.option( "--min-age", @@ -345,7 +347,7 @@ def _main( *key_columns, update_column, *columns, - ) + ), ) logging.info(f"Diffing using columns: key={key_columns} update={update_column} extra={columns}") @@ -385,7 +387,7 @@ def _main( if differ.stats: print("Extra-Info:") for k, v in differ.stats.items(): - print(f' {k} = {v}') + print(f" {k} = {v}") else: for op, values in diff_iter: color = COLOR_SCHEME[op] diff --git a/data_diff/config.py b/data_diff/config.py index ad7c972d..941e2643 100644 --- a/data_diff/config.py +++ b/data_diff/config.py @@ -26,13 +26,13 @@ def _apply_config(config: Dict[str, Any], run_name: str, kw: Dict[str, Any]): else: run_name = "default" - if 'database1' in kw: - for attr in ('table1', 'database2', 'table2'): + if "database1" in kw: + for attr in ("table1", "database2", "table2"): if kw[attr] is None: raise ValueError(f"Specified database1 but not {attr}. 
Must specify all 4 arguments, or niether.") for index in "12": - run_args[index] = {attr: kw.pop(f"{attr}{index}") for attr in ('database', 'table')} + run_args[index] = {attr: kw.pop(f"{attr}{index}") for attr in ("database", "table")} # Process databases + tables for index in "12": diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 8b3a465d..c96ec6ae 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -129,8 +129,8 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list): if getattr(self, "_interactive", False) and isinstance(sql_ast, Select): explained_sql = compiler.compile(Explain(sql_ast)) explain = self._query(explained_sql) - for row, in explain: - logger.debug(f'EXPLAIN: {row}') + for (row,) in explain: + logger.debug(f"EXPLAIN: {row}") answer = input("Continue? [y/n] ") if not answer.lower() in ["y", "yes"]: sys.exit(1) diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 6801e0d2..24627c45 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -157,8 +157,8 @@ def _bisect_and_diff_tables(self, table1, table2): raise NotImplementedError("Composite key not supported yet!") if len(table2.key_columns) > 1: raise NotImplementedError("Composite key not supported yet!") - key1 ,= table1.key_columns - key2 ,= table2.key_columns + (key1,) = table1.key_columns + (key2,) = table2.key_columns key_type = table1._schema[key1] key_type2 = table2._schema[key2] @@ -214,7 +214,7 @@ def _bisect_and_diff_segments( assert table1.is_bounded and table2.is_bounded # Choose evenly spaced checkpoints (according to min_key and max_key) - biggest_table = max(table1, table2, key=methodcaller('approximate_size')) + biggest_table = max(table1, table2, key=methodcaller("approximate_size")) checkpoints = biggest_table.choose_checkpoints(self.bisection_factor - 1) # Create new instances of TableSegment between each checkpoint diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 3a8175e3..7617495f 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -92,8 +92,7 @@ def append_to_table_oracle(path: DbPath, expr: Expr): def append_to_table(path: DbPath, expr: Expr): - """Append to table - """ + """Append to table""" assert expr.schema, expr t = table(path, schema=expr.schema) yield t.create(if_not_exists=True) # uses expr.schema diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index f000ec67..d9c0945f 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -32,7 +32,7 @@ def cte(expr: Expr, *, name: Optional[str] = None, params: Sequence[str] = None) def table(*path: str, schema: Schema = None) -> TablePath: if len(path) == 1 and isinstance(path[0], tuple): - path ,= path + (path,) = path if not all(isinstance(i, str) for i in path): raise TypeError(f"All elements of table path must be of type 'str'. 
Got: {path}") return TablePath(path, schema) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index ac57bbe9..a73a69db 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -257,7 +257,7 @@ def type(self): types = {get_type(i) for i in self.args} if len(types) > 1: raise TypeError(f"Expected all args to have the same type, got {types}") - t ,= types + (t,) = types return t @@ -607,6 +607,7 @@ class Explain(ExprNode): def compile(self, c: Compiler) -> str: return c.database.explain_as_text(c.compile(self.select)) + # DDL diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index aa1a1498..c3219dc1 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -81,11 +81,11 @@ def with_schema(self) -> "TableSegment": def _make_key_range(self): if self.min_key is not None: assert len(self.key_columns) == 1 - k ,= self.key_columns + (k,) = self.key_columns yield self.min_key <= this[k] if self.max_key is not None: assert len(self.key_columns) == 1 - k ,= self.key_columns + (k,) = self.key_columns yield this[k] < self.max_key def _make_update_range(self): @@ -178,7 +178,7 @@ def query_key_range(self) -> Tuple[int, int]: """Query database for minimum and maximum key. This is used for setting the initial bounds.""" # Normalizes the result (needed for UUIDs) after the min/max computation # TODO better error if there is no schema - k ,= self.key_columns + (k,) = self.key_columns select = self._make_select().select( ApplyFuncAndNormalizeAsString(this[k], min_), ApplyFuncAndNormalizeAsString(this[k], max_), diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 9dfc7de7..57f9415b 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -44,6 +44,7 @@ def _class_per_db_dec(filter_name=None): ] return parameterized_class(("name", "db_name"), names) + def _table_segment(database, table_path, key_columns, *args, **kw): if isinstance(key_columns, str): key_columns = (key_columns,) @@ -195,16 +196,24 @@ def test_basic(self): def test_offset(self): differ = HashDiffer(bisection_factor=2, bisection_threshold=10) sec1 = self.now.shift(seconds=-1).datetime - a = _table_segment(self.connection, self.table_src_path, "id", "datetime", max_update=sec1, case_sensitive=False) - b = _table_segment(self.connection, self.table_dst_path, "id", "datetime", max_update=sec1, case_sensitive=False) + a = _table_segment( + self.connection, self.table_src_path, "id", "datetime", max_update=sec1, case_sensitive=False + ) + b = _table_segment( + self.connection, self.table_dst_path, "id", "datetime", max_update=sec1, case_sensitive=False + ) assert a.count() == 4 assert b.count() == 3 assert not list(differ.diff_tables(a, a)) self.assertEqual(len(list(differ.diff_tables(a, b))), 1) - a = _table_segment(self.connection, self.table_src_path, "id", "datetime", min_update=sec1, case_sensitive=False) - b = _table_segment(self.connection, self.table_dst_path, "id", "datetime", min_update=sec1, case_sensitive=False) + a = _table_segment( + self.connection, self.table_src_path, "id", "datetime", min_update=sec1, case_sensitive=False + ) + b = _table_segment( + self.connection, self.table_dst_path, "id", "datetime", min_update=sec1, case_sensitive=False + ) assert a.count() == 2 assert b.count() == 2 assert not list(differ.diff_tables(a, b)) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index 62f7cd8e..5203c5a2 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -30,15 +30,15 @@ 
def init_instances(): TEST_DATABASES = ( - db.PostgreSQL, - db.Snowflake, - db.MySQL, - db.BigQuery, - db.Presto, - db.Vertica, - db.Trino, - db.Oracle, - db.Redshift, + db.PostgreSQL, + db.Snowflake, + db.MySQL, + db.BigQuery, + db.Presto, + db.Vertica, + db.Trino, + db.Oracle, + db.Redshift, ) @@ -49,10 +49,12 @@ def test_per_database(cls, dbs=TEST_DATABASES): ) return _class_per_db_dec(cls) + def test_per_database2(*dbs): @wraps(test_per_database) def dec(cls): return test_per_database(cls, dbs) + return dec @@ -83,18 +85,22 @@ def test_composite_key(self): _commit(self.connection) # Sanity - table1 = TableSegment(self.connection, self.table_src_path, ("id",), "timestamp", ('userid',), case_sensitive=False) - table2 = TableSegment(self.connection, self.table_dst_path, ("id",), "timestamp", ('userid',), case_sensitive=False) + table1 = TableSegment( + self.connection, self.table_src_path, ("id",), "timestamp", ("userid",), case_sensitive=False + ) + table2 = TableSegment( + self.connection, self.table_dst_path, ("id",), "timestamp", ("userid",), case_sensitive=False + ) diff = list(self.differ.diff_tables(table1, table2)) assert len(diff) == 2 - assert self.differ.stats['exclusive_count'] == 0 + assert self.differ.stats["exclusive_count"] == 0 # Test pks diffed, by checking exclusive_count table1 = TableSegment(self.connection, self.table_src_path, ("id", "userid"), "timestamp", case_sensitive=False) table2 = TableSegment(self.connection, self.table_dst_path, ("id", "userid"), "timestamp", case_sensitive=False) diff = list(self.differ.diff_tables(table1, table2)) assert len(diff) == 2 - assert self.differ.stats['exclusive_count'] == 2 + assert self.differ.stats["exclusive_count"] == 2 @test_per_database From e8965fd00b1da6d469a8f3fe83fe8ffb0e23610b Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 7 Oct 2022 16:52:27 +0300 Subject: [PATCH 32/93] Joindiff: Fix stats collections --- data_diff/databases/database_types.py | 92 ++++++++++++++------------- data_diff/databases/presto.py | 2 +- data_diff/joindiff_tables.py | 4 +- data_diff/table_segment.py | 1 - tests/test_diff_tables.py | 6 -- tests/test_joindiff.py | 2 + tests/test_query.py | 5 +- 7 files changed, 58 insertions(+), 54 deletions(-) diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index 27249f0d..dc7a806c 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -141,6 +141,8 @@ class UnknownColType(ColType): class AbstractDialect(ABC): + """Dialect-dependent query expressions""" + name: str @abstractmethod @@ -177,56 +179,18 @@ def explain_as_text(self, query: str) -> str: "Provide SQL for explaining a query, returned in as table(varchar)" ... - -class AbstractDatabase(AbstractDialect): @abstractmethod - def timestamp_value(self, t: DbTime) -> str: + def timestamp_value(self, t: datetime) -> str: "Provide SQL for the given timestamp value" ... - @abstractmethod - def md5_to_int(self, s: str) -> str: - "Provide SQL for computing md5 and returning an int" - ... - - @abstractmethod - def _query(self, sql_code: str) -> list: - "Send query to database and return result" - ... - - @abstractmethod - def select_table_schema(self, path: DbPath) -> str: - "Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)" - ... 
- @abstractmethod - def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: - """Query the table for its schema for table in 'path', and return {column: tuple} - where the tuple is (table_name, col_name, type_repr, datetime_precision?, numeric_precision?, numeric_scale?) - """ - ... +class AbstractDatadiffDialect(ABC): + """Dialect-dependent query expressions, that are specific to data-diff""" @abstractmethod - def _process_table_schema( - self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None - ): - """Process the result of query_table_schema(). - - Done in a separate step, to minimize the amount of processed columns. - Needed because processing each column may: - * throw errors and warnings - * query the database to sample values - - """ - - @abstractmethod - def parse_table_name(self, name: str) -> DbPath: - "Parse the given table name into a DbPath" - ... - - @abstractmethod - def close(self): - "Close connection(s) to the database instance. Querying will stop functioning." + def md5_to_int(self, s: str) -> str: + "Provide SQL for computing md5 and returning an int" ... @abstractmethod @@ -294,6 +258,48 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: return self.normalize_uuid(value, coltype) return self.to_string(value) + +class AbstractDatabase(AbstractDialect, AbstractDatadiffDialect): + @abstractmethod + def _query(self, sql_code: str) -> list: + "Send query to database and return result" + ... + + @abstractmethod + def select_table_schema(self, path: DbPath) -> str: + "Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)" + ... + + @abstractmethod + def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: + """Query the table for its schema for table in 'path', and return {column: tuple} + where the tuple is (table_name, col_name, type_repr, datetime_precision?, numeric_precision?, numeric_scale?) + """ + ... + + @abstractmethod + def _process_table_schema( + self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None + ): + """Process the result of query_table_schema(). + + Done in a separate step, to minimize the amount of processed columns. + Needed because processing each column may: + * throw errors and warnings + * query the database to sample values + + """ + + @abstractmethod + def parse_table_name(self, name: str) -> DbPath: + "Parse the given table name into a DbPath" + ... + + @abstractmethod + def close(self): + "Close connection(s) to the database instance. Querying will stop functioning." + ... + @abstractmethod def _normalize_table_path(self, path: DbPath) -> DbPath: ... 
diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index 2fb041fc..d7204775 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -83,7 +83,7 @@ def close(self): self._conn.close() def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - # TODO + # TODO rounds if coltype.rounds: s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" else: diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 7617495f..a1f23b23 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -10,7 +10,7 @@ from runtype import dataclass -from data_diff.databases.database_types import DbPath, Schema +from data_diff.databases.database_types import DbPath, NumericType, Schema from data_diff.databases.base import QueryError @@ -273,7 +273,7 @@ def _collect_stats(self, i, table): f"max_{c}": max_(this[c]), } for c in table._relevant_columns - if c == "id" # TODO just if the right type + if isinstance(table._schema[c], NumericType) ) col_exprs["count"] = Count() diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index c3219dc1..170955cd 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -177,7 +177,6 @@ def count_and_checksum(self) -> Tuple[int, int]: def query_key_range(self) -> Tuple[int, int]: """Query database for minimum and maximum key. This is used for setting the initial bounds.""" # Normalizes the result (needed for UUIDs) after the min/max computation - # TODO better error if there is no schema (k,) = self.key_columns select = self._make_select().select( ApplyFuncAndNormalizeAsString(this[k], min_), diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 57f9415b..84d70434 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -485,8 +485,6 @@ def setUp(self): self.new_uuid = uuid.uuid1(32132131) queries.append(f"INSERT INTO {self.table_src} VALUES ('{self.new_uuid}', 'This one is different')") - # TODO test unexpected values? - for query in queries: self.connection.query(query, None) @@ -542,8 +540,6 @@ def setUp(self): self.new_alphanum = "aBcDeFgHiJ" queries.append(f"INSERT INTO {self.table_src} VALUES ('{self.new_alphanum}', 'This one is different')") - # TODO test unexpected values? - for query in queries: self.connection.query(query, None) @@ -594,8 +590,6 @@ def setUp(self): self.new_alphanum = "aBcDeFgHiJ" queries.append(f"INSERT INTO {self.table_src} VALUES ('{self.new_alphanum}', 'This one is different')") - # TODO test unexpected values? 
- for query in queries: self.connection.query(query, None) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index 5203c5a2..b1babe35 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -137,6 +137,8 @@ def test_diff_small_tables(self): self.assertEqual(expected, diff) self.assertEqual(2, self.differ.stats["table1_count"]) self.assertEqual(1, self.differ.stats["table2_count"]) + self.assertEqual(3, self.differ.stats["table1_sum_id"]) + self.assertEqual(1, self.differ.stats["table2_sum_id"]) # Test materialize materialize_path = self.connection.parse_table_name(f"test_mat_{random_table_suffix()}") diff --git a/tests/test_query.py b/tests/test_query.py index fe0de696..3ab26e43 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,4 +1,4 @@ -from cmath import exp +from datetime import datetime from typing import List, Optional import unittest from data_diff.databases.database_types import AbstractDialect, CaseInsensitiveDict, CaseSensitiveDict @@ -35,6 +35,9 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None def explain_as_text(self, query: str) -> str: return f"explain {query}" + def timestamp_value(self, t: datetime) -> str: + return f"timestamp '{t}'" + class TestQuery(unittest.TestCase): def setUp(self): From 47b9faa3202584d30bf3f2765c1ab57ae1e1636d Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Sat, 8 Oct 2022 10:37:33 +0300 Subject: [PATCH 33/93] Cleanup and minor fixes (pylint pass) --- data_diff/__init__.py | 17 ++++--- data_diff/databases/base.py | 14 ++++-- data_diff/databases/bigquery.py | 3 +- data_diff/databases/database_types.py | 53 ++++++++++---------- data_diff/databases/databricks.py | 14 +++++- data_diff/databases/mysql.py | 12 ++++- data_diff/databases/oracle.py | 16 +++++- data_diff/databases/postgresql.py | 12 ++++- data_diff/databases/presto.py | 23 +++++++-- data_diff/databases/redshift.py | 3 +- data_diff/databases/snowflake.py | 4 +- data_diff/databases/trino.py | 2 +- data_diff/diff_tables.py | 2 +- data_diff/hashdiff_tables.py | 21 ++++---- data_diff/joindiff_tables.py | 70 +++++++++++---------------- data_diff/queries/ast_classes.py | 20 +++----- data_diff/queries/compiler.py | 2 +- data_diff/table_segment.py | 29 +++++------ data_diff/thread_utils.py | 2 +- data_diff/utils.py | 11 ++++- tests/test_database_types.py | 11 ++--- tests/test_query.py | 2 +- 22 files changed, 203 insertions(+), 140 deletions(-) diff --git a/data_diff/__init__.py b/data_diff/__init__.py index ae6f021d..20c6b57d 100644 --- a/data_diff/__init__.py +++ b/data_diff/__init__.py @@ -70,24 +70,27 @@ def diff_tables( Parameters: key_columns (Tuple[str, ...]): Name of the key column, which uniquely identifies each row (usually id) - update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update). - Used by `min_update` and `max_update`. + update_column (str, optional): Name of updated column, which signals that rows changed. + Usually updated_at or last_update. Used by `min_update` and `max_update`. 
extra_columns (Tuple[str, ...], optional): Extra columns to compare min_key (:data:`DbKey`, optional): Lowest key value, used to restrict the segment max_key (:data:`DbKey`, optional): Highest key value, used to restrict the segment min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment max_update (:data:`DbTime`, optional): Highest update_column value, used to restrict the segment algorithm (:class:`Algorithm`): Which diffing algorithm to use (`HASHDIFF` or `JOINDIFF`) - bisection_factor (int): Into how many segments to bisect per iteration. (when algorithm is `HASHDIFF`) - bisection_threshold (Number): When should we stop bisecting and compare locally (when algorithm is `HASHDIFF`; in row count). + bisection_factor (int): Into how many segments to bisect per iteration. (Used when algorithm is `HASHDIFF`) + bisection_threshold (Number): Minimal row count of segment to bisect, otherwise download + and compare locally. (Used when algorithm is `HASHDIFF`). threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads. - max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. Only relevant when `threaded` is ``True``. + max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. + Only relevant when `threaded` is ``True``. There may be many pools, so number of actual threads can be a lot higher. Note: The following parameters are used to override the corresponding attributes of the given :class:`TableSegment` instances: - `key_columns`, `update_column`, `extra_columns`, `min_key`, `max_key`. If different values are needed per table, it's - possible to omit them here, and instead set them directly when creating each :class:`TableSegment`. + `key_columns`, `update_column`, `extra_columns`, `min_key`, `max_key`. + If different values are needed per table, it's possible to omit them here, and instead set + them directly when creating each :class:`TableSegment`. Example: >>> table1 = connect_to_table('postgresql:///', 'Rating', 'id') diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index c96ec6ae..897d4c3a 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -8,6 +8,7 @@ from abc import abstractmethod from data_diff.utils import is_uuid, safezip +from data_diff.queries import Expr, Compiler, table, Select, SKIP, Explain from .database_types import ( AbstractDatabase, ColType, @@ -27,8 +28,6 @@ DbPath, ) -from data_diff.queries import Expr, Compiler, table, Select, SKIP, Explain - logger = logging.getLogger("database") @@ -110,6 +109,8 @@ class Database(AbstractDatabase): default_schema: str = None SUPPORTS_ALPHANUMS = True + _interactive = False + @property def name(self): return type(self).__name__ @@ -126,11 +127,14 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list): return SKIP logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code) - if getattr(self, "_interactive", False) and isinstance(sql_ast, Select): + if self._interactive and isinstance(sql_ast, Select): explained_sql = compiler.compile(Explain(sql_ast)) explain = self._query(explained_sql) - for (row,) in explain: - logger.debug(f"EXPLAIN: {row}") + for row in explain: + # Most returned a 1-tuple. Presto returns a string + if isinstance(row, tuple): + row ,= row + logger.debug("EXPLAIN: %s", row) answer = input("Continue? 
[y/n] ") if not answer.lower() in ["y", "yes"]: sys.exit(1) diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 7044c084..603bfecc 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -1,4 +1,5 @@ -from .database_types import * +from typing import Union +from .database_types import Timestamp, Datetime, Integer, Decimal, Float, Text, DbPath, FractionalType, TemporalType from .base import Database, import_helper, parse_table_name, ConnectError, apply_query from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index dc7a806c..1de1d2fc 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -1,6 +1,6 @@ import logging import decimal -from abc import ABC, abstractmethod, abstractproperty +from abc import ABC, abstractmethod from typing import Sequence, Optional, Tuple, Union, Dict, List from datetime import datetime @@ -234,30 +234,6 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: """ ... - def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - """Creates an SQL expression, that converts 'value' to a normalized representation. - - The returned expression must accept any SQL value, and return a string. - - The default implementation dispatches to a method according to `coltype`: - - :: - - TemporalType -> normalize_timestamp() - FractionalType -> normalize_number() - *else* -> to_string() - - (`Integer` falls in the *else* category) - - """ - if isinstance(coltype, TemporalType): - return self.normalize_timestamp(value, coltype) - elif isinstance(coltype, FractionalType): - return self.normalize_number(value, coltype) - elif isinstance(coltype, ColType_UUID): - return self.normalize_uuid(value, coltype) - return self.to_string(value) - class AbstractDatabase(AbstractDialect, AbstractDatadiffDialect): @abstractmethod @@ -304,10 +280,35 @@ def close(self): def _normalize_table_path(self, path: DbPath) -> DbPath: ... - @abstractproperty + @property + @abstractmethod def is_autocommit(self) -> bool: ... + def normalize_value_by_type(self, value: str, coltype: ColType) -> str: + """Creates an SQL expression, that converts 'value' to a normalized representation. + + The returned expression must accept any SQL value, and return a string. 
+ + The default implementation dispatches to a method according to `coltype`: + + :: + + TemporalType -> normalize_timestamp() + FractionalType -> normalize_number() + *else* -> to_string() + + (`Integer` falls in the *else* category) + + """ + if isinstance(coltype, TemporalType): + return self.normalize_timestamp(value, coltype) + elif isinstance(coltype, FractionalType): + return self.normalize_number(value, coltype) + elif isinstance(coltype, ColType_UUID): + return self.normalize_uuid(value, coltype) + return self.to_string(value) + Schema = CaseAwareMapping diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 5d381b66..612c1c8d 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -1,6 +1,18 @@ +from typing import Dict, Sequence import logging -from .database_types import * +from .database_types import ( + Integer, + Float, + Decimal, + Timestamp, + Text, + TemporalType, + NumericType, + DbPath, + ColType, + UnknownColType, +) from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Database, import_helper, parse_table_name diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index b666e0c5..3f9eb98c 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -1,4 +1,14 @@ -from .database_types import * +from .database_types import ( + Datetime, + Timestamp, + Float, + Decimal, + Integer, + Text, + TemporalType, + FractionalType, + ColType_UUID, +) from .base import ThreadedDatabase, import_helper, ConnectError from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 59004412..e65fd65a 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -1,6 +1,20 @@ +from typing import Dict, List, Optional + from ..utils import match_regexps -from .database_types import * +from .database_types import ( + Decimal, + Float, + Text, + DbPath, + TemporalType, + ColType, + DbTime, + ColType_UUID, + Timestamp, + TimestampTZ, + FractionalType, +) from .base import ThreadedDatabase, import_helper, ConnectError, QueryError from .base import TIMESTAMP_PRECISION_POS diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index d65ac7de..72d26d07 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -1,4 +1,14 @@ -from .database_types import * +from .database_types import ( + Timestamp, + TimestampTZ, + Float, + Decimal, + Integer, + TemporalType, + Native_UUID, + Text, + FractionalType, +) from .base import ThreadedDatabase, import_helper, ConnectError from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, _CHECKSUM_BITSIZE, TIMESTAMP_PRECISION_POS diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index d7204775..811a9491 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -3,7 +3,19 @@ from data_diff.utils import match_regexps -from .database_types import * +from .database_types import ( + Timestamp, + TimestampTZ, + Integer, + Float, + Text, + FractionalType, + DbPath, + Decimal, + ColType, + ColType_UUID, + TemporalType, +) from .base import Database, import_helper, ThreadLocalInterpreter from .base import ( MD5_HEXDIGITS, @@ -17,7 +29,7 @@ def query_cursor(c, sql_code): if sql_code.lower().startswith("select"): return c.fetchall() # Required for the query to actually run 🤯 - if re.match(r"(insert|create|truncate|drop)", sql_code, re.IGNORECASE): + if 
re.match(r"(insert|create|truncate|drop|explain)", sql_code, re.IGNORECASE): return c.fetchone() @@ -98,7 +110,7 @@ def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) return ( - "SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision " + "SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision, NULL as numeric_scale " "FROM INFORMATION_SCHEMA.COLUMNS " f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) @@ -110,6 +122,7 @@ def _parse_type( type_repr: str, datetime_precision: int = None, numeric_precision: int = None, + numeric_scale: int = None, ) -> ColType: timestamp_regexps = { r"timestamp\((\d)\)": Timestamp, @@ -134,5 +147,9 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Trim doesn't work on CHAR type return f"TRIM(CAST({value} AS VARCHAR))" + @property def is_autocommit(self) -> bool: return False + + def explain_as_text(self, query: str) -> str: + return f"EXPLAIN (FORMAT TEXT) {query}" diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index f11b950c..291d180b 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -1,4 +1,5 @@ -from .database_types import * +from typing import List +from .database_types import Float, TemporalType, FractionalType, DbPath from .postgresql import PostgreSQL, MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 714fb5f0..635ba8f4 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,6 +1,7 @@ +from typing import Union import logging -from .database_types import * +from .database_types import Timestamp, TimestampTZ, Decimal, Float, Text, FractionalType, TemporalType, DbPath from .base import ConnectError, Database, import_helper, CHECKSUM_MASK, ThreadLocalInterpreter @@ -88,6 +89,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: def normalize_number(self, value: str, coltype: FractionalType) -> str: return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") + @property def is_autocommit(self) -> bool: return True diff --git a/data_diff/databases/trino.py b/data_diff/databases/trino.py index c3e3e581..73ef4a97 100644 --- a/data_diff/databases/trino.py +++ b/data_diff/databases/trino.py @@ -1,4 +1,4 @@ -from .database_types import * +from .database_types import TemporalType, ColType_UUID from .presto import Presto from .base import import_helper from .base import TIMESTAMP_PRECISION_POS diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 24627c45..1148041b 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -78,6 +78,7 @@ def _run_in_background(self, *funcs): class TableDiffer(ThreadBase, ABC): bisection_factor = 32 + stats: dict = {} def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: """Diff the given tables. @@ -177,7 +178,6 @@ def _bisect_and_diff_tables(self, table1, table2): table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)] logger.info( - # f"Diffing tables | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}. " f"Diffing segments at key-range: {table1.min_key}..{table2.max_key}. 
" f"size: table1 <= {table1.approximate_size()}, table2 <= {table2.approximate_size()}" ) diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py index ec575bdc..38e6fee5 100644 --- a/data_diff/hashdiff_tables.py +++ b/data_diff/hashdiff_tables.py @@ -9,7 +9,7 @@ from .utils import safezip from .thread_utils import ThreadedYielder -from .databases.database_types import ColType_UUID, IKey, NumericType, PrecisionType, StringType +from .databases.database_types import ColType_UUID, NumericType, PrecisionType, StringType from .table_segment import TableSegment from .diff_tables import TableDiffer @@ -27,7 +27,7 @@ def diff_sets(a: set, b: set) -> Iterator: s2 = set(b) d = defaultdict(list) - # The first item is always the key (see TableDiffer._relevant_columns) + # The first item is always the key (see TableDiffer.relevant_columns) for i in s1 - s2: d[i[0]].append(("-", i)) for i in s2 - s1: @@ -50,7 +50,8 @@ class HashDiffer(TableDiffer): bisection_factor (int): Into how many segments to bisect per iteration. bisection_threshold (Number): When should we stop bisecting and compare locally (in row count). threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads. - max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. Only relevant when `threaded` is ``True``. + max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. + Only relevant when `threaded` is ``True``. There may be many pools, so number of actual threads can be a lot higher. """ @@ -67,7 +68,7 @@ def __post_init__(self): raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)") def _validate_and_adjust_columns(self, table1, table2): - for c1, c2 in safezip(table1._relevant_columns, table2._relevant_columns): + for c1, c2 in safezip(table1.relevant_columns, table2.relevant_columns): if c1 not in table1._schema: raise ValueError(f"Column '{c1}' not found in schema for table {table1}") if c2 not in table2._schema: @@ -109,7 +110,7 @@ def _validate_and_adjust_columns(self, table1, table2): raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") for t in [table1, table2]: - for c in t._relevant_columns: + for c in t.relevant_columns: ctype = t._schema[c] if not ctype.supported: logger.warning( @@ -144,10 +145,12 @@ def _diff_segments( (count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2]) if count1 == 0 and count2 == 0: - # logger.warning( - # f"Uneven distribution of keys detected in segment {table1.min_key}..{table2.max_key}. (big gaps in the key column). " - # "For better performance, we recommend to increase the bisection-threshold." - # ) + logger.debug( + "Uneven distribution of keys detected in segment %s..%s (big gaps in the key column). 
" + "For better performance, we recommend to increase the bisection-threshold.", + table1.min_key, + table1.max_key, + ) assert checksum1 is None and checksum2 is None return diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index a1f23b23..58246def 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -6,22 +6,22 @@ from decimal import Decimal from functools import partial import logging -from typing import Dict, List, Optional +from typing import List from runtype import dataclass -from data_diff.databases.database_types import DbPath, NumericType, Schema +from data_diff.databases.database_types import DbPath, NumericType from data_diff.databases.base import QueryError from .utils import safezip from .databases.base import Database -from .databases import MySQL, BigQuery, Presto, Oracle, PostgreSQL, Snowflake +from .databases import MySQL, BigQuery, Presto, Oracle, Snowflake from .table_segment import TableSegment from .diff_tables import TableDiffer, DiffResult from .thread_utils import ThreadedYielder -from .queries import table, sum_, min_, max_, avg, SKIP, commit +from .queries import table, sum_, min_, max_, avg, commit from .queries.api import and_, if_, or_, outerjoin, leftjoin, rightjoin, this, ITable from .queries.ast_classes import Concat, Count, Expr, Random, TablePath from .queries.compiler import Compiler @@ -40,29 +40,20 @@ def merge_dicts(dicts): return res -@dataclass(frozen=False) -class Stats: - exclusive_count: int - exclusive_sample: List[tuple] - diff_ratio_by_column: Dict[str, float] - diff_ratio_total: float - metrics: Dict[str, float] +def sample(table_expr): + return table_expr.order_by(Random()).limit(10) -def sample(table): - return table.order_by(Random()).limit(10) - - -def create_temp_table(c: Compiler, table: TablePath, expr: Expr): +def create_temp_table(c: Compiler, path: TablePath, expr: Expr): db = c.database if isinstance(db, BigQuery): - return f"create table {c.compile(table)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}" + return f"create table {c.compile(path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}" elif isinstance(db, Presto): - return f"create table {c.compile(table)} as {c.compile(expr)}" + return f"create table {c.compile(path)} as {c.compile(expr)}" elif isinstance(db, Oracle): - return f"create global temporary table {c.compile(table)} as {c.compile(expr)}" + return f"create global temporary table {c.compile(path)} as {c.compile(expr)}" else: - return f"create temporary table {c.compile(table)} as {c.compile(expr)}" + return f"create temporary table {c.compile(path)} as {c.compile(expr)}" def drop_table_oracle(name: DbPath): @@ -149,7 +140,8 @@ class JoinDiffer(TableDiffer): Parameters: threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads. - max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. Only relevant when `threaded` is ``True``. + max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. + Only relevant when `threaded` is ``True``. There may be many pools, so number of actual threads can be a lot higher. validate_unique_key (bool): Enable/disable validating that the key columns are unique. Single query, and can't be threaded, so it's very slow on non-cloud dbs. 
@@ -227,8 +219,8 @@ def _diff_segments( if is_xa and is_xb: # Can't both be exclusive, meaning a pk is NULL # This can happen if the explicit null test didn't finish running yet - raise ValueError(f"NULL values in one or more primary keys") - is_diff, a_row, b_row = _slice_tuple(x, len(is_diff_cols), len(a_cols), len(b_cols)) + raise ValueError("NULL values in one or more primary keys") + _is_diff, a_row, b_row = _slice_tuple(x, len(is_diff_cols), len(a_cols), len(b_cols)) if not is_xb: yield "-", tuple(a_row) if not is_xa: @@ -239,7 +231,7 @@ def _test_duplicate_keys(self, table1, table2): # Test duplicate keys for ts in [table1, table2]: - t = ts._make_select() + t = ts.make_select() key_columns = ts.key_columns q = t.select(total=Count(), total_distinct=Count(Concat(this[key_columns]), distinct=True)) @@ -252,17 +244,17 @@ def _test_null_keys(self, table1, table2): # Test null keys for ts in [table1, table2]: - t = ts._make_select() + t = ts.make_select() key_columns = ts.key_columns q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns)) nulls = ts.database.query(q, list) if nulls: - raise ValueError(f"NULL values in one or more primary keys") + raise ValueError("NULL values in one or more primary keys") - def _collect_stats(self, i, table): + def _collect_stats(self, i, table_seg: TableSegment): logger.info(f"Collecting stats for table #{i}") - db = table.database + db = table_seg.database # Metrics col_exprs = merge_dicts( @@ -272,21 +264,17 @@ def _collect_stats(self, i, table): f"min_{c}": min_(this[c]), f"max_{c}": max_(this[c]), } - for c in table._relevant_columns - if isinstance(table._schema[c], NumericType) + for c in table_seg.relevant_columns + if isinstance(table_seg._schema[c], NumericType) ) col_exprs["count"] = Count() - res = db.query(table._make_select().select(**col_exprs), tuple) + res = db.query(table_seg.make_select().select(**col_exprs), tuple) res = dict(zip([f"table{i}_{n}" for n in col_exprs], map(json_friendly_value, res))) for k, v in res.items(): self.stats[k] = self.stats.get(k, 0) + (v or 0) - # self.stats.update(res) - - logger.debug(f"Done collecting stats for table #{i}") - # stats.diff_ratio_by_column = diff_stats - # stats.diff_ratio_total = diff_stats['total_diff'] + logger.debug("Done collecting stats for table #%s", i) def _create_outer_join(self, table1, table2): db = table1.database @@ -298,13 +286,13 @@ def _create_outer_join(self, table1, table2): if len(keys1) != len(keys2): raise ValueError("The provided key columns are of a different count") - cols1 = table1._relevant_columns - cols2 = table2._relevant_columns + cols1 = table1.relevant_columns + cols2 = table2.relevant_columns if len(cols1) != len(cols2): raise ValueError("The provided columns are of a different count") - a = table1._make_select() - b = table2._make_select() + a = table1.make_select() + b = table2.make_select() is_diff_cols = {f"is_diff_{c1}": bool_to_int(a[c1].is_distinct_from(b[c2])) for c1, c2 in safezip(cols1, cols2)} @@ -359,4 +347,4 @@ def _materialize_diff(self, db, diff_rows, segment_index=None): f = append_to_table_oracle if isinstance(db, Oracle) else append_to_table db.query(f(self.materialize_to_table, diff_rows.limit(self.write_limit))) - logger.info(f"Materialized diff to table '{'.'.join(self.materialize_to_table)}'.") + logger.info("Materialized diff to table '%s'.", ".".join(self.materialize_to_table)) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index a73a69db..226c246b 100644 --- 
a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -131,7 +131,7 @@ def count(self): return Select(self, [Count()]) def union(self, other: "ITable"): - return Union(self, other) + return SetUnion(self, other) @dataclass @@ -401,7 +401,7 @@ def having(self): @dataclass -class Union(ExprNode, ITable): +class SetUnion(ExprNode, ITable): table1: ITable table2: ITable @@ -422,12 +422,12 @@ def schema(self): def compile(self, parent_c: Compiler) -> str: c = parent_c.replace(in_select=False) - union_all = f"{c.compile(self.table1)} UNION {c.compile(self.table2)}" + union = f"{c.compile(self.table1)} UNION {c.compile(self.table2)}" if parent_c.in_select: - union_all = f"({union_all}) {c.new_unique_name()}" + union = f"({union}) {c.new_unique_name()}" elif parent_c.in_join: - union_all = f"({union_all})" - return union_all + union = f"({union})" + return union @dataclass @@ -567,14 +567,6 @@ def __getitem__(self, name): return _ResolveColumn(name) -@dataclass -class Explain(ExprNode): - sql: Select - - def compile(self, c: Compiler) -> str: - return f"EXPLAIN {c.compile(self.sql)}" - - @dataclass class In(ExprNode): expr: Expr diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index 02bb48bc..eda7d981 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -1,7 +1,7 @@ import random from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, Dict, Sequence, List, Union +from typing import Any, Dict, Sequence, List from runtype import dataclass diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index 170955cd..cddbe9f5 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -1,5 +1,5 @@ import time -from typing import List, Sequence, Tuple +from typing import List, Tuple import logging from runtype import dataclass @@ -12,7 +12,7 @@ logger = logging.getLogger("table_segment") -RECOMMENDED_CHECKSUM_DURATION = 10 +RECOMMENDED_CHECKSUM_DURATION = 20 @dataclass @@ -23,8 +23,8 @@ class TableSegment: database (Database): Database instance. See :meth:`connect` table_path (:data:`DbPath`): Path to table in form of a tuple. e.g. `('my_dataset', 'table_name')` key_columns (Tuple[str]): Name of the key column, which uniquely identifies each row (usually id) - update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update) - Used by `min_update` and `max_update`. + update_column (str, optional): Name of updated column, which signals that rows changed. + Usually updated_at or last_update. Used by `min_update` and `max_update`. 
extra_columns (Tuple[str, ...], optional): Extra columns to compare
         min_key (:data:`DbKey`, optional): Lowest key value, used to restrict the segment
         max_key (:data:`DbKey`, optional): Highest key value, used to restrict the segment
@@ -68,7 +68,7 @@ def __post_init__(self):
             )

     def _with_raw_schema(self, raw_schema: dict) -> "TableSegment":
-        schema = self.database._process_table_schema(self.table_path, raw_schema, self._relevant_columns, self.where)
+        schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self.where)
         return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive))

     def with_schema(self) -> "TableSegment":
@@ -98,12 +98,12 @@ def _make_update_range(self):
     def source_table(self):
         return table(*self.table_path, schema=self._schema)

-    def _make_select(self):
+    def make_select(self):
         return self.source_table.where(*self._make_key_range(), *self._make_update_range(), self.where or SKIP)

     def get_values(self) -> list:
         "Download all the relevant values of the segment from the database"
-        select = self._make_select().select(*self._relevant_columns_repr)
+        select = self.make_select().select(*self._relevant_columns_repr)
         return self.database.query(select, List[Tuple])

     def choose_checkpoints(self, count: int) -> List[DbKey]:
@@ -142,7 +142,7 @@ def new(self, **kwargs) -> "TableSegment":
         return self.replace(**kwargs)

     @property
-    def _relevant_columns(self) -> List[str]:
+    def relevant_columns(self) -> List[str]:
         extras = list(self.extra_columns)

         if self.update_column and self.update_column not in extras:
@@ -152,22 +152,23 @@ def _relevant_columns(self) -> List[str]:

     @property
     def _relevant_columns_repr(self) -> List[Expr]:
-        return [NormalizeAsString(this[c]) for c in self._relevant_columns]
+        return [NormalizeAsString(this[c]) for c in self.relevant_columns]

     def count(self) -> Tuple[int, int]:
         """Count how many rows are in the segment, in one pass."""
-        return self.database.query(self._make_select().select(Count()), int)
+        return self.database.query(self.make_select().select(Count()), int)

     def count_and_checksum(self) -> Tuple[int, int]:
         """Count and checksum the rows in the segment, in one pass."""
         start = time.monotonic()
-        q = self._make_select().select(Count(), Checksum(self._relevant_columns_repr))
+        q = self.make_select().select(Count(), Checksum(self._relevant_columns_repr))
         count, checksum = self.database.query(q, tuple)
         duration = time.monotonic() - start
         if duration > RECOMMENDED_CHECKSUM_DURATION:
             logger.warning(
-                f"Checksum is taking longer than expected ({duration:.2f}s). "
-                "We recommend increasing --bisection-factor or decreasing --threads."
+                "Checksum is taking longer than expected (%.2fs). "
+                "We recommend increasing --bisection-factor or decreasing --threads.",
+                duration,
             )

         if count:
@@ -178,7 +179,7 @@ def query_key_range(self) -> Tuple[int, int]:
         """Query database for minimum and maximum key.
This is used for setting the initial bounds.""" # Normalizes the result (needed for UUIDs) after the min/max computation (k,) = self.key_columns - select = self._make_select().select( + select = self.make_select().select( ApplyFuncAndNormalizeAsString(this[k], min_), ApplyFuncAndNormalizeAsString(this[k], max_), ) diff --git a/data_diff/thread_utils.py b/data_diff/thread_utils.py index 1e0d26b8..1be94ad4 100644 --- a/data_diff/thread_utils.py +++ b/data_diff/thread_utils.py @@ -1,9 +1,9 @@ import itertools -from concurrent.futures.thread import _WorkItem from queue import PriorityQueue from collections import deque from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor +from concurrent.futures.thread import _WorkItem from time import sleep from typing import Callable, Iterator, Optional diff --git a/data_diff/utils.py b/data_diff/utils.py index b572db1b..2c8ccfba 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -192,7 +192,10 @@ def remove_password_from_url(url: str, replace_with: str = "***") -> str: def join_iter(joiner: Any, iterable: Iterable) -> Iterable: it = iter(iterable) - yield next(it) + try: + yield next(it) + except StopIteration: + return for i in it: yield joiner yield i @@ -221,6 +224,10 @@ def __contains__(self, key: str) -> bool: def __repr__(self): return repr(dict(self.items())) + @abstractmethod + def items(self) -> Iterable[Tuple[str, V]]: + ... + class CaseInsensitiveDict(CaseAwareMapping): def __init__(self, initial): @@ -302,7 +309,7 @@ def getLogger(name): def eval_name_template(name): - def get_timestamp(m): + def get_timestamp(_match): return datetime.now().isoformat("_", "seconds").replace(":", "_") return re.sub("%t", get_timestamp, name) diff --git a/tests/test_database_types.py b/tests/test_database_types.py index c9e9042c..b63eb2b2 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -418,13 +418,10 @@ def __iter__(self): type_pairs = [] for source_db, source_type_categories in DATABASE_TYPES.items(): for target_db, target_type_categories in DATABASE_TYPES.items(): - for ( - type_category, - source_types, - ) in source_type_categories.items(): # int, datetime, .. - for source_type in source_types: - for target_type in target_type_categories[type_category]: - if CONN_STRINGS.get(source_db, False) and CONN_STRINGS.get(target_db, False): + if CONN_STRINGS.get(source_db, False) and CONN_STRINGS.get(target_db, False): + for type_category, source_types in source_type_categories.items(): # int, datetime, .. 
+ for source_type in source_types: + for target_type in target_type_categories[type_category]: type_pairs.append( ( source_db, diff --git a/tests/test_query.py b/tests/test_query.py index 3ab26e43..d02e9745 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -149,7 +149,7 @@ def test_funcs(self): q = c.compile(t.order_by(Random()).limit(10)) assert q == "SELECT * FROM a ORDER BY random() limit 10" - def test_union_all(self): + def test_union(self): c = Compiler(MockDialect()) a = table("a").select("x") b = table("b").select("y") From 68e6228a76b32f7266bc8fbca7f3f1cddacf106d Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 13 Oct 2022 12:30:59 +0200 Subject: [PATCH 34/93] Deprecate use of FixedAlphanum (Issue #252) --- data_diff/databases/base.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 897d4c3a..ea6d1d0f 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -19,7 +19,6 @@ Native_UUID, String_UUID, String_Alphanum, - String_FixedAlphanum, String_VaryingAlphanum, TemporalType, UnknownColType, @@ -133,7 +132,7 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list): for row in explain: # Most returned a 1-tuple. Presto returns a string if isinstance(row, tuple): - row ,= row + (row,) = row logger.debug("EXPLAIN: %s", row) answer = input("Continue? [y/n] ") if not answer.lower() in ["y", "yes"]: @@ -240,7 +239,7 @@ def _process_table_schema( # Return a dict of form {name: type} after normalization return col_dict - def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], where: str = None): + def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], where: str = None, sample_size=32): """Refine the types in the column dict, by querying the database for a sample of their values 'where' restricts the rows to be sampled. @@ -251,7 +250,7 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe return fields = [self.normalize_uuid(c, String_UUID()) for c in text_columns] - samples_by_row = self.query(table(*table_path).select(*fields).where(where or SKIP).limit(16), list) + samples_by_row = self.query(table(*table_path).select(*fields).where(where or SKIP).limit(sample_size), list) if not samples_by_row: raise ValueError(f"Table {table_path} is empty.") @@ -279,13 +278,7 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe ) else: assert col_name in col_dict - lens = set(map(len, alphanum_samples)) - if len(lens) > 1: - col_dict[col_name] = String_VaryingAlphanum() - else: - (length,) = lens - col_dict[col_name] = String_FixedAlphanum(length=length) - continue + col_dict[col_name] = String_VaryingAlphanum() # @lru_cache() # def get_table_schema(self, path: DbPath) -> Dict[str, ColType]: From 4ea5e2de1b6e43808f49c40156e879ec33d1504a Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 13 Oct 2022 13:21:26 +0200 Subject: [PATCH 35/93] Bugfix: Joindiff crashed when no numeric columns were used. 
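
The crash fixed here: when a table has no numeric columns, _collect_stats() passes an
empty sequence of per-column metric dicts to merge_dicts(), and the unguarded next(i)
raised StopIteration. A minimal sketch of the before/after behavior (assuming
merge_dicts remains importable from data_diff.joindiff_tables, as in the diff below;
the dict keys are illustrative only):

    from data_diff.joindiff_tables import merge_dicts

    merge_dicts([{"min_a": 1}, {"max_a": 2}])  # -> {"min_a": 1, "max_a": 2}
    merge_dicts([])  # raised StopIteration before this patch; now returns {}
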
--- data_diff/joindiff_tables.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 58246def..b1134774 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -34,7 +34,11 @@ def merge_dicts(dicts): i = iter(dicts) - res = next(i) + try: + res = next(i) + except StopIteration: + return {} + for d in i: res.update(d) return res From 6ab6f5f26bbf5b8597785b365d005f5e03acb679 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 14 Oct 2022 09:45:10 +0200 Subject: [PATCH 36/93] Small refactor for Oracle --- data_diff/joindiff_tables.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index b1134774..85cc7225 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -103,12 +103,11 @@ def bool_to_int(x): def _outerjoin(db: Database, a: ITable, b: ITable, keys1: List[str], keys2: List[str], select_fields: dict) -> ITable: on = [a[k1] == b[k2] for k1, k2 in safezip(keys1, keys2)] + is_exclusive_a = and_(b[k] == None for k in keys2) + is_exclusive_b = and_(a[k] == None for k in keys1) if isinstance(db, Oracle): - is_exclusive_a = and_(bool_to_int(b[k] == None) for k in keys2) - is_exclusive_b = and_(bool_to_int(a[k] == None) for k in keys1) - else: - is_exclusive_a = and_(b[k] == None for k in keys2) - is_exclusive_b = and_(a[k] == None for k in keys1) + is_exclusive_a = bool_to_int(is_exclusive_a) + is_exclusive_b = bool_to_int(is_exclusive_b) if isinstance(db, MySQL): # No outer join From 3ed6d220a940effc8c9a39439cd14151e6894bea Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 11 Oct 2022 18:03:22 +0200 Subject: [PATCH 37/93] Refactored test_per_class variants into common.test_each_class_in_list --- tests/common.py | 68 +++++++++++++++++++++- tests/test_database_types.py | 5 +- tests/test_diff_tables.py | 108 ++++++----------------------------- tests/test_joindiff.py | 46 +++------------ tests/test_postgresql.py | 9 +-- 5 files changed, 98 insertions(+), 138 deletions(-) diff --git a/tests/common.py b/tests/common.py index 5cce3964..996d29bf 100644 --- a/tests/common.py +++ b/tests/common.py @@ -3,11 +3,16 @@ import os import string import random +from typing import Callable +import unittest +import logging +import subprocess + +from parameterized import parameterized_class from data_diff import databases as db from data_diff import tracking -import logging -import subprocess +from data_diff import connect tracking.disable_tracking() @@ -72,6 +77,14 @@ def get_git_revision_short_hash() -> str: db.Vertica: TEST_VERTICA_CONN_STRING, } +_database_instances = {} + + +def get_conn(cls: type): + if cls not in _database_instances: + _database_instances[cls] = connect(CONN_STRINGS[cls], N_THREADS) + return _database_instances[cls] + def _print_used_dbs(): used = {k.__name__ for k, v in CONN_STRINGS.items() if v is not None} @@ -115,3 +128,54 @@ def _drop_table_if_exists(conn, table): conn.query(f"DROP TABLE IF EXISTS {table}", None) if not isinstance(conn, (db.BigQuery, db.Databricks, db.Clickhouse)): conn.query("COMMIT", None) + + +class TestPerDatabase(unittest.TestCase): + db_cls = None + with_preql = False + + preql = None + + def setUp(self): + assert self.db_cls, self.db_cls + + self.connection = get_conn(self.db_cls) + if self.with_preql: + import preql + + self.preql = preql.Preql(CONN_STRINGS[self.db_cls]) + + table_suffix = random_table_suffix() + self.table_src_name = 
f"src{table_suffix}" + self.table_dst_name = f"dst{table_suffix}" + + self.table_src_path = self.connection.parse_table_name(self.table_src_name) + self.table_dst_path = self.connection.parse_table_name(self.table_dst_name) + + self.table_src = ".".join(map(self.connection.quote, self.table_src_path)) + self.table_dst = ".".join(map(self.connection.quote, self.table_dst_path)) + + _drop_table_if_exists(self.connection, self.table_src) + _drop_table_if_exists(self.connection, self.table_dst) + + return super().setUp() + + def tearDown(self): + if self.preql: + self.preql._interp.state.db.rollback() + self.preql.close() + + _drop_table_if_exists(self.connection, self.table_src) + _drop_table_if_exists(self.connection, self.table_dst) + + +def _parameterized_class_per_conn(test_databases): + names = [(cls.__name__, cls) for cls in CONN_STRINGS if cls in test_databases] + return parameterized_class(("name", "db_cls"), names) + + +def test_each_database_in_list(databases) -> Callable: + def _test_per_database(cls): + return _parameterized_class_per_conn(databases)(cls) + + return _test_per_database diff --git a/tests/test_database_types.py b/tests/test_database_types.py index b63eb2b2..08bacb96 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -23,6 +23,7 @@ N_THREADS, BENCHMARK, GIT_REVISION, + get_conn, random_table_suffix, _drop_table_if_exists, ) @@ -35,8 +36,8 @@ def init_conns(): if CONNS is not None: return - CONNS = {k: db.connect.connect(v, N_THREADS) for k, v in CONN_STRINGS.items()} - CONNS[db.MySQL].query("SET @@session.time_zone='+00:00'", None) + CONNS = {cls: get_conn(cls) for cls in CONN_STRINGS} + CONNS[db.MySQL].query("SET @@session.time_zone='+00:00'") oracle.SESSION_TIME_ZONE = postgresql.SESSION_TIME_ZONE = "UTC" diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 84d70434..9703b7d4 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -1,48 +1,21 @@ import datetime -import unittest +from typing import Callable import uuid +import unittest -from parameterized import parameterized_class -import preql import arrow # comes with preql -from data_diff.databases.connect import connect from data_diff.hashdiff_tables import HashDiffer from data_diff.table_segment import TableSegment, split_space from data_diff import databases as db from data_diff.utils import ArithAlphanumeric, numberToAlphanum -from .common import ( - TEST_MYSQL_CONN_STRING, - str_to_checksum, - random_table_suffix, - _drop_table_if_exists, - CONN_STRINGS, - N_THREADS, -) - -DATABASE_INSTANCES = None -DATABASE_URIS = {k.__name__: v for k, v in CONN_STRINGS.items()} - - -def init_instances(): - global DATABASE_INSTANCES - if DATABASE_INSTANCES is not None: - return - - DATABASE_INSTANCES = {k.__name__: connect(v, N_THREADS) for k, v in CONN_STRINGS.items()} - +from .common import str_to_checksum, test_each_database_in_list, TestPerDatabase -TEST_DATABASES = {x.__name__ for x in (db.MySQL, db.PostgreSQL, db.Oracle, db.Redshift, db.Snowflake, db.BigQuery)} +TEST_DATABASES = {db.MySQL, db.PostgreSQL, db.Oracle, db.Redshift, db.Snowflake, db.BigQuery} -def _class_per_db_dec(filter_name=None): - names = [ - (name, name) - for name in DATABASE_URIS - if (name in TEST_DATABASES) and (filter_name is None or filter_name(name)) - ] - return parameterized_class(("name", "db_name"), names) +test_each_database: Callable = test_each_database_in_list(TEST_DATABASES) def _table_segment(database, table_path, key_columns, *args, **kw): @@ -51,17 +24,6 @@ def 
_table_segment(database, table_path, key_columns, *args, **kw): return TableSegment(database, table_path, key_columns, *args, **kw) -def test_per_database(cls): - return _class_per_db_dec()(cls) - - -def test_per_database__filter_name(filter_name): - def _test_per_database(cls): - return _class_per_db_dec(filter_name=filter_name)(cls) - - return _test_per_database - - def _insert_row(conn, table, fields, values): fields = ", ".join(map(str, fields)) values = ", ".join(map(str, values)) @@ -101,45 +63,7 @@ def test_split_space(self): assert len(r) == n, f"split_space({i}, {j+n}, {n}) = {(r)}" -class TestPerDatabase(unittest.TestCase): - db_name = None - with_preql = False - - preql = None - - def setUp(self): - assert self.db_name, self.db_name - init_instances() - - self.connection = DATABASE_INSTANCES[self.db_name] - if self.with_preql: - self.preql = preql.Preql(DATABASE_URIS[self.db_name]) - - table_suffix = random_table_suffix() - self.table_src_name = f"src{table_suffix}" - self.table_dst_name = f"dst{table_suffix}" - - self.table_src_path = self.connection.parse_table_name(self.table_src_name) - self.table_dst_path = self.connection.parse_table_name(self.table_dst_name) - - self.table_src = ".".join(map(self.connection.quote, self.table_src_path)) - self.table_dst = ".".join(map(self.connection.quote, self.table_dst_path)) - - _drop_table_if_exists(self.connection, self.table_src) - _drop_table_if_exists(self.connection, self.table_dst) - - return super().setUp() - - def tearDown(self): - if self.preql: - self.preql._interp.state.db.rollback() - self.preql.close() - - _drop_table_if_exists(self.connection, self.table_src) - _drop_table_if_exists(self.connection, self.table_dst) - - -@test_per_database +@test_each_database class TestDates(TestPerDatabase): with_preql = True @@ -244,7 +168,7 @@ def test_offset(self): self.assertEqual(len(list(differ.diff_tables(a, b))), 1) -@test_per_database +@test_each_database class TestDiffTables(TestPerDatabase): with_preql = True @@ -411,7 +335,7 @@ def test_diff_sorted_by_key(self): self.assertEqual(expected, diff) -@test_per_database +@test_each_database class TestDiffTables2(TestPerDatabase): def test_diff_column_names(self): float_type = _get_float_type(self.connection) @@ -465,7 +389,7 @@ def test_diff_column_names(self): assert diff == [] -@test_per_database +@test_each_database class TestUUIDs(TestPerDatabase): def setUp(self): super().setUp() @@ -515,7 +439,7 @@ def test_where_sampling(self): self.assertRaises(ValueError, list, differ.diff_tables(a_empty, self.b)) -@test_per_database__filter_name(lambda n: n != "MySQL") +@test_each_database_in_list(TEST_DATABASES - {db.MySQL}) class TestAlphanumericKeys(TestPerDatabase): def setUp(self): super().setUp() @@ -565,7 +489,7 @@ def test_alphanum_keys(self): self.assertRaises(NotImplementedError, list, differ.diff_tables(self.a, self.b)) -@test_per_database__filter_name(lambda n: n != "MySQL") +@test_each_database_in_list(TEST_DATABASES - {db.MySQL}) class TestVaryingAlphanumericKeys(TestPerDatabase): def setUp(self): super().setUp() @@ -623,7 +547,7 @@ def test_varying_alphanum_keys(self): self.assertRaises(NotImplementedError, list, differ.diff_tables(self.a, self.b)) -@test_per_database +@test_each_database class TestTableSegment(TestPerDatabase): def setUp(self) -> None: super().setUp() @@ -657,7 +581,7 @@ def test_case_awareness(self): ) -@test_per_database +@test_each_database class TestTableUUID(TestPerDatabase): def setUp(self): super().setUp() @@ -691,7 +615,7 @@ def 
test_uuid_column_with_nulls(self): self.assertEqual(diff, [("-", (str(self.null_uuid), None))]) -@test_per_database +@test_each_database class TestTableNullRowChecksum(TestPerDatabase): def setUp(self): super().setUp() @@ -741,7 +665,7 @@ def test_uuid_columns_with_nulls(self): self.assertEqual(diff, [("-", (str(self.null_uuid), None))]) -@test_per_database +@test_each_database class TestConcatMultipleColumnWithNulls(TestPerDatabase): def setUp(self): super().setUp() @@ -805,7 +729,7 @@ def test_tables_are_different(self): self.assertEqual(diff, self.diffs) -@test_per_database +@test_each_database class TestTableTableEmpty(TestPerDatabase): def setUp(self): super().setUp() diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index b1babe35..a2c220cb 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -1,35 +1,19 @@ -from functools import wraps from typing import List -from parameterized import parameterized_class -from data_diff.databases.connect import connect from data_diff.queries.ast_classes import TablePath -from data_diff.table_segment import TableSegment, split_space +from data_diff.table_segment import TableSegment from data_diff import databases as db from data_diff.joindiff_tables import JoinDiffer -from .test_diff_tables import TestPerDatabase, _get_float_type, _get_text_type, _commit, _insert_row, _insert_rows +from .test_diff_tables import TestPerDatabase, _get_float_type, _commit, _insert_row, _insert_rows from .common import ( random_table_suffix, - str_to_checksum, - CONN_STRINGS, - N_THREADS, + test_each_database_in_list, ) -DATABASE_INSTANCES = None -DATABASE_URIS = {k.__name__: v for k, v in CONN_STRINGS.items()} - -def init_instances(): - global DATABASE_INSTANCES - if DATABASE_INSTANCES is not None: - return - - DATABASE_INSTANCES = {k.__name__: connect(v, N_THREADS) for k, v in CONN_STRINGS.items()} - - -TEST_DATABASES = ( +TEST_DATABASES = { db.PostgreSQL, db.Snowflake, db.MySQL, @@ -39,26 +23,12 @@ def init_instances(): db.Trino, db.Oracle, db.Redshift, -) - - -def test_per_database(cls, dbs=TEST_DATABASES): - dbs = {db.__name__ for db in dbs} - _class_per_db_dec = parameterized_class( - ("name", "db_name"), [(name, name) for name in DATABASE_URIS if name in dbs] - ) - return _class_per_db_dec(cls) - - -def test_per_database2(*dbs): - @wraps(test_per_database) - def dec(cls): - return test_per_database(cls, dbs) +} - return dec +test_each_database = test_each_database_in_list(TEST_DATABASES) -@test_per_database2(db.Snowflake, db.BigQuery) +@test_each_database_in_list({db.Snowflake, db.BigQuery}) class TestCompositeKey(TestPerDatabase): def setUp(self): super().setUp() @@ -103,7 +73,7 @@ def test_composite_key(self): assert self.differ.stats["exclusive_count"] == 2 -@test_per_database +@test_each_database class TestJoindiff(TestPerDatabase): def setUp(self): super().setUp() diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 21e64b3e..0c57d299 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -1,12 +1,13 @@ import unittest -from data_diff import TableSegment, HashDiffer, connect -from .common import TEST_POSTGRESQL_CONN_STRING, TEST_MYSQL_CONN_STRING, random_table_suffix +from data_diff import TableSegment, HashDiffer +from data_diff import databases as db +from .common import get_conn, random_table_suffix class TestUUID(unittest.TestCase): def setUp(self) -> None: - self.connection = connect(TEST_POSTGRESQL_CONN_STRING) + self.connection = get_conn(db.PostgreSQL) table_suffix = random_table_suffix() @@ 
-46,7 +47,7 @@ def test_uuid(self): self.assertEqual(diff, [("-", (uuid, "This one is different"))]) # Compare with MySql - mysql_conn = connect(TEST_MYSQL_CONN_STRING) + mysql_conn = get_conn(db.MySQL) rows = self.connection.query(f"SELECT * FROM {self.table_src}", list) From 79e5af68bec9b881c43565d4dd1fcd5da6114239 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 11 Oct 2022 19:27:01 +0200 Subject: [PATCH 38/93] Tests refactor: Use queries-builder to create tables --- data_diff/databases/base.py | 2 + data_diff/databases/bigquery.py | 4 +- data_diff/databases/presto.py | 6 +++ data_diff/queries/api.py | 7 ++- tests/test_diff_tables.py | 91 +++++++++++++-------------------- tests/test_joindiff.py | 34 +++++++----- 6 files changed, 71 insertions(+), 73 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index ea6d1d0f..f2cb370d 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -1,3 +1,4 @@ +from datetime import datetime import math import sys import logging @@ -329,6 +330,7 @@ def type_repr(self, t) -> str: str: "VARCHAR", bool: "BOOLEAN", float: "FLOAT", + datetime: "TIMESTAMP", }[t] def _query_cursor(self, c, sql_code: str): diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 603bfecc..9c500dd5 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -109,8 +109,6 @@ def is_autocommit(self) -> bool: def type_repr(self, t) -> str: try: - return { - str: "STRING", - }[t] + return {str: "STRING", float: "FLOAT64"}[t] except KeyError: return super().type_repr(t) diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index 811a9491..2d69efc8 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -153,3 +153,9 @@ def is_autocommit(self) -> bool: def explain_as_text(self, query: str) -> str: return f"EXPLAIN (FORMAT TEXT) {query}" + + def type_repr(self, t) -> str: + try: + return {float: "REAL"}[t] + except KeyError: + return super().type_repr(t) diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index d9c0945f..2f5d96be 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -1,4 +1,6 @@ from typing import Optional + +from data_diff.utils import CaseAwareMapping, CaseSensitiveDict from .ast_classes import * from .base import args_as_tuple @@ -30,11 +32,14 @@ def cte(expr: Expr, *, name: Optional[str] = None, params: Sequence[str] = None) return Cte(expr, name, params) -def table(*path: str, schema: Schema = None) -> TablePath: +def table(*path: str, schema: Union[dict, CaseAwareMapping] = None) -> TablePath: if len(path) == 1 and isinstance(path[0], tuple): (path,) = path if not all(isinstance(i, str) for i in path): raise TypeError(f"All elements of table path must be of type 'str'. 
Got: {path}") + if schema and not isinstance(schema, CaseAwareMapping): + assert isinstance(schema, dict) + schema = CaseSensitiveDict(schema) return TablePath(path, schema) diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 9703b7d4..e4465ad8 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -5,6 +5,8 @@ import arrow # comes with preql +from data_diff.queries import table + from data_diff.hashdiff_tables import HashDiffer from data_diff.table_segment import TableSegment, split_space from data_diff import databases as db @@ -40,20 +42,6 @@ def _commit(conn): conn.query("COMMIT", None) -def _get_text_type(conn): - if isinstance(conn, db.BigQuery): - return "STRING" - return "varchar(100)" - - -def _get_float_type(conn): - if isinstance(conn, db.BigQuery): - return "FLOAT64" - elif isinstance(conn, db.Presto): - return "REAL" - return "float" - - class TestUtils(unittest.TestCase): def test_split_space(self): for i in range(0, 10): @@ -175,15 +163,22 @@ class TestDiffTables(TestPerDatabase): def setUp(self): super().setUp() - float_type = _get_float_type(self.connection) + src_table = table( + self.table_src_path, + schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime.datetime}, + ) + dst_table = table( + self.table_dst_path, + schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime.datetime}, + ) self.connection.query( - f"create table {self.table_src}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", - None, + src_table.create(), + # f"create table {self.table_src}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", ) self.connection.query( - f"create table {self.table_dst}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", - None, + # f"create table {self.table_dst}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", + dst_table.create() ) _commit(self.connection) @@ -338,16 +333,12 @@ def test_diff_sorted_by_key(self): @test_each_database class TestDiffTables2(TestPerDatabase): def test_diff_column_names(self): - float_type = _get_float_type(self.connection) - self.connection.query( - f"create table {self.table_src}(id int, rating {float_type}, timestamp timestamp)", - None, - ) - self.connection.query( - f"create table {self.table_dst}(id2 int, rating2 {float_type}, timestamp2 timestamp)", - None, - ) + src_table = table(self.table_src_path, schema={"id": int, "rating": float, "timestamp": datetime.datetime}) + dst_table = table(self.table_dst_path, schema={"id2": int, "rating2": float, "timestamp2": datetime.datetime}) + + self.connection.query(src_table.create()) + self.connection.query(dst_table.create()) _commit(self.connection) time = "2022-01-01 00:00:00" @@ -394,11 +385,9 @@ class TestUUIDs(TestPerDatabase): def setUp(self): super().setUp() - text_type = _get_text_type(self.connection) + src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) - queries = [ - f"CREATE TABLE {self.table_src}(id {text_type}, text_comment {text_type})", - ] + queries = [src_table.create()] for i in range(100): queries.append(f"INSERT INTO {self.table_src} VALUES ('{uuid.uuid1(i)}', '{i}')") @@ -444,11 +433,9 @@ class TestAlphanumericKeys(TestPerDatabase): def setUp(self): super().setUp() - text_type = _get_text_type(self.connection) + src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) - queries = [ - f"CREATE TABLE {self.table_src}(id 
{text_type}, text_comment {text_type})", - ] + queries = [src_table.create()] for i in range(0, 10000, 1000): a = ArithAlphanumeric(numberToAlphanum(i), max_len=10) if not a and isinstance(self.connection, db.Oracle): @@ -494,11 +481,9 @@ class TestVaryingAlphanumericKeys(TestPerDatabase): def setUp(self): super().setUp() - text_type = _get_text_type(self.connection) + src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) - queries = [ - f"CREATE TABLE {self.table_src}(id {text_type}, text_comment {text_type})", - ] + queries = [src_table.create()] for i in range(0, 10000, 1000): a = ArithAlphanumeric(numberToAlphanum(i * i)) if not a and isinstance(self.connection, db.Oracle): @@ -586,11 +571,9 @@ class TestTableUUID(TestPerDatabase): def setUp(self): super().setUp() - text_type = _get_text_type(self.connection) + src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) - queries = [ - f"CREATE TABLE {self.table_src}(id {text_type}, text_comment {text_type})", - ] + queries = [src_table.create()] for i in range(10): uuid_value = uuid.uuid1(i) queries.append(f"INSERT INTO {self.table_src} VALUES ('{uuid_value}', '{uuid_value}')") @@ -620,11 +603,11 @@ class TestTableNullRowChecksum(TestPerDatabase): def setUp(self): super().setUp() - text_type = _get_text_type(self.connection) + src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) self.null_uuid = uuid.uuid1(1) queries = [ - f"CREATE TABLE {self.table_src}(id {text_type}, text_comment {text_type})", + src_table.create(), f"INSERT INTO {self.table_src} VALUES ('{uuid.uuid1(1)}', '1')", f"CREATE TABLE {self.table_dst} AS SELECT * FROM {self.table_src}", # Add a row where a column has NULL value @@ -670,12 +653,10 @@ class TestConcatMultipleColumnWithNulls(TestPerDatabase): def setUp(self): super().setUp() - text_type = _get_text_type(self.connection) + src_table = table(self.table_src_path, schema={"id": str, "c1": str, "c2": str}) + dst_table = table(self.table_dst_path, schema={"id": str, "c1": str, "c2": str}) - queries = [ - f"CREATE TABLE {self.table_src}(id {text_type}, c1 {text_type}, c2 {text_type})", - f"CREATE TABLE {self.table_dst}(id {text_type}, c1 {text_type}, c2 {text_type})", - ] + queries = [src_table.create(), dst_table.create()] self.diffs = [] for i in range(0, 8): @@ -734,13 +715,11 @@ class TestTableTableEmpty(TestPerDatabase): def setUp(self): super().setUp() - text_type = _get_text_type(self.connection) + src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) + dst_table = table(self.table_dst_path, schema={"id": str, "text_comment": str}) self.null_uuid = uuid.uuid1(1) - queries = [ - f"CREATE TABLE {self.table_src}(id {text_type}, text_comment {text_type})", - f"CREATE TABLE {self.table_dst}(id {text_type}, text_comment {text_type})", - ] + queries = [src_table.create(), dst_table.create()] self.diffs = [(uuid.uuid1(i), i) for i in range(100)] for pk, value in self.diffs: diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index a2c220cb..5a01b6e1 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -1,11 +1,13 @@ from typing import List +from datetime import datetime from data_diff.queries.ast_classes import TablePath +from data_diff.queries import table from data_diff.table_segment import TableSegment from data_diff import databases as db from data_diff.joindiff_tables import JoinDiffer -from .test_diff_tables import TestPerDatabase, _get_float_type, _commit, _insert_row, _insert_rows +from 
.test_diff_tables import TestPerDatabase, _commit, _insert_row, _insert_rows from .common import ( random_table_suffix, @@ -33,14 +35,17 @@ class TestCompositeKey(TestPerDatabase): def setUp(self): super().setUp() - float_type = _get_float_type(self.connection) - - self.connection.query( - f"create table {self.table_src}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", + src_table = table( + self.table_src_path, + schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}, ) - self.connection.query( - f"create table {self.table_dst}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", + dst_table = table( + self.table_dst_path, + schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}, ) + + self.connection.query(src_table.create()) + self.connection.query(dst_table.create()) _commit(self.connection) self.differ = JoinDiffer() @@ -78,14 +83,17 @@ class TestJoindiff(TestPerDatabase): def setUp(self): super().setUp() - float_type = _get_float_type(self.connection) - - self.connection.query( - f"create table {self.table_src}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", + src_table = table( + self.table_src_path, + schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}, ) - self.connection.query( - f"create table {self.table_dst}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", + dst_table = table( + self.table_dst_path, + schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}, ) + + self.connection.query(src_table.create()) + self.connection.query(dst_table.create()) _commit(self.connection) self.table = TableSegment(self.connection, self.table_src_path, ("id",), "timestamp", case_sensitive=False) From da48ceedb0dde371f64224e079cea03b7bfa3c1a Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 12 Oct 2022 14:31:49 +0200 Subject: [PATCH 39/93] Tests: Removed dependency on Preql; using query-builder instead --- data_diff/queries/ast_classes.py | 36 +++++++++++++-- tests/common.py | 11 ----- tests/test_api.py | 71 +++++++++++++++-------------- tests/test_cli.py | 69 ++++++++++++++-------------- tests/test_diff_tables.py | 78 +++++++++++++------------------- 5 files changed, 134 insertions(+), 131 deletions(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 226c246b..7f88fb76 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -1,6 +1,6 @@ from dataclasses import field from datetime import datetime -from typing import Any, Generator, Optional, Sequence, Tuple, Union +from typing import Any, Generator, List, Optional, Sequence, Tuple, Union from runtype import dataclass @@ -306,8 +306,12 @@ def create(self, if_not_exists=False): def drop(self, if_exists=False): return DropTable(self, if_exists=if_exists) - def insert_values(self, rows): - raise NotImplementedError() + def insert_rows(self, rows): + rows = list(rows) + return InsertToTable(self, ConstantTable(rows)) + + def insert_row(self, *values): + return InsertToTable(self, ConstantTable([values])) def insert_expr(self, expr: Expr): return InsertToTable(self, expr) @@ -592,6 +596,25 @@ def compile(self, c: Compiler) -> str: return c.database.random() +@dataclass +class ConstantTable(ExprNode): + rows: List[Tuple] + + def compile(self, c: Compiler) -> str: + raise NotImplementedError() + + def _value(self, v): + if 
isinstance(v, str): + return f"'{v}'" + elif isinstance(v, datetime): + return f"timestamp '{v}'" + return str(v) + + def compile_for_insert(self, c: Compiler): + values = ", ".join("(%s)" % ", ".join(self._value(v) for v in row) for row in self.rows) + return f"VALUES {values}" + + @dataclass class Explain(ExprNode): select: Select @@ -635,7 +658,12 @@ class InsertToTable(Statement): expr: Expr def compile(self, c: Compiler) -> str: - return f"INSERT INTO {c.compile(self.path)} {c.compile(self.expr)}" + if isinstance(self.expr, ConstantTable): + expr = self.expr.compile_for_insert(c) + else: + expr = c.compile(self.expr) + + return f"INSERT INTO {c.compile(self.path)} {expr}" @dataclass diff --git a/tests/common.py b/tests/common.py index 996d29bf..5e689131 100644 --- a/tests/common.py +++ b/tests/common.py @@ -132,18 +132,11 @@ def _drop_table_if_exists(conn, table): class TestPerDatabase(unittest.TestCase): db_cls = None - with_preql = False - - preql = None def setUp(self): assert self.db_cls, self.db_cls self.connection = get_conn(self.db_cls) - if self.with_preql: - import preql - - self.preql = preql.Preql(CONN_STRINGS[self.db_cls]) table_suffix = random_table_suffix() self.table_src_name = f"src{table_suffix}" @@ -161,10 +154,6 @@ def setUp(self): return super().setUp() def tearDown(self): - if self.preql: - self.preql._interp.state.db.rollback() - self.preql.close() - _drop_table_if_exists(self.connection, self.table_src) _drop_table_if_exists(self.connection, self.table_dst) diff --git a/tests/test_api.py b/tests/test_api.py index bac88c84..2c67b481 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,49 +1,50 @@ import unittest -import preql import arrow +from datetime import datetime from data_diff import diff_tables, connect_to_table +from data_diff.databases import MySQL +from data_diff.queries.api import table -from .common import TEST_MYSQL_CONN_STRING +from .common import TEST_MYSQL_CONN_STRING, get_conn + + +def _commit(conn): + conn.query("COMMIT", None) class TestApi(unittest.TestCase): def setUp(self) -> None: - self.preql = preql.Preql(TEST_MYSQL_CONN_STRING) - self.preql( - r""" - table test_api { - datetime: datetime - comment: string - } - commit() - - func add(date, comment) { - new test_api(date, comment) - } - """ - ) - self.now = now = arrow.get(self.preql.now()) - self.preql.add(now, "now") - self.preql.add(now, self.now.shift(seconds=-10)) - self.preql.add(now, self.now.shift(seconds=-7)) - self.preql.add(now, self.now.shift(seconds=-6)) - - self.preql( - r""" - const table test_api_2 = test_api - commit() - """ - ) - - self.preql.add(self.now.shift(seconds=-3), "3 seconds ago") - self.preql.commit() + self.conn = get_conn(MySQL) + table_src_name = "test_api" + table_dst_name = "test_api_2" + self.conn.query(f"drop table if exists {table_src_name}") + self.conn.query(f"drop table if exists {table_dst_name}") + + src_table = table(table_src_name, schema={"id": int, "datetime": datetime, "text_comment": str}) + self.conn.query(src_table.create()) + self.now = now = arrow.get() + + rows = [ + (now, "now"), + (self.now.shift(seconds=-10), "a"), + (self.now.shift(seconds=-7), "b"), + (self.now.shift(seconds=-6), "c"), + ] + + self.conn.query(src_table.insert_rows((i, ts.datetime, s) for i, (ts, s) in enumerate(rows))) + _commit(self.conn) + + self.conn.query(f"CREATE TABLE {table_dst_name} AS SELECT * FROM {table_src_name}") + _commit(self.conn) + + self.conn.query(src_table.insert_row(len(rows), self.now.shift(seconds=-3).datetime, "3 seconds ago")) + 
_commit(self.conn) def tearDown(self) -> None: - self.preql.run_statement("drop table if exists test_api") - self.preql.run_statement("drop table if exists test_api_2") - self.preql.commit() - self.preql.close() + self.conn.query("drop table if exists test_api") + self.conn.query("drop table if exists test_api_2") + _commit(self.conn) return super().tearDown() diff --git a/tests/test_cli.py b/tests/test_cli.py index 4e866680..263dc872 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,13 +1,19 @@ import logging import unittest -import preql import arrow import subprocess import sys +from datetime import datetime from data_diff import diff_tables, connect_to_table +from data_diff.databases import MySQL +from data_diff.queries import table -from .common import TEST_MYSQL_CONN_STRING +from .common import TEST_MYSQL_CONN_STRING, get_conn + + +def _commit(conn): + conn.query("COMMIT", None) def run_datadiff_cli(*args): @@ -23,41 +29,36 @@ def run_datadiff_cli(*args): class TestCLI(unittest.TestCase): def setUp(self) -> None: - self.preql = preql.Preql(TEST_MYSQL_CONN_STRING) - self.preql( - r""" - table test_cli { - datetime: datetime - comment: string - } - commit() - - func add(date, comment) { - new test_cli(date, comment) - } - """ - ) - self.now = now = arrow.get(self.preql.now()) - self.preql.add(now, "now") - self.preql.add(now, self.now.shift(seconds=-10)) - self.preql.add(now, self.now.shift(seconds=-7)) - self.preql.add(now, self.now.shift(seconds=-6)) - - self.preql( - r""" - const table test_cli_2 = test_cli - commit() - """ - ) + self.conn = get_conn(MySQL) + self.conn.query("drop table if exists test_cli") + self.conn.query("drop table if exists test_cli_2") + table_src_name = "test_cli" + table_dst_name = "test_cli_2" + + src_table = table(table_src_name, schema={"id": int, "datetime": datetime, "text_comment": str}) + self.conn.query(src_table.create()) + self.now = now = arrow.get() + + rows = [ + (now, "now"), + (self.now.shift(seconds=-10), "a"), + (self.now.shift(seconds=-7), "b"), + (self.now.shift(seconds=-6), "c"), + ] + + self.conn.query(src_table.insert_rows((i, ts.datetime, s) for i, (ts, s) in enumerate(rows))) + _commit(self.conn) + + self.conn.query(f"CREATE TABLE {table_dst_name} AS SELECT * FROM {table_src_name}") + _commit(self.conn) - self.preql.add(self.now.shift(seconds=-3), "3 seconds ago") - self.preql.commit() + self.conn.query(src_table.insert_row(len(rows), self.now.shift(seconds=-3).datetime, "3 seconds ago")) + _commit(self.conn) def tearDown(self) -> None: - self.preql.run_statement("drop table if exists test_cli") - self.preql.run_statement("drop table if exists test_cli_2") - self.preql.commit() - self.preql.close() + self.conn.query("drop table if exists test_cli") + self.conn.query("drop table if exists test_cli_2") + _commit(self.conn) return super().tearDown() diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index e4465ad8..fbbdac13 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -1,4 +1,4 @@ -import datetime +from datetime import datetime from typing import Callable import uuid import unittest @@ -53,39 +53,29 @@ def test_split_space(self): @test_each_database class TestDates(TestPerDatabase): - with_preql = True - def setUp(self): super().setUp() - self.preql( - f""" - table {self.table_src_name} {{ - datetime: timestamp - text_comment: string - }} - commit() - - func add(date, text_comment) {{ - new {self.table_src_name}(date, text_comment) - }} - """ - ) - self.now = now = 
arrow.get(self.preql.now()) - self.preql.add(now.shift(days=-50), "50 days ago") - self.preql.add(now.shift(hours=-3), "3 hours ago") - self.preql.add(now.shift(minutes=-10), "10 mins ago") - self.preql.add(now.shift(seconds=-1), "1 second ago") - self.preql.add(now, "now") - - self.preql( - f""" - const table {self.table_dst_name} = {self.table_src_name} - commit() - """ - ) - self.preql.add(self.now.shift(seconds=-3), "2 seconds ago") - self.preql.commit() + src_table = table(self.table_src_path, schema={"id": int, "datetime": datetime, "text_comment": str}) + self.connection.query(src_table.create()) + self.now = now = arrow.get() + + rows = [ + (now.shift(days=-50), "50 days ago"), + (now.shift(hours=-3), "3 hours ago"), + (now.shift(minutes=-10), "10 mins ago"), + (now.shift(seconds=-1), "1 second ago"), + (now, "now"), + ] + + self.connection.query(src_table.insert_rows((i, ts.datetime, s) for i, (ts, s) in enumerate(rows))) + _commit(self.connection) + + self.connection.query(f"CREATE TABLE {self.table_dst_name} AS SELECT * FROM {self.table_src_name}") + _commit(self.connection) + + self.connection.query(src_table.insert_row(len(rows), self.now.shift(seconds=-3).datetime, "3 seconds ago")) + _commit(self.connection) def test_init(self): a = _table_segment( @@ -107,14 +97,14 @@ def test_basic(self): def test_offset(self): differ = HashDiffer(bisection_factor=2, bisection_threshold=10) - sec1 = self.now.shift(seconds=-1).datetime + sec1 = self.now.shift(seconds=-2).datetime a = _table_segment( self.connection, self.table_src_path, "id", "datetime", max_update=sec1, case_sensitive=False ) b = _table_segment( self.connection, self.table_dst_path, "id", "datetime", max_update=sec1, case_sensitive=False ) - assert a.count() == 4 + assert a.count() == 4, a.count() assert b.count() == 3 assert not list(differ.diff_tables(a, a)) @@ -158,28 +148,22 @@ def test_offset(self): @test_each_database class TestDiffTables(TestPerDatabase): - with_preql = True - def setUp(self): super().setUp() src_table = table( self.table_src_path, - schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime.datetime}, + schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}, ) dst_table = table( self.table_dst_path, - schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime.datetime}, + schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}, ) self.connection.query( src_table.create(), - # f"create table {self.table_src}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", - ) - self.connection.query( - # f"create table {self.table_dst}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", - dst_table.create() ) + self.connection.query(dst_table.create()) _commit(self.connection) self.table = _table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) @@ -279,7 +263,7 @@ def test_return_empty_array_when_same(self): _insert_row(self.connection, self.table_src, cols, [1, 1, 1, 9, time_str]) _insert_row(self.connection, self.table_dst, cols, [1, 1, 1, 9, time_str]) - self.preql.commit() + _commit(self.connection) diff = list(self.differ.diff_tables(self.table, self.table2)) self.assertEqual([], diff) @@ -334,8 +318,8 @@ def test_diff_sorted_by_key(self): class TestDiffTables2(TestPerDatabase): def test_diff_column_names(self): - src_table = table(self.table_src_path, schema={"id": int, "rating": float, 
"timestamp": datetime.datetime}) - dst_table = table(self.table_dst_path, schema={"id2": int, "rating2": float, "timestamp2": datetime.datetime}) + src_table = table(self.table_src_path, schema={"id": int, "rating": float, "timestamp": datetime}) + dst_table = table(self.table_dst_path, schema={"id2": int, "rating2": float, "timestamp2": datetime}) self.connection.query(src_table.create()) self.connection.query(dst_table.create()) @@ -540,8 +524,8 @@ def setUp(self) -> None: self.table2 = _table_segment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) def test_table_segment(self): - early = datetime.datetime(2021, 1, 1, 0, 0) - late = datetime.datetime(2022, 1, 1, 0, 0) + early = datetime(2021, 1, 1, 0, 0) + late = datetime(2022, 1, 1, 0, 0) self.assertRaises(ValueError, self.table.replace, min_update=late, max_update=early) self.assertRaises(ValueError, self.table.replace, min_key=10, max_key=0) From d7e736df09dd7b02c5d9cf9cd5ede8baa8389c5b Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 13 Oct 2022 11:24:35 +0200 Subject: [PATCH 40/93] Tests refactor: Convert more code to use query-builder --- data_diff/databases/base.py | 6 +- data_diff/queries/ast_classes.py | 53 +++- tests/test_diff_tables.py | 425 +++++++++++++++---------------- tests/test_joindiff.py | 175 +++++++------ 4 files changed, 340 insertions(+), 319 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index f2cb370d..e1ac2208 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -121,6 +121,10 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list): compiler = Compiler(self) if isinstance(sql_ast, Generator): sql_code = ThreadLocalInterpreter(compiler, sql_ast) + elif isinstance(sql_ast, list): + for i in sql_ast[:-1]: + self.query(i) + return self.query(sql_ast[-1], res_type) else: sql_code = compiler.compile(sql_ast) if sql_code is SKIP: @@ -250,7 +254,7 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe if not text_columns: return - fields = [self.normalize_uuid(c, String_UUID()) for c in text_columns] + fields = [self.normalize_uuid(self.quote(c), String_UUID()) for c in text_columns] samples_by_row = self.query(table(*table_path).select(*fields).where(where or SKIP).limit(sample_size), list) if not samples_by_row: raise ValueError(f"Table {table_path} is empty.") diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 7f88fb76..0a7642d5 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -1,6 +1,7 @@ from dataclasses import field from datetime import datetime from typing import Any, Generator, List, Optional, Sequence, Tuple, Union +from uuid import UUID from runtype import dataclass @@ -298,22 +299,29 @@ class TablePath(ExprNode, ITable): path: DbPath schema: Optional[Schema] = field(default=None, repr=False) - def create(self, if_not_exists=False): - if not self.schema: - raise ValueError("Schema must have a value to create table") - return CreateTable(self, if_not_exists=if_not_exists) + def create(self, source_table: ITable = None, *, if_not_exists=False): + if source_table is None and not self.schema: + raise ValueError("Either schema or source table needed to create table") + if isinstance(source_table, TablePath): + source_table = source_table.select() + return CreateTable(self, source_table, if_not_exists=if_not_exists) def drop(self, if_exists=False): return DropTable(self, if_exists=if_exists) - def 
insert_rows(self, rows): + def truncate(self): + return TruncateTable(self) + + def insert_rows(self, rows, *, columns=None): rows = list(rows) - return InsertToTable(self, ConstantTable(rows)) + return InsertToTable(self, ConstantTable(rows), columns=columns) - def insert_row(self, *values): - return InsertToTable(self, ConstantTable([values])) + def insert_row(self, *values, columns=None): + return InsertToTable(self, ConstantTable([values]), columns=columns) def insert_expr(self, expr: Expr): + if isinstance(expr, TablePath): + expr = expr.select() return InsertToTable(self, expr) @property @@ -598,17 +606,21 @@ def compile(self, c: Compiler) -> str: @dataclass class ConstantTable(ExprNode): - rows: List[Tuple] + rows: Sequence[Sequence] def compile(self, c: Compiler) -> str: raise NotImplementedError() def _value(self, v): - if isinstance(v, str): + if v is None: + return "NULL" + elif isinstance(v, str): return f"'{v}'" elif isinstance(v, datetime): return f"timestamp '{v}'" - return str(v) + elif isinstance(v, UUID): + return f"'{v}'" + return repr(v) def compile_for_insert(self, c: Compiler): values = ", ".join("(%s)" % ", ".join(self._value(v) for v in row) for row in self.rows) @@ -633,11 +645,15 @@ class Statement(Compilable): @dataclass class CreateTable(Statement): path: TablePath + source_table: Expr = None if_not_exists: bool = False def compile(self, c: Compiler) -> str: - schema = ", ".join(f"{k} {c.database.type_repr(v)}" for k, v in self.path.schema.items()) ne = "IF NOT EXISTS " if self.if_not_exists else "" + if self.source_table: + return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}" + + schema = ", ".join(f"{c.database.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items()) return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})" @@ -651,11 +667,20 @@ def compile(self, c: Compiler) -> str: return f"DROP TABLE {ie}{c.compile(self.path)}" +@dataclass +class TruncateTable(Statement): + path: TablePath + + def compile(self, c: Compiler) -> str: + return f"TRUNCATE TABLE {c.compile(self.path)}" + + @dataclass class InsertToTable(Statement): # TODO Support insert for only some columns path: TablePath expr: Expr + columns: List[str] = None def compile(self, c: Compiler) -> str: if isinstance(self.expr, ConstantTable): @@ -663,7 +688,9 @@ def compile(self, c: Compiler) -> str: else: expr = c.compile(self.expr) - return f"INSERT INTO {c.compile(self.path)} {expr}" + columns = f"(%s)" % ", ".join(map(c.quote, self.columns)) if self.columns is not None else "" + + return f"INSERT INTO {c.compile(self.path)}{columns} {expr}" @dataclass diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index fbbdac13..e615acc9 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -5,7 +5,7 @@ import arrow # comes with preql -from data_diff.queries import table +from data_diff.queries import table, this, commit from data_diff.hashdiff_tables import HashDiffer from data_diff.table_segment import TableSegment, split_space @@ -26,22 +26,6 @@ def _table_segment(database, table_path, key_columns, *args, **kw): return TableSegment(database, table_path, key_columns, *args, **kw) -def _insert_row(conn, table, fields, values): - fields = ", ".join(map(str, fields)) - values = ", ".join(map(str, values)) - conn.query(f"INSERT INTO {table}({fields}) VALUES ({values})", None) - - -def _insert_rows(conn, table, fields, tuple_list): - for t in tuple_list: - _insert_row(conn, table, fields, t) - - -def _commit(conn): - if 
not isinstance(conn, db.BigQuery): - conn.query("COMMIT", None) - - class TestUtils(unittest.TestCase): def test_split_space(self): for i in range(0, 10): @@ -68,14 +52,15 @@ def setUp(self): (now, "now"), ] - self.connection.query(src_table.insert_rows((i, ts.datetime, s) for i, (ts, s) in enumerate(rows))) - _commit(self.connection) - - self.connection.query(f"CREATE TABLE {self.table_dst_name} AS SELECT * FROM {self.table_src_name}") - _commit(self.connection) - - self.connection.query(src_table.insert_row(len(rows), self.now.shift(seconds=-3).datetime, "3 seconds ago")) - _commit(self.connection) + self.connection.query( + [ + src_table.insert_rows((i, ts.datetime, s) for i, (ts, s) in enumerate(rows)), + table(self.table_dst_path).create(src_table), + commit, + src_table.insert_row(len(rows), self.now.shift(seconds=-3).datetime, "3 seconds ago"), + commit, + ] + ) def test_init(self): a = _table_segment( @@ -151,20 +136,16 @@ class TestDiffTables(TestPerDatabase): def setUp(self): super().setUp() - src_table = table( + self.src_table = table( self.table_src_path, schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}, ) - dst_table = table( + self.dst_table = table( self.table_dst_path, schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}, ) - self.connection.query( - src_table.create(), - ) - self.connection.query(dst_table.create()) - _commit(self.connection) + self.connection.query([self.src_table.create(), self.dst_table.create(), commit]) self.table = _table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) self.table2 = _table_segment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) @@ -178,12 +159,12 @@ def test_properties_on_empty_table(self): def test_get_values(self): time = "2022-01-01 00:00:00.000000" - time_str = f"timestamp '{time}'" + time_obj = datetime.fromisoformat(time) cols = "id userid movieid rating timestamp".split() - _insert_row(self.connection, self.table_src, cols, [1, 1, 1, 9, time_str]) - _commit(self.connection) - id_ = self.connection.query(f"select id from {self.table_src}", int) + id_ = self.connection.query( + [self.src_table.insert_row(1, 1, 1, 9, time_obj, columns=cols), commit, self.src_table.select(this.id)], int + ) table = self.table.with_schema() @@ -193,12 +174,17 @@ def test_get_values(self): def test_diff_small_tables(self): time = "2022-01-01 00:00:00" - time_str = f"timestamp '{time}'" + time_obj = datetime.fromisoformat(time) cols = "id userid movieid rating timestamp".split() - _insert_rows(self.connection, self.table_src, cols, [[1, 1, 1, 9, time_str], [2, 2, 2, 9, time_str]]) - _insert_rows(self.connection, self.table_dst, cols, [[1, 1, 1, 9, time_str]]) - _commit(self.connection) + self.connection.query( + [ + self.src_table.insert_rows([[1, 1, 1, 9, time_obj], [2, 2, 2, 9, time_obj]], columns=cols), + self.dst_table.insert_rows([[1, 1, 1, 9, time_obj]], columns=cols), + commit, + ] + ) + diff = list(self.differ.diff_tables(self.table, self.table2)) expected = [("-", ("2", time + ".000000"))] self.assertEqual(expected, diff) @@ -209,44 +195,49 @@ def test_non_threaded(self): differ = HashDiffer(bisection_factor=3, bisection_threshold=4, threaded=False) time = "2022-01-01 00:00:00" - time_str = f"timestamp '{time}'" + time_obj = datetime.fromisoformat(time) cols = "id userid movieid rating timestamp".split() - _insert_row(self.connection, self.table_src, cols, [1, 1, 1, 9, time_str]) - 
_insert_rows(self.connection, self.table_dst, cols, [[1, 1, 1, 9, time_str]]) - _commit(self.connection) + self.connection.query( + [ + self.src_table.insert_row(1, 1, 1, 9, time_obj, columns=cols), + self.dst_table.insert_row(1, 1, 1, 9, time_obj, columns=cols), + commit, + ] + ) + diff = list(differ.diff_tables(self.table, self.table2)) self.assertEqual(diff, []) def test_diff_table_above_bisection_threshold(self): time = "2022-01-01 00:00:00" - time_str = f"timestamp '{time}'" + time_obj = datetime.fromisoformat(time) cols = "id userid movieid rating timestamp".split() - _insert_rows( - self.connection, - self.table_src, - cols, - [ - [1, 1, 1, 9, time_str], - [2, 2, 2, 9, time_str], - [3, 3, 3, 9, time_str], - [4, 4, 4, 9, time_str], - [5, 5, 5, 9, time_str], - ], - ) - _insert_rows( - self.connection, - self.table_dst, - cols, + self.connection.query( [ - [1, 1, 1, 9, time_str], - [2, 2, 2, 9, time_str], - [3, 3, 3, 9, time_str], - [4, 4, 4, 9, time_str], - ], + self.src_table.insert_rows( + [ + [1, 1, 1, 9, time_obj], + [2, 2, 2, 9, time_obj], + [3, 3, 3, 9, time_obj], + [4, 4, 4, 9, time_obj], + [5, 5, 5, 9, time_obj], + ], + columns=cols, + ), + self.dst_table.insert_rows( + [ + [1, 1, 1, 9, time_obj], + [2, 2, 2, 9, time_obj], + [3, 3, 3, 9, time_obj], + [4, 4, 4, 9, time_obj], + ], + columns=cols, + ), + commit, + ] ) - _commit(self.connection) diff = list(self.differ.diff_tables(self.table, self.table2)) expected = [("-", ("5", time + ".000000"))] @@ -256,14 +247,18 @@ def test_diff_table_above_bisection_threshold(self): def test_return_empty_array_when_same(self): time = "2022-01-01 00:00:00" - time_str = f"timestamp '{time}'" + time_obj = datetime.fromisoformat(time) cols = "id userid movieid rating timestamp".split() - _insert_row(self.connection, self.table_src, cols, [1, 1, 1, 9, time_str]) - _insert_row(self.connection, self.table_dst, cols, [1, 1, 1, 9, time_str]) + self.connection.query( + [ + self.src_table.insert_row(1, 1, 1, 9, time_obj, columns=cols), + self.dst_table.insert_row(1, 1, 1, 9, time_obj, columns=cols), + commit, + ] + ) - _commit(self.connection) diff = list(self.differ.diff_tables(self.table, self.table2)) self.assertEqual([], diff) @@ -271,37 +266,36 @@ def test_diff_sorted_by_key(self): time = "2022-01-01 00:00:00" time2 = "2021-01-01 00:00:00" - time_str = f"timestamp '{time}'" - time_str2 = f"timestamp '{time2}'" + time_obj = datetime.fromisoformat(time) + time_obj2 = datetime.fromisoformat(time2) cols = "id userid movieid rating timestamp".split() - _insert_rows( - self.connection, - self.table_src, - cols, - [ - [1, 1, 1, 9, time_str], - [2, 2, 2, 9, time_str2], - [3, 3, 3, 9, time_str], - [4, 4, 4, 9, time_str2], - [5, 5, 5, 9, time_str], - ], - ) - - _insert_rows( - self.connection, - self.table_dst, - cols, + self.connection.query( [ - [1, 1, 1, 9, time_str], - [2, 2, 2, 9, time_str], - [3, 3, 3, 9, time_str], - [4, 4, 4, 9, time_str], - [5, 5, 5, 9, time_str], - ], + self.src_table.insert_rows( + [ + [1, 1, 1, 9, time_obj], + [2, 2, 2, 9, time_obj2], + [3, 3, 3, 9, time_obj], + [4, 4, 4, 9, time_obj2], + [5, 5, 5, 9, time_obj], + ], + columns=cols, + ), + self.dst_table.insert_rows( + [ + [1, 1, 1, 9, time_obj], + [2, 2, 2, 9, time_obj], + [3, 3, 3, 9, time_obj], + [4, 4, 4, 9, time_obj], + [5, 5, 5, 9, time_obj], + ], + columns=cols, + ), + commit, + ] ) - _commit(self.connection) differ = HashDiffer() diff = list(differ.diff_tables(self.table, self.table2)) @@ -318,42 +312,40 @@ def test_diff_sorted_by_key(self): class 
TestDiffTables2(TestPerDatabase): def test_diff_column_names(self): - src_table = table(self.table_src_path, schema={"id": int, "rating": float, "timestamp": datetime}) - dst_table = table(self.table_dst_path, schema={"id2": int, "rating2": float, "timestamp2": datetime}) + self.src_table = table(self.table_src_path, schema={"id": int, "rating": float, "timestamp": datetime}) + self.dst_table = table(self.table_dst_path, schema={"id2": int, "rating2": float, "timestamp2": datetime}) - self.connection.query(src_table.create()) - self.connection.query(dst_table.create()) - _commit(self.connection) + self.connection.query([self.src_table.create(), self.dst_table.create(), commit]) time = "2022-01-01 00:00:00" time2 = "2021-01-01 00:00:00" - time_str = f"timestamp '{time}'" - time_str2 = f"timestamp '{time2}'" - _insert_rows( - self.connection, - self.table_src, - ["id", "rating", "timestamp"], - [ - [1, 9, time_str], - [2, 9, time_str2], - [3, 9, time_str], - [4, 9, time_str2], - [5, 9, time_str], - ], - ) + time_obj = datetime.fromisoformat(time) + time_obj2 = datetime.fromisoformat(time2) - _insert_rows( - self.connection, - self.table_dst, - ["id2", "rating2", "timestamp2"], + self.connection.query( [ - [1, 9, time_str], - [2, 9, time_str2], - [3, 9, time_str], - [4, 9, time_str2], - [5, 9, time_str], - ], + self.src_table.insert_rows( + [ + [1, 9, time_obj], + [2, 9, time_obj2], + [3, 9, time_obj], + [4, 9, time_obj2], + [5, 9, time_obj], + ], + columns=["id", "rating", "timestamp"], + ), + self.dst_table.insert_rows( + [ + [1, 9, time_obj], + [2, 9, time_obj2], + [3, 9, time_obj], + [4, 9, time_obj2], + [5, 9, time_obj], + ], + columns=["id2", "rating2", "timestamp2"], + ), + ] ) table1 = _table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) @@ -369,23 +361,19 @@ class TestUUIDs(TestPerDatabase): def setUp(self): super().setUp() - src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) - - queries = [src_table.create()] - for i in range(100): - queries.append(f"INSERT INTO {self.table_src} VALUES ('{uuid.uuid1(i)}', '{i}')") - - queries += [ - f"CREATE TABLE {self.table_dst} AS SELECT * FROM {self.table_src}", - ] + self.src_table = src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) self.new_uuid = uuid.uuid1(32132131) - queries.append(f"INSERT INTO {self.table_src} VALUES ('{self.new_uuid}', 'This one is different')") - - for query in queries: - self.connection.query(query, None) - _commit(self.connection) + self.connection.query( + [ + src_table.create(), + src_table.insert_rows((uuid.uuid1(i), str(i)) for i in range(100)), + table(self.table_dst_path).create(src_table), + src_table.insert_row(self.new_uuid, "This one is different"), + commit, + ] + ) self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) @@ -396,7 +384,7 @@ def test_string_keys(self): self.assertEqual(diff, [("-", (str(self.new_uuid), "This one is different"))]) self.connection.query( - f"INSERT INTO {self.table_src} VALUES ('unexpected', '<-- this bad value should not break us')", None + self.src_table.insert_row('unexpected', '<-- this bad value should not break us') ) self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) @@ -417,29 +405,29 @@ class TestAlphanumericKeys(TestPerDatabase): def setUp(self): super().setUp() - src_table = 
table(self.table_src_path, schema={"id": str, "text_comment": str}) + self.src_table = src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) + self.new_alphanum = "aBcDeFgHiJ" - queries = [src_table.create()] + values = [] for i in range(0, 10000, 1000): a = ArithAlphanumeric(numberToAlphanum(i), max_len=10) if not a and isinstance(self.connection, db.Oracle): # Skip empty string, because Oracle treats it as NULL .. continue - queries.append(f"INSERT INTO {self.table_src} VALUES ('{a}', '{i}')") + values.append((str(a), str(i))) - queries += [ - f"CREATE TABLE {self.table_dst} AS SELECT * FROM {self.table_src}", + queries = [ + src_table.create(), + src_table.insert_rows(values), + table(self.table_dst_path).create(src_table), + src_table.insert_row(self.new_alphanum, 'This one is different'), + commit, ] - self.new_alphanum = "aBcDeFgHiJ" - queries.append(f"INSERT INTO {self.table_src} VALUES ('{self.new_alphanum}', 'This one is different')") - for query in queries: self.connection.query(query, None) - _commit(self.connection) - self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) @@ -449,10 +437,7 @@ def test_alphanum_keys(self): diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, [("-", (str(self.new_alphanum), "This one is different"))]) - self.connection.query( - f"INSERT INTO {self.table_src} VALUES ('@@@', '<-- this bad value should not break us')", None - ) - _commit(self.connection) + self.connection.query([self.src_table.insert_row("@@@", "<-- this bad value should not break us"), commit]) self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) @@ -465,28 +450,28 @@ class TestVaryingAlphanumericKeys(TestPerDatabase): def setUp(self): super().setUp() - src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) + self.src_table = src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) - queries = [src_table.create()] + values = [] for i in range(0, 10000, 1000): a = ArithAlphanumeric(numberToAlphanum(i * i)) if not a and isinstance(self.connection, db.Oracle): # Skip empty string, because Oracle treats it as NULL .. 
continue - queries.append(f"INSERT INTO {self.table_src} VALUES ('{a}', '{i}')") - - queries += [ - f"CREATE TABLE {self.table_dst} AS SELECT * FROM {self.table_src}", - ] + values.append((str(a), str(i))) self.new_alphanum = "aBcDeFgHiJ" - queries.append(f"INSERT INTO {self.table_src} VALUES ('{self.new_alphanum}', 'This one is different')") - for query in queries: - self.connection.query(query, None) + queries = [ + src_table.create(), + src_table.insert_rows(values), + table(self.table_dst_path).create(src_table), + src_table.insert_row(self.new_alphanum, "This one is different"), + commit, + ] - _commit(self.connection) + self.connection.query(queries) self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) @@ -506,9 +491,9 @@ def test_varying_alphanum_keys(self): self.assertEqual(diff, [("-", (str(self.new_alphanum), "This one is different"))]) self.connection.query( - f"INSERT INTO {self.table_src} VALUES ('@@@', '<-- this bad value should not break us')", None + self.src_table.insert_row('@@@', '<-- this bad value should not break us'), + commit, ) - _commit(self.connection) self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) @@ -531,16 +516,15 @@ def test_table_segment(self): self.assertRaises(ValueError, self.table.replace, min_key=10, max_key=0) def test_case_awareness(self): - # create table - self.connection.query(f"create table {self.table_src}(id int, userid int, timestamp timestamp)", None) - _commit(self.connection) + src_table = table(self.table_src_path, schema={"id": int, "userid": int, "timestamp": datetime}) - # insert rows cols = "id userid timestamp".split() time = "2022-01-01 00:00:00.000000" - time_str = f"timestamp '{time}'" - _insert_rows(self.connection, self.table_src, cols, [[1, 9, time_str], [2, 2, time_str]]) - _commit(self.connection) + time_obj = datetime.fromisoformat(time) + + self.connection.query( + [src_table.create(), src_table.insert_rows([[1, 9, time_obj], [2, 2, time_obj]], columns=cols), commit] + ) res = tuple(self.table.replace(key_columns=("Id",), case_sensitive=False).with_schema().query_key_range()) assert res == ("1", "2") @@ -557,21 +541,20 @@ def setUp(self): src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) - queries = [src_table.create()] + values = [] for i in range(10): uuid_value = uuid.uuid1(i) - queries.append(f"INSERT INTO {self.table_src} VALUES ('{uuid_value}', '{uuid_value}')") + values.append((uuid_value, uuid_value)) self.null_uuid = uuid.uuid1(32132131) - queries += [ - f"CREATE TABLE {self.table_dst} AS SELECT * FROM {self.table_src}", - f"INSERT INTO {self.table_src} VALUES ('{self.null_uuid}', NULL)", - ] - - for query in queries: - self.connection.query(query, None) - _commit(self.connection) + self.connection.query([ + src_table.create(), + src_table.insert_rows(values), + table(self.table_dst_path).create(src_table), + src_table.insert_row(self.null_uuid, None), + commit, + ]) self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) @@ -590,18 +573,15 @@ def setUp(self): src_table = table(self.table_src_path, schema={"id": str, 
"text_comment": str}) self.null_uuid = uuid.uuid1(1) - queries = [ - src_table.create(), - f"INSERT INTO {self.table_src} VALUES ('{uuid.uuid1(1)}', '1')", - f"CREATE TABLE {self.table_dst} AS SELECT * FROM {self.table_src}", - # Add a row where a column has NULL value - f"INSERT INTO {self.table_src} VALUES ('{self.null_uuid}', NULL)", - ] - - for query in queries: - self.connection.query(query, None) - - _commit(self.connection) + self.connection.query( + [ + src_table.create(), + src_table.insert_row(uuid.uuid1(1), '1'), + table(self.table_dst_path).create(src_table), + src_table.insert_row(self.null_uuid, None), # Add a row where a column has NULL value + commit, + ] + ) self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) @@ -640,24 +620,30 @@ def setUp(self): src_table = table(self.table_src_path, schema={"id": str, "c1": str, "c2": str}) dst_table = table(self.table_dst_path, schema={"id": str, "c1": str, "c2": str}) - queries = [src_table.create(), dst_table.create()] + src_values = [] + dst_values = [] self.diffs = [] for i in range(0, 8): pk = uuid.uuid1(i) - table_src_c1_val = str(i) - table_dst_c1_val = str(i) + "-different" + src_row = (str(pk), str(i), None) + dst_row = (str(pk), str(i) + "-different", None) - queries.append(f"INSERT INTO {self.table_src} VALUES ('{pk}', '{table_src_c1_val}', NULL)") - queries.append(f"INSERT INTO {self.table_dst} VALUES ('{pk}', '{table_dst_c1_val}', NULL)") + src_values.append(src_row) + dst_values.append(dst_row) - self.diffs.append(("-", (str(pk), table_src_c1_val, None))) - self.diffs.append(("+", (str(pk), table_dst_c1_val, None))) - - for query in queries: - self.connection.query(query, None) + self.diffs.append(("-", src_row)) + self.diffs.append(("+", dst_row)) - _commit(self.connection) + self.connection.query( + [ + src_table.create(), + dst_table.create(), + src_table.insert_rows(src_values), + dst_table.insert_rows(dst_values), + commit, + ] + ) self.a = _table_segment( self.connection, self.table_src_path, "id", extra_columns=("c1", "c2"), case_sensitive=False @@ -699,20 +685,14 @@ class TestTableTableEmpty(TestPerDatabase): def setUp(self): super().setUp() - src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) - dst_table = table(self.table_dst_path, schema={"id": str, "text_comment": str}) + self.src_table = src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) + self.dst_table = dst_table = table(self.table_dst_path, schema={"id": str, "text_comment": str}) self.null_uuid = uuid.uuid1(1) - queries = [src_table.create(), dst_table.create()] - - self.diffs = [(uuid.uuid1(i), i) for i in range(100)] - for pk, value in self.diffs: - queries.append(f"INSERT INTO {self.table_src} VALUES ('{pk}', '{value}')") - for query in queries: - self.connection.query(query, None) + self.diffs = [(uuid.uuid1(i), str(i)) for i in range(100)] - _commit(self.connection) + self.connection.query([src_table.create(), dst_table.create(), src_table.insert_rows(self.diffs), commit]) self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) @@ -722,14 +702,7 @@ def test_right_table_empty(self): self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) def 
test_left_table_empty(self): - queries = [ - f"INSERT INTO {self.table_dst} SELECT id, text_comment FROM {self.table_src}", - f"TRUNCATE TABLE {self.table_src}", - ] - for query in queries: - self.connection.query(query, None) - - _commit(self.connection) + self.connection.query([self.dst_table.insert_expr(self.src_table), self.src_table.truncate(), commit]) differ = HashDiffer() self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index 5a01b6e1..d3db82e0 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -2,12 +2,12 @@ from datetime import datetime from data_diff.queries.ast_classes import TablePath -from data_diff.queries import table +from data_diff.queries import table, commit from data_diff.table_segment import TableSegment from data_diff import databases as db from data_diff.joindiff_tables import JoinDiffer -from .test_diff_tables import TestPerDatabase, _commit, _insert_row, _insert_rows +from .test_diff_tables import TestPerDatabase from .common import ( random_table_suffix, @@ -35,29 +35,32 @@ class TestCompositeKey(TestPerDatabase): def setUp(self): super().setUp() - src_table = table( + self.src_table = table( self.table_src_path, schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}, ) - dst_table = table( + self.dst_table = table( self.table_dst_path, schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}, ) - self.connection.query(src_table.create()) - self.connection.query(dst_table.create()) - _commit(self.connection) + self.connection.query([self.src_table.create(), self.dst_table.create(), commit]) self.differ = JoinDiffer() def test_composite_key(self): time = "2022-01-01 00:00:00" - time_str = f"timestamp '{time}'" + time_obj = datetime.fromisoformat(time) cols = "id userid movieid rating timestamp".split() - _insert_rows(self.connection, self.table_src, cols, [[1, 1, 1, 9, time_str], [2, 2, 2, 9, time_str]]) - _insert_rows(self.connection, self.table_dst, cols, [[1, 1, 1, 9, time_str], [2, 3, 2, 9, time_str]]) - _commit(self.connection) + + self.connection.query( + [ + self.src_table.insert_rows([[1, 1, 1, 9, time_obj], [2, 2, 2, 9, time_obj]], columns=cols), + self.dst_table.insert_rows([[1, 1, 1, 9, time_obj], [2, 3, 2, 9, time_obj]], columns=cols), + commit, + ] + ) # Sanity table1 = TableSegment( @@ -83,18 +86,16 @@ class TestJoindiff(TestPerDatabase): def setUp(self): super().setUp() - src_table = table( + self.src_table = table( self.table_src_path, schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}, ) - dst_table = table( + self.dst_table = table( self.table_dst_path, schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}, ) - self.connection.query(src_table.create()) - self.connection.query(dst_table.create()) - _commit(self.connection) + self.connection.query([self.src_table.create(), self.dst_table.create(), commit]) self.table = TableSegment(self.connection, self.table_src_path, ("id",), "timestamp", case_sensitive=False) self.table2 = TableSegment(self.connection, self.table_dst_path, ("id",), "timestamp", case_sensitive=False) @@ -103,12 +104,18 @@ def setUp(self): def test_diff_small_tables(self): time = "2022-01-01 00:00:00" - time_str = f"timestamp '{time}'" + time_obj = datetime.fromisoformat(time) cols = "id userid movieid rating timestamp".split() - _insert_rows(self.connection, self.table_src, cols, [[1, 1, 1, 
9, time_str], [2, 2, 2, 9, time_str]]) - _insert_rows(self.connection, self.table_dst, cols, [[1, 1, 1, 9, time_str]]) - _commit(self.connection) + + self.connection.query( + [ + self.src_table.insert_rows([[1, 1, 1, 9, time_obj], [2, 2, 2, 9, time_obj]], columns=cols), + self.dst_table.insert_rows([[1, 1, 1, 9, time_obj]], columns=cols), + commit, + ] + ) + diff = list(self.differ.diff_tables(self.table, self.table2)) expected_row = ("2", time + ".000000") expected = [("-", expected_row)] @@ -132,34 +139,34 @@ def test_diff_small_tables(self): def test_diff_table_above_bisection_threshold(self): time = "2022-01-01 00:00:00" - time_str = f"timestamp '{time}'" + time_obj = datetime.fromisoformat(time) cols = "id userid movieid rating timestamp".split() - _insert_rows( - self.connection, - self.table_src, - cols, - [ - [1, 1, 1, 9, time_str], - [2, 2, 2, 9, time_str], - [3, 3, 3, 9, time_str], - [4, 4, 4, 9, time_str], - [5, 5, 5, 9, time_str], - ], - ) - _insert_rows( - self.connection, - self.table_dst, - cols, + self.connection.query( [ - [1, 1, 1, 9, time_str], - [2, 2, 2, 9, time_str], - [3, 3, 3, 9, time_str], - [4, 4, 4, 9, time_str], - ], + self.src_table.insert_rows( + [ + [1, 1, 1, 9, time_obj], + [2, 2, 2, 9, time_obj], + [3, 3, 3, 9, time_obj], + [4, 4, 4, 9, time_obj], + [5, 5, 5, 9, time_obj], + ], + columns=cols, + ), + self.dst_table.insert_rows( + [ + [1, 1, 1, 9, time_obj], + [2, 2, 2, 9, time_obj], + [3, 3, 3, 9, time_obj], + [4, 4, 4, 9, time_obj], + ], + columns=cols, + ), + commit, + ] ) - _commit(self.connection) diff = list(self.differ.diff_tables(self.table, self.table2)) expected = [("-", ("5", time + ".000000"))] @@ -169,12 +176,16 @@ def test_diff_table_above_bisection_threshold(self): def test_return_empty_array_when_same(self): time = "2022-01-01 00:00:00" - time_str = f"timestamp '{time}'" + time_obj = datetime.fromisoformat(time) cols = "id userid movieid rating timestamp".split() - _insert_row(self.connection, self.table_src, cols, [1, 1, 1, 9, time_str]) - _insert_row(self.connection, self.table_dst, cols, [1, 1, 1, 9, time_str]) + self.connection.query( + [ + self.src_table.insert_row(1, 1, 1, 9, time_obj, columns=cols), + self.dst_table.insert_row(1, 1, 1, 9, time_obj, columns=cols), + ] + ) diff = list(self.differ.diff_tables(self.table, self.table2)) self.assertEqual([], diff) @@ -183,37 +194,36 @@ def test_diff_sorted_by_key(self): time = "2022-01-01 00:00:00" time2 = "2021-01-01 00:00:00" - time_str = f"timestamp '{time}'" - time_str2 = f"timestamp '{time2}'" + time_obj = datetime.fromisoformat(time) + time_obj2 = datetime.fromisoformat(time2) cols = "id userid movieid rating timestamp".split() - _insert_rows( - self.connection, - self.table_src, - cols, - [ - [1, 1, 1, 9, time_str], - [2, 2, 2, 9, time_str2], - [3, 3, 3, 9, time_str], - [4, 4, 4, 9, time_str2], - [5, 5, 5, 9, time_str], - ], - ) - - _insert_rows( - self.connection, - self.table_dst, - cols, + self.connection.query( [ - [1, 1, 1, 9, time_str], - [2, 2, 2, 9, time_str], - [3, 3, 3, 9, time_str], - [4, 4, 4, 9, time_str], - [5, 5, 5, 9, time_str], - ], + self.src_table.insert_rows( + [ + [1, 1, 1, 9, time_obj], + [2, 2, 2, 9, time_obj2], + [3, 3, 3, 9, time_obj], + [4, 4, 4, 9, time_obj2], + [5, 5, 5, 9, time_obj], + ], + columns=cols, + ), + self.dst_table.insert_rows( + [ + [1, 1, 1, 9, time_obj], + [2, 2, 2, 9, time_obj], + [3, 3, 3, 9, time_obj], + [4, 4, 4, 9, time_obj], + [5, 5, 5, 9, time_obj], + ], + columns=cols, + ), + commit, + ] ) - _commit(self.connection) diff = 
list(self.differ.diff_tables(self.table, self.table2)) expected = [ @@ -226,25 +236,32 @@ def test_diff_sorted_by_key(self): def test_dup_pks(self): time = "2022-01-01 00:00:00" - time_str = f"timestamp '{time}'" + time_obj = datetime.fromisoformat(time) cols = "id rating timestamp".split() - _insert_row(self.connection, self.table_src, cols, [1, 9, time_str]) - _insert_row(self.connection, self.table_src, cols, [1, 10, time_str]) - _insert_row(self.connection, self.table_dst, cols, [1, 9, time_str]) + self.connection.query( + [ + self.src_table.insert_rows([[1, 9, time_obj], [1, 10, time_obj]], columns=cols), + self.dst_table.insert_row(1, 9, time_obj, columns=cols), + ] + ) x = self.differ.diff_tables(self.table, self.table2) self.assertRaises(ValueError, list, x) def test_null_pks(self): time = "2022-01-01 00:00:00" - time_str = f"timestamp '{time}'" + time_obj = datetime.fromisoformat(time) cols = "id rating timestamp".split() - _insert_row(self.connection, self.table_src, cols, ["null", 9, time_str]) - _insert_row(self.connection, self.table_dst, cols, [1, 9, time_str]) + self.connection.query( + [ + self.src_table.insert_row(None, 9, time_obj, columns=cols), + self.dst_table.insert_row(1, 9, time_obj, columns=cols), + ] + ) x = self.differ.diff_tables(self.table, self.table2) self.assertRaises(ValueError, list, x) From fcb37f36f7005d6f469ed57eeb6f5d6beb9326fe Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 13 Oct 2022 19:00:02 +0200 Subject: [PATCH 41/93] Refactor repeating functions into query_utils --- data_diff/joindiff_tables.py | 50 ++++---------------------------- data_diff/query_utils.py | 55 ++++++++++++++++++++++++++++++++++++ tests/common.py | 21 ++++---------- tests/test_database_types.py | 26 ++++++++++++----- 4 files changed, 85 insertions(+), 67 deletions(-) create mode 100644 data_diff/query_utils.py diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 85cc7225..641f11d0 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -10,8 +10,8 @@ from runtype import dataclass -from data_diff.databases.database_types import DbPath, NumericType -from data_diff.databases.base import QueryError +from .databases.database_types import DbPath, NumericType +from .query_utils import append_to_table, drop_table from .utils import safezip @@ -48,7 +48,7 @@ def sample(table_expr): return table_expr.order_by(Random()).limit(10) -def create_temp_table(c: Compiler, path: TablePath, expr: Expr): +def create_temp_table(c: Compiler, path: TablePath, expr: Expr) -> str: db = c.database if isinstance(db, BigQuery): return f"create table {c.compile(path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}" @@ -60,42 +60,6 @@ def create_temp_table(c: Compiler, path: TablePath, expr: Expr): return f"create temporary table {c.compile(path)} as {c.compile(expr)}" -def drop_table_oracle(name: DbPath): - t = table(name) - # Experience shows double drop is necessary - with suppress(QueryError): - yield t.drop() - yield t.drop() - yield commit - - -def drop_table(name: DbPath): - t = table(name) - yield t.drop(if_exists=True) - yield commit - - -def append_to_table_oracle(path: DbPath, expr: Expr): - """See append_to_table""" - assert expr.schema, expr - t = table(path, schema=expr.schema) - with suppress(QueryError): - yield t.create() # uses expr.schema - yield commit - yield t.insert_expr(expr) - yield commit - - -def append_to_table(path: DbPath, expr: Expr): - """Append to table""" - assert 
expr.schema, expr - t = table(path, schema=expr.schema) - yield t.create(if_not_exists=True) # uses expr.schema - yield commit - yield t.insert_expr(expr) - yield commit - - def bool_to_int(x): return if_(x, 1, 0) @@ -170,10 +134,7 @@ def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult bg_funcs = [partial(self._test_duplicate_keys, table1, table2)] if self.validate_unique_key else [] if self.materialize_to_table: - if isinstance(db, Oracle): - db.query(drop_table_oracle(self.materialize_to_table)) - else: - db.query(drop_table(self.materialize_to_table)) + drop_table(db, self.materialize_to_table) with self._run_in_background(*bg_funcs): @@ -348,6 +309,5 @@ def exclusive_rows(expr): def _materialize_diff(self, db, diff_rows, segment_index=None): assert self.materialize_to_table - f = append_to_table_oracle if isinstance(db, Oracle) else append_to_table - db.query(f(self.materialize_to_table, diff_rows.limit(self.write_limit))) + append_to_table(db, self.materialize_to_table, diff_rows.limit(self.write_limit)) logger.info("Materialized diff to table '%s'.", ".".join(self.materialize_to_table)) diff --git a/data_diff/query_utils.py b/data_diff/query_utils.py new file mode 100644 index 00000000..dac22fe1 --- /dev/null +++ b/data_diff/query_utils.py @@ -0,0 +1,55 @@ +"Module for query utilities that didn't make it into the query-builder (yet)" + +from contextlib import suppress + +from data_diff.databases.database_types import DbPath +from data_diff.databases.base import QueryError + +from .databases import Oracle +from .queries import table, commit, Expr + +def _drop_table_oracle(name: DbPath): + t = table(name) + # Experience shows double drop is necessary + with suppress(QueryError): + yield t.drop() + yield t.drop() + yield commit + + +def _drop_table(name: DbPath): + t = table(name) + yield t.drop(if_exists=True) + yield commit + + +def drop_table(db, tbl): + if isinstance(db, Oracle): + db.query(_drop_table_oracle(tbl)) + else: + db.query(_drop_table(tbl)) + + +def _append_to_table_oracle(path: DbPath, expr: Expr): + """See append_to_table""" + assert expr.schema, expr + t = table(path, schema=expr.schema) + with suppress(QueryError): + yield t.create() # uses expr.schema + yield commit + yield t.insert_expr(expr) + yield commit + + +def _append_to_table(path: DbPath, expr: Expr): + """Append to table""" + assert expr.schema, expr + t = table(path, schema=expr.schema) + yield t.create(if_not_exists=True) # uses expr.schema + yield commit + yield t.insert_expr(expr) + yield commit + +def append_to_table(db, path, expr): + f = _append_to_table_oracle if isinstance(db, Oracle) else _append_to_table + db.query(f(path, expr)) diff --git a/tests/common.py b/tests/common.py index 5e689131..a652e1c4 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,4 +1,3 @@ -from contextlib import suppress import hashlib import os import string @@ -13,6 +12,8 @@ from data_diff import databases as db from data_diff import tracking from data_diff import connect +from data_diff.queries.api import table +from data_diff.query_utils import drop_table tracking.disable_tracking() @@ -119,16 +120,6 @@ def str_to_checksum(str: str): return int(md5[half_pos:], 16) -def _drop_table_if_exists(conn, table): - with suppress(db.QueryError): - if isinstance(conn, db.Oracle): - conn.query(f"DROP TABLE {table}", None) - conn.query(f"DROP TABLE {table}", None) - else: - conn.query(f"DROP TABLE IF EXISTS {table}", None) - if not isinstance(conn, (db.BigQuery, db.Databricks, db.Clickhouse)): - 
conn.query("COMMIT", None) - class TestPerDatabase(unittest.TestCase): db_cls = None @@ -148,14 +139,14 @@ def setUp(self): self.table_src = ".".join(map(self.connection.quote, self.table_src_path)) self.table_dst = ".".join(map(self.connection.quote, self.table_dst_path)) - _drop_table_if_exists(self.connection, self.table_src) - _drop_table_if_exists(self.connection, self.table_dst) + drop_table(self.connection, self.table_src_path) + drop_table(self.connection, self.table_dst_path) return super().setUp() def tearDown(self): - _drop_table_if_exists(self.connection, self.table_src) - _drop_table_if_exists(self.connection, self.table_dst) + drop_table(self.connection, self.table_src_path) + drop_table(self.connection, self.table_dst_path) def _parameterized_class_per_conn(test_databases): diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 08bacb96..bb792826 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -1,3 +1,4 @@ +from contextlib import suppress import unittest import time import json @@ -14,6 +15,7 @@ from data_diff import databases as db from data_diff.databases import postgresql, oracle +from data_diff.query_utils import drop_table from data_diff.utils import number_to_human, accumulate from data_diff.hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD from data_diff.table_segment import TableSegment @@ -25,7 +27,6 @@ GIT_REVISION, get_conn, random_table_suffix, - _drop_table_if_exists, ) CONNS = None @@ -465,6 +466,17 @@ def expand_params(testcase_func, param_num, param): return name +def _drop_table_if_exists(conn, tbl): + if isinstance(conn, db.Oracle): + with suppress(db.QueryError): + conn.query(f"DROP TABLE {tbl}", None) + conn.query(f"DROP TABLE {tbl}", None) + else: + conn.query(f"DROP TABLE IF EXISTS {tbl}", None) + if not isinstance(conn, (db.BigQuery, db.Databricks, db.Clickhouse)): + conn.query("COMMIT", None) + + def _insert_to_table(conn, table, values, type): current_n_rows = conn.query(f"SELECT COUNT(*) FROM {table}", int) if current_n_rows == N_SAMPLES: @@ -593,8 +605,8 @@ def setUp(self) -> None: def tearDown(self) -> None: if not BENCHMARK: - _drop_table_if_exists(self.src_conn, self.src_table) - _drop_table_if_exists(self.dst_conn, self.dst_table) + drop_table(self.src_conn, self.src_table_path) + drop_table(self.dst_conn, self.dst_table_path) return super().tearDown() @@ -618,14 +630,14 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego src_table_name = f"src_{self._testMethodName[11:]}{table_suffix}" dst_table_name = f"dst_{self._testMethodName[11:]}{table_suffix}" - src_table_path = src_conn.parse_table_name(src_table_name) - dst_table_path = dst_conn.parse_table_name(dst_table_name) + self.src_table_path = src_table_path = src_conn.parse_table_name(src_table_name) + self.dst_table_path = dst_table_path = dst_conn.parse_table_name(dst_table_name) self.src_table = src_table = ".".join(map(src_conn.quote, src_table_path)) self.dst_table = dst_table = ".".join(map(dst_conn.quote, dst_table_path)) start = time.monotonic() if not BENCHMARK: - _drop_table_if_exists(src_conn, src_table) + drop_table(src_conn, src_table_path) _create_table_with_indexes(src_conn, src_table, source_type) _insert_to_table(src_conn, src_table, enumerate(sample_values, 1), source_type) insertion_source_duration = time.monotonic() - start @@ -639,7 +651,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego start = time.monotonic() if not BENCHMARK: - 
_drop_table_if_exists(dst_conn, dst_table) + drop_table(dst_conn, dst_table_path) _create_table_with_indexes(dst_conn, dst_table, target_type) _insert_to_table(dst_conn, dst_table, values_in_source, target_type) insertion_target_duration = time.monotonic() - start From b09eca38ff3278fb88666384132d5823068266c6 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 14 Oct 2022 10:09:27 +0200 Subject: [PATCH 42/93] Remove trino from CI tests, to reduce load on github actions --- .github/workflows/ci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 21115c45..b5f8d548 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -46,7 +46,6 @@ jobs: env: DATADIFF_SNOWFLAKE_URI: '${{ secrets.DATADIFF_SNOWFLAKE_URI }}' DATADIFF_PRESTO_URI: '${{ secrets.DATADIFF_PRESTO_URI }}' - DATADIFF_TRINO_URI: '${{ secrets.DATADIFF_TRINO_URI }}' DATADIFF_CLICKHOUSE_URI: 'clickhouse://clickhouse:Password1@localhost:9000/clickhouse' DATADIFF_VERTICA_URI: 'vertica://vertica:Password1@localhost:5433/vertica' run: | From 1fb79da944f36c80545c7da9c53b1cc497960f1e Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 14 Oct 2022 11:45:44 +0200 Subject: [PATCH 43/93] Fix for databricks; clickhouse --- data_diff/databases/clickhouse.py | 4 ++++ data_diff/databases/databricks.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index f0f7a5ad..9b657d89 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -150,3 +150,7 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: ) """ return value + + @property + def is_autocommit(self) -> bool: + return True diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 612c1c8d..e1d76349 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -151,3 +151,7 @@ def parse_table_name(self, name: str) -> DbPath: def close(self): self._conn.close() + + @property + def is_autocommit(self) -> bool: + return True From 3bffe5b7201d21f09c57da32c858848fe216d3c0 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 14 Oct 2022 13:25:46 +0200 Subject: [PATCH 44/93] General tests now include Presto, Trino & Vertica; Includes small fixes --- data_diff/databases/base.py | 4 ++- data_diff/databases/presto.py | 4 +++ data_diff/query_utils.py | 2 ++ tests/common.py | 1 - tests/test_diff_tables.py | 54 ++++++++++++++++++++++------------- tests/test_joindiff.py | 8 +++--- 6 files changed, 47 insertions(+), 26 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index e1ac2208..c33f7ead 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -79,6 +79,7 @@ def apply_queries(self, callback: Callable[[str], Any]): q: Expr = next(self.gen) while True: sql = self.compiler.compile(q) + logger.debug("Running SQL (%s-TL): %s", self.compiler.database.name, sql) try: try: res = callback(sql) if sql is not SKIP else SKIP @@ -130,7 +131,8 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list): if sql_code is SKIP: return SKIP - logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code) + logger.debug("Running SQL (%s): %s", self.name, sql_code) + if self._interactive and isinstance(sql_ast, Select): explained_sql = compiler.compile(Explain(sql_ast)) explain = self._query(explained_sql) diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index 
2d69efc8..56a32d48 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -11,6 +11,7 @@ Text, FractionalType, DbPath, + DbTime, Decimal, ColType, ColType_UUID, @@ -159,3 +160,6 @@ def type_repr(self, t) -> str: return {float: "REAL"}[t] except KeyError: return super().type_repr(t) + + def timestamp_value(self, t: DbTime) -> str: + return f"timestamp '{t.isoformat(' ')}'" diff --git a/data_diff/query_utils.py b/data_diff/query_utils.py index dac22fe1..825dbdc3 100644 --- a/data_diff/query_utils.py +++ b/data_diff/query_utils.py @@ -8,6 +8,7 @@ from .databases import Oracle from .queries import table, commit, Expr + def _drop_table_oracle(name: DbPath): t = table(name) # Experience shows double drop is necessary @@ -50,6 +51,7 @@ def _append_to_table(path: DbPath, expr: Expr): yield t.insert_expr(expr) yield commit + def append_to_table(db, path, expr): f = _append_to_table_oracle if isinstance(db, Oracle) else _append_to_table db.query(f(path, expr)) diff --git a/tests/common.py b/tests/common.py index a652e1c4..aad75074 100644 --- a/tests/common.py +++ b/tests/common.py @@ -120,7 +120,6 @@ def str_to_checksum(str: str): return int(md5[half_pos:], 16) - class TestPerDatabase(unittest.TestCase): db_cls = None diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index e615acc9..c87b23bf 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -15,7 +15,17 @@ from .common import str_to_checksum, test_each_database_in_list, TestPerDatabase -TEST_DATABASES = {db.MySQL, db.PostgreSQL, db.Oracle, db.Redshift, db.Snowflake, db.BigQuery} +TEST_DATABASES = { + db.MySQL, + db.PostgreSQL, + db.Oracle, + db.Redshift, + db.Snowflake, + db.BigQuery, + db.Presto, + db.Trino, + db.Vertica, +} test_each_database: Callable = test_each_database_in_list(TEST_DATABASES) @@ -383,9 +393,7 @@ def test_string_keys(self): diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, [("-", (str(self.new_uuid), "This one is different"))]) - self.connection.query( - self.src_table.insert_row('unexpected', '<-- this bad value should not break us') - ) + self.connection.query(self.src_table.insert_row("unexpected", "<-- this bad value should not break us")) self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) @@ -421,7 +429,7 @@ def setUp(self): src_table.create(), src_table.insert_rows(values), table(self.table_dst_path).create(src_table), - src_table.insert_row(self.new_alphanum, 'This one is different'), + src_table.insert_row(self.new_alphanum, "This one is different"), commit, ] @@ -491,7 +499,7 @@ def test_varying_alphanum_keys(self): self.assertEqual(diff, [("-", (str(self.new_alphanum), "This one is different"))]) self.connection.query( - self.src_table.insert_row('@@@', '<-- this bad value should not break us'), + self.src_table.insert_row("@@@", "<-- this bad value should not break us"), commit, ) @@ -548,13 +556,15 @@ def setUp(self): self.null_uuid = uuid.uuid1(32132131) - self.connection.query([ - src_table.create(), - src_table.insert_rows(values), - table(self.table_dst_path).create(src_table), - src_table.insert_row(self.null_uuid, None), - commit, - ]) + self.connection.query( + [ + src_table.create(), + src_table.insert_rows(values), + table(self.table_dst_path).create(src_table), + src_table.insert_row(self.null_uuid, None), + commit, + ] + ) self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) self.b = _table_segment(self.connection, self.table_dst_path, "id", 
"text_comment", case_sensitive=False) @@ -576,9 +586,9 @@ def setUp(self): self.connection.query( [ src_table.create(), - src_table.insert_row(uuid.uuid1(1), '1'), + src_table.insert_row(uuid.uuid1(1), "1"), table(self.table_dst_path).create(src_table), - src_table.insert_row(self.null_uuid, None), # Add a row where a column has NULL value + src_table.insert_row(self.null_uuid, None), # Add a row where a column has NULL value commit, ] ) @@ -685,24 +695,28 @@ class TestTableTableEmpty(TestPerDatabase): def setUp(self): super().setUp() - self.src_table = src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) - self.dst_table = dst_table = table(self.table_dst_path, schema={"id": str, "text_comment": str}) + self.src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) + self.dst_table = table(self.table_dst_path, schema={"id": str, "text_comment": str}) self.null_uuid = uuid.uuid1(1) self.diffs = [(uuid.uuid1(i), str(i)) for i in range(100)] - self.connection.query([src_table.create(), dst_table.create(), src_table.insert_rows(self.diffs), commit]) - self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_right_table_empty(self): + self.connection.query( + [self.src_table.create(), self.dst_table.create(), self.src_table.insert_rows(self.diffs), commit] + ) + differ = HashDiffer() self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) def test_left_table_empty(self): - self.connection.query([self.dst_table.insert_expr(self.src_table), self.src_table.truncate(), commit]) + self.connection.query( + [self.src_table.create(), self.dst_table.create(), self.dst_table.insert_rows(self.diffs), commit] + ) differ = HashDiffer() self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index d3db82e0..3bf2246d 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -17,14 +17,14 @@ TEST_DATABASES = { db.PostgreSQL, - db.Snowflake, db.MySQL, + db.Snowflake, db.BigQuery, - db.Presto, - db.Vertica, - db.Trino, db.Oracle, db.Redshift, + db.Presto, + db.Trino, + db.Vertica, } test_each_database = test_each_database_in_list(TEST_DATABASES) From d86950ef78e79e80299d6e327ddf972d46c7f22c Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 14 Oct 2022 14:48:14 +0200 Subject: [PATCH 45/93] Added differ name to tracking info --- data_diff/diff_tables.py | 1 + 1 file changed, 1 insertion(+) diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 1148041b..bf30cd9a 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -96,6 +96,7 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: if is_tracking_enabled(): options = dict(self) + options["differ_name"] = type(self).__name__ event_json = create_start_event_json(options) run_as_daemon(send_event_json, event_json) From a010c6a28c3622df81d4778c3ca1e5bb2d861554 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Sun, 16 Oct 2022 11:56:51 +0200 Subject: [PATCH 46/93] Added --materialize-all-rows switch --- data_diff/__main__.py | 9 ++++++++- data_diff/joindiff_tables.py | 17 +++++++++++------ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 5ca5e15b..4d0fc6ce 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -144,7 
+144,12 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - @click.option( "--sample-exclusive-rows", is_flag=True, - help="Sample several rows that only appear in one of the tables, but not the other.", + help="Sample several rows that only appear in one of the tables, but not the other. (joindiff only)", +) +@click.option( + "--materialize-all-rows", + is_flag=True, + help="Materialize every row, even if they are the same, instead of just the differing rows. (joindiff only)", ) @click.option( "-j", @@ -214,6 +219,7 @@ def _main( where, assume_unique_key, sample_exclusive_rows, + materialize_all_rows, materialize, threads1=None, threads2=None, @@ -303,6 +309,7 @@ def _main( max_threadpool_size=threads and threads * 2, validate_unique_key=not assume_unique_key, sample_exclusive_rows=sample_exclusive_rows, + materialize_all_rows=materialize_all_rows, materialize_to_table=materialize and db1.parse_table_name(eval_name_template(materialize)), ) else: diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 641f11d0..af66e6f8 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -121,6 +121,7 @@ class JoinDiffer(TableDiffer): validate_unique_key: bool = True sample_exclusive_rows: bool = True materialize_to_table: DbPath = None + materialize_all_rows: bool = False write_limit: int = WRITE_LIMIT stats: dict = {} @@ -165,7 +166,7 @@ def _diff_segments( ) db = table1.database - diff_rows, a_cols, b_cols, is_diff_cols = self._create_outer_join(table1, table2) + diff_rows, a_cols, b_cols, is_diff_cols, all_rows = self._create_outer_join(table1, table2) with self._run_in_background( partial(self._collect_stats, 1, table1), @@ -173,7 +174,12 @@ def _diff_segments( partial(self._test_null_keys, table1, table2), partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols), partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols), - partial(self._materialize_diff, db, diff_rows, segment_index=segment_index) + partial( + self._materialize_diff, + db, + all_rows if self.materialize_all_rows else diff_rows, + segment_index=segment_index, + ) if self.materialize_to_table else None, ): @@ -263,10 +269,9 @@ def _create_outer_join(self, table1, table2): a_cols = {f"table1_{c}": NormalizeAsString(a[c]) for c in cols1} b_cols = {f"table2_{c}": NormalizeAsString(b[c]) for c in cols2} - diff_rows = _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **a_cols, **b_cols}).where( - or_(this[c] == 1 for c in is_diff_cols) - ) - return diff_rows, a_cols, b_cols, is_diff_cols + all_rows = _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **a_cols, **b_cols}) + diff_rows = all_rows.where(or_(this[c] == 1 for c in is_diff_cols)) + return diff_rows, a_cols, b_cols, is_diff_cols, all_rows def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols): logger.info("Counting differences per column") From 08209e8064731bbbf3d36ec311626a696f8d946c Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Sun, 16 Oct 2022 12:09:12 +0200 Subject: [PATCH 47/93] Added tests for materialize_all_rows switch --- tests/test_joindiff.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index d3db82e0..60279f6f 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -133,9 +133,18 @@ def test_diff_small_tables(self): t = TablePath(materialize_path) rows = self.connection.query(t.select(), List[tuple]) - self.connection.query(t.drop()) # 
is_xa, is_xb, is_diff1, is_diff2, row1, row2 assert rows == [(1, 0, 1, 1) + expected_row + (None, None)], rows + self.connection.query(t.drop()) + + # Test materialize all rows + mdiffer = mdiffer.replace(materialize_all_rows=True) + diff = list(mdiffer.diff_tables(self.table, self.table2)) + self.assertEqual(expected, diff) + rows = self.connection.query(t.select(), List[tuple]) + assert len(rows) == 2, len(rows) + self.connection.query(t.drop()) + def test_diff_table_above_bisection_threshold(self): time = "2022-01-01 00:00:00" From 3bf7e1cc084b9365b6b4cc5a5cde6a71d39fae7c Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Sun, 16 Oct 2022 18:43:42 +0200 Subject: [PATCH 48/93] Added switch --table-write-limit --- data_diff/__main__.py | 10 +++++++++- data_diff/joindiff_tables.py | 10 +++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 4d0fc6ce..06b1cd60 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -12,7 +12,7 @@ from .utils import eval_name_template, remove_password_from_url, safezip, match_like from .diff_tables import Algorithm from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR -from .joindiff_tables import JoinDiffer +from .joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer from .table_segment import TableSegment from .databases.database_types import create_schema from .databases.connect import connect @@ -151,6 +151,12 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - is_flag=True, help="Materialize every row, even if they are the same, instead of just the differing rows. (joindiff only)", ) +@click.option( + "--table-write-limit", + default=TABLE_WRITE_LIMIT, + help=f"Maximum number of rows to write when creating materialized or sample tables, per thread. Default={TABLE_WRITE_LIMIT}", + metavar="COUNT", +) @click.option( "-j", "--threads", @@ -220,6 +226,7 @@ def _main( assume_unique_key, sample_exclusive_rows, materialize_all_rows, + table_write_limit, materialize, threads1=None, threads2=None, @@ -310,6 +317,7 @@ def _main( validate_unique_key=not assume_unique_key, sample_exclusive_rows=sample_exclusive_rows, materialize_all_rows=materialize_all_rows, + table_write_limit=table_write_limit, materialize_to_table=materialize and db1.parse_table_name(eval_name_template(materialize)), ) else: diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index af66e6f8..b630d66e 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -29,7 +29,7 @@ logger = logging.getLogger("joindiff_tables") -WRITE_LIMIT = 1000 +TABLE_WRITE_LIMIT = 1000 def merge_dicts(dicts): @@ -115,14 +115,14 @@ class JoinDiffer(TableDiffer): Future versions will detect UNIQUE constraints in the schema. sample_exclusive_rows (bool): Enable/disable sampling of exclusive rows. Creates a temporary table. materialize_to_table (DbPath, optional): Path of new table to write diff results to. Disabled if not provided. - write_limit (int): Maximum number of rows to write when materializing, per thread. + table_write_limit (int): Maximum number of rows to write when materializing, per thread. 
""" validate_unique_key: bool = True sample_exclusive_rows: bool = True materialize_to_table: DbPath = None materialize_all_rows: bool = False - write_limit: int = WRITE_LIMIT + table_write_limit: int = TABLE_WRITE_LIMIT stats: dict = {} def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: @@ -298,7 +298,7 @@ def exclusive_rows(expr): c = Compiler(db) name = c.new_unique_table_name("temp_table") exclusive_rows = table(name, schema=expr.source_table.schema) - yield create_temp_table(c, exclusive_rows, expr.limit(self.write_limit)) + yield create_temp_table(c, exclusive_rows, expr.limit(self.table_write_limit)) count = yield exclusive_rows.count() self.stats["exclusive_count"] = self.stats.get("exclusive_count", 0) + count[0][0] @@ -314,5 +314,5 @@ def exclusive_rows(expr): def _materialize_diff(self, db, diff_rows, segment_index=None): assert self.materialize_to_table - append_to_table(db, self.materialize_to_table, diff_rows.limit(self.write_limit)) + append_to_table(db, self.materialize_to_table, diff_rows.limit(self.table_write_limit)) logger.info("Materialized diff to table '%s'.", ".".join(self.materialize_to_table)) From c01e75faffcb54820262cc870e8d2f2fc0e15f54 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 19 Oct 2022 10:26:29 -0300 Subject: [PATCH 49/93] Small fix --- data_diff/__main__.py | 2 +- tests/test_joindiff.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 06b1cd60..50847b7c 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -411,7 +411,7 @@ def _main( jsonl = json.dumps([op, list(values)]) rich.print(f"[{color}]{jsonl}[/{color}]") else: - text = f"{op} {', '.join(values)}" + text = f"{op} {', '.join(map(str, values))}" rich.print(f"[{color}]{text}[/{color}]") sys.stdout.flush() diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index afdeeca9..03ca3d69 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -145,7 +145,6 @@ def test_diff_small_tables(self): assert len(rows) == 2, len(rows) self.connection.query(t.drop()) - def test_diff_table_above_bisection_threshold(self): time = "2022-01-01 00:00:00" time_obj = datetime.fromisoformat(time) From d308a43c3cb021340ec237bf1424de03fe164535 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 19 Oct 2022 12:04:01 -0300 Subject: [PATCH 50/93] Tests: Fix for Oracle --- data_diff/databases/base.py | 16 ++++++++++++++++ data_diff/databases/oracle.py | 4 +++- data_diff/queries/ast_classes.py | 15 +-------------- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index c33f7ead..f0a96f20 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -7,6 +7,7 @@ from concurrent.futures import ThreadPoolExecutor import threading from abc import abstractmethod +from uuid import UUID from data_diff.utils import is_uuid, safezip from data_diff.queries import Expr, Compiler, table, Select, SKIP, Explain @@ -328,6 +329,21 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: def random(self) -> str: return "RANDOM()" + def _constant_value(self, v): + if v is None: + return "NULL" + elif isinstance(v, str): + return f"'{v}'" + elif isinstance(v, datetime): + return f"timestamp '{v}'" + elif isinstance(v, UUID): + return f"'{v}'" + return repr(v) + + def constant_values(self, rows) -> str: + values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows) + return 
f"VALUES {values}" + def type_repr(self, t) -> str: if isinstance(t, str): return t diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index e65fd65a..80647ba3 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -1,7 +1,6 @@ from typing import Dict, List, Optional from ..utils import match_regexps - from .database_types import ( Decimal, Float, @@ -152,3 +151,6 @@ def type_repr(self, t) -> str: }[t] except KeyError: return super().type_repr(t) + + def constant_values(self, rows) -> str: + return " UNION ALL ".join("SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 0a7642d5..b5456b59 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -1,7 +1,6 @@ from dataclasses import field from datetime import datetime from typing import Any, Generator, List, Optional, Sequence, Tuple, Union -from uuid import UUID from runtype import dataclass @@ -611,20 +610,8 @@ class ConstantTable(ExprNode): def compile(self, c: Compiler) -> str: raise NotImplementedError() - def _value(self, v): - if v is None: - return "NULL" - elif isinstance(v, str): - return f"'{v}'" - elif isinstance(v, datetime): - return f"timestamp '{v}'" - elif isinstance(v, UUID): - return f"'{v}'" - return repr(v) - def compile_for_insert(self, c: Compiler): - values = ", ".join("(%s)" % ", ".join(self._value(v) for v in row) for row in self.rows) - return f"VALUES {values}" + return c.database.constant_values(self.rows) @dataclass From 83d9c51ebcd57628acbc0bfbc69d49054ea0fa10 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 19 Oct 2022 11:07:14 -0300 Subject: [PATCH 51/93] Version bump - Release candidate --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 386dba0d..46f579cd 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "data-diff" -version = "0.2.8" +version = "0.3.0rc1" description = "Command-line tool and Python library to efficiently diff rows across two different databases." 
authors = ["Datafold "] license = "MIT" From 9b6129a2c61d32e67bb84dd14f4da39e3488b868 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 14 Oct 2022 17:26:04 +0200 Subject: [PATCH 52/93] Now tests for unique key constraints (if possible) instead of always actively validating (+ tests) --- data_diff/databases/base.py | 15 ++++++++ data_diff/databases/bigquery.py | 5 ++- data_diff/databases/database_types.py | 11 ++++++ data_diff/databases/mysql.py | 1 + data_diff/databases/oracle.py | 1 + data_diff/databases/postgresql.py | 1 + data_diff/databases/snowflake.py | 5 ++- data_diff/joindiff_tables.py | 16 ++++++--- data_diff/queries/ast_classes.py | 8 +++-- tests/common.py | 1 + tests/test_joindiff.py | 49 +++++++++++++++++++++++++++ 11 files changed, 103 insertions(+), 10 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index f0a96f20..48a551b2 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -110,6 +110,7 @@ class Database(AbstractDatabase): TYPE_CLASSES: Dict[str, type] = {} default_schema: str = None SUPPORTS_ALPHANUMS = True + SUPPORTS_PRIMARY_KEY = False _interactive = False @@ -235,6 +236,20 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: assert len(d) == len(rows) return d + + def select_table_unique_columns(self, path: DbPath) -> str: + schema, table = self._normalize_table_path(path) + + return ( + "SELECT column_name " + "FROM information_schema.key_column_usage " + f"WHERE table_name = '{table}' AND table_schema = '{schema}'" + ) + + def query_table_unique_columns(self, path: DbPath) -> List[str]: + res = self.query(self.select_table_unique_columns(path), List[str]) + return list(res) + def _process_table_schema( self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None ): diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 9c500dd5..3d3720b6 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import List, Union from .database_types import Timestamp, Datetime, Integer, Decimal, Float, Text, DbPath, FractionalType, TemporalType from .base import Database, import_helper, parse_table_name, ConnectError, apply_query from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter @@ -78,6 +78,9 @@ def select_table_schema(self, path: DbPath) -> str: f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) + def query_table_unique_columns(self, path: DbPath) -> List[str]: + return [] + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index 1de1d2fc..2a76ae05 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -253,6 +253,17 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: """ ... + @abstractmethod + def select_table_unique_columns(self, path: DbPath) -> str: + "Provide SQL for selecting the names of unique columns in the table" + ... + + @abstractmethod + def query_table_unique_columns(self, path: DbPath) -> List[str]: + """Query the table for its unique columns for table in 'path', and return {column} + """ + ... 
+ @abstractmethod def _process_table_schema( self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index 3f9eb98c..f7023946 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -39,6 +39,7 @@ class MySQL(ThreadedDatabase): } ROUNDS_ON_PREC_LOSS = True SUPPORTS_ALPHANUMS = False + SUPPORTS_PRIMARY_KEY = True def __init__(self, *, thread_count, **kw): self._args = kw diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 80647ba3..d59e5b55 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -38,6 +38,7 @@ class Oracle(ThreadedDatabase): "VARCHAR2": Text, } ROUNDS_ON_PREC_LOSS = True + SUPPORTS_PRIMARY_KEY = True def __init__(self, *, host, database, thread_count, **kw): self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index 72d26d07..02920c2b 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -46,6 +46,7 @@ class PostgreSQL(ThreadedDatabase): "uuid": Native_UUID, } ROUNDS_ON_PREC_LOSS = True + SUPPORTS_PRIMARY_KEY = True default_schema = "public" diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 635ba8f4..afd52ba8 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, List import logging from .database_types import Timestamp, TimestampTZ, Decimal, Float, Text, FractionalType, TemporalType, DbPath @@ -95,3 +95,6 @@ def is_autocommit(self) -> bool: def explain_as_text(self, query: str) -> str: return f"EXPLAIN USING TEXT {query}" + + def query_table_unique_columns(self, path: DbPath) -> List[str]: + return [] diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index b630d66e..26789bc2 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -196,18 +196,24 @@ def _diff_segments( if not is_xa: yield "+", tuple(b_row) - def _test_duplicate_keys(self, table1, table2): + def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment): logger.debug("Testing for duplicate keys") # Test duplicate keys for ts in [table1, table2]: + unique = ts.database.query_table_unique_columns(ts.table_path) + t = ts.make_select() key_columns = ts.key_columns - q = t.select(total=Count(), total_distinct=Count(Concat(this[key_columns]), distinct=True)) - total, total_distinct = ts.database.query(q, tuple) - if total != total_distinct: - raise ValueError("Duplicate primary keys") + unvalidated = list(set(key_columns) - set(unique)) + if unvalidated: + # Validate that there are no duplicate keys + self.stats['validated_unique_keys'] = self.stats.get('validated_unique_keys', []) + [unvalidated] + q = t.select(total=Count(), total_distinct=Count(Concat(this[unvalidated]), distinct=True)) + total, total_distinct = ts.database.query(q, tuple) + if total != total_distinct: + raise ValueError("Duplicate primary keys") def _test_null_keys(self, table1, table2): logger.debug("Testing for null keys") diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index b5456b59..9b7fe63a 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -298,12 +298,12 @@ class TablePath(ExprNode, ITable): path: DbPath schema: Optional[Schema] = 
field(default=None, repr=False) - def create(self, source_table: ITable = None, *, if_not_exists=False): + def create(self, source_table: ITable = None, *, if_not_exists=False, primary_keys=None): if source_table is None and not self.schema: raise ValueError("Either schema or source table needed to create table") if isinstance(source_table, TablePath): source_table = source_table.select() - return CreateTable(self, source_table, if_not_exists=if_not_exists) + return CreateTable(self, source_table, if_not_exists=if_not_exists, primary_keys=primary_keys) def drop(self, if_exists=False): return DropTable(self, if_exists=if_exists) @@ -634,6 +634,7 @@ class CreateTable(Statement): path: TablePath source_table: Expr = None if_not_exists: bool = False + primary_keys: List[str] = None def compile(self, c: Compiler) -> str: ne = "IF NOT EXISTS " if self.if_not_exists else "" @@ -641,7 +642,8 @@ def compile(self, c: Compiler) -> str: return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}" schema = ", ".join(f"{c.database.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items()) - return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})" + pks = ", PRIMARY KEY (%s)" % ', '.join(self.primary_keys) if self.primary_keys and c.database.SUPPORTS_PRIMARY_KEY else "" + return f"CREATE TABLE {ne}{c.compile(self.path)}({schema}{pks})" @dataclass diff --git a/tests/common.py b/tests/common.py index aad75074..cd974e34 100644 --- a/tests/common.py +++ b/tests/common.py @@ -149,6 +149,7 @@ def tearDown(self): def _parameterized_class_per_conn(test_databases): + test_databases = set(test_databases) names = [(cls.__name__, cls) for cls in CONN_STRINGS if cls in test_databases] return parameterized_class(("name", "db_cls"), names) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index 03ca3d69..557b207f 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -273,3 +273,52 @@ def test_null_pks(self): x = self.differ.diff_tables(self.table, self.table2) self.assertRaises(ValueError, list, x) + + +@test_each_database_in_list(d for d in TEST_DATABASES if d.SUPPORTS_PRIMARY_KEY) +class TestPrimaryKeys(TestPerDatabase): + def setUp(self): + super().setUp() + + self.src_table = table( + self.table_src_path, + schema={"id": int, "userid": int, "movieid": int, "rating": float}, + ) + self.dst_table = table( + self.table_dst_path, + schema={"id": int, "userid": int, "movieid": int, "rating": float}, + ) + + self.connection.query([ + self.src_table.create(primary_keys=['id']), + self.dst_table.create(primary_keys=['id', 'userid']), + commit + ] + ) + + self.differ = JoinDiffer() + + def test_unique_constraint(self): + self.connection.query( + [ + self.src_table.insert_rows([[1, 1, 1, 9], [2, 2, 2, 9]]), + self.dst_table.insert_rows([[1, 1, 1, 9], [2, 2, 2, 9]]), + commit, + ] + ) + + # Test no active validation + table = TableSegment(self.connection, self.table_src_path, ("id",), case_sensitive=False) + table2 = TableSegment(self.connection, self.table_dst_path, ("id",), case_sensitive=False) + + res = list(self.differ.diff_tables(table, table2)) + assert not res + assert 'validated_unique_keys' not in self.differ.stats + + # Test active validation + table = TableSegment(self.connection, self.table_src_path, ("userid",), case_sensitive=False) + table2 = TableSegment(self.connection, self.table_dst_path, ("userid",), case_sensitive=False) + + res = list(self.differ.diff_tables(table, table2)) + assert not res + self.assertEqual( 
self.differ.stats['validated_unique_keys'], [['userid']] ) From 9ad272df0c738a74da94246003332ad65a44ad40 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 14 Oct 2022 17:26:16 +0200 Subject: [PATCH 53/93] black --- data_diff/databases/base.py | 1 - data_diff/databases/database_types.py | 3 +-- data_diff/joindiff_tables.py | 2 +- data_diff/queries/ast_classes.py | 6 +++++- tests/test_joindiff.py | 11 ++++------- 5 files changed, 11 insertions(+), 12 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 48a551b2..18192152 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -236,7 +236,6 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: assert len(d) == len(rows) return d - def select_table_unique_columns(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index 2a76ae05..c02cc3e2 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -260,8 +260,7 @@ def select_table_unique_columns(self, path: DbPath) -> str: @abstractmethod def query_table_unique_columns(self, path: DbPath) -> List[str]: - """Query the table for its unique columns for table in 'path', and return {column} - """ + """Query the table for its unique columns for table in 'path', and return {column}""" ... @abstractmethod diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 26789bc2..863a8450 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -209,7 +209,7 @@ def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment): unvalidated = list(set(key_columns) - set(unique)) if unvalidated: # Validate that there are no duplicate keys - self.stats['validated_unique_keys'] = self.stats.get('validated_unique_keys', []) + [unvalidated] + self.stats["validated_unique_keys"] = self.stats.get("validated_unique_keys", []) + [unvalidated] q = t.select(total=Count(), total_distinct=Count(Concat(this[unvalidated]), distinct=True)) total, total_distinct = ts.database.query(q, tuple) if total != total_distinct: diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 9b7fe63a..4b93efdf 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -642,7 +642,11 @@ def compile(self, c: Compiler) -> str: return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}" schema = ", ".join(f"{c.database.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items()) - pks = ", PRIMARY KEY (%s)" % ', '.join(self.primary_keys) if self.primary_keys and c.database.SUPPORTS_PRIMARY_KEY else "" + pks = ( + ", PRIMARY KEY (%s)" % ", ".join(self.primary_keys) + if self.primary_keys and c.database.SUPPORTS_PRIMARY_KEY + else "" + ) return f"CREATE TABLE {ne}{c.compile(self.path)}({schema}{pks})" diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index 557b207f..35e2c79f 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -289,11 +289,8 @@ def setUp(self): schema={"id": int, "userid": int, "movieid": int, "rating": float}, ) - self.connection.query([ - self.src_table.create(primary_keys=['id']), - self.dst_table.create(primary_keys=['id', 'userid']), - commit - ] + self.connection.query( + [self.src_table.create(primary_keys=["id"]), self.dst_table.create(primary_keys=["id", "userid"]), commit] ) self.differ = JoinDiffer() @@ -313,7 +310,7 @@ def 
test_unique_constraint(self): res = list(self.differ.diff_tables(table, table2)) assert not res - assert 'validated_unique_keys' not in self.differ.stats + assert "validated_unique_keys" not in self.differ.stats # Test active validation table = TableSegment(self.connection, self.table_src_path, ("userid",), case_sensitive=False) @@ -321,4 +318,4 @@ def test_unique_constraint(self): res = list(self.differ.diff_tables(table, table2)) assert not res - self.assertEqual( self.differ.stats['validated_unique_keys'], [['userid']] ) + self.assertEqual(self.differ.stats["validated_unique_keys"], [["userid"]]) From 56b45da4aafdbb4c8cea033d469bf51abc55ec4a Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 17 Oct 2022 13:31:25 +0200 Subject: [PATCH 54/93] Added Database.SUPPORTS_UNIQUE_CONSTAINT --- data_diff/databases/base.py | 3 +++ data_diff/databases/mysql.py | 1 + data_diff/databases/oracle.py | 1 + data_diff/databases/postgresql.py | 1 + data_diff/joindiff_tables.py | 2 +- 5 files changed, 7 insertions(+), 1 deletion(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 18192152..7d0ac864 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -111,6 +111,7 @@ class Database(AbstractDatabase): default_schema: str = None SUPPORTS_ALPHANUMS = True SUPPORTS_PRIMARY_KEY = False + SUPPORTS_UNIQUE_CONSTAINT = False _interactive = False @@ -246,6 +247,8 @@ def select_table_unique_columns(self, path: DbPath) -> str: ) def query_table_unique_columns(self, path: DbPath) -> List[str]: + if not self.SUPPORTS_UNIQUE_CONSTAINT: + raise NotImplementedError("This database doesn't support 'unique' constraints") res = self.query(self.select_table_unique_columns(path), List[str]) return list(res) diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index f7023946..e8e47b1b 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -40,6 +40,7 @@ class MySQL(ThreadedDatabase): ROUNDS_ON_PREC_LOSS = True SUPPORTS_ALPHANUMS = False SUPPORTS_PRIMARY_KEY = True + SUPPORTS_UNIQUE_CONSTAINT = True def __init__(self, *, thread_count, **kw): self._args = kw diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index d59e5b55..c135f71f 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -39,6 +39,7 @@ class Oracle(ThreadedDatabase): } ROUNDS_ON_PREC_LOSS = True SUPPORTS_PRIMARY_KEY = True + SUPPORTS_UNIQUE_CONSTAINT = True def __init__(self, *, host, database, thread_count, **kw): self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index 02920c2b..3181dab1 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -47,6 +47,7 @@ class PostgreSQL(ThreadedDatabase): } ROUNDS_ON_PREC_LOSS = True SUPPORTS_PRIMARY_KEY = True + SUPPORTS_UNIQUE_CONSTAINT = True default_schema = "public" diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 863a8450..1107ba87 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -201,7 +201,7 @@ def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment): # Test duplicate keys for ts in [table1, table2]: - unique = ts.database.query_table_unique_columns(ts.table_path) + unique = ts.database.query_table_unique_columns(ts.table_path) if ts.database.SUPPORTS_UNIQUE_CONSTAINT else [] t = ts.make_select() key_columns = ts.key_columns From 
3e82588d4f10815d612100f683e1f11249766ac7 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 19 Oct 2022 15:23:56 -0300 Subject: [PATCH 55/93] Fix for Oracle --- data_diff/databases/oracle.py | 1 - tests/test_joindiff.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index c135f71f..d59e5b55 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -39,7 +39,6 @@ class Oracle(ThreadedDatabase): } ROUNDS_ON_PREC_LOSS = True SUPPORTS_PRIMARY_KEY = True - SUPPORTS_UNIQUE_CONSTAINT = True def __init__(self, *, host, database, thread_count, **kw): self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index 35e2c79f..22ed217d 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -275,8 +275,8 @@ def test_null_pks(self): self.assertRaises(ValueError, list, x) -@test_each_database_in_list(d for d in TEST_DATABASES if d.SUPPORTS_PRIMARY_KEY) -class TestPrimaryKeys(TestPerDatabase): +@test_each_database_in_list(d for d in TEST_DATABASES if d.SUPPORTS_PRIMARY_KEY and d.SUPPORTS_UNIQUE_CONSTAINT) +class TestUniqueConstraint(TestPerDatabase): def setUp(self): super().setUp() From d7205cbb428e4e6e346fffc17db8ca75ecc55575 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 20 Oct 2022 10:03:50 -0300 Subject: [PATCH 56/93] Various small fixes for 'queries' module --- data_diff/databases/base.py | 1 + data_diff/databases/database_types.py | 9 ++-- data_diff/queries/api.py | 4 +- data_diff/queries/ast_classes.py | 67 +++++++++++++++------------ data_diff/queries/compiler.py | 10 ++-- 5 files changed, 51 insertions(+), 40 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index f0a96f20..189dced4 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -335,6 +335,7 @@ def _constant_value(self, v): elif isinstance(v, str): return f"'{v}'" elif isinstance(v, datetime): + # TODO use self.timestamp_value return f"timestamp '{v}'" elif isinstance(v, UUID): return f"'{v}'" diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index 1de1d2fc..86df4489 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -147,12 +147,12 @@ class AbstractDialect(ABC): @abstractmethod def quote(self, s: str): - "Quote SQL name (implementation specific)" + "Quote SQL name" ... @abstractmethod def concat(self, l: List[str]) -> str: - "Provide SQL for concatenating a bunch of column into a string" + "Provide SQL for concatenating a bunch of columns into a string" ... @abstractmethod @@ -162,12 +162,13 @@ def is_distinct_from(self, a: str, b: str) -> str: @abstractmethod def to_string(self, s: str) -> str: + # TODO rewrite using cast_to(x, str) "Provide SQL for casting a column to string" ... @abstractmethod def random(self) -> str: - "Provide SQL for generating a random number" + "Provide SQL for generating a random number betweein 0..1" @abstractmethod def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): @@ -176,7 +177,7 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None @abstractmethod def explain_as_text(self, query: str) -> str: - "Provide SQL for explaining a query, returned in as table(varchar)" + "Provide SQL for explaining a query, returned as table(varchar)" ... 
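To make the `AbstractDialect` contract that these docstrings describe more concrete, a minimal hypothetical dialect is sketched below. The SQL fragments are assumptions for an ANSI-like database and are not taken from the patch:

```python
# Hypothetical dialect, for illustration only; it mirrors the AbstractDialect methods.
class ToyDialect:
    name = "toy"

    def quote(self, s: str) -> str:
        return f'"{s}"'  # quote an SQL identifier

    def concat(self, items) -> str:
        return "CONCAT(%s)" % ", ".join(items)  # concatenate stringified columns

    def is_distinct_from(self, a: str, b: str) -> str:
        return f"{a} IS DISTINCT FROM {b}"  # comparison where NULL = NULL is true

    def to_string(self, s: str) -> str:
        return f"CAST({s} AS VARCHAR)"  # cast a column to string

    def random(self) -> str:
        return "RANDOM()"  # random number between 0..1

    def offset_limit(self, offset=None, limit=None) -> str:
        parts = []
        if limit is not None:
            parts.append(f"LIMIT {limit}")
        if offset is not None:
            parts.append(f"OFFSET {offset}")
        return " ".join(parts)
```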
@abstractmethod diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index 2f5d96be..797fafa5 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -47,14 +47,14 @@ def or_(*exprs: Expr): exprs = args_as_tuple(exprs) if len(exprs) == 1: return exprs[0] - return BinOp("OR", exprs) + return BinBoolOp("OR", exprs) def and_(*exprs: Expr): exprs = args_as_tuple(exprs) if len(exprs) == 1: return exprs[0] - return BinOp("AND", exprs) + return BinBoolOp("AND", exprs) def sum_(expr: Expr): diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index b5456b59..59395ca2 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -32,12 +32,13 @@ def cast_to(self, to): Expr = Union[ExprNode, str, bool, int, datetime, ArithString, None] -def get_type(e: Expr) -> type: +def _expr_type(e: Expr) -> type: if isinstance(e, ExprNode): return e.type return type(e) + @dataclass class Alias(ExprNode): expr: Expr @@ -48,7 +49,7 @@ def compile(self, c: Compiler) -> str: @property def type(self): - return get_type(self.expr) + return _expr_type(self.expr) def _drop_skips(exprs): @@ -156,6 +157,8 @@ class Count(ExprNode): expr: Expr = "*" distinct: bool = False + type = int + def compile(self, c: Compiler) -> str: expr = c.compile(self.expr) if self.distinct: @@ -174,11 +177,6 @@ def compile(self, c: Compiler) -> str: return f"{self.name}({args})" -def _expr_type(e: Expr): - if isinstance(e, ExprNode): - return e.type - return type(e) - @dataclass class CaseWhen(ExprNode): @@ -226,6 +224,9 @@ def __le__(self, other): def __or__(self, other): return BinBoolOp("OR", [self, other]) + def __and__(self, other): + return BinBoolOp("AND", [self, other]) + def is_distinct_from(self, other): return IsDistinctFrom(self, other) @@ -254,7 +255,7 @@ def compile(self, c: Compiler) -> str: @property def type(self): - types = {get_type(i) for i in self.args} + types = {_expr_type(i) for i in self.args} if len(types) > 1: raise TypeError(f"Expected all args to have the same type, got {types}") (t,) = types @@ -298,6 +299,16 @@ class TablePath(ExprNode, ITable): path: DbPath schema: Optional[Schema] = field(default=None, repr=False) + @property + def source_table(self): + return self + + def compile(self, c: Compiler) -> str: + path = self.path # c.database._normalize_table_path(self.name) + return ".".join(map(c.quote, path)) + + # Statement shorthands + def create(self, source_table: ITable = None, *, if_not_exists=False): if source_table is None and not self.schema: raise ValueError("Either schema or source table needed to create table") @@ -323,14 +334,6 @@ def insert_expr(self, expr: Expr): expr = expr.select() return InsertToTable(self, expr) - @property - def source_table(self): - return self - - def compile(self, c: Compiler) -> str: - path = self.path # c.database._normalize_table_path(self.name) - return ".".join(map(c.quote, path)) - @dataclass class TableAlias(ExprNode, ITable): @@ -386,7 +389,7 @@ def compile(self, parent_c: Compiler) -> str: tables = [ t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in self.source_tables ] - c = parent_c.add_table_context(*tables).replace(in_join=True, in_select=False) + c = parent_c.add_table_context(*tables, in_join=True, in_select=False) op = " JOIN " if self.op is None else f" {self.op} JOIN " joined = op.join(c.compile(t) for t in tables) @@ -408,7 +411,7 @@ def compile(self, parent_c: Compiler) -> str: class GroupBy(ITable): def having(self): - pass + raise 
NotImplementedError() @dataclass @@ -546,26 +549,26 @@ class _ResolveColumn(ExprNode, LazyOps): resolve_name: str resolved: Expr = None - def resolve(self, expr): - assert self.resolved is None + def resolve(self, expr: Expr): + if self.resolved is not None: + raise RuntimeError("Already resolved!") self.resolved = expr - def compile(self, c: Compiler) -> str: + def _get_resolved(self) -> Expr: if self.resolved is None: raise RuntimeError(f"Column not resolved: {self.resolve_name}") - return self.resolved.compile(c) + return self.resolved + + def compile(self, c: Compiler) -> str: + return self._get_resolved().compile(c) @property def type(self): - if self.resolved is None: - raise RuntimeError(f"Column not resolved: {self.resolve_name}") - return self.resolved.type + return self._get_resolved().type @property def name(self): - if self.resolved is None: - raise RuntimeError(f"Column not resolved: {self.name}") - return self.resolved.name + return self._get_resolved().name class This: @@ -583,6 +586,8 @@ class In(ExprNode): expr: Expr list: Sequence[Expr] + type = bool + def compile(self, c: Compiler): elems = ", ".join(map(c.compile, self.list)) return f"({c.compile(self.expr)} IN ({elems}))" @@ -599,6 +604,8 @@ def compile(self, c: Compiler) -> str: @dataclass class Random(ExprNode): + type = float + def compile(self, c: Compiler) -> str: return c.database.random() @@ -618,6 +625,8 @@ def compile_for_insert(self, c: Compiler): class Explain(ExprNode): select: Select + type = str + def compile(self, c: Compiler) -> str: return c.database.explain_as_text(c.compile(self.select)) @@ -640,7 +649,7 @@ def compile(self, c: Compiler) -> str: if self.source_table: return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}" - schema = ", ".join(f"{c.database.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items()) + schema = ", ".join(f"{c.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items()) return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})" diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index eda7d981..31242131 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -21,9 +21,6 @@ class Compiler: _counter: List = [0] - def quote(self, s: str): - return self.database.quote(s) - def compile(self, elem) -> str: res = self._compile(elem) if self.root and self._subqueries: @@ -57,8 +54,11 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath: self._counter[0] += 1 return self.database.parse_table_name(f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}") - def add_table_context(self, *tables: Sequence): - return self.replace(_table_context=self._table_context + list(tables)) + def add_table_context(self, *tables: Sequence, **kw): + return self.replace(_table_context=self._table_context + list(tables), **kw) + + def quote(self, s: str): + return self.database.quote(s) class Compilable(ABC): From fc16e921c05439b52307f0b01db7d150992be7b3 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 20 Oct 2022 10:35:33 -0300 Subject: [PATCH 57/93] Rewrite CaseAwareMapping --- data_diff/utils.py | 54 ++++++++++++++-------------------------------- 1 file changed, 16 insertions(+), 38 deletions(-) diff --git a/data_diff/utils.py b/data_diff/utils.py index 2c8ccfba..76576327 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -1,8 +1,8 @@ import logging import re import math -from typing import Iterable, Tuple, Union, Any, Sequence, Dict -from typing import TypeVar, 
Generic +from typing import Iterable, Iterator, MutableMapping, Union, Any, Sequence, Dict +from typing import TypeVar from abc import ABC, abstractmethod from urllib.parse import urlparse from uuid import UUID @@ -204,58 +204,36 @@ def join_iter(joiner: Any, iterable: Iterable) -> Iterable: V = TypeVar("V") -class CaseAwareMapping(ABC, Generic[V]): +class CaseAwareMapping(MutableMapping[str, V]): @abstractmethod def get_key(self, key: str) -> str: ... - @abstractmethod - def __getitem__(self, key: str) -> V: - ... - - @abstractmethod - def __setitem__(self, key: str, value: V): - ... - - @abstractmethod - def __contains__(self, key: str) -> bool: - ... - - def __repr__(self): - return repr(dict(self.items())) - - @abstractmethod - def items(self) -> Iterable[Tuple[str, V]]: - ... - class CaseInsensitiveDict(CaseAwareMapping): def __init__(self, initial): self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()} - def get_key(self, key: str) -> str: - return self._dict[key.lower()][0] - def __getitem__(self, key: str) -> V: return self._dict[key.lower()][1] - def __setitem__(self, key: str, value): - k = key.lower() - if k in self._dict: - key = self._dict[k][0] - self._dict[k] = key, value + def __iter__(self) -> Iterator[V]: + return iter(self._dict) + + def __len__(self) -> int: + return len(self._dict) - def __contains__(self, key): - return key.lower() in self._dict + def __setitem__(self, key: str, value): + self._dict[key.lower()] = key, value - def keys(self) -> Iterable[str]: - return self._dict.keys() + def __delitem__(self, key: str): + del self._dict[key.lower()] - def items(self) -> Iterable[Tuple[str, V]]: - return ((k, v[1]) for k, v in self._dict.items()) + def get_key(self, key: str) -> str: + return self._dict[key.lower()][0] - def __len__(self): - return len(self._dict) + def __repr__(self) -> str: + return repr(dict(self.items())) class CaseSensitiveDict(dict, CaseAwareMapping): From 5e3a220c1b2ee398caceeb8d0aaf41098bb0b11b Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 20 Oct 2022 10:38:35 -0300 Subject: [PATCH 58/93] Ran black --- data_diff/databases/oracle.py | 4 +++- data_diff/queries/ast_classes.py | 2 -- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 80647ba3..bd849e59 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -153,4 +153,6 @@ def type_repr(self, t) -> str: return super().type_repr(t) def constant_values(self, rows) -> str: - return " UNION ALL ".join("SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows) + return " UNION ALL ".join( + "SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows + ) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 59395ca2..88d7ab11 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -38,7 +38,6 @@ def _expr_type(e: Expr) -> type: return type(e) - @dataclass class Alias(ExprNode): expr: Expr @@ -177,7 +176,6 @@ def compile(self, c: Compiler) -> str: return f"{self.name}({args})" - @dataclass class CaseWhen(ExprNode): cases: Sequence[Tuple[Expr, Expr]] From 601d2bbcd3d83818f26b7ac7a0c51e388a2e7f11 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 20 Oct 2022 16:25:02 -0300 Subject: [PATCH 59/93] Fix --- data_diff/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/data_diff/utils.py b/data_diff/utils.py index 76576327..a2b7e801 
100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -224,7 +224,10 @@ def __len__(self) -> int: return len(self._dict) def __setitem__(self, key: str, value): - self._dict[key.lower()] = key, value + k = key.lower() + if k in self._dict: + key = self._dict[k][0] + self._dict[k] = key, value def __delitem__(self, key: str): del self._dict[key.lower()] From 1fc52c25bab2a86248a63a40b131cf80bdf4875c Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 20 Oct 2022 16:44:49 -0300 Subject: [PATCH 60/93] README rewritten by upper management. I'm forced to push as-is, may contain incorrect details! --- CONTRIBUTING.md | 83 ++++++++ README.md | 537 +++++++++++++++++++----------------------------- 2 files changed, 289 insertions(+), 331 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6a6418d9..56c717ab 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -79,3 +79,86 @@ New databases should be added as a new module in the `data-diff/databases/` fold If possible, please also add the database setup to `docker-compose.yml`, so that we can run and test it for ourselves. If you do, also update the CI (`ci.yml`). Guide to implementing a new database driver: https://data-diff.readthedocs.io/en/latest/new-database-driver-guide.html + +## Development Setup + +The development setup centers around using `docker-compose` to boot up various +databases, and then inserting data into them. + +For Mac for performance of Docker, we suggest enabling in the UI: + +* Use new Virtualization Framework +* Enable VirtioFS accelerated directory sharing + +**1. Install Data Diff** + +When developing/debugging, it's recommended to install dependencies and run it +directly with `poetry` rather than go through the package. + +``` +$ brew install mysql postgresql # MacOS dependencies for C bindings +$ apt-get install libpq-dev libmysqlclient-dev # Debian dependencies +$ pip install poetry # Python dependency isolation tool +$ poetry install # Install dependencies +``` +**2. Start Databases** + +[Install **docker-compose**][docker-compose] if you haven't already. + +```shell-session +$ docker-compose up -d mysql postgres # run mysql and postgres dbs in background +``` + +[docker-compose]: https://docs.docker.com/compose/install/ + +**3. Run Unit Tests** + +There are more than 1000 tests for all the different type and database +combinations, so we recommend using `unittest-parallel` that's installed as a +development dependency. + +```shell-session +$ poetry run unittest-parallel -j 16 # run all tests +$ poetry run python -m unittest -k # run individual test +``` + +**4. Seed the Database(s) (optional)** + +First, download the CSVs of seeding data: + +```shell-session +$ curl https://datafold-public.s3.us-west-2.amazonaws.com/1m.csv -o dev/ratings.csv +# For a larger data-set (but takes 25x longer to import): +# - curl https://datafold-public.s3.us-west-2.amazonaws.com/25m.csv -o dev/ratings.csv +``` + +Now you can insert it into the testing database(s): + +```shell-session +# It's optional to seed more than one to run data-diff(1) against. +$ poetry run preql -f dev/prepare_db.pql mysql://mysql:Password1@127.0.0.1:3306/mysql +$ poetry run preql -f dev/prepare_db.pql postgresql://postgres:Password1@127.0.0.1:5432/postgres +# Cloud databases +$ poetry run preql -f dev/prepare_db.pql snowflake:// +$ poetry run preql -f dev/prepare_db.pql mssql:// +$ poetry run preql -f dev/prepare_db.pql bigquery:/// +``` + +**5. 
Run **data-diff** against seeded database (optional)** + +```bash +poetry run python3 -m data_diff postgresql://postgres:Password1@localhost/postgres rating postgresql://postgres:Password1@localhost/postgres rating_del1 --verbose +``` + +**6. Run benchmarks (optional)** + +```shell-session +$ dev/benchmark.sh # runs benchmarks and puts results in benchmark_.csv +$ poetry run python3 dev/graph.py # create graphs from benchmark_*.csv files +``` + +You can adjust how many rows we benchmark with by passing `N_SAMPLES` to `dev/benchmark.sh`: + +```shell-session +$ N_SAMPLES=100000000 dev/benchmark.sh # 100m which is our canonical target +``` diff --git a/README.md b/README.md index 83f93ae9..815da8fa 100644 --- a/README.md +++ b/README.md @@ -1,232 +1,176 @@ -# **data-diff** - -**data-diff is in shape to be run in production, but also under development. If -you run into issues or bugs, please [open an issue](https://github.com/datafold/data-diff/issues/new/choose) and we'll help you out ASAP! You can -also find us in `#tools-data-diff` in the [Locally Optimistic Slack][slack].** - -**We'd love to hear about your experience using data-diff, and learn more your use cases. [Reach out to product team share any product feedback or feature requests!](https://calendly.com/jp-toor/customer-interview-oss)** - -💸💸 **Looking for paid contributors!** 💸💸 If you're up for making money working on awesome open-source tools, we're looking for developers with a deep understanding of databases and solid Python knowledge. [**Apply here!**](https://docs.google.com/forms/d/e/1FAIpQLScEa5tc9CM0uNsb3WigqRFq92OZENkThM04nIs7ZVl_bwsGMw/viewform) +

[image: Datafold logo]

----- +# **data-diff** -**data-diff** is a command-line tool and Python library to efficiently diff -rows across two different databases. +## What is `data-diff`? +data-diff is a **free, open-source tool** that enables data professionals to detect differences in values between any two tables. It's fast, easy to use, and reliable. Even at massive scale. -* ⇄ Verifies across [many different databases][dbs] (e.g. PostgreSQL -> Snowflake) -* 🔍 Outputs [diff of rows](#example-command-and-output) in detail -* 🚨 Simple CLI/API to create monitoring and alerts -* 🔁 Bridges column types of different formats and levels of precision (e.g. Double ⇆ Float ⇆ Decimal) -* 🔥 Fast! Verify 25M+ rows in <10s, and 1B+ rows in ~5min. -* ♾️ Works for tables with 10s of billions of rows +_Are you a developer with a deep understanding of databases and solid Python knowledge? [We're hiring!](https://www.datafold.com/careers)_ -data-diff can diff tables within the same database, or across different databases. +## Use cases -**Same-DB Diff**: Uses an outer-join to diff the rows as efficiently and accurately as possible. +### Diff Tables Between Databases +#### Quickly identify issues when moving data between databases -Supports materializing the diff results to a database table. +

[image: diff2]

-Can also collect various extra statistics about the tables. +### Diff Tables Within a Database (available in pre release) +#### Improve code reviews by identifying data problems you don't have tests for +

[image link: Intro to Diff]

-**Cross-DB Diff**: Employs a divide and conquer algorithm based on hashing, optimized for few changes. +  +  -data-diff splits the table into smaller segments, then checksums each -segment in both databases. When the checksums for a segment aren't equal, it -will further divide that segment into yet smaller segments, checksumming those -until it gets to the differing row(s). See [Technical Explanation][tech-explain] for more -details. +## Get started -This approach has performance within an order of magnitude of `count(*)` when -there are few/no changes, but is able to output each differing row! By pushing -the compute into the databases, it's _much_ faster than querying for and -comparing every row. +### Installation -![Performance for 100M rows](https://user-images.githubusercontent.com/97400/175182987-a3900d4e-c097-4732-a4e9-19a40fac8cdc.png) +#### First, install `data-diff` using `pip`. -**†:** The implementation for downloading all rows that `data-diff` and -`count(*)` is compared to is not optimal. It is a single Python multi-threaded -process. The performance is fairly driver-specific, e.g. PostgreSQL's performs 10x -better than MySQL. +``` +pip install data-diff +``` -## Table of Contents - -- [**data-diff**](#data-diff) - - [Table of Contents](#table-of-contents) - - [Common use-cases](#common-use-cases) - - [Example Command and Output](#example-command-and-output) - - [Supported Databases](#supported-databases) -- [How to install](#how-to-install) - - [Install drivers](#install-drivers) -- [How to use](#how-to-use) - - [How to use from the command-line](#how-to-use-from-the-command-line) - - [How to use from Python](#how-to-use-from-python) -- [Technical Explanation](#technical-explanation) - - [Performance Considerations](#performance-considerations) -- [Anonymous Tracking](#anonymous-tracking) -- [Development Setup](#development-setup) -- [License](#license) - -## Common use-cases - -* **Verify data migrations.** Verify that all data was copied when doing a - critical data migration. For example, migrating from Heroku PostgreSQL to Amazon RDS. -* **Verifying data pipelines.** Moving data from a relational database to a - warehouse/data lake with Fivetran, Airbyte, Debezium, or some other pipeline. -* **Alerting and maintaining data integrity SLOs.** You can create and monitor - your SLO of e.g. 99.999% data integrity, and alert your team when data is - missing. -* **Debugging complex data pipelines.** When data gets lost in pipelines that - may span a half-dozen systems, without verifying each intermediate datastore - it's extremely difficult to track down where a row got lost. -* **Detecting hard deletes for an `updated_at`-based pipeline**. If you're - copying data to your warehouse based on an `updated_at`-style column, data-diff - can find any hard-deletes that you might have missed. -* **Make your replication self-healing.** You can use **data-diff** to - self-heal by using the diff output to write/update rows in the target - database. 
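A side note on the "self-healing replication" bullet in the use-case list above: the idea is simply to react to the `('-', row)` / `('+', row)` tuples that a diff yields. A rough sketch, in which `upsert_into_target` is a hypothetical user-supplied helper rather than part of data-diff:

```python
# Rough sketch, for illustration only.
# `diff` is the iterator returned by data_diff.diff_tables(source, target).

def heal_target(diff, upsert_into_target):
    for sign, row in diff:
        if sign == "-":
            # Present in the source but missing (or different) in the target
            upsert_into_target(row)
```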
- -## Example Command and Output - -Below we run a comparison with the CLI for 25M rows in PostgreSQL where the -right-hand table is missing single row with `id=12500048`: +To try out bleeding-edge features, including materialization of results in your data warehouse: ``` -$ data-diff \ - postgresql://user:password@localhost/database rating \ - postgresql://user:password@localhost/database rating_del1 \ - --bisection-threshold 100000 \ # for readability, try default first - --bisection-factor 6 \ # for readability, try default first - --update-column timestamp \ - --verbose - - # Consider running with --interactive the first time. - # Runs `EXPLAIN` for you to verify the queries are using indexes. - # --interactive -[10:15:00] INFO - Diffing tables | segments: 6, bisection threshold: 100000. -[10:15:00] INFO - . Diffing segment 1/6, key-range: 1..4166683, size: 4166682 -[10:15:03] INFO - . Diffing segment 2/6, key-range: 4166683..8333365, size: 4166682 -[10:15:06] INFO - . Diffing segment 3/6, key-range: 8333365..12500047, size: 4166682 -[10:15:09] INFO - . Diffing segment 4/6, key-range: 12500047..16666729, size: 4166682 -[10:15:12] INFO - . . Diffing segment 1/6, key-range: 12500047..13194494, size: 694447 -[10:15:13] INFO - . . . Diffing segment 1/6, key-range: 12500047..12615788, size: 115741 -[10:15:13] INFO - . . . . Diffing segment 1/6, key-range: 12500047..12519337, size: 19290 -[10:15:13] INFO - . . . . Diff found 1 different rows. -[10:15:13] INFO - . . . . Diffing segment 2/6, key-range: 12519337..12538627, size: 19290 -[10:15:13] INFO - . . . . Diffing segment 3/6, key-range: 12538627..12557917, size: 19290 -[10:15:13] INFO - . . . . Diffing segment 4/6, key-range: 12557917..12577207, size: 19290 -[10:15:13] INFO - . . . . Diffing segment 5/6, key-range: 12577207..12596497, size: 19290 -[10:15:13] INFO - . . . . Diffing segment 6/6, key-range: 12596497..12615788, size: 19291 -[10:15:13] INFO - . . . Diffing segment 2/6, key-range: 12615788..12731529, size: 115741 -[10:15:13] INFO - . . . Diffing segment 3/6, key-range: 12731529..12847270, size: 115741 -[10:15:13] INFO - . . . Diffing segment 4/6, key-range: 12847270..12963011, size: 115741 -[10:15:14] INFO - . . . Diffing segment 5/6, key-range: 12963011..13078752, size: 115741 -[10:15:14] INFO - . . . Diffing segment 6/6, key-range: 13078752..13194494, size: 115742 -[10:15:14] INFO - . . Diffing segment 2/6, key-range: 13194494..13888941, size: 694447 -[10:15:14] INFO - . . Diffing segment 3/6, key-range: 13888941..14583388, size: 694447 -[10:15:15] INFO - . . Diffing segment 4/6, key-range: 14583388..15277835, size: 694447 -[10:15:15] INFO - . . Diffing segment 5/6, key-range: 15277835..15972282, size: 694447 -[10:15:15] INFO - . . Diffing segment 6/6, key-range: 15972282..16666729, size: 694447 -+ (12500048, 1268104625) -[10:15:16] INFO - . Diffing segment 5/6, key-range: 16666729..20833411, size: 4166682 -[10:15:19] INFO - . Diffing segment 6/6, key-range: 20833411..25000096, size: 4166685 +pip install data-diff --pre ``` -## Supported Databases +#### Then, install one or more driver(s) specific to the database(s) you want to connect to. 
-| Database | Connection string | Status | -|---------------|-------------------------------------------------------------------------------------------------------------------------------------|--------| -| PostgreSQL >=10 | `postgresql://:@:5432/` | 💚 | -| MySQL | `mysql://:@:5432/` | 💚 | -| Snowflake | `"snowflake://[:]@//?warehouse=&role=[&authenticator=externalbrowser]"` | 💚 | -| BigQuery | `bigquery:///` | 💚 | -| Redshift | `redshift://:@:5439/` | 💚 | -| Oracle | `oracle://:@/database` | 💛 | -| Presto | `presto://:@:8080/` | 💛 | -| Databricks | `databricks://:@//` | 💛 | -| Trino | `trino://:@:8080/` | 💛 | -| Clickhouse | `clickhouse://:@:9000/` | 💛 | -| Vertica | `vertica://:@:5433/` | 💛 | -| ElasticSearch | | 📝 | -| Planetscale | | 📝 | -| Pinot | | 📝 | -| Druid | | 📝 | -| Kafka | | 📝 | -| DuckDB | | 📝 | -| SQLite | | 📝 | +- `pip install 'data-diff[mysql]'` -* 💚: Implemented and thoroughly tested. -* 💛: Implemented, but not thoroughly tested yet. -* ⏳: Implementation in progress. -* 📝: Implementation planned. Contributions welcome. +- `pip install 'data-diff[postgresql]'` -If a database is not on the list, we'd still love to support it. Open an issue -to discuss it. +- `pip install 'data-diff[snowflake]'` -Note: Because URLs allow many special characters, and may collide with the syntax of your command-line, -it's recommended to surround them with quotes. Alternatively, you may provide them in a TOML file via the `--config` option. +- `pip install 'data-diff[presto]'` +- `pip install 'data-diff[oracle]'` -# How to install +- `pip install 'data-diff[trino]'` -Requires Python 3.7+ with pip. +- `pip install 'data-diff[clickhouse]'` -```pip install data-diff``` +- `pip install 'data-diff[vertica]'` -## Install drivers +- For BigQuery, see: https://pypi.org/project/google-cloud-bigquery/ -To connect to a database, we need to have its driver installed, in the form of a Python library. +_Some drivers have dependencies that cannot be installed using `pip` and still need to be installed manually._ -While you may install them manually, we offer an easy way to install them along with data-diff*: +### Run your first diff -- `pip install 'data-diff[mysql]'` +Once you've installed `data-diff`, you can run it from the command line. -- `pip install 'data-diff[postgresql]'` +``` +data-diff DB1_URI TABLE1_NAME DB2_URI TABLE2_NAME [OPTIONS] +``` -- `pip install 'data-diff[snowflake]'` +Be sure to read [the How to Use section below](#how-to-use) which gets into specific details about how to build one of these commands depending on your database setup. -- `pip install 'data-diff[presto]'` +#### Code Example: Diff Tables Between Databases +Here's an example command for your copy/pasting, taken from the screenshot above when we diffed data between Snowflake and Postgres. -- `pip install 'data-diff[oracle]'` +``` +data-diff \ + postgresql://:''@localhost:5432/ \ + \ + "snowflake://:@//?warehouse=&role=" \ +
\ + -k activity_id \ + -c activity \ + -w "event_timestamp < '2022-10-10'" +``` -- `pip install 'data-diff[trino]'` +#### Code Example: Diff Tables Within a Database (available in pre release) -- `pip install 'data-diff[clickhouse]'` +Here's a code example from [the video](https://www.loom.com/share/682e4b7d74e84eb4824b983311f0a3b2), where we compare data between two Snowflake tables within one database. -- `pip install 'data-diff[vertica]'` +``` +data-diff \ + "snowflake://:@//?warehouse=&role=" \ + . \ + -k org_id \ + -c created_at -c is_internal \ + -w "org_id != 1 and org_id < 2000" \ + -m test_results_%t \ + --materialize-all-rows \ + --table-write-limit 10000 +``` -- For BigQuery, see: https://pypi.org/project/google-cloud-bigquery/ +In both code examples, I've used `<>` carrots to represent values that **should be replaced with your values** in the database connection strings. For the flags (`-k`, `-c`, etc.), I opted for "real" values (`org_id`, `is_internal`) to give you a more realistic view of what your command will look like. +### We're here to help! -Users can also install several drivers at once: +We know, that `data-diff DB1_URI TABLE1_NAME DB2_URI TABLE2_NAME [OPTIONS]` command can become long and dense. And maybe you're new to the command line. -```pip install 'data-diff[mysql,postgresql,snowflake]'``` +We're here to help [on slack](https://locallyoptimistic.slack.com/archives/C03HUNGQV0S) if you have ANY questions as you use `data-diff` in your workflow. -_* Some drivers have dependencies that cannot be installed using `pip` and still need to be installed manually._ +## How to Use +This section gets into more details, including: +- [database-specific syntax](#how-to-use-from-the-command-line) +- [the many options (flags) you can use beyond the examples presented above](#options) +- [how to run `data-diff` using a TOML configuration file](#how-to-use-with-a-configuration-file) +- [how to run`data-diff` from Python](#how-to-use-from-python) +### How to use from the command line -### Install Psycopg2 +To run `data-diff` from the command line, run this command: -In order to run Postgresql, you'll need `psycopg2`. This Python package requires some additional dependencies described in their [documentation](https://www.psycopg.org/docs/install.html#build-prerequisites). -An easy solution is to install [psycopg2-binary](https://www.psycopg.org/docs/install.html#quick-install) by running: +`data-diff DB1_URI TABLE1_NAME DB2_URI TABLE2_NAME [OPTIONS]` -```pip install psycopg2-binary``` +Let's break this down. Assume there are two tables stored in two databases, and you want to know the differences between those tables. -Which comes with a pre-compiled binary and does not require additonal prerequisites. However, note that for production use it is adviced to use `psycopg2`. +- `DB1_URI` will be a string that `data-diff` uses to connect to the database where the first table is stored. +- `TABLE1_NAME` is the name of the table in the `DB1_URI` database. +- `DB2_URI` will be a string that `data-diff` uses to connect to the database where the second table is stored. +- `TABLE2_NAME` is the name of the second table in the `DB2_URI` database. +- `[OPTIONS]` can be replaced with a variety of additional commands, [detailed here](#options). 
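These same pieces map directly onto the Python API covered later in this README. A minimal sketch of the Snowflake-vs-PostgreSQL example above, where every URI, name, and credential is a placeholder:

```python
# Minimal sketch; all connection details below are placeholders.
from data_diff import connect_to_table, diff_tables

table1 = connect_to_table("postgresql://user:password@localhost:5432/mydb",
                          "events", "activity_id")
table2 = connect_to_table("snowflake://user:password@account/MYDB/ANALYTICS?warehouse=WH&role=ROLE",
                          "EVENTS", "activity_id")

for sign, row in diff_tables(table1, table2):
    print(sign, row)  # '-' rows come from table1, '+' rows from table2
```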
-# How to use -## How to use from the command-line +| Database | Connection string | Status | +|---------------|-------------------------------------------------------------------------------------------------------------------------------------|--------| +| PostgreSQL >=10 | `postgresql://:''@:5432/` | 💚 | +| MySQL | `mysql://:@:5432/` | 💚 | +| Snowflake | **With password:**`"snowflake://:@//?warehouse=&role="`
**With SSO:** `"snowflake://@//?warehouse=&role=&authenticator=externalbrowser"`
_Note: Unless something is explicitly case sensitive (like your password) use all caps._ | 💚 | +| BigQuery | `bigquery:///` | 💚 | +| Redshift | `redshift://:@:5439/` | 💚 | +| Oracle | `oracle://:@/database` | 💛 | +| Presto | `presto://:@:8080/` | 💛 | +| Databricks | `databricks://:@//` | 💛 | +| Trino | `trino://:@:8080/` | 💛 | +| Clickhouse | `clickhouse://:@:9000/` | 💛 | +| Vertica | `vertica://:@:5433/` | 💛 | +| ElasticSearch | | 📝 | +| Planetscale | | 📝 | +| Pinot | | 📝 | +| Druid | | 📝 | +| Kafka | | 📝 | +| DuckDB | | 📝 | +| SQLite | | 📝 | -Usage: `data-diff DB1_URI TABLE1_NAME DB2_URI TABLE2_NAME [OPTIONS]` +* 💚: Implemented and thoroughly tested. +* 💛: Implemented, but not thoroughly tested yet. +* ⏳: Implementation in progress. +* 📝: Implementation planned. Contributions welcome. -See the [example command](#example-command-and-output) and the [sample -connection strings](#supported-databases). +If a database is not on the list, we'd still love to support it. Open an issue +to discuss it. -Note that for some databases, the arguments that you enter in the command line -may be case-sensitive. This is the case for the Snowflake schema and table names. +Note: Because URLs allow many special characters, and may collide with the syntax of your command-line, +it's recommended to surround them with quotes. Alternatively, you may [provide them in a TOML file](#how-to-use-with-a-configuration-file) via the `--config` option. -Options: +#### Options - `--help` - Show help message and exit. - `-k` or `--key-columns` - Name of the primary key column. If none provided, default is 'id'. @@ -248,19 +192,24 @@ Options: - `-w`, `--where` - An additional 'where' expression to restrict the search space. - `--conf`, `--run` - Specify the run and configuration from a TOML file. (see below) - `--no-tracking` - data-diff sends home anonymous usage data. Use this to disable it. - - `-a`, `--algorithm` `[auto|joindiff|hashdiff]` - Force algorithm choice -Same-DB diff only: + **The following two options are not available when using the pre release In-DB feature:** + + - `--bisection-threshold` - Minimal size of segment to be split. Smaller segments will be downloaded and compared locally. + - `--bisection-factor` - Segments per iteration. When set to 2, it performs binary search. + +**In-DB commands, available in pre release only:** - `-m`, `--materialize` - Materialize the diff results into a new table in the database. If a table exists by that name, it will be replaced. Use `%t` in the name to place a timestamp. Example: `-m test_mat_%t` - `--assume-unique-key` - Skip validating the uniqueness of the key column during joindiff, which is costly in non-cloud dbs. - `--sample-exclusive-rows` - Sample several rows that only appear in one of the tables, but not the other. Use with `-s`. + - `--materialize-all-rows` - Materialize every row, even if they are the same, instead of just the differing rows. + - `--table-write-limit` - Maximum number of rows to write when creating materialized or sample tables, per thread. Default=1000. + - `-a`, `--algorithm` `[auto|joindiff|hashdiff]` - Force algorithm choice + -Cross-DB diff only: - - `--bisection-threshold` - Minimal size of segment to be split. Smaller segments will be downloaded and compared locally. - - `--bisection-factor` - Segments per iteration. When set to 2, it performs binary search. @@ -268,11 +217,11 @@ Cross-DB diff only: Data-diff lets you load the configuration for a run from a TOML file. 
-Reasons to use a configuration file: +**Reasons to use a configuration file:** -- Convenience - Set-up the parameters for diffs that need to run often +- Convenience: Set-up the parameters for diffs that need to run often -- Easier and more readable - you can define the database connection settings as config values, instead of in a URI. +- Easier and more readable: You can define the database connection settings as config values, instead of in a URI. - Gives you fine-grained control over the settings switches, without requiring any Python code. @@ -313,8 +262,7 @@ flag is overwritten to `false`. Running it with `data-diff --conf myconfig.toml --run test_diff -v` will set verbose back to `true`. - -## How to use from Python +### How to use from Python API reference: [https://data-diff.readthedocs.io/en/latest/](https://data-diff.readthedocs.io/en/latest/) @@ -337,7 +285,62 @@ for different_row in diff_tables(table1, table2): Run `help(diff_tables)` or [read the docs](https://data-diff.readthedocs.io/en/latest/) to learn about the different options. -# Technical Explanation +## Reporting bugs and contributing + +- [Open an issue](https://github.com/datafold/data-diff/issues/new/choose) or chat with us [on slack](https://locallyoptimistic.slack.com/archives/C03HUNGQV0S). +- Interested in contributing to this open source project? Please see our [Contributing Guideline](https://github.com/datafold/data-diff/blob/master/CONTRIBUTING.md)! +- Did we mention [we're hiring](https://www.datafold.com/careers)? + +## Usage Analytics & Data Privacy + +data-diff collects anonymous usage data to help our team improve the tool and to apply development efforts to where our users need them most. + +We capture two events: one when the data-diff run starts, and one when it is finished. No user data or potentially sensitive information is or ever will be collected. The captured data is limited to: + +- Operating System and Python version +- Types of databases used (postgresql, mysql, etc.) +- Sizes of tables diffed, run time, and diff row count (numbers only) +- Error message, if any, truncated to the first 20 characters. +- A persistent UUID to indentify the session, stored in `~/.datadiff.toml` + +If you do not wish to participate, the tracking can be easily disabled with one of the following methods: + +* In the CLI, use the `--no-tracking` flag. +* In the config file, set `no_tracking = true` (for example, under `[run.default]`) +* If you're using the Python API: +```python +import data_diff +data_diff.disable_tracking() # Call this first, before making any API calls +# Connect and diff your tables without any tracking +``` + +## Technical Explanation + +### Overview + +data-diff splits the table into smaller segments, then checksums each segment in both databases. When the checksums for a segment aren't equal, it will further divide that segment into yet smaller segments, checksumming those until it gets to the differing row(s). + +This approach has performance within an order of magnitude of count(*) when there are few/no changes, but is able to output each differing row! By pushing the compute into the databases, it's much faster than querying for and comparing every row. + +![Performance for 100M rows](https://user-images.githubusercontent.com/97400/175182987-a3900d4e-c097-4732-a4e9-19a40fac8cdc.png) + +**†:** The implementation for downloading all rows that `data-diff` and +`count(*)` is compared to is not optimal. It is a single Python multi-threaded +process. 
The performance is fairly driver-specific, e.g. PostgreSQL's performs 10x +better than MySQL. + +### A note on same-db diff vs cross-db diff + +data-diff can diff tables within the same database, or across different databases. + +**Same-DB Diff:** +- Uses an outer-join to diff the rows as efficiently and accurately as possible. +- Supports materializing the diff results to a database table. +- Can also collect various extra statistics about the tables. + +**Cross-DB Diff:** Employs a divide and conquer algorithm based on hashing, optimized for few changes. + +### Deep Dive In this section we'll be doing a walk-through of exactly how **data-diff** works, and how to tune `--bisection-factor` and `--bisection-threshold`. @@ -411,11 +414,11 @@ WHERE (id >= 1) AND (id < 100000) This keeps the amount of data that has to be transferred between the databases to a minimum, making it very performant! Additionally, if you have an index on -`updated_at` (highly recommended) then the query will be fast as the database +`updated_at` (highly recommended), then the query will be fast, as the database only has to do a partial index scan between `id=1..100k`. If you are not sure whether the queries are using an index, you can run it with -`--interactive`. This puts **data-diff** in interactive mode where it shows an +`--interactive`. This puts **data-diff** in interactive mode, where it shows an `EXPLAIN` before executing each query, requiring confirmation to proceed. After running the checksum queries on both sides, we see that all segments @@ -443,7 +446,7 @@ Now **data-diff** will do exactly as it just did for the _whole table_ for only this segment: Split it into `--bisection-factor` segments. However, this time, because each segment has `100k/10=10k` entries, which is -less than the `--bisection-threshold` it will pull down every row in the segment +less than the `--bisection-threshold`, it will pull down every row in the segment and compare them in memory in **data-diff**. ``` @@ -470,35 +473,31 @@ Finally **data-diff** will output the `(id, updated_at)` for each row that was d (122001, 1653672821) ``` -If you pass `--stats` you'll see e.g. what % of rows were different. +If you pass `--stats` you'll see stats such as the % of rows were different. -## Performance Considerations +### Performance Considerations * Ensure that you have indexes on the columns you are comparing. Preferably a compound index. You can run with `--interactive` to see an `EXPLAIN` for the queries. * Consider increasing the number of simultaneous threads executing queries per database with `--threads`. For databases that limit concurrency - per query, e.g. PostgreSQL/MySQL, this can improve performance dramatically. + per query, such as PostgreSQL/MySQL, this can improve performance dramatically. * If you are only interested in _whether_ something changed, pass `--limit 1`. This can be useful if changes are very rare. This is often faster than doing a `count(*)`, for the reason mentioned above. -* If the table is _very_ large, consider a larger `--bisection-factor`. Explained in - the [technical explanation][tech-explain]. Otherwise you may run into timeouts. +* If the table is _very_ large, consider a larger `--bisection-factor`. Otherwise, you may run into timeouts. * If there are a lot of changes, consider a larger `--bisection-threshold`. - Explained in the [technical explanation][tech-explain]. -* If there are very large gaps in your key column, e.g. 
10s of millions of - continuous rows missing, then **data-diff** may perform poorly doing lots of - queries for ranges of rows that do not exist (see [technical - explanation][tech-explain]). We have ideas on how to tackle this issue, which we have - yet to implement. If you're experiencing this effect, please open an issue and we +* If there are very large gaps in your key column (e.g., 10s of millions of + continuous rows missing), then **data-diff** may perform poorly, doing lots of + queries for ranges of rows that do not exist. We have ideas on how to tackle this issue, which we have yet to implement. If you're experiencing this effect, please open an issue, and we will prioritize it. * The fewer columns you verify (passed with `--columns`), the faster - **data-diff** will be. On one extreme you can verify every column, on the - other you can verify _only_ `updated_at`, if you trust it enough. You can also - _only_ verify `id` if you're interested in only presence, e.g. to detect + **data-diff** will be. On one extreme, you can verify every column; on the + other, you can verify _only_ `updated_at`, if you trust it enough. You can also + _only_ verify `id` if you're interested in only presence, such as to detect missing hard deletes. You can do also do a hybrid where you verify - `updated_at` and the most critical value, e.g a money value in `amount` but + `updated_at` and the most critical value, such as a money value in `amount`, but not verify a large serialized column like `json_settings`. * We have ideas for making **data-diff** even faster that we haven't implemented yet: faster checksums by reducing type-casts @@ -507,130 +506,6 @@ If you pass `--stats` you'll see e.g. what % of rows were different. gaps), and improvements to bypass Python/driver performance limitations when comparing huge amounts of rows locally (i.e. for very high `bisection_threshold` values). -# Usage Analytics - -data-diff collects anonymous usage data to help our team improve the tool and to apply development efforts to where our users need them most. - -We capture two events, one when the data-diff run starts and one when it is finished. No user data or potentially sensitive information is or ever will be collected. The captured data is limited to: - -- Operating System and Python version - -- Types of databases used (postgresql, mysql, etc.) - -- Sizes of tables diffed, run time, and diff row count (numbers only) - -- Error message, if any, truncated to the first 20 characters. - -- A persistent UUID to indentify the session, stored in `~/.datadiff.toml` - -If you do not wish to participate, the tracking can be easily disabled with one of the following methods: - -* In the CLI, use the `--no-tracking` flag. - -* In the config file, set `no_tracking = true` (for example, under `[run.default]`) - -* If you're using the Python API: - -```python -import data_diff -data_diff.disable_tracking() # Call this first, before making any API calls - -# Connect and diff your tables without any tracking -``` - - -# Development Setup - -The development setup centers around using `docker-compose` to boot up various -databases, and then inserting data into them. - -For Mac for performance of Docker, we suggest enabling in the UI: - -* Use new Virtualization Framework -* Enable VirtioFS accelerated directory sharing - -**1. Install Data Diff** - -When developing/debugging, it's recommended to install dependencies and run it -directly with `poetry` rather than go through the package. 
- -``` -$ brew install mysql postgresql # MacOS dependencies for C bindings -$ apt-get install libpq-dev libmysqlclient-dev # Debian dependencies - -$ pip install poetry # Python dependency isolation tool -$ poetry install # Install dependencies -``` -**2. Start Databases** - -[Install **docker-compose**][docker-compose] if you haven't already. - -```shell-session -$ docker-compose up -d mysql postgres # run mysql and postgres dbs in background -``` - -[docker-compose]: https://docs.docker.com/compose/install/ - -**3. Run Unit Tests** - -There are more than 1000 tests for all the different type and database -combinations, so we recommend using `unittest-parallel` that's installed as a -development dependency. - -```shell-session -$ poetry run unittest-parallel -j 16 # run all tests -$ poetry run python -m unittest -k # run individual test -``` - -**4. Seed the Database(s) (optional)** - -First, download the CSVs of seeding data: - -```shell-session -$ curl https://datafold-public.s3.us-west-2.amazonaws.com/1m.csv -o dev/ratings.csv - -# For a larger data-set (but takes 25x longer to import): -# - curl https://datafold-public.s3.us-west-2.amazonaws.com/25m.csv -o dev/ratings.csv -``` - -Now you can insert it into the testing database(s): - -```shell-session -# It's optional to seed more than one to run data-diff(1) against. -$ poetry run preql -f dev/prepare_db.pql mysql://mysql:Password1@127.0.0.1:3306/mysql -$ poetry run preql -f dev/prepare_db.pql postgresql://postgres:Password1@127.0.0.1:5432/postgres - -# Cloud databases -$ poetry run preql -f dev/prepare_db.pql snowflake:// -$ poetry run preql -f dev/prepare_db.pql mssql:// -$ poetry run preql -f dev/prepare_db.pql bigquery:/// -``` - -**5. Run **data-diff** against seeded database (optional)** - -```bash -poetry run python3 -m data_diff postgresql://postgres:Password1@localhost/postgres rating postgresql://postgres:Password1@localhost/postgres rating_del1 --verbose -``` - -**6. Run benchmarks (optional)** - -```shell-session -$ dev/benchmark.sh # runs benchmarks and puts results in benchmark_.csv -$ poetry run python3 dev/graph.py # create graphs from benchmark_*.csv files -``` - -You can adjust how many rows we benchmark with by passing `N_SAMPLES` to `dev/benchmark.sh`: - -```shell-session -$ N_SAMPLES=100000000 dev/benchmark.sh # 100m which is our canonical target -``` - - -# License - -[MIT License](https://github.com/datafold/data-diff/blob/master/LICENSE) +## License -[dbs]: #supported-databases -[tech-explain]: #technical-explanation -[perf]: #performance-considerations -[slack]: https://locallyoptimistic.com/community/ +This project is licensed under the terms of the [MIT License](https://github.com/datafold/data-diff/blob/master/LICENSE). From 12b63dfc04bd2d2007cc2f9209beb2671452d5d3 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 20 Oct 2022 17:32:14 -0300 Subject: [PATCH 61/93] Improve docs --- docs/index.rst | 44 +++++++++++++++++++++++++++++-------- docs/supported-databases.md | 29 ++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 9 deletions(-) create mode 100644 docs/supported-databases.md diff --git a/docs/index.rst b/docs/index.rst index af5f5c5d..b3fad229 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,6 +3,7 @@ :caption: Reference :hidden: + supported-databases python-api new-database-driver-guide @@ -33,12 +34,40 @@ Requires Python 3.7+ with pip. 
pip install data-diff -or when you need extras like mysql and postgresql: +For installing with 3rd-party database connectors, use the following syntax: :: + pip install "data-diff[db1,db2]" + + e.g. pip install "data-diff[mysql,postgresql]" +Supported connectors: + +- mysql +- postgresql +- snowflake +- presto +- oracle +- trino +- clickhouse +- vertica + + +How to use from the shell +------------------------- + +.. code-block:: bash + + # Same-DB diff, using outer join + $ data-diff DB TABLE1 TABLE2 [options] + + # Cross-DB diff, using hashes + $ data-diff DB1 TABLE1 DB2 TABLE2 [options] + +We recommend using a configuration file, with the ``--conf`` switch, to keep the command simple and managable. + How to use from Python ---------------------- @@ -66,12 +95,9 @@ How to use from Python Resources --------- -- Source code (git): ``_ -- API Reference - - :doc:`python-api` -- Guides +- Users + - Source code (git): ``_ + - :doc:`supported-databases` + - :doc:`python-api` +- Contributors - :doc:`new-database-driver-guide` -- Tutorials - - TODO - - diff --git a/docs/supported-databases.md b/docs/supported-databases.md new file mode 100644 index 00000000..7cfef6ad --- /dev/null +++ b/docs/supported-databases.md @@ -0,0 +1,29 @@ +# List of supported databases + +| Database | Status | Connection string | +|---------------|-------------------------------------------------------------------------------------------------------------------------------------|--------| +| PostgreSQL >=10 | 💚 | `postgresql://:@:5432/` | +| MySQL | 💚 | `mysql://:@:5432/` | +| Snowflake | 💚 | `"snowflake://[:]@//?warehouse=&role=[&authenticator=externalbrowser]"` | +| BigQuery | 💚 | `bigquery:///` | +| Redshift | 💚 | `redshift://:@:5439/` | +| Oracle | 💛 | `oracle://:@/database` | +| Presto | 💛 | `presto://:@:8080/` | +| Databricks | 💛 | `databricks://:@//` | +| Trino | 💛 | `trino://:@:8080/` | +| Clickhouse | 💛 | `clickhouse://:@:9000/` | +| Vertica | 💛 | `vertica://:@:5433/` | +| ElasticSearch | 📝 | | +| Planetscale | 📝 | | +| Pinot | 📝 | | +| Druid | 📝 | | +| Kafka | 📝 | | +| DuckDB | 📝 | | +| SQLite | 📝 | | + +* 💚: Implemented and thoroughly tested. +* 💛: Implemented, but not thoroughly tested yet. +* ⏳: Implementation in progress. +* 📝: Implementation planned. Contributions welcome. + +Is your database not listed here? We accept pull-requests! From 6e54c479f51f611b29930e316acf71a9384bd00a Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 20 Oct 2022 17:51:17 -0300 Subject: [PATCH 62/93] Fix docs & README, pass 1 --- README.md | 190 +-------------------------------- docs/how-to-use.md | 159 ++++++++++++++++++++++++++++ docs/index.rst | 40 +------ docs/technical-explanation.md | 191 ++++++++++++++++++++++++++++++++++ 4 files changed, 355 insertions(+), 225 deletions(-) create mode 100644 docs/how-to-use.md create mode 100644 docs/technical-explanation.md diff --git a/README.md b/README.md index 815da8fa..d9ea28f0 100644 --- a/README.md +++ b/README.md @@ -316,195 +316,7 @@ data_diff.disable_tracking() # Call this first, before making any API calls ## Technical Explanation -### Overview - -data-diff splits the table into smaller segments, then checksums each segment in both databases. When the checksums for a segment aren't equal, it will further divide that segment into yet smaller segments, checksumming those until it gets to the differing row(s). - -This approach has performance within an order of magnitude of count(*) when there are few/no changes, but is able to output each differing row! 
By pushing the compute into the databases, it's much faster than querying for and comparing every row. - -![Performance for 100M rows](https://user-images.githubusercontent.com/97400/175182987-a3900d4e-c097-4732-a4e9-19a40fac8cdc.png) - -**†:** The implementation for downloading all rows that `data-diff` and -`count(*)` is compared to is not optimal. It is a single Python multi-threaded -process. The performance is fairly driver-specific, e.g. PostgreSQL's performs 10x -better than MySQL. - -### A note on same-db diff vs cross-db diff - -data-diff can diff tables within the same database, or across different databases. - -**Same-DB Diff:** -- Uses an outer-join to diff the rows as efficiently and accurately as possible. -- Supports materializing the diff results to a database table. -- Can also collect various extra statistics about the tables. - -**Cross-DB Diff:** Employs a divide and conquer algorithm based on hashing, optimized for few changes. - -### Deep Dive - -In this section we'll be doing a walk-through of exactly how **data-diff** -works, and how to tune `--bisection-factor` and `--bisection-threshold`. - -Let's consider a scenario with an `orders` table with 1M rows. Fivetran is -replicating it contionously from PostgreSQL to Snowflake: - -``` -┌─────────────┐ ┌─────────────┐ -│ PostgreSQL │ │ Snowflake │ -├─────────────┤ ├─────────────┤ -│ │ │ │ -│ │ │ │ -│ │ ┌─────────────┐ │ table with │ -│ table with ├──┤ replication ├──────▶│ ?maybe? all │ -│lots of rows!│ └─────────────┘ │ the same │ -│ │ │ rows. │ -│ │ │ │ -│ │ │ │ -│ │ │ │ -└─────────────┘ └─────────────┘ -``` - -In order to check whether the two tables are the same, **data-diff** splits -the table into `--bisection-factor=10` segments. - -We also have to choose which columns we want to checksum. In our case, we care -about the primary key, `--key-column=id` and the update column -`--update-column=updated_at`. `updated_at` is updated every time the row is, and -we have an index on it. - -**data-diff** starts by querying both databases for the `min(id)` and `max(id)` -of the table. Then it splits the table into `--bisection-factor=10` segments of -`1M/10 = 100K` keys each: - -``` -┌──────────────────────┐ ┌──────────────────────┐ -│ PostgreSQL │ │ Snowflake │ -├──────────────────────┤ ├──────────────────────┤ -│ id=1..100k │ │ id=1..100k │ -├──────────────────────┤ ├──────────────────────┤ -│ id=100k..200k │ │ id=100k..200k │ -├──────────────────────┤ ├──────────────────────┤ -│ id=200k..300k ├─────────────▶│ id=200k..300k │ -├──────────────────────┤ ├──────────────────────┤ -│ id=300k..400k │ │ id=300k..400k │ -├──────────────────────┤ ├──────────────────────┤ -│ ... │ │ ... │ -├──────────────────────┤ ├──────────────────────┤ -│ 900k..100k │ │ 900k..100k │ -└───────────────────▲──┘ └▲─────────────────────┘ - ┃ ┃ - ┃ ┃ - ┃ checksum queries ┃ - ┃ ┃ - ┌─┻──────────────────┻────┐ - │ data-diff │ - └─────────────────────────┘ -``` - -Now **data-diff** will start running `--threads=1` queries in parallel that -checksum each segment. The queries for checksumming each segment will look -something like this, depending on the database: - -```sql -SELECT count(*), - sum(cast(conv(substring(md5(concat(cast(id as char), cast(timestamp as char))), 18), 16, 10) as unsigned)) -FROM `rating_del1` -WHERE (id >= 1) AND (id < 100000) -``` - -This keeps the amount of data that has to be transferred between the databases -to a minimum, making it very performant! 
Additionally, if you have an index on -`updated_at` (highly recommended), then the query will be fast, as the database -only has to do a partial index scan between `id=1..100k`. - -If you are not sure whether the queries are using an index, you can run it with -`--interactive`. This puts **data-diff** in interactive mode, where it shows an -`EXPLAIN` before executing each query, requiring confirmation to proceed. - -After running the checksum queries on both sides, we see that all segments -are the same except `id=100k..200k`: - -``` -┌──────────────────────┐ ┌──────────────────────┐ -│ PostgreSQL │ │ Snowflake │ -├──────────────────────┤ ├──────────────────────┤ -│ checksum=0102 │ │ checksum=0102 │ -├──────────────────────┤ mismatch! ├──────────────────────┤ -│ checksum=ffff ◀──────────────▶ checksum=aaab │ -├──────────────────────┤ ├──────────────────────┤ -│ checksum=abab │ │ checksum=abab │ -├──────────────────────┤ ├──────────────────────┤ -│ checksum=f0f0 │ │ checksum=f0f0 │ -├──────────────────────┤ ├──────────────────────┤ -│ ... │ │ ... │ -├──────────────────────┤ ├──────────────────────┤ -│ checksum=9494 │ │ checksum=9494 │ -└──────────────────────┘ └──────────────────────┘ -``` - -Now **data-diff** will do exactly as it just did for the _whole table_ for only -this segment: Split it into `--bisection-factor` segments. - -However, this time, because each segment has `100k/10=10k` entries, which is -less than the `--bisection-threshold`, it will pull down every row in the segment -and compare them in memory in **data-diff**. - -``` -┌──────────────────────┐ ┌──────────────────────┐ -│ PostgreSQL │ │ Snowflake │ -├──────────────────────┤ ├──────────────────────┤ -│ id=100k..110k │ │ id=100k..110k │ -├──────────────────────┤ ├──────────────────────┤ -│ id=110k..120k │ │ id=110k..120k │ -├──────────────────────┤ ├──────────────────────┤ -│ id=120k..130k │ │ id=120k..130k │ -├──────────────────────┤ ├──────────────────────┤ -│ id=130k..140k │ │ id=130k..140k │ -├──────────────────────┤ ├──────────────────────┤ -│ ... │ │ ... │ -├──────────────────────┤ ├──────────────────────┤ -│ 190k..200k │ │ 190k..200k │ -└──────────────────────┘ └──────────────────────┘ -``` - -Finally **data-diff** will output the `(id, updated_at)` for each row that was different: - -``` -(122001, 1653672821) -``` - -If you pass `--stats` you'll see stats such as the % of rows were different. - -### Performance Considerations - -* Ensure that you have indexes on the columns you are comparing. Preferably a - compound index. You can run with `--interactive` to see an `EXPLAIN` for the - queries. -* Consider increasing the number of simultaneous threads executing - queries per database with `--threads`. For databases that limit concurrency - per query, such as PostgreSQL/MySQL, this can improve performance dramatically. -* If you are only interested in _whether_ something changed, pass `--limit 1`. - This can be useful if changes are very rare. This is often faster than doing a - `count(*)`, for the reason mentioned above. -* If the table is _very_ large, consider a larger `--bisection-factor`. Otherwise, you may run into timeouts. -* If there are a lot of changes, consider a larger `--bisection-threshold`. -* If there are very large gaps in your key column (e.g., 10s of millions of - continuous rows missing), then **data-diff** may perform poorly, doing lots of - queries for ranges of rows that do not exist. We have ideas on how to tackle this issue, which we have yet to implement. 
If you're experiencing this effect, please open an issue, and we - will prioritize it. -* The fewer columns you verify (passed with `--columns`), the faster - **data-diff** will be. On one extreme, you can verify every column; on the - other, you can verify _only_ `updated_at`, if you trust it enough. You can also - _only_ verify `id` if you're interested in only presence, such as to detect - missing hard deletes. You can do also do a hybrid where you verify - `updated_at` and the most critical value, such as a money value in `amount`, but - not verify a large serialized column like `json_settings`. -* We have ideas for making **data-diff** even faster that - we haven't implemented yet: faster checksums by reducing type-casts - and using a faster hash than MD5, dynamic adaptation of - `bisection_factor`/`threads`/`bisection_threshold` (especially with large key - gaps), and improvements to bypass Python/driver performance limitations when - comparing huge amounts of rows locally (i.e. for very high `bisection_threshold` values). +See here: https://data-diff.readthedocs.io/en/latest/technical-explanation.html ## License diff --git a/docs/how-to-use.md b/docs/how-to-use.md new file mode 100644 index 00000000..6f78a634 --- /dev/null +++ b/docs/how-to-use.md @@ -0,0 +1,159 @@ +# How to use + +## How to use from the shell (or: command-line) + +Run the following command: + +```bash + # Same-DB diff, using outer join + $ data-diff DB TABLE1 TABLE2 [options] + + # Cross-DB diff, using hashes + $ data-diff DB1 TABLE1 DB2 TABLE2 [options] +``` + +Where DB is either a database URL that's compatible with SQLAlchemy, or the name of a database specified in a configuration file. + +We recommend using a configuration file, with the ``--conf`` switch, to keep the command simple and managable. + +### Options + + - `--help` - Show help message and exit. + - `-k` or `--key-columns` - Name of the primary key column. If none provided, default is 'id'. + - `-t` or `--update-column` - Name of updated_at/last_updated column + - `-c` or `--columns` - Names of extra columns to compare. Can be used more than once in the same command. + Accepts a name or a pattern like in SQL. + Example: `-c col% -c another_col -c %foorb.r%` + - `-l` or `--limit` - Maximum number of differences to find (limits maximum bandwidth and runtime) + - `-s` or `--stats` - Print stats instead of a detailed diff + - `-d` or `--debug` - Print debug info + - `-v` or `--verbose` - Print extra info + - `-i` or `--interactive` - Confirm queries, implies `--debug` + - `--json` - Print JSONL output for machine readability + - `--min-age` - Considers only rows older than specified. Useful for specifying replication lag. + Example: `--min-age=5min` ignores rows from the last 5 minutes. + Valid units: `d, days, h, hours, min, minutes, mon, months, s, seconds, w, weeks, y, years` + - `--max-age` - Considers only rows younger than specified. See `--min-age`. + - `-j` or `--threads` - Number of worker threads to use per database. Default=1. + - `-w`, `--where` - An additional 'where' expression to restrict the search space. + - `--conf`, `--run` - Specify the run and configuration from a TOML file. (see below) + - `--no-tracking` - data-diff sends home anonymous usage data. Use this to disable it. + + **The following two options are not available when using the pre release In-DB feature:** + + - `--bisection-threshold` - Minimal size of segment to be split. Smaller segments will be downloaded and compared locally. 
+ - `--bisection-factor` - Segments per iteration. When set to 2, it performs binary search. + +**In-DB commands, available in pre release only:** + - `-m`, `--materialize` - Materialize the diff results into a new table in the database. + If a table exists by that name, it will be replaced. + Use `%t` in the name to place a timestamp. + Example: `-m test_mat_%t` + - `--assume-unique-key` - Skip validating the uniqueness of the key column during joindiff, which is costly in non-cloud dbs. + - `--sample-exclusive-rows` - Sample several rows that only appear in one of the tables, but not the other. Use with `-s`. + - `--materialize-all-rows` - Materialize every row, even if they are the same, instead of just the differing rows. + - `--table-write-limit` - Maximum number of rows to write when creating materialized or sample tables, per thread. Default=1000. + - `-a`, `--algorithm` `[auto|joindiff|hashdiff]` - Force algorithm choice + + + +### How to use with a configuration file + +Data-diff lets you load the configuration for a run from a TOML file. + +**Reasons to use a configuration file:** + +- Convenience: Set-up the parameters for diffs that need to run often + +- Easier and more readable: You can define the database connection settings as config values, instead of in a URI. + +- Gives you fine-grained control over the settings switches, without requiring any Python code. + +Use `--conf` to specify that path to the configuration file. data-diff will load the settings from `run.default`, if it's defined. + +Then you can, optionally, use `--run` to choose to load the settings of a specific run, and override the settings `run.default`. (all runs extend `run.default`, like inheritance). + +Finally, CLI switches have the final say, and will override the settings defined by the configuration file, and the current run. + +Example TOML file: + +```toml +# Specify the connection params to the test database. +[database.test_postgresql] +driver = "postgresql" +user = "postgres" +password = "Password1" + +# Specify the default run params +[run.default] +update_column = "timestamp" +verbose = true + +# Specify params for a run 'test_diff'. +[run.test_diff] +verbose = false +# Source 1 ("left") +1.database = "test_postgresql" # Use options from database.test_postgresql +1.table = "rating" +# Source 2 ("right") +2.database = "postgresql://postgres:Password1@/" # Use URI like in the CLI +2.table = "rating_del1" +``` + +In this example, running `data-diff --conf myconfig.toml --run test_diff` will compare between `rating` and `rating_del1`. +It will use the `timestamp` column as the update column, as specified in `run.default`. However, it won't be verbose, since that +flag is overwritten to `false`. + +Running it with `data-diff --conf myconfig.toml --run test_diff -v` will set verbose back to `true`. 
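+
+To make the precedence rules above concrete, here is a hypothetical invocation (it reuses the example file and run from above; the extra column name is only an illustration). CLI switches always win over the chosen run, which in turn wins over `run.default`:
+
+```bash
+# myconfig.toml: run.default sets update_column = "timestamp" and verbose = true;
+# run.test_diff sets verbose = false.
+# The CLI flags below take precedence: verbose is switched back on, and the
+# update column is overridden to "updated_at" for this invocation only.
+data-diff --conf myconfig.toml --run test_diff -v --update-column updated_at
+```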
+ + +## How to use from Python + +Import the `data_diff` module, and use the following functions: + +- `connect_to_table()` to connect to a specific table in the database + +- `diff_tables()` to diff those tables + + +Example: + +```python +# Optional: Set logging to display the progress of the diff +import logging +logging.basicConfig(level=logging.INFO) + +from data_diff import connect_to_table, diff_tables + +table1 = connect_to_table("postgresql:///", "table_name", "id") +table2 = connect_to_table("mysql:///", "table_name", "id") + +for different_row in diff_tables(table1, table2): + plus_or_minus, columns = different_row + print(plus_or_minus, columns) +``` + +Run `help(diff_tables)` or [read the docs](https://data-diff.readthedocs.io/en/latest/) to learn about the different options. + +## Usage Analytics & Data Privacy + +data-diff collects anonymous usage data to help our team improve the tool and to apply development efforts to where our users need them most. + +We capture two events: one when the data-diff run starts, and one when it is finished. No user data or potentially sensitive information is or ever will be collected. The captured data is limited to: + +- Operating System and Python version +- Types of databases used (postgresql, mysql, etc.) +- Sizes of tables diffed, run time, and diff row count (numbers only) +- Error message, if any, truncated to the first 20 characters. +- A persistent UUID to indentify the session, stored in `~/.datadiff.toml` + +If you do not wish to participate, the tracking can be easily disabled with one of the following methods: + +* In the CLI, use the `--no-tracking` flag. +* In the config file, set `no_tracking = true` (for example, under `[run.default]`) +* If you're using the Python API: +```python +import data_diff +data_diff.disable_tracking() # Call this first, before making any API calls +# Connect and diff your tables without any tracking +``` diff --git a/docs/index.rst b/docs/index.rst index b3fad229..7b78b66f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -4,7 +4,9 @@ :hidden: supported-databases + how-to-use python-api + technical-explanation new-database-driver-guide Introduction @@ -55,42 +57,6 @@ Supported connectors: - vertica -How to use from the shell -------------------------- - -.. code-block:: bash - - # Same-DB diff, using outer join - $ data-diff DB TABLE1 TABLE2 [options] - - # Cross-DB diff, using hashes - $ data-diff DB1 TABLE1 DB2 TABLE2 [options] - -We recommend using a configuration file, with the ``--conf`` switch, to keep the command simple and managable. - - -How to use from Python ----------------------- - -.. 
code-block:: python - - # Optional: Set logging to display the progress of the diff - import logging - logging.basicConfig(level=logging.INFO) - - from data_diff import connect_to_table, diff_tables - - table1 = connect_to_table("postgresql:///", "table_name", "id") - table2 = connect_to_table("mysql:///", "table_name", "id") - - for sign, columns in diff_tables(table1, table2): - print(sign, columns) - - # Example output: - + ('4775622148347', '2022-06-05 16:57:32.000000') - - ('4775622312187', '2022-06-05 16:57:32.000000') - - ('4777375432955', '2022-06-07 16:57:36.000000') - Resources --------- @@ -98,6 +64,8 @@ Resources - Users - Source code (git): ``_ - :doc:`supported-databases` + - :doc:`how-to-use` - :doc:`python-api` + - :doc:`technical-explanation` - Contributors - :doc:`new-database-driver-guide` diff --git a/docs/technical-explanation.md b/docs/technical-explanation.md new file mode 100644 index 00000000..572bd5eb --- /dev/null +++ b/docs/technical-explanation.md @@ -0,0 +1,191 @@ +# Technical explanation + +data-diff can diff tables within the same database, or across different databases. + +**Same-DB Diff:** +- Uses an outer-join to diff the rows as efficiently and accurately as possible. +- Supports materializing the diff results to a database table. +- Can also collect various extra statistics about the tables. + +**Cross-DB Diff:** Employs a divide and conquer algorithm based on hashing, optimized for few changes. + +The following is a technical explanation of the cross-db diff. + +### Overview + +data-diff splits the table into smaller segments, then checksums each segment in both databases. When the checksums for a segment aren't equal, it will further divide that segment into yet smaller segments, checksumming those until it gets to the differing row(s). + +This approach has performance within an order of magnitude of count(*) when there are few/no changes, but is able to output each differing row! By pushing the compute into the databases, it's much faster than querying for and comparing every row. + +![Performance for 100M rows](https://user-images.githubusercontent.com/97400/175182987-a3900d4e-c097-4732-a4e9-19a40fac8cdc.png) + +**†:** The implementation for downloading all rows that `data-diff` and +`count(*)` is compared to is not optimal. It is a single Python multi-threaded +process. The performance is fairly driver-specific, e.g. PostgreSQL's performs 10x +better than MySQL. + +### Deep Dive + +In this section we'll be doing a walk-through of exactly how **data-diff** +works, and how to tune `--bisection-factor` and `--bisection-threshold`. + +Let's consider a scenario with an `orders` table with 1M rows. Fivetran is +replicating it contionously from PostgreSQL to Snowflake: + +``` +┌─────────────┐ ┌─────────────┐ +│ PostgreSQL │ │ Snowflake │ +├─────────────┤ ├─────────────┤ +│ │ │ │ +│ │ │ │ +│ │ ┌─────────────┐ │ table with │ +│ table with ├──┤ replication ├──────▶│ ?maybe? all │ +│lots of rows!│ └─────────────┘ │ the same │ +│ │ │ rows. │ +│ │ │ │ +│ │ │ │ +│ │ │ │ +└─────────────┘ └─────────────┘ +``` + +In order to check whether the two tables are the same, **data-diff** splits +the table into `--bisection-factor=10` segments. + +We also have to choose which columns we want to checksum. In our case, we care +about the primary key, `--key-column=id` and the update column +`--update-column=updated_at`. `updated_at` is updated every time the row is, and +we have an index on it. 
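+
+For reference, a command line matching this walk-through might look roughly like the following sketch. The connection strings, account names, and table names are illustrative placeholders, and the flags simply restate the scenario described above:
+
+```bash
+data-diff \
+  "postgresql://user:password@localhost:5432/mydb" orders \
+  "snowflake://user:password@account/DATABASE/SCHEMA?warehouse=WAREHOUSE&role=ROLE" ORDERS \
+  --key-column=id \
+  --update-column=updated_at \
+  --bisection-factor=10 \
+  --threads=1
+```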
+ +**data-diff** starts by querying both databases for the `min(id)` and `max(id)` +of the table. Then it splits the table into `--bisection-factor=10` segments of +`1M/10 = 100K` keys each: + +``` +┌──────────────────────┐ ┌──────────────────────┐ +│ PostgreSQL │ │ Snowflake │ +├──────────────────────┤ ├──────────────────────┤ +│ id=1..100k │ │ id=1..100k │ +├──────────────────────┤ ├──────────────────────┤ +│ id=100k..200k │ │ id=100k..200k │ +├──────────────────────┤ ├──────────────────────┤ +│ id=200k..300k ├─────────────▶│ id=200k..300k │ +├──────────────────────┤ ├──────────────────────┤ +│ id=300k..400k │ │ id=300k..400k │ +├──────────────────────┤ ├──────────────────────┤ +│ ... │ │ ... │ +├──────────────────────┤ ├──────────────────────┤ +│ 900k..100k │ │ 900k..100k │ +└───────────────────▲──┘ └▲─────────────────────┘ + ┃ ┃ + ┃ ┃ + ┃ checksum queries ┃ + ┃ ┃ + ┌─┻──────────────────┻────┐ + │ data-diff │ + └─────────────────────────┘ +``` + +Now **data-diff** will start running `--threads=1` queries in parallel that +checksum each segment. The queries for checksumming each segment will look +something like this, depending on the database: + +```sql +SELECT count(*), + sum(cast(conv(substring(md5(concat(cast(id as char), cast(timestamp as char))), 18), 16, 10) as unsigned)) +FROM `rating_del1` +WHERE (id >= 1) AND (id < 100000) +``` + +This keeps the amount of data that has to be transferred between the databases +to a minimum, making it very performant! Additionally, if you have an index on +`updated_at` (highly recommended), then the query will be fast, as the database +only has to do a partial index scan between `id=1..100k`. + +If you are not sure whether the queries are using an index, you can run it with +`--interactive`. This puts **data-diff** in interactive mode, where it shows an +`EXPLAIN` before executing each query, requiring confirmation to proceed. + +After running the checksum queries on both sides, we see that all segments +are the same except `id=100k..200k`: + +``` +┌──────────────────────┐ ┌──────────────────────┐ +│ PostgreSQL │ │ Snowflake │ +├──────────────────────┤ ├──────────────────────┤ +│ checksum=0102 │ │ checksum=0102 │ +├──────────────────────┤ mismatch! ├──────────────────────┤ +│ checksum=ffff ◀──────────────▶ checksum=aaab │ +├──────────────────────┤ ├──────────────────────┤ +│ checksum=abab │ │ checksum=abab │ +├──────────────────────┤ ├──────────────────────┤ +│ checksum=f0f0 │ │ checksum=f0f0 │ +├──────────────────────┤ ├──────────────────────┤ +│ ... │ │ ... │ +├──────────────────────┤ ├──────────────────────┤ +│ checksum=9494 │ │ checksum=9494 │ +└──────────────────────┘ └──────────────────────┘ +``` + +Now **data-diff** will do exactly as it just did for the _whole table_ for only +this segment: Split it into `--bisection-factor` segments. + +However, this time, because each segment has `100k/10=10k` entries, which is +less than the `--bisection-threshold`, it will pull down every row in the segment +and compare them in memory in **data-diff**. + +``` +┌──────────────────────┐ ┌──────────────────────┐ +│ PostgreSQL │ │ Snowflake │ +├──────────────────────┤ ├──────────────────────┤ +│ id=100k..110k │ │ id=100k..110k │ +├──────────────────────┤ ├──────────────────────┤ +│ id=110k..120k │ │ id=110k..120k │ +├──────────────────────┤ ├──────────────────────┤ +│ id=120k..130k │ │ id=120k..130k │ +├──────────────────────┤ ├──────────────────────┤ +│ id=130k..140k │ │ id=130k..140k │ +├──────────────────────┤ ├──────────────────────┤ +│ ... │ │ ... 
│ +├──────────────────────┤ ├──────────────────────┤ +│ 190k..200k │ │ 190k..200k │ +└──────────────────────┘ └──────────────────────┘ +``` + +Finally **data-diff** will output the `(id, updated_at)` for each row that was different: + +``` +(122001, 1653672821) +``` + +If you pass `--stats` you'll see stats such as the % of rows were different. + +### Performance Considerations + +* Ensure that you have indexes on the columns you are comparing. Preferably a + compound index. You can run with `--interactive` to see an `EXPLAIN` for the + queries. +* Consider increasing the number of simultaneous threads executing + queries per database with `--threads`. For databases that limit concurrency + per query, such as PostgreSQL/MySQL, this can improve performance dramatically. +* If you are only interested in _whether_ something changed, pass `--limit 1`. + This can be useful if changes are very rare. This is often faster than doing a + `count(*)`, for the reason mentioned above. +* If the table is _very_ large, consider a larger `--bisection-factor`. Otherwise, you may run into timeouts. +* If there are a lot of changes, consider a larger `--bisection-threshold`. +* If there are very large gaps in your key column (e.g., 10s of millions of + continuous rows missing), then **data-diff** may perform poorly, doing lots of + queries for ranges of rows that do not exist. We have ideas on how to tackle this issue, which we have yet to implement. If you're experiencing this effect, please open an issue, and we + will prioritize it. +* The fewer columns you verify (passed with `--columns`), the faster + **data-diff** will be. On one extreme, you can verify every column; on the + other, you can verify _only_ `updated_at`, if you trust it enough. You can also + _only_ verify `id` if you're interested in only presence, such as to detect + missing hard deletes. You can do also do a hybrid where you verify + `updated_at` and the most critical value, such as a money value in `amount`, but + not verify a large serialized column like `json_settings`. +* We have ideas for making **data-diff** even faster that + we haven't implemented yet: faster checksums by reducing type-casts + and using a faster hash than MD5, dynamic adaptation of + `bisection_factor`/`threads`/`bisection_threshold` (especially with large key + gaps), and improvements to bypass Python/driver performance limitations when + comparing huge amounts of rows locally (i.e. for very high `bisection_threshold` values). From a431cefc212d56cb7b076b765e141c3d04d620e4 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 20 Oct 2022 18:05:17 -0300 Subject: [PATCH 63/93] Fix docs: pass 2 --- README.md | 202 ++------------------------------------------- docs/how-to-use.md | 5 ++ 2 files changed, 12 insertions(+), 195 deletions(-) diff --git a/README.md b/README.md index d9ea28f0..8eba7010 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,8 @@ data-diff is a **free, open-source tool** that enables data professionals to det _Are you a developer with a deep understanding of databases and solid Python knowledge? 
[We're hiring!](https://www.datafold.com/careers)_ +[**Documentation**](https://data-diff.readthedocs.io/en/latest/) + ## Use cases ### Diff Tables Between Databases @@ -91,7 +93,7 @@ data-diff \ -w "event_timestamp < '2022-10-10'" ``` -#### Code Example: Diff Tables Within a Database (available in pre release) +#### Code Example: Diff Tables Within a Database (available in pre-release) Here's a code example from [the video](https://www.loom.com/share/682e4b7d74e84eb4824b983311f0a3b2), where we compare data between two Snowflake tables within one database. @@ -111,208 +113,18 @@ In both code examples, I've used `<>` carrots to represent values that **should ### We're here to help! -We know, that `data-diff DB1_URI TABLE1_NAME DB2_URI TABLE2_NAME [OPTIONS]` command can become long and dense. And maybe you're new to the command line. +We know, that in some cases, the data-diff command can become long and dense. And maybe you're new to the command line. We're here to help [on slack](https://locallyoptimistic.slack.com/archives/C03HUNGQV0S) if you have ANY questions as you use `data-diff` in your workflow. ## How to Use -This section gets into more details, including: -- [database-specific syntax](#how-to-use-from-the-command-line) -- [the many options (flags) you can use beyond the examples presented above](#options) -- [how to run `data-diff` using a TOML configuration file](#how-to-use-with-a-configuration-file) -- [how to run`data-diff` from Python](#how-to-use-from-python) - -### How to use from the command line - -To run `data-diff` from the command line, run this command: - -`data-diff DB1_URI TABLE1_NAME DB2_URI TABLE2_NAME [OPTIONS]` - -Let's break this down. Assume there are two tables stored in two databases, and you want to know the differences between those tables. - -- `DB1_URI` will be a string that `data-diff` uses to connect to the database where the first table is stored. -- `TABLE1_NAME` is the name of the table in the `DB1_URI` database. -- `DB2_URI` will be a string that `data-diff` uses to connect to the database where the second table is stored. -- `TABLE2_NAME` is the name of the second table in the `DB2_URI` database. -- `[OPTIONS]` can be replaced with a variety of additional commands, [detailed here](#options). - - - -| Database | Connection string | Status | -|---------------|-------------------------------------------------------------------------------------------------------------------------------------|--------| -| PostgreSQL >=10 | `postgresql://:''@:5432/` | 💚 | -| MySQL | `mysql://:@:5432/` | 💚 | -| Snowflake | **With password:**`"snowflake://:@//?warehouse=&role="`
**With SSO:** `"snowflake://@//?warehouse=&role=&authenticator=externalbrowser"`
_Note: Unless something is explicitly case sensitive (like your password) use all caps._ | 💚 | -| BigQuery | `bigquery:///` | 💚 | -| Redshift | `redshift://:@:5439/` | 💚 | -| Oracle | `oracle://:@/database` | 💛 | -| Presto | `presto://:@:8080/` | 💛 | -| Databricks | `databricks://:@//` | 💛 | -| Trino | `trino://:@:8080/` | 💛 | -| Clickhouse | `clickhouse://:@:9000/` | 💛 | -| Vertica | `vertica://:@:5433/` | 💛 | -| ElasticSearch | | 📝 | -| Planetscale | | 📝 | -| Pinot | | 📝 | -| Druid | | 📝 | -| Kafka | | 📝 | -| DuckDB | | 📝 | -| SQLite | | 📝 | - -* 💚: Implemented and thoroughly tested. -* 💛: Implemented, but not thoroughly tested yet. -* ⏳: Implementation in progress. -* 📝: Implementation planned. Contributions welcome. - -If a database is not on the list, we'd still love to support it. Open an issue -to discuss it. - -Note: Because URLs allow many special characters, and may collide with the syntax of your command-line, -it's recommended to surround them with quotes. Alternatively, you may [provide them in a TOML file](#how-to-use-with-a-configuration-file) via the `--config` option. - -#### Options - - - `--help` - Show help message and exit. - - `-k` or `--key-columns` - Name of the primary key column. If none provided, default is 'id'. - - `-t` or `--update-column` - Name of updated_at/last_updated column - - `-c` or `--columns` - Names of extra columns to compare. Can be used more than once in the same command. - Accepts a name or a pattern like in SQL. - Example: `-c col% -c another_col -c %foorb.r%` - - `-l` or `--limit` - Maximum number of differences to find (limits maximum bandwidth and runtime) - - `-s` or `--stats` - Print stats instead of a detailed diff - - `-d` or `--debug` - Print debug info - - `-v` or `--verbose` - Print extra info - - `-i` or `--interactive` - Confirm queries, implies `--debug` - - `--json` - Print JSONL output for machine readability - - `--min-age` - Considers only rows older than specified. Useful for specifying replication lag. - Example: `--min-age=5min` ignores rows from the last 5 minutes. - Valid units: `d, days, h, hours, min, minutes, mon, months, s, seconds, w, weeks, y, years` - - `--max-age` - Considers only rows younger than specified. See `--min-age`. - - `-j` or `--threads` - Number of worker threads to use per database. Default=1. - - `-w`, `--where` - An additional 'where' expression to restrict the search space. - - `--conf`, `--run` - Specify the run and configuration from a TOML file. (see below) - - `--no-tracking` - data-diff sends home anonymous usage data. Use this to disable it. - - **The following two options are not available when using the pre release In-DB feature:** - - - `--bisection-threshold` - Minimal size of segment to be split. Smaller segments will be downloaded and compared locally. - - `--bisection-factor` - Segments per iteration. When set to 2, it performs binary search. - -**In-DB commands, available in pre release only:** - - `-m`, `--materialize` - Materialize the diff results into a new table in the database. - If a table exists by that name, it will be replaced. - Use `%t` in the name to place a timestamp. - Example: `-m test_mat_%t` - - `--assume-unique-key` - Skip validating the uniqueness of the key column during joindiff, which is costly in non-cloud dbs. - - `--sample-exclusive-rows` - Sample several rows that only appear in one of the tables, but not the other. Use with `-s`. - - `--materialize-all-rows` - Materialize every row, even if they are the same, instead of just the differing rows. 
- - `--table-write-limit` - Maximum number of rows to write when creating materialized or sample tables, per thread. Default=1000. - - `-a`, `--algorithm` `[auto|joindiff|hashdiff]` - Force algorithm choice - - - - - -### How to use with a configuration file - -Data-diff lets you load the configuration for a run from a TOML file. - -**Reasons to use a configuration file:** - -- Convenience: Set-up the parameters for diffs that need to run often - -- Easier and more readable: You can define the database connection settings as config values, instead of in a URI. - -- Gives you fine-grained control over the settings switches, without requiring any Python code. - -Use `--conf` to specify that path to the configuration file. data-diff will load the settings from `run.default`, if it's defined. - -Then you can, optionally, use `--run` to choose to load the settings of a specific run, and override the settings `run.default`. (all runs extend `run.default`, like inheritance). - -Finally, CLI switches have the final say, and will override the settings defined by the configuration file, and the current run. - -Example TOML file: - -```toml -# Specify the connection params to the test database. -[database.test_postgresql] -driver = "postgresql" -user = "postgres" -password = "Password1" - -# Specify the default run params -[run.default] -update_column = "timestamp" -verbose = true - -# Specify params for a run 'test_diff'. -[run.test_diff] -verbose = false -# Source 1 ("left") -1.database = "test_postgresql" # Use options from database.test_postgresql -1.table = "rating" -# Source 2 ("right") -2.database = "postgresql://postgres:Password1@/" # Use URI like in the CLI -2.table = "rating_del1" -``` - -In this example, running `data-diff --conf myconfig.toml --run test_diff` will compare between `rating` and `rating_del1`. -It will use the `timestamp` column as the update column, as specified in `run.default`. However, it won't be verbose, since that -flag is overwritten to `false`. - -Running it with `data-diff --conf myconfig.toml --run test_diff -v` will set verbose back to `true`. - -### How to use from Python - -API reference: [https://data-diff.readthedocs.io/en/latest/](https://data-diff.readthedocs.io/en/latest/) -Example: +[How to use from the shell (or: command-line)](https://data-diff.readthedocs.io/en/latest/how-to-use.html#how-to-use-from-the-shell-or-command-line) -```python -# Optional: Set logging to display the progress of the diff -import logging -logging.basicConfig(level=logging.INFO) +[How to use from Python](https://data-diff.readthedocs.io/en/latest/how-to-use.html#how-to-use-from-python) -from data_diff import connect_to_table, diff_tables +[Usage Analytics & Data Privacy](https://data-diff.readthedocs.io/en/latest/how-to-use.html#usage-analytics-data-privacy) -table1 = connect_to_table("postgresql:///", "table_name", "id") -table2 = connect_to_table("mysql:///", "table_name", "id") - -for different_row in diff_tables(table1, table2): - plus_or_minus, columns = different_row - print(plus_or_minus, columns) -``` - -Run `help(diff_tables)` or [read the docs](https://data-diff.readthedocs.io/en/latest/) to learn about the different options. - -## Reporting bugs and contributing - -- [Open an issue](https://github.com/datafold/data-diff/issues/new/choose) or chat with us [on slack](https://locallyoptimistic.slack.com/archives/C03HUNGQV0S). -- Interested in contributing to this open source project? 
Please see our [Contributing Guideline](https://github.com/datafold/data-diff/blob/master/CONTRIBUTING.md)! -- Did we mention [we're hiring](https://www.datafold.com/careers)? - -## Usage Analytics & Data Privacy - -data-diff collects anonymous usage data to help our team improve the tool and to apply development efforts to where our users need them most. - -We capture two events: one when the data-diff run starts, and one when it is finished. No user data or potentially sensitive information is or ever will be collected. The captured data is limited to: - -- Operating System and Python version -- Types of databases used (postgresql, mysql, etc.) -- Sizes of tables diffed, run time, and diff row count (numbers only) -- Error message, if any, truncated to the first 20 characters. -- A persistent UUID to indentify the session, stored in `~/.datadiff.toml` - -If you do not wish to participate, the tracking can be easily disabled with one of the following methods: - -* In the CLI, use the `--no-tracking` flag. -* In the config file, set `no_tracking = true` (for example, under `[run.default]`) -* If you're using the Python API: -```python -import data_diff -data_diff.disable_tracking() # Call this first, before making any API calls -# Connect and diff your tables without any tracking -``` ## Technical Explanation diff --git a/docs/how-to-use.md b/docs/how-to-use.md index 6f78a634..b5a1f5bb 100644 --- a/docs/how-to-use.md +++ b/docs/how-to-use.md @@ -16,6 +16,11 @@ Where DB is either a database URL that's compatible with SQLAlchemy, or the name We recommend using a configuration file, with the ``--conf`` switch, to keep the command simple and managable. +For a list of example URLs, see [list of supported databases](supported-databases.md). + +Note: Because URLs allow many special characters, and may collide with the syntax of your command-line, +it's recommended to surround them with quotes. + ### Options - `--help` - Show help message and exit. From 4c39bc11fcda71447491b5960a1ce7e5a270ea4c Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 20 Oct 2022 18:17:57 -0300 Subject: [PATCH 64/93] Update link text --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8eba7010..ada3141e 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ data-diff is a **free, open-source tool** that enables data professionals to det _Are you a developer with a deep understanding of databases and solid Python knowledge? 
[We're hiring!](https://www.datafold.com/careers)_ -[**Documentation**](https://data-diff.readthedocs.io/en/latest/) +[**Documentation on readthedocs.io**](https://data-diff.readthedocs.io/en/latest/) ## Use cases From 73f2c8ab1fa0ba3baa0ef51287c95d33ffc25317 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 21 Oct 2022 09:44:24 -0300 Subject: [PATCH 65/93] Downgrade mysql-connector-python to 8.0.29 --- poetry.lock | 121 ++++++++++++++++++++++--------------------------- pyproject.toml | 2 +- 2 files changed, 55 insertions(+), 68 deletions(-) diff --git a/poetry.lock b/poetry.lock index 51ad249e..9e473982 100644 --- a/poetry.lock +++ b/poetry.lock @@ -197,19 +197,19 @@ regex = ["regex"] [[package]] name = "mysql-connector-python" -version = "8.0.30" +version = "8.0.29" description = "MySQL driver written in Python" category = "main" optional = false python-versions = "*" [package.dependencies] -protobuf = ">=3.11.0,<=3.20.1" +protobuf = ">=3.0.0" [package.extras] -compression = ["lz4 (>=2.1.6,<=3.1.3)", "zstandard (>=0.12.0,<=0.15.2)"] -dns-srv = ["dnspython (>=1.16.0,<=2.1.0)"] -gssapi = ["gssapi (>=1.6.9,<=1.7.3)"] +compression = ["lz4 (>=2.1.6)", "zstandard (>=0.12.0)"] +dns-srv = ["dnspython (>=1.16.0)"] +gssapi = ["gssapi (>=1.6.9)"] [[package]] name = "oscrypto" @@ -287,8 +287,8 @@ wcwidth = "*" [[package]] name = "protobuf" -version = "3.20.1" -description = "Protocol Buffers" +version = "4.21.8" +description = "" category = "main" optional = false python-versions = ">=3.7" @@ -330,15 +330,15 @@ plugins = ["importlib-metadata"] [[package]] name = "PyJWT" -version = "2.5.0" +version = "2.6.0" description = "JSON Web Token implementation in Python" category = "main" optional = false python-versions = ">=3.7" [package.extras] -crypto = ["cryptography (>=3.3.1)", "types-cryptography (>=3.3.21)"] -dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.3.1)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "types-cryptography (>=3.3.21)", "zope.interface"] +crypto = ["cryptography (>=3.4.0)"] +dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] @@ -370,7 +370,7 @@ six = ">=1.5" [[package]] name = "pytz" -version = "2022.4" +version = "2022.5" description = "World timezone definitions, modern and historical" category = "main" optional = false @@ -432,7 +432,7 @@ python-versions = ">=3.6,<4.0" [[package]] name = "setuptools" -version = "65.4.1" +version = "65.5.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" category = "main" optional = false @@ -519,7 +519,7 @@ python-versions = ">=3.7" [[package]] name = "tzdata" -version = "2022.4" +version = "2022.5" description = "Provider of IANA time zone data" category = "main" optional = false @@ -612,7 +612,7 @@ vertica = [] [metadata] lock-version = "1.1" python-versions = "^3.7" -content-hash = "6f68ef35366f62a4a47721baed5b2734e0a9015d7c4a49ff9a8410284acb6e71" +content-hash = "074d933430a02a6ba56dadc359f89b875dc74e42070c1dfc1b79e8e8cd610404" [metadata.files] arrow = [ @@ -893,27 +893,24 @@ lark-parser = [ {file = "lark-parser-0.11.3.tar.gz", hash = "sha256:e29ca814a98bb0f81674617d878e5f611cb993c19ea47f22c80da3569425f9bd"}, ] mysql-connector-python = [ - {file = "mysql-connector-python-8.0.30.tar.gz", hash = 
"sha256:59a8592e154c874c299763bb8aa12c518384c364bcfd0d193e85c869ea81a895"}, - {file = "mysql_connector_python-8.0.30-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f1eb74eb30bb04ff314f5e19af5421d23b504e41d16ddcee2603b4100d18fd68"}, - {file = "mysql_connector_python-8.0.30-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:712cdfa97f35fec715e8d7aaa15ed9ce04f3cf71b3c177fcca273047040de9f2"}, - {file = "mysql_connector_python-8.0.30-cp310-cp310-manylinux1_i686.whl", hash = "sha256:ce23ca9c27e1f7b4707b3299ce515125f312736d86a7e5b2aa778484fa3ffa10"}, - {file = "mysql_connector_python-8.0.30-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:8876b1d51cae33cdfe7021d68206661e94dcd2666e5e14a743f8321e2b068e84"}, - {file = "mysql_connector_python-8.0.30-cp310-cp310-win_amd64.whl", hash = "sha256:41a04d1900e366bf6c2a645ead89ab9a567806d5ada7d417a3a31f170321dd14"}, - {file = "mysql_connector_python-8.0.30-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:7f771bd5cba3ade6d9f7a649e65d7c030f69f0e69980632b5cbbd3d19c39cee5"}, - {file = "mysql_connector_python-8.0.30-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:611c6945805216104575f7143ff6497c87396ce82d3257e6da7257b65406f13e"}, - {file = "mysql_connector_python-8.0.30-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:47deb8c3324db7eb2bfb720ec8084d547b1bce457672ea261bc21836024249db"}, - {file = "mysql_connector_python-8.0.30-cp37-cp37m-win_amd64.whl", hash = "sha256:234c6b156a1989bebca6eb564dc8f2e9d352f90a51bd228ccd68eb66fcd5fd7a"}, - {file = "mysql_connector_python-8.0.30-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8b7d50c221320b0e609dce9ca8801ab2f2a748dfee65cd76b1e4c6940757734a"}, - {file = "mysql_connector_python-8.0.30-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:d8f74c9388176635f75c01d47d0abc783a47e58d7f36d04fb6ee40ab6fb35c9b"}, - {file = "mysql_connector_python-8.0.30-cp38-cp38-manylinux1_i686.whl", hash = "sha256:1d9d3af14594aceda2c3096564b4c87ffac21e375806a802daeaf7adcd18d36b"}, - {file = "mysql_connector_python-8.0.30-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:f5d812245754d4759ebc8c075662fef65397e1e2a438a3c391eac9d545077b8b"}, - {file = "mysql_connector_python-8.0.30-cp38-cp38-win_amd64.whl", hash = "sha256:a130c5489861c7ff2990e5b503c37beb2fb7b32211b92f9107ad864ee90654c0"}, - {file = "mysql_connector_python-8.0.30-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:954a1fc2e9a811662c5b17cea24819c020ff9d56b2ff8e583dd0a233fb2399f6"}, - {file = "mysql_connector_python-8.0.30-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:62266d1b18cb4e286a05df0e1c99163a4955c82d41045305bcf0ab2aac107843"}, - {file = "mysql_connector_python-8.0.30-cp39-cp39-manylinux1_i686.whl", hash = "sha256:36e763f21e62b3c9623a264f2513ee11924ea1c9cc8640c115a279d3087064be"}, - {file = "mysql_connector_python-8.0.30-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:b5dc0f3295e404f93b674bfaff7589a9fbb8b5ae6c1c134112a1d1beb2f664b2"}, - {file = "mysql_connector_python-8.0.30-cp39-cp39-win_amd64.whl", hash = "sha256:33c4e567547a9a1868462fda8f2b19ea186a7b1afe498171dca39c0f3aa43a75"}, - {file = "mysql_connector_python-8.0.30-py2.py3-none-any.whl", hash = "sha256:f1d40cac9c786e292433716c1ade7a8968cbc3ea177026697b86a63188ddba34"}, + {file = "mysql-connector-python-8.0.29.tar.gz", hash = "sha256:29ec05ded856b4da4e47239f38489c03b31673ae0f46a090d0e4e29c670e6181"}, + {file = "mysql_connector_python-8.0.29-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:bed43ea3a11f8d4e7c2e3f20c891214e68b45451314f91fddf9ca701de7a53ac"}, + {file = 
"mysql_connector_python-8.0.29-cp310-cp310-manylinux1_i686.whl", hash = "sha256:6e2267ad75b37b5e1c480cde77cdc4f795427a54266ead30aabcdbf75ac70064"}, + {file = "mysql_connector_python-8.0.29-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:d5afb766b379111942d4260f29499f93355823c7241926471d843c9281fe477c"}, + {file = "mysql_connector_python-8.0.29-cp310-cp310-win_amd64.whl", hash = "sha256:4de5959e27038cbd11dfccb1afaa2fd258c013e59d3e15709dd1992086103050"}, + {file = "mysql_connector_python-8.0.29-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:895135cde57622edf48e1fce3beb4ed85f18332430d48f5c1d9630d49f7712b0"}, + {file = "mysql_connector_python-8.0.29-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:fdd262d8538aa504475f8860cfda939a297d3b213c8d15f7ceed52508aeb2aa3"}, + {file = "mysql_connector_python-8.0.29-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:89597c091c4f25b6e023cbbcd32be73affbb0b44256761fe3b8e1d4b14d14d02"}, + {file = "mysql_connector_python-8.0.29-cp37-cp37m-win_amd64.whl", hash = "sha256:ab0e9d9b5fc114b78dfa9c74e8bfa30b48fcfa17dbb9241ad6faada08a589900"}, + {file = "mysql_connector_python-8.0.29-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:245087999f081b389d66621f2abfe2463e3927f63c7c4c0f70ce0f82786ccb93"}, + {file = "mysql_connector_python-8.0.29-cp38-cp38-manylinux1_i686.whl", hash = "sha256:5eef51e48b22aadd633563bbdaf02112d98d954a4ead53f72fde283ea3f88152"}, + {file = "mysql_connector_python-8.0.29-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:b7dccd7f72f19c97b58428ebf8e709e24eb7e9b67a408af7e77b60efde44bea4"}, + {file = "mysql_connector_python-8.0.29-cp38-cp38-win_amd64.whl", hash = "sha256:7be3aeff73b85eab3af2a1e80c053a98cbcb99e142192e551ebd4c1e41ce2596"}, + {file = "mysql_connector_python-8.0.29-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:a7fd6a71df824f5a7d9a94060598d67b3a32eeccdc9837ee2cd98a44e2536cae"}, + {file = "mysql_connector_python-8.0.29-cp39-cp39-manylinux1_i686.whl", hash = "sha256:fd608c288f596c4c8767d9a8e90f129385bd19ee6e3adaf6974ad8012c6138b8"}, + {file = "mysql_connector_python-8.0.29-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:f353893481476a537cca7afd4e81e0ed84dd2173932b7f1721ab3e3351cbf324"}, + {file = "mysql_connector_python-8.0.29-cp39-cp39-win_amd64.whl", hash = "sha256:1bef2a4a2b529c6e9c46414100ab7032c252244e8a9e017d2b6a41bb9cea9312"}, + {file = "mysql_connector_python-8.0.29-py2.py3-none-any.whl", hash = "sha256:047420715bbb51d3cba78de446c8a6db4666459cd23e168568009c620a3f5b90"}, ] oscrypto = [ {file = "oscrypto-1.3.0-py2.py3-none-any.whl", hash = "sha256:2b2f1d2d42ec152ca90ccb5682f3e051fb55986e1b170ebde472b133713e7085"}, @@ -936,30 +933,20 @@ prompt-toolkit = [ {file = "prompt_toolkit-3.0.31.tar.gz", hash = "sha256:9ada952c9d1787f52ff6d5f3484d0b4df8952787c087edf6a1f7c2cb1ea88148"}, ] protobuf = [ - {file = "protobuf-3.20.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3cc797c9d15d7689ed507b165cd05913acb992d78b379f6014e013f9ecb20996"}, - {file = "protobuf-3.20.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ff8d8fa42675249bb456f5db06c00de6c2f4c27a065955917b28c4f15978b9c3"}, - {file = "protobuf-3.20.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:cd68be2559e2a3b84f517fb029ee611546f7812b1fdd0aa2ecc9bc6ec0e4fdde"}, - {file = "protobuf-3.20.1-cp310-cp310-win32.whl", hash = "sha256:9016d01c91e8e625141d24ec1b20fed584703e527d28512aa8c8707f105a683c"}, - {file = "protobuf-3.20.1-cp310-cp310-win_amd64.whl", hash = "sha256:32ca378605b41fd180dfe4e14d3226386d8d1b002ab31c969c366549e66a2bb7"}, - 
{file = "protobuf-3.20.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:9be73ad47579abc26c12024239d3540e6b765182a91dbc88e23658ab71767153"}, - {file = "protobuf-3.20.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:097c5d8a9808302fb0da7e20edf0b8d4703274d140fd25c5edabddcde43e081f"}, - {file = "protobuf-3.20.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e250a42f15bf9d5b09fe1b293bdba2801cd520a9f5ea2d7fb7536d4441811d20"}, - {file = "protobuf-3.20.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:cdee09140e1cd184ba9324ec1df410e7147242b94b5f8b0c64fc89e38a8ba531"}, - {file = "protobuf-3.20.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:af0ebadc74e281a517141daad9d0f2c5d93ab78e9d455113719a45a49da9db4e"}, - {file = "protobuf-3.20.1-cp37-cp37m-win32.whl", hash = "sha256:755f3aee41354ae395e104d62119cb223339a8f3276a0cd009ffabfcdd46bb0c"}, - {file = "protobuf-3.20.1-cp37-cp37m-win_amd64.whl", hash = "sha256:62f1b5c4cd6c5402b4e2d63804ba49a327e0c386c99b1675c8a0fefda23b2067"}, - {file = "protobuf-3.20.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:06059eb6953ff01e56a25cd02cca1a9649a75a7e65397b5b9b4e929ed71d10cf"}, - {file = "protobuf-3.20.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:cb29edb9eab15742d791e1025dd7b6a8f6fcb53802ad2f6e3adcb102051063ab"}, - {file = "protobuf-3.20.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:69ccfdf3657ba59569c64295b7d51325f91af586f8d5793b734260dfe2e94e2c"}, - {file = "protobuf-3.20.1-cp38-cp38-win32.whl", hash = "sha256:dd5789b2948ca702c17027c84c2accb552fc30f4622a98ab5c51fcfe8c50d3e7"}, - {file = "protobuf-3.20.1-cp38-cp38-win_amd64.whl", hash = "sha256:77053d28427a29987ca9caf7b72ccafee011257561259faba8dd308fda9a8739"}, - {file = "protobuf-3.20.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f50601512a3d23625d8a85b1638d914a0970f17920ff39cec63aaef80a93fb7"}, - {file = "protobuf-3.20.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:284f86a6207c897542d7e956eb243a36bb8f9564c1742b253462386e96c6b78f"}, - {file = "protobuf-3.20.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7403941f6d0992d40161aa8bb23e12575637008a5a02283a930addc0508982f9"}, - {file = "protobuf-3.20.1-cp39-cp39-win32.whl", hash = "sha256:db977c4ca738dd9ce508557d4fce0f5aebd105e158c725beec86feb1f6bc20d8"}, - {file = "protobuf-3.20.1-cp39-cp39-win_amd64.whl", hash = "sha256:7e371f10abe57cee5021797126c93479f59fccc9693dafd6bd5633ab67808a91"}, - {file = "protobuf-3.20.1-py2.py3-none-any.whl", hash = "sha256:adfc6cf69c7f8c50fd24c793964eef18f0ac321315439d94945820612849c388"}, - {file = "protobuf-3.20.1.tar.gz", hash = "sha256:adc31566d027f45efe3f44eeb5b1f329da43891634d61c75a5944e9be6dd42c9"}, + {file = "protobuf-4.21.8-cp310-abi3-win32.whl", hash = "sha256:c252c55ee15175aa1b21b7b9896e6add5162d066d5202e75c39f96136f08cce3"}, + {file = "protobuf-4.21.8-cp310-abi3-win_amd64.whl", hash = "sha256:809ca0b225d3df42655a12f311dd0f4148a943c51f1ad63c38343e457492b689"}, + {file = "protobuf-4.21.8-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:bbececaf3cfea9ea65ebb7974e6242d310d2a7772a6f015477e0d79993af4511"}, + {file = "protobuf-4.21.8-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:b02eabb9ebb1a089ed20626a90ad7a69cee6bcd62c227692466054b19c38dd1f"}, + {file = "protobuf-4.21.8-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:4761201b93e024bb70ee3a6a6425d61f3152ca851f403ba946fb0cde88872661"}, + {file = "protobuf-4.21.8-cp37-cp37m-win32.whl", hash = 
"sha256:f2d55ff22ec300c4d954d3b0d1eeb185681ec8ad4fbecff8a5aee6a1cdd345ba"}, + {file = "protobuf-4.21.8-cp37-cp37m-win_amd64.whl", hash = "sha256:c5f94911dd8feb3cd3786fc90f7565c9aba7ce45d0f254afd625b9628f578c3f"}, + {file = "protobuf-4.21.8-cp38-cp38-win32.whl", hash = "sha256:b37b76efe84d539f16cba55ee0036a11ad91300333abd213849cbbbb284b878e"}, + {file = "protobuf-4.21.8-cp38-cp38-win_amd64.whl", hash = "sha256:2c92a7bfcf4ae76a8ac72e545e99a7407e96ffe52934d690eb29a8809ee44d7b"}, + {file = "protobuf-4.21.8-cp39-cp39-win32.whl", hash = "sha256:89d641be4b5061823fa0e463c50a2607a97833e9f8cfb36c2f91ef5ccfcc3861"}, + {file = "protobuf-4.21.8-cp39-cp39-win_amd64.whl", hash = "sha256:bc471cf70a0f53892fdd62f8cd4215f0af8b3f132eeee002c34302dff9edd9b6"}, + {file = "protobuf-4.21.8-py2.py3-none-any.whl", hash = "sha256:a55545ce9eec4030cf100fcb93e861c622d927ef94070c1a3c01922902464278"}, + {file = "protobuf-4.21.8-py3-none-any.whl", hash = "sha256:0f236ce5016becd989bf39bd20761593e6d8298eccd2d878eda33012645dc369"}, + {file = "protobuf-4.21.8.tar.gz", hash = "sha256:427426593b55ff106c84e4a88cac855175330cb6eb7e889e85aaa7b5652b686d"}, ] psycopg2 = [ {file = "psycopg2-2.9.4-cp310-cp310-win32.whl", hash = "sha256:8de6a9fc5f42fa52f559e65120dcd7502394692490c98fed1221acf0819d7797"}, @@ -1015,8 +1002,8 @@ Pygments = [ {file = "Pygments-2.13.0.tar.gz", hash = "sha256:56a8508ae95f98e2b9bdf93a6be5ae3f7d8af858b43e02c5a2ff083726be40c1"}, ] PyJWT = [ - {file = "PyJWT-2.5.0-py3-none-any.whl", hash = "sha256:8d82e7087868e94dd8d7d418e5088ce64f7daab4b36db654cbaedb46f9d1ca80"}, - {file = "PyJWT-2.5.0.tar.gz", hash = "sha256:e77ab89480905d86998442ac5788f35333fa85f65047a534adc38edf3c88fc3b"}, + {file = "PyJWT-2.6.0-py3-none-any.whl", hash = "sha256:d83c3d892a77bbb74d3e1a2cfa90afaadb60945205d1095d9221f04466f64c14"}, + {file = "PyJWT-2.6.0.tar.gz", hash = "sha256:69285c7e31fc44f68a1feb309e948e0df53259d579295e6cfe2b1792329f05fd"}, ] pyOpenSSL = [ {file = "pyOpenSSL-22.0.0-py2.py3-none-any.whl", hash = "sha256:ea252b38c87425b64116f808355e8da644ef9b07e429398bfece610f893ee2e0"}, @@ -1027,8 +1014,8 @@ python-dateutil = [ {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, ] pytz = [ - {file = "pytz-2022.4-py2.py3-none-any.whl", hash = "sha256:2c0784747071402c6e99f0bafdb7da0fa22645f06554c7ae06bf6358897e9c91"}, - {file = "pytz-2022.4.tar.gz", hash = "sha256:48ce799d83b6f8aab2020e369b627446696619e79645419610b9facd909b3174"}, + {file = "pytz-2022.5-py2.py3-none-any.whl", hash = "sha256:335ab46900b1465e714b4fda4963d87363264eb662aab5e65da039c25f1f5b22"}, + {file = "pytz-2022.5.tar.gz", hash = "sha256:c4d88f472f54d615e9cd582a5004d1e5f624854a6a27a6211591c251f22a6914"}, ] pytz-deprecation-shim = [ {file = "pytz_deprecation_shim-0.1.0.post0-py2.py3-none-any.whl", hash = "sha256:8314c9692a636c8eb3bda879b9f119e350e93223ae83e70e80c31675a0fdc1a6"}, @@ -1047,8 +1034,8 @@ runtype = [ {file = "runtype-0.2.7.tar.gz", hash = "sha256:5a9e1212846b3e54d4ba29fd7db602af5544a2a4253d1f8d829087214a8766ad"}, ] setuptools = [ - {file = "setuptools-65.4.1-py3-none-any.whl", hash = "sha256:1b6bdc6161661409c5f21508763dc63ab20a9ac2f8ba20029aaaa7fdb9118012"}, - {file = "setuptools-65.4.1.tar.gz", hash = "sha256:3050e338e5871e70c72983072fe34f6032ae1cdeeeb67338199c2f74e083a80e"}, + {file = "setuptools-65.5.0-py3-none-any.whl", hash = "sha256:f62ea9da9ed6289bfe868cd6845968a2c854d1427f8548d52cae02a42b4f0356"}, + {file = "setuptools-65.5.0.tar.gz", hash = 
"sha256:512e5536220e38146176efb833d4a62aa726b7bbff82cfbc8ba9eaa3996e0b17"}, ] six = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, @@ -1089,8 +1076,8 @@ typing-extensions = [ {file = "typing_extensions-4.4.0.tar.gz", hash = "sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa"}, ] tzdata = [ - {file = "tzdata-2022.4-py2.py3-none-any.whl", hash = "sha256:74da81ecf2b3887c94e53fc1d466d4362aaf8b26fc87cda18f22004544694583"}, - {file = "tzdata-2022.4.tar.gz", hash = "sha256:ada9133fbd561e6ec3d1674d3fba50251636e918aa97bd59d63735bef5a513bb"}, + {file = "tzdata-2022.5-py2.py3-none-any.whl", hash = "sha256:323161b22b7802fdc78f20ca5f6073639c64f1a7227c40cd3e19fd1d0ce6650a"}, + {file = "tzdata-2022.5.tar.gz", hash = "sha256:e15b2b3005e2546108af42a0eb4ccab4d9e225e2dfbf4f77aad50c70a4b1f3ab"}, ] tzlocal = [ {file = "tzlocal-4.2-py3-none-any.whl", hash = "sha256:89885494684c929d9191c57aa27502afc87a579be5cdd3225c77c463ea043745"}, diff --git a/pyproject.toml b/pyproject.toml index 46f579cd..b0ae7ff6 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dsnparse = "*" click = "^8.1" rich = "*" toml = "^0.10.2" -mysql-connector-python = {version="*", optional=true} +mysql-connector-python = {version="8.0.29", optional=true} psycopg2 = {version="*", optional=true} snowflake-connector-python = {version="^2.7.2", optional=true} cryptography = {version="*", optional=true} From 99cdd0c4d992ef8a3ed47eab700eb3667b539027 Mon Sep 17 00:00:00 2001 From: Will Sweet Date: Fri, 21 Oct 2022 10:08:07 -0500 Subject: [PATCH 66/93] Update documentation link Switch from the current readthedocs documentation to the Datafold hosted documentation, which is also [open-source](https://github.com/datafold/datafold-docs/tree/main/docs/os_diff). --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ada3141e..11f4a66c 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ data-diff is a **free, open-source tool** that enables data professionals to det _Are you a developer with a deep understanding of databases and solid Python knowledge? [We're hiring!](https://www.datafold.com/careers)_ -[**Documentation on readthedocs.io**](https://data-diff.readthedocs.io/en/latest/) +[**Check out our documentation!**](https://docs.datafold.com/os_diff/about) ## Use cases From 307111384a90d0d25c637c6ad62370c981518ed7 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 25 Oct 2022 10:21:08 -0300 Subject: [PATCH 67/93] Add example to docstring (Issue #261) --- data_diff/databases/connect.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/data_diff/databases/connect.py b/data_diff/databases/connect.py index 94cb52d2..8468a734 100644 --- a/data_diff/databases/connect.py +++ b/data_diff/databases/connect.py @@ -184,6 +184,8 @@ def connect(db_conf: Union[str, dict], thread_count: Optional[int] = 1) -> Datab Configuration can be given either as a URI string, or as a dict of {option: value}. + The dictionary configuration uses the same keys as the TOML 'database' definition given with --conf. + thread_count determines the max number of worker threads per database, if relevant. None means no limit. 
@@ -205,6 +207,12 @@ def connect(db_conf: Union[str, dict], thread_count: Optional[int] = 1) -> Datab - trino - clickhouse - vertica + + Example: + >>> connect("mysql://localhost/db") + + >>> connect({"driver": "mysql", "host": "localhost", "database": "db"}) + """ if isinstance(db_conf, str): return connect_to_uri(db_conf, thread_count) From f8d24ea09b819774000858d2d35a8baf6fa16405 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 27 Oct 2022 16:36:57 -0300 Subject: [PATCH 68/93] Tests: one more data-type for oracle --- tests/test_database_types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_database_types.py b/tests/test_database_types.py index bb792826..1c36ff7f 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -178,6 +178,7 @@ def init_conns(): "numeric", "real", "double precision", + "Number(5, 2)", ], "uuid": [ "CHAR(100)", From 5e879faba934be56dcbaae99d3538bcab34b2390 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 28 Oct 2022 10:40:21 -0300 Subject: [PATCH 69/93] Queries: Added Param mechanism, to help speed up query construction. --- data_diff/queries/ast_classes.py | 17 ++++++++++++++++- data_diff/queries/compiler.py | 10 +++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 88d7ab11..f363df14 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -6,7 +6,7 @@ from data_diff.utils import ArithString, join_iter -from .compiler import Compilable, Compiler +from .compiler import Compilable, Compiler, cv_params from .base import SKIP, CompileError, DbPath, Schema, args_as_tuple @@ -691,3 +691,18 @@ def compile(self, c: Compiler) -> str: class Commit(Statement): def compile(self, c: Compiler) -> str: return "COMMIT" if not c.database.is_autocommit else SKIP + +@dataclass +class Param(ExprNode, ITable): + """A value placeholder, to be specified at compilation time using the `cv_params` context variable.""" + + name: str + + @property + def source_table(self): + return self + + def compile(self, c: Compiler) -> str: + params = cv_params.get() + return c._compile(params[self.name]) + diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index 31242131..e9a66bed 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -8,10 +8,15 @@ from data_diff.utils import ArithString from data_diff.databases.database_types import AbstractDialect, DbPath +import contextvars + +cv_params = contextvars.ContextVar("params") + @dataclass class Compiler: database: AbstractDialect + params: dict = {} in_select: bool = False # Compilation runtime flag in_join: bool = False # Compilation runtime flag @@ -21,7 +26,10 @@ class Compiler: _counter: List = [0] - def compile(self, elem) -> str: + def compile(self, elem, params=None) -> str: + if params: + cv_params.set(params) + res = self._compile(elem) if self.root and self._subqueries: subq = ", ".join(f"\n {k} AS ({v})" for k, v in self._subqueries.items()) From f152baf306038167aad45f6000ba364f744dba5b Mon Sep 17 00:00:00 2001 From: Jardayn Date: Mon, 31 Oct 2022 19:14:20 +0200 Subject: [PATCH 70/93] Added link on how to get a slack invite --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 11f4a66c..e305f45c 100644 --- a/README.md +++ b/README.md @@ -117,6 +117,8 @@ We know, that in some cases, the data-diff command can become long and dense. 
An We're here to help [on slack](https://locallyoptimistic.slack.com/archives/C03HUNGQV0S) if you have ANY questions as you use `data-diff` in your workflow. +To get a Slack invite - [click here](https://locallyoptimistic.com/community/) + ## How to Use [How to use from the shell (or: command-line)](https://data-diff.readthedocs.io/en/latest/how-to-use.html#how-to-use-from-the-shell-or-command-line) From 078920895ec4eedba389948e5ff36fc17c0b6d0e Mon Sep 17 00:00:00 2001 From: Leo Folsom Date: Mon, 31 Oct 2022 11:42:26 -0700 Subject: [PATCH 71/93] link to docs and incorporate roman/gerard feedback --- README.md | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 11f4a66c..f97f433d 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,9 @@ data-diff is a **free, open-source tool** that enables data professionals to det _Are you a developer with a deep understanding of databases and solid Python knowledge? [We're hiring!](https://www.datafold.com/careers)_ -[**Check out our documentation!**](https://docs.datafold.com/os_diff/about) +## Documentation + +[**Our detailed documentation**](https://docs.datafold.com/os_diff/about) has everything you need to start diffing. ## Use cases @@ -20,7 +22,7 @@ _Are you a developer with a deep understanding of databases and solid Python kno diff2

-### Diff Tables Within a Database (available in pre release) +### Diff Tables Within a Database (available in pre-release) #### Improve code reviews by identifying data problems you don't have tests for

@@ -77,7 +79,7 @@ Once you've installed `data-diff`, you can run it from the command line. data-diff DB1_URI TABLE1_NAME DB2_URI TABLE2_NAME [OPTIONS] ``` -Be sure to read [the How to Use section below](#how-to-use) which gets into specific details about how to build one of these commands depending on your database setup. +Be sure to read [the docs](https://docs.datafold.com/os_diff/how_to_use) for detailed instructions how to build one of these commands depending on your database setup. #### Code Example: Diff Tables Between Databases Here's an example command for your copy/pasting, taken from the screenshot above when we diffed data between Snowflake and Postgres. @@ -113,22 +115,26 @@ In both code examples, I've used `<>` carrots to represent values that **should ### We're here to help! -We know, that in some cases, the data-diff command can become long and dense. And maybe you're new to the command line. +We know that in some cases, the data-diff command can become long and dense. And maybe you're new to the command line. -We're here to help [on slack](https://locallyoptimistic.slack.com/archives/C03HUNGQV0S) if you have ANY questions as you use `data-diff` in your workflow. +* We're here to help [on slack](https://locallyoptimistic.slack.com/archives/C03HUNGQV0S) if you have ANY questions as you use `data-diff` in your workflow. +* You can also post a question in [GitHub Discussions](https://github.com/datafold/data-diff/discussions). -## How to Use -[How to use from the shell (or: command-line)](https://data-diff.readthedocs.io/en/latest/how-to-use.html#how-to-use-from-the-shell-or-command-line) +## How to Use -[How to use from Python](https://data-diff.readthedocs.io/en/latest/how-to-use.html#how-to-use-from-python) +* [How to use from the shell (or: command-line)](https://docs.datafold.com/os_diff/how_to_use#how-to-use-from-the-command-line) +* [How to use from Python](https://docs.datafold.com/os_diff/how_to_use#how-to-use-from-python) +* [Usage Analytics & Data Privacy](https://docs.datafold.com/os_diff/usage_analytics_data_privacy) -[Usage Analytics & Data Privacy](https://data-diff.readthedocs.io/en/latest/how-to-use.html#usage-analytics-data-privacy) +## How to Contribute +* Feel free to open an issue or contribute to the project by working on an existing issue. +* Please read the [contributing guidelines](https://github.com/leoebfolsom/data-diff/blob/master/CONTRIBUTING.md) to get started. ## Technical Explanation -See here: https://data-diff.readthedocs.io/en/latest/technical-explanation.html +Check out this [technical explanation](https://docs.datafold.com/os_diff/technical_explanation) of how data-diff works. 
## License From 5716aee0998c8614caae13e6589f9eddc0da7224 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 31 Oct 2022 11:21:06 -0300 Subject: [PATCH 72/93] Tiny Cleanup --- data_diff/__main__.py | 4 ++-- data_diff/databases/base.py | 8 ++++---- data_diff/databases/database_types.py | 2 +- data_diff/databases/oracle.py | 4 ++-- data_diff/databases/redshift.py | 4 ++-- data_diff/databases/vertica.py | 4 ++-- data_diff/joindiff_tables.py | 3 +-- data_diff/utils.py | 2 +- 8 files changed, 15 insertions(+), 16 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 50847b7c..0ad6de11 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -65,8 +65,8 @@ def __init__(self, **kwargs): self.indent_increment = 6 def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -> None: - self.write(f"data-diff - efficiently diff rows across database tables.\n\n") - self.write(f"Usage:\n") + self.write("data-diff - efficiently diff rows across database tables.\n\n") + self.write("Usage:\n") self.write(f" * In-db diff: {prog} [OPTIONS]\n") self.write(f" * Cross-db diff: {prog} [OPTIONS]\n") self.write(f" * Using config: {prog} --conf PATH [--run NAME] [OPTIONS]\n") diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 189dced4..b68739d8 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -143,7 +143,7 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list): (row,) = row logger.debug("EXPLAIN: %s", row) answer = input("Continue? [y/n] ") - if not answer.lower() in ["y", "yes"]: + if answer.lower() not in ["y", "yes"]: sys.exit(1) res = self._query(sql_code) @@ -310,9 +310,9 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None return f"LIMIT {limit}" - def concat(self, l: List[str]) -> str: - assert len(l) > 1 - joined_exprs = ", ".join(l) + def concat(self, items: List[str]) -> str: + assert len(items) > 1 + joined_exprs = ", ".join(items) return f"concat({joined_exprs})" def is_distinct_from(self, a: str, b: str) -> str: diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index 86df4489..8adc9fbb 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -151,7 +151,7 @@ def quote(self, s: str): ... @abstractmethod - def concat(self, l: List[str]) -> str: + def concat(self, items: List[str]) -> str: "Provide SQL for concatenating a bunch of columns into a string" ... 
diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index bd849e59..6b4ebe2c 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -127,8 +127,8 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None return f"FETCH NEXT {limit} ROWS ONLY" - def concat(self, l: List[str]) -> str: - joined_exprs = " || ".join(l) + def concat(self, items: List[str]) -> str: + joined_exprs = " || ".join(items) return f"({joined_exprs})" def timestamp_value(self, t: DbTime) -> str: diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index 291d180b..afaa28a4 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -36,8 +36,8 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: def normalize_number(self, value: str, coltype: FractionalType) -> str: return self.to_string(f"{value}::decimal(38,{coltype.precision})") - def concat(self, l: List[str]) -> str: - joined_exprs = " || ".join(l) + def concat(self, items: List[str]) -> str: + joined_exprs = " || ".join(items) return f"({joined_exprs})" def select_table_schema(self, path: DbPath) -> str: diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index cc606511..6b486555 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -99,8 +99,8 @@ def select_table_schema(self, path: DbPath) -> str: def quote(self, s: str): return f'"{s}"' - def concat(self, l: List[str]) -> str: - return " || ".join(l) + def concat(self, items: List[str]) -> str: + return " || ".join(items) def md5_to_int(self, s: str) -> str: return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0))" diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index b630d66e..d2dbca61 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -2,7 +2,6 @@ """ -from contextlib import suppress from decimal import Decimal from functools import partial import logging @@ -21,7 +20,7 @@ from .diff_tables import TableDiffer, DiffResult from .thread_utils import ThreadedYielder -from .queries import table, sum_, min_, max_, avg, commit +from .queries import table, sum_, min_, max_, avg from .queries.api import and_, if_, or_, outerjoin, leftjoin, rightjoin, this, ITable from .queries.ast_classes import Concat, Count, Expr, Random, TablePath from .queries.compiler import Compiler diff --git a/data_diff/utils.py b/data_diff/utils.py index a2b7e801..de011d02 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -3,7 +3,7 @@ import math from typing import Iterable, Iterator, MutableMapping, Union, Any, Sequence, Dict from typing import TypeVar -from abc import ABC, abstractmethod +from abc import abstractmethod from urllib.parse import urlparse from uuid import UUID import operator From a8fddb77574f3effd6fd4eec938848af94cca5df Mon Sep 17 00:00:00 2001 From: Jardayn Date: Tue, 1 Nov 2022 00:56:35 +0200 Subject: [PATCH 73/93] Modified Docker-compose to give containers unique names --- docker-compose.yml | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index b23c3b1e..60bab061 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,7 +2,7 @@ version: "3.8" services: postgres: - container_name: postgresql + container_name: dd-postgresql image: postgres:14.1-alpine # work_mem: less tmp files # maintenance_work_mem: improve table-level op perf @@ 
-25,7 +25,7 @@ services: - local mysql: - container_name: mysql + container_name: dd-mysql image: mysql:oracle # fsync less aggressively for insertion perf for test setup command: > @@ -52,7 +52,7 @@ services: - local clickhouse: - container_name: clickhouse + container_name: dd-clickhouse image: clickhouse/clickhouse-server:21.12.3.32 restart: always volumes: @@ -76,6 +76,7 @@ services: # prestodb.dbapi.connect(host="127.0.0.1", user="presto").cursor().execute('SELECT * FROM system.runtime.nodes') presto: + container_name: dd-presto build: context: ./dev dockerfile: ./Dockerfile.prestosql.340 @@ -88,6 +89,7 @@ services: - local trino: + container_name: dd-trino image: 'trinodb/trino:389' hostname: trino ports: @@ -98,7 +100,7 @@ services: - local vertica: - container_name: vertica + container_name: dd-vertica image: vertica/vertica-ce:12.0.0-0 restart: always volumes: From d56e7c1496a8bad614552f3ce28ae1edcd7157d9 Mon Sep 17 00:00:00 2001 From: Jardayn Date: Tue, 1 Nov 2022 00:59:30 +0200 Subject: [PATCH 74/93] Added arrow to poetry --- poetry.lock | 118 +++++++++++++++++++++++++++---------------------- pyproject.toml | 1 + 2 files changed, 67 insertions(+), 52 deletions(-) diff --git a/poetry.lock b/poetry.lock index 9e473982..e8adc739 100644 --- a/poetry.lock +++ b/poetry.lock @@ -19,7 +19,7 @@ optional = false python-versions = "*" [[package]] -name = "backports.zoneinfo" +name = "backports-zoneinfo" version = "0.2.1" description = "Backport of the standard library zoneinfo module" category = "main" @@ -57,7 +57,7 @@ optional = false python-versions = ">=3.6.0" [package.extras] -unicode_backport = ["unicodedata2"] +unicode-backport = ["unicodedata2"] [[package]] name = "click" @@ -90,11 +90,11 @@ zstd = ["clickhouse-cityhash (>=1.0.2.1)", "zstd"] [[package]] name = "colorama" -version = "0.4.5" +version = "0.4.6" description = "Cross-platform colored terminal text." category = "main" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" [[package]] name = "commonmark" @@ -191,7 +191,7 @@ optional = false python-versions = "*" [package.extras] -atomic_cache = ["atomicwrites"] +atomic-cache = ["atomicwrites"] nearley = ["js2py"] regex = ["regex"] @@ -270,7 +270,7 @@ six = "*" [package.extras] all = ["google-auth", "requests-kerberos"] -google_auth = ["google-auth"] +google-auth = ["google-auth"] kerberos = ["requests-kerberos"] tests = ["google-auth", "httpretty", "pytest", "pytest-runner", "requests-kerberos"] @@ -287,7 +287,7 @@ wcwidth = "*" [[package]] name = "protobuf" -version = "4.21.8" +version = "4.21.9" description = "" category = "main" optional = false @@ -295,7 +295,7 @@ python-versions = ">=3.7" [[package]] name = "psycopg2" -version = "2.9.4" +version = "2.9.5" description = "psycopg2 - Python-PostgreSQL Database Adapter" category = "main" optional = false @@ -318,7 +318,7 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [[package]] -name = "Pygments" +name = "pygments" version = "2.13.0" description = "Pygments is a syntax highlighting package written in Python." 
category = "main" @@ -329,7 +329,7 @@ python-versions = ">=3.6" plugins = ["importlib-metadata"] [[package]] -name = "PyJWT" +name = "pyjwt" version = "2.6.0" description = "JSON Web Token implementation in Python" category = "main" @@ -343,7 +343,7 @@ docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] [[package]] -name = "pyOpenSSL" +name = "pyopenssl" version = "22.0.0" description = "Python wrapper module around the OpenSSL library" category = "main" @@ -404,7 +404,7 @@ urllib3 = ">=1.21.1,<1.27" [package.extras] socks = ["PySocks (>=1.5.6,!=1.5.7)"] -use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"] +use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "rich" @@ -519,7 +519,7 @@ python-versions = ">=3.7" [[package]] name = "tzdata" -version = "2022.5" +version = "2022.6" description = "Provider of IANA time zone data" category = "main" optional = false @@ -544,7 +544,7 @@ test = ["pytest (>=4.3)", "pytest-mock (>=3.3)"] [[package]] name = "unittest-parallel" -version = "1.5.2" +version = "1.5.3" description = "Parallel unit test runner with coverage support" category = "dev" optional = false @@ -588,7 +588,7 @@ python-versions = "*" [[package]] name = "zipp" -version = "3.9.0" +version = "3.10.0" description = "Backport of pathlib-compatible object wrapper for zip files" category = "main" optional = false @@ -612,7 +612,7 @@ vertica = [] [metadata] lock-version = "1.1" python-versions = "^3.7" -content-hash = "074d933430a02a6ba56dadc359f89b875dc74e42070c1dfc1b79e8e8cd610404" +content-hash = "2ee6a778364480c8f72eda863926460037a8a7f580dc9d920388ec8c178ddb35" [metadata.files] arrow = [ @@ -623,7 +623,7 @@ asn1crypto = [ {file = "asn1crypto-1.5.1-py2.py3-none-any.whl", hash = "sha256:db4e40728b728508912cbb3d44f19ce188f218e9eba635821bb4b68564f8fd67"}, {file = "asn1crypto-1.5.1.tar.gz", hash = "sha256:13ae38502be632115abf8a24cbe5f4da52e3b5231990aff31123c805306ccb9c"}, ] -"backports.zoneinfo" = [ +backports-zoneinfo = [ {file = "backports.zoneinfo-0.2.1-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:da6013fd84a690242c310d77ddb8441a559e9cb3d3d59ebac9aca1a57b2e18bc"}, {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:89a48c0d158a3cc3f654da4c2de1ceba85263fafb861b98b59040a5086259722"}, {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:1c5742112073a563c81f786e77514969acb58649bcdf6cdf0b4ed31a348d4546"}, @@ -734,6 +734,19 @@ clickhouse-driver = [ {file = "clickhouse_driver-0.2.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8c776ab9592d456351ba2c4b05a8149761f36991033257d9c31ef6d26952dbe7"}, {file = "clickhouse_driver-0.2.4-cp310-cp310-win32.whl", hash = "sha256:bb1423d6daa8736aade0f7d31870c28ab2e7553a21cf923af6f1ff4a219c0ac9"}, {file = "clickhouse_driver-0.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:3955616073d030dc8cc7b0ef68ffe045510334137c1b5d11447347352d0dec28"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cd217a28e3821cbba49fc0ea87e0e6dde799e62af327d1f9c5f9480abd71e17c"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f4863a8eb369a36266f372e9eacc3fe222e4d31e1b2a7a2da759b521fffad1c"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c87ce674ebb2f5e38b68c36464538fffdea4f5432bb136cb6980489ae3c6dbe9"}, + {file = 
"clickhouse_driver-0.2.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:383690650fccaffa7f0e56d6fb0b00b9227b408fb3d92291a1f1ed66ce83df7c"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a68ac4633fd4cf265e619adeec1c0ee67ff1d9b5373c140b8400adcc4831c19"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5551056e5ab9e1dac67abbdf202b6e67590aa79a013a9da8ecbaec828e0790fe"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:93b601573779c5e8c2344cd983ef87de0fc8a31392bdc8571e79ed318f30dbbb"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:f228edc258f9ef6ee29b5618832b38479a0cfaa5bb837150ba62bbc1357a58cd"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:4b39c0f962a3664a72e0bfaa30929d0317b5e3427ff8b36d7889f8d31f4ff89e"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:47465598241cdf0b3a7810667c6104ada7b992a797768883ce30635a213568c3"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e9f31e7ccfd5cf526fd9db50ade94504007992922c8e556ba54e8ba637e9cca0"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-win32.whl", hash = "sha256:da0bcd41aeb50ec4316808c11d591aef60fe1b9de997df10ffcad9ab3cb0efa2"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:bb64ad0dfcc5ee158b01411e7a828bb3b20e4a2bc2f99da219acff0a9d18808c"}, {file = "clickhouse_driver-0.2.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b746e83652fbb89cb907adfdd69d6a7c7019bb5bbbdf68454bdcd16b09959a00"}, {file = "clickhouse_driver-0.2.4-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b9cf37f0b7165619d2e0188a71018300ed1a937ca518b01a5d168aec0a09add"}, {file = "clickhouse_driver-0.2.4-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:be4da882979c5a6d5631b5db6acb6407d76c73be6635ab0a76378b98aac8e5ab"}, @@ -793,8 +806,8 @@ clickhouse-driver = [ {file = "clickhouse_driver-0.2.4-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:821c7efff84fda8b68680140e07241991ab27296dc62213eb788efef4470fdd5"}, ] colorama = [ - {file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"}, - {file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"}, + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] commonmark = [ {file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"}, @@ -933,33 +946,33 @@ prompt-toolkit = [ {file = "prompt_toolkit-3.0.31.tar.gz", hash = "sha256:9ada952c9d1787f52ff6d5f3484d0b4df8952787c087edf6a1f7c2cb1ea88148"}, ] protobuf = [ - {file = "protobuf-4.21.8-cp310-abi3-win32.whl", hash = "sha256:c252c55ee15175aa1b21b7b9896e6add5162d066d5202e75c39f96136f08cce3"}, - {file = "protobuf-4.21.8-cp310-abi3-win_amd64.whl", hash = "sha256:809ca0b225d3df42655a12f311dd0f4148a943c51f1ad63c38343e457492b689"}, - {file = "protobuf-4.21.8-cp37-abi3-macosx_10_9_universal2.whl", hash = 
"sha256:bbececaf3cfea9ea65ebb7974e6242d310d2a7772a6f015477e0d79993af4511"}, - {file = "protobuf-4.21.8-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:b02eabb9ebb1a089ed20626a90ad7a69cee6bcd62c227692466054b19c38dd1f"}, - {file = "protobuf-4.21.8-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:4761201b93e024bb70ee3a6a6425d61f3152ca851f403ba946fb0cde88872661"}, - {file = "protobuf-4.21.8-cp37-cp37m-win32.whl", hash = "sha256:f2d55ff22ec300c4d954d3b0d1eeb185681ec8ad4fbecff8a5aee6a1cdd345ba"}, - {file = "protobuf-4.21.8-cp37-cp37m-win_amd64.whl", hash = "sha256:c5f94911dd8feb3cd3786fc90f7565c9aba7ce45d0f254afd625b9628f578c3f"}, - {file = "protobuf-4.21.8-cp38-cp38-win32.whl", hash = "sha256:b37b76efe84d539f16cba55ee0036a11ad91300333abd213849cbbbb284b878e"}, - {file = "protobuf-4.21.8-cp38-cp38-win_amd64.whl", hash = "sha256:2c92a7bfcf4ae76a8ac72e545e99a7407e96ffe52934d690eb29a8809ee44d7b"}, - {file = "protobuf-4.21.8-cp39-cp39-win32.whl", hash = "sha256:89d641be4b5061823fa0e463c50a2607a97833e9f8cfb36c2f91ef5ccfcc3861"}, - {file = "protobuf-4.21.8-cp39-cp39-win_amd64.whl", hash = "sha256:bc471cf70a0f53892fdd62f8cd4215f0af8b3f132eeee002c34302dff9edd9b6"}, - {file = "protobuf-4.21.8-py2.py3-none-any.whl", hash = "sha256:a55545ce9eec4030cf100fcb93e861c622d927ef94070c1a3c01922902464278"}, - {file = "protobuf-4.21.8-py3-none-any.whl", hash = "sha256:0f236ce5016becd989bf39bd20761593e6d8298eccd2d878eda33012645dc369"}, - {file = "protobuf-4.21.8.tar.gz", hash = "sha256:427426593b55ff106c84e4a88cac855175330cb6eb7e889e85aaa7b5652b686d"}, + {file = "protobuf-4.21.9-cp310-abi3-win32.whl", hash = "sha256:6e0be9f09bf9b6cf497b27425487706fa48c6d1632ddd94dab1a5fe11a422392"}, + {file = "protobuf-4.21.9-cp310-abi3-win_amd64.whl", hash = "sha256:a7d0ea43949d45b836234f4ebb5ba0b22e7432d065394b532cdca8f98415e3cf"}, + {file = "protobuf-4.21.9-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:b5ab0b8918c136345ff045d4b3d5f719b505b7c8af45092d7f45e304f55e50a1"}, + {file = "protobuf-4.21.9-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:2c9c2ed7466ad565f18668aa4731c535511c5d9a40c6da39524bccf43e441719"}, + {file = "protobuf-4.21.9-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:e575c57dc8b5b2b2caa436c16d44ef6981f2235eb7179bfc847557886376d740"}, + {file = "protobuf-4.21.9-cp37-cp37m-win32.whl", hash = "sha256:9227c14010acd9ae7702d6467b4625b6fe853175a6b150e539b21d2b2f2b409c"}, + {file = "protobuf-4.21.9-cp37-cp37m-win_amd64.whl", hash = "sha256:a419cc95fca8694804709b8c4f2326266d29659b126a93befe210f5bbc772536"}, + {file = "protobuf-4.21.9-cp38-cp38-win32.whl", hash = "sha256:5b0834e61fb38f34ba8840d7dcb2e5a2f03de0c714e0293b3963b79db26de8ce"}, + {file = "protobuf-4.21.9-cp38-cp38-win_amd64.whl", hash = "sha256:84ea107016244dfc1eecae7684f7ce13c788b9a644cd3fca5b77871366556444"}, + {file = "protobuf-4.21.9-cp39-cp39-win32.whl", hash = "sha256:f9eae277dd240ae19bb06ff4e2346e771252b0e619421965504bd1b1bba7c5fa"}, + {file = "protobuf-4.21.9-cp39-cp39-win_amd64.whl", hash = "sha256:6e312e280fbe3c74ea9e080d9e6080b636798b5e3939242298b591064470b06b"}, + {file = "protobuf-4.21.9-py2.py3-none-any.whl", hash = "sha256:7eb8f2cc41a34e9c956c256e3ac766cf4e1a4c9c925dc757a41a01be3e852965"}, + {file = "protobuf-4.21.9-py3-none-any.whl", hash = "sha256:48e2cd6b88c6ed3d5877a3ea40df79d08374088e89bedc32557348848dff250b"}, + {file = "protobuf-4.21.9.tar.gz", hash = "sha256:61f21493d96d2a77f9ca84fefa105872550ab5ef71d21c458eb80edcf4885a99"}, ] psycopg2 = [ - {file = "psycopg2-2.9.4-cp310-cp310-win32.whl", hash = 
"sha256:8de6a9fc5f42fa52f559e65120dcd7502394692490c98fed1221acf0819d7797"}, - {file = "psycopg2-2.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:1da77c061bdaab450581458932ae5e469cc6e36e0d62f988376e9f513f11cb5c"}, - {file = "psycopg2-2.9.4-cp36-cp36m-win32.whl", hash = "sha256:a11946bad3557ca254f17357d5a4ed63bdca45163e7a7d2bfb8e695df069cc3a"}, - {file = "psycopg2-2.9.4-cp36-cp36m-win_amd64.whl", hash = "sha256:46361c054df612c3cc813fdb343733d56543fb93565cff0f8ace422e4da06acb"}, - {file = "psycopg2-2.9.4-cp37-cp37m-win32.whl", hash = "sha256:aafa96f2da0071d6dd0cbb7633406d99f414b40ab0f918c9d9af7df928a1accb"}, - {file = "psycopg2-2.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:aa184d551a767ad25df3b8d22a0a62ef2962e0e374c04f6cbd1204947f540d61"}, - {file = "psycopg2-2.9.4-cp38-cp38-win32.whl", hash = "sha256:839f9ea8f6098e39966d97fcb8d08548fbc57c523a1e27a1f0609addf40f777c"}, - {file = "psycopg2-2.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:c7fa041b4acb913f6968fce10169105af5200f296028251d817ab37847c30184"}, - {file = "psycopg2-2.9.4-cp39-cp39-win32.whl", hash = "sha256:07b90a24d5056687781ddaef0ea172fd951f2f7293f6ffdd03d4f5077801f426"}, - {file = "psycopg2-2.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:849bd868ae3369932127f0771c08d1109b254f08d48dc42493c3d1b87cb2d308"}, - {file = "psycopg2-2.9.4.tar.gz", hash = "sha256:d529926254e093a1b669f692a3aa50069bc71faf5b0ecd91686a78f62767d52f"}, + {file = "psycopg2-2.9.5-cp310-cp310-win32.whl", hash = "sha256:d3ef67e630b0de0779c42912fe2cbae3805ebaba30cda27fea2a3de650a9414f"}, + {file = "psycopg2-2.9.5-cp310-cp310-win_amd64.whl", hash = "sha256:4cb9936316d88bfab614666eb9e32995e794ed0f8f6b3b718666c22819c1d7ee"}, + {file = "psycopg2-2.9.5-cp36-cp36m-win32.whl", hash = "sha256:b9ac1b0d8ecc49e05e4e182694f418d27f3aedcfca854ebd6c05bb1cffa10d6d"}, + {file = "psycopg2-2.9.5-cp36-cp36m-win_amd64.whl", hash = "sha256:fc04dd5189b90d825509caa510f20d1d504761e78b8dfb95a0ede180f71d50e5"}, + {file = "psycopg2-2.9.5-cp37-cp37m-win32.whl", hash = "sha256:922cc5f0b98a5f2b1ff481f5551b95cd04580fd6f0c72d9b22e6c0145a4840e0"}, + {file = "psycopg2-2.9.5-cp37-cp37m-win_amd64.whl", hash = "sha256:1e5a38aa85bd660c53947bd28aeaafb6a97d70423606f1ccb044a03a1203fe4a"}, + {file = "psycopg2-2.9.5-cp38-cp38-win32.whl", hash = "sha256:f5b6320dbc3cf6cfb9f25308286f9f7ab464e65cfb105b64cc9c52831748ced2"}, + {file = "psycopg2-2.9.5-cp38-cp38-win_amd64.whl", hash = "sha256:1a5c7d7d577e0eabfcf15eb87d1e19314c8c4f0e722a301f98e0e3a65e238b4e"}, + {file = "psycopg2-2.9.5-cp39-cp39-win32.whl", hash = "sha256:322fd5fca0b1113677089d4ebd5222c964b1760e361f151cbb2706c4912112c5"}, + {file = "psycopg2-2.9.5-cp39-cp39-win_amd64.whl", hash = "sha256:190d51e8c1b25a47484e52a79638a8182451d6f6dff99f26ad9bd81e5359a0fa"}, + {file = "psycopg2-2.9.5.tar.gz", hash = "sha256:a5246d2e683a972e2187a8714b5c2cf8156c064629f9a9b1a873c1730d9e245a"}, ] pycparser = [ {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, @@ -997,15 +1010,15 @@ pycryptodomex = [ {file = "pycryptodomex-3.15.0-pp36-pypy36_pp73-win32.whl", hash = "sha256:35a8f7afe1867118330e2e0e0bf759c409e28557fb1fc2fbb1c6c937297dbe9a"}, {file = "pycryptodomex-3.15.0.tar.gz", hash = "sha256:7341f1bb2dadb0d1a0047f34c3a58208a92423cdbd3244d998e4b28df5eac0ed"}, ] -Pygments = [ +pygments = [ {file = "Pygments-2.13.0-py3-none-any.whl", hash = "sha256:f643f331ab57ba3c9d89212ee4a2dabc6e94f117cf4eefde99a0574720d14c42"}, {file = "Pygments-2.13.0.tar.gz", hash = 
"sha256:56a8508ae95f98e2b9bdf93a6be5ae3f7d8af858b43e02c5a2ff083726be40c1"}, ] -PyJWT = [ +pyjwt = [ {file = "PyJWT-2.6.0-py3-none-any.whl", hash = "sha256:d83c3d892a77bbb74d3e1a2cfa90afaadb60945205d1095d9221f04466f64c14"}, {file = "PyJWT-2.6.0.tar.gz", hash = "sha256:69285c7e31fc44f68a1feb309e948e0df53259d579295e6cfe2b1792329f05fd"}, ] -pyOpenSSL = [ +pyopenssl = [ {file = "pyOpenSSL-22.0.0-py2.py3-none-any.whl", hash = "sha256:ea252b38c87425b64116f808355e8da644ef9b07e429398bfece610f893ee2e0"}, {file = "pyOpenSSL-22.0.0.tar.gz", hash = "sha256:660b1b1425aac4a1bea1d94168a85d99f0b3144c869dd4390d27629d0087f1bf"}, ] @@ -1076,15 +1089,16 @@ typing-extensions = [ {file = "typing_extensions-4.4.0.tar.gz", hash = "sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa"}, ] tzdata = [ - {file = "tzdata-2022.5-py2.py3-none-any.whl", hash = "sha256:323161b22b7802fdc78f20ca5f6073639c64f1a7227c40cd3e19fd1d0ce6650a"}, - {file = "tzdata-2022.5.tar.gz", hash = "sha256:e15b2b3005e2546108af42a0eb4ccab4d9e225e2dfbf4f77aad50c70a4b1f3ab"}, + {file = "tzdata-2022.6-py2.py3-none-any.whl", hash = "sha256:04a680bdc5b15750c39c12a448885a51134a27ec9af83667663f0b3a1bf3f342"}, + {file = "tzdata-2022.6.tar.gz", hash = "sha256:91f11db4503385928c15598c98573e3af07e7229181bee5375bd30f1695ddcae"}, ] tzlocal = [ {file = "tzlocal-4.2-py3-none-any.whl", hash = "sha256:89885494684c929d9191c57aa27502afc87a579be5cdd3225c77c463ea043745"}, {file = "tzlocal-4.2.tar.gz", hash = "sha256:ee5842fa3a795f023514ac2d801c4a81d1743bbe642e3940143326b3a00addd7"}, ] unittest-parallel = [ - {file = "unittest-parallel-1.5.2.tar.gz", hash = "sha256:42e82215862619ba7ce269db30eb63b878671ebb2ab9bfcead1fede43800b7ef"}, + {file = "unittest-parallel-1.5.3.tar.gz", hash = "sha256:32182bb2230371d651e6fc9795ddf52c134eb36f5064dc339fdbb5984a639517"}, + {file = "unittest_parallel-1.5.3-py3-none-any.whl", hash = "sha256:5670c9eca19450dedb493e9dad2ca4dcbbe12e04477d934ff6c92071d36bace7"}, ] urllib3 = [ {file = "urllib3-1.26.12-py2.py3-none-any.whl", hash = "sha256:b930dd878d5a8afb066a637fbb35144fe7901e3b209d1cd4f524bd0e9deee997"}, @@ -1099,6 +1113,6 @@ wcwidth = [ {file = "wcwidth-0.2.5.tar.gz", hash = "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83"}, ] zipp = [ - {file = "zipp-3.9.0-py3-none-any.whl", hash = "sha256:972cfa31bc2fedd3fa838a51e9bc7e64b7fb725a8c00e7431554311f180e9980"}, - {file = "zipp-3.9.0.tar.gz", hash = "sha256:3a7af91c3db40ec72dd9d154ae18e008c69efe8ca88dde4f9a731bb82fe2f9eb"}, + {file = "zipp-3.10.0-py3-none-any.whl", hash = "sha256:4fcb6f278987a6605757302a6e40e896257570d11c51628968ccb2a47e80c6c1"}, + {file = "zipp-3.10.0.tar.gz", hash = "sha256:7a7262fd930bd3e36c50b9a64897aec3fafff3dfdeec9623ae22b40e93f99bb8"}, ] diff --git a/pyproject.toml b/pyproject.toml index b0ae7ff6..212d177d 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ presto-python-client = {version="*", optional=true} clickhouse-driver = {version="*", optional=true} [tool.poetry.dev-dependencies] +arrow = "^1.2.3" parameterized = "*" unittest-parallel = "*" preql = "^0.2.19" From 7dc4ca078a365dba2d2e0d321fad0e30e0afb075 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 1 Nov 2022 14:12:32 -0300 Subject: [PATCH 75/93] Bugfix for commit 7bab3acdfbc082004f02f5ac2343c1d247993d37 --- data_diff/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_diff/config.py b/data_diff/config.py index 941e2643..ae289852 100644 --- a/data_diff/config.py +++ b/data_diff/config.py @@ -26,7 +26,7 @@ def 
_apply_config(config: Dict[str, Any], run_name: str, kw: Dict[str, Any]): else: run_name = "default" - if "database1" in kw: + if kw.get("database1") is not None: for attr in ("table1", "database2", "table2"): if kw[attr] is None: raise ValueError(f"Specified database1 but not {attr}. Must specify all 4 arguments, or niether.") From 57d4352f67ccca662a3a4b75c8369d6d855aca3a Mon Sep 17 00:00:00 2001 From: Jardayn Date: Tue, 1 Nov 2022 21:08:09 +0200 Subject: [PATCH 76/93] Fixed test + Contrib readme improvements --- CONTRIBUTING.md | 2 +- tests/test_cli.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 56c717ab..9f8c0b10 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -34,7 +34,7 @@ The same goes for other technical requests, like missing features, or gaps in th See [issues](/datafold/data-diff/issues/). -For questions, and non-technical discussions, see [discussions](/datafold/data-diff/discussions). +For questions, and non-technical discussions, see [discussions](https://github.com/datafold/data-diff/discussions). ### Contributing code diff --git a/tests/test_cli.py b/tests/test_cli.py index 263dc872..1c131f54 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -37,7 +37,7 @@ def setUp(self) -> None: src_table = table(table_src_name, schema={"id": int, "datetime": datetime, "text_comment": str}) self.conn.query(src_table.create()) - self.now = now = arrow.get() + self.now = now = arrow.get(datetime.now()) rows = [ (now, "now"), From d8c94f727f304fbfdc795d33d53a8483a57f6848 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 1 Nov 2022 16:49:46 -0300 Subject: [PATCH 77/93] Better error in config. (Issue #270) --- data_diff/config.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/data_diff/config.py b/data_diff/config.py index ae289852..18c31e6e 100644 --- a/data_diff/config.py +++ b/data_diff/config.py @@ -36,7 +36,10 @@ def _apply_config(config: Dict[str, Any], run_name: str, kw: Dict[str, Any]): # Process databases + tables for index in "12": - args = run_args.pop(index, {}) + try: + args = run_args.pop(index) + except KeyError: + raise ConfigParseError(f"Could not find source #{index}: Expecting a key of '{index}' containing '.database' and '.table'.") for attr in ("database", "table"): if attr not in args: raise ConfigParseError(f"Running 'run.{run_name}': Connection #{index} is missing attribute '{attr}'.") From 9845328eb0b34574a2888116fa5af51e0d753f7a Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 1 Nov 2022 15:31:39 -0300 Subject: [PATCH 78/93] Refactor Dialect out of Database (partial refactor) --- data_diff/databases/base.py | 122 ++++++++++---------- data_diff/databases/bigquery.py | 79 +++++++------ data_diff/databases/clickhouse.py | 156 ++++++++++++++------------ data_diff/databases/database_types.py | 68 +++++------ data_diff/databases/databricks.py | 51 +++++---- data_diff/databases/mssql.py | 2 +- data_diff/databases/mysql.py | 87 +++++++------- data_diff/databases/oracle.py | 140 ++++++++++++----------- data_diff/databases/postgresql.py | 57 +++++----- data_diff/databases/presto.py | 81 ++++++------- data_diff/databases/redshift.py | 27 +++-- data_diff/databases/snowflake.py | 53 +++++---- data_diff/databases/trino.py | 18 ++- data_diff/databases/vertica.py | 79 +++++++------ data_diff/queries/ast_classes.py | 20 ++-- data_diff/queries/compiler.py | 12 +- data_diff/queries/extras.py | 4 +- tests/common.py | 4 +- tests/test_database.py | 4 +- tests/test_database_types.py 
| 4 +- tests/test_joindiff.py | 2 +- tests/test_query.py | 35 ++++-- 22 files changed, 607 insertions(+), 498 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 06b6ea08..0a907c6b 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -13,6 +13,9 @@ from data_diff.queries import Expr, Compiler, table, Select, SKIP, Explain from .database_types import ( AbstractDatabase, + AbstractDialect, + AbstractMixin_MD5, + AbstractMixin_NormalizeValue, ColType, Integer, Decimal, @@ -99,6 +102,65 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal return callback(sql_code) +class BaseDialect(AbstractDialect, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + SUPPORTS_PRIMARY_KEY = False + + def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): + if offset: + raise NotImplementedError("No support for OFFSET in query") + + return f"LIMIT {limit}" + + def concat(self, items: List[str]) -> str: + assert len(items) > 1 + joined_exprs = ", ".join(items) + return f"concat({joined_exprs})" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"{a} is distinct from {b}" + + def timestamp_value(self, t: DbTime) -> str: + return f"'{t.isoformat()}'" + + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + if isinstance(coltype, String_UUID): + return f"TRIM({value})" + return self.to_string(value) + + def random(self) -> str: + return "RANDOM()" + + def explain_as_text(self, query: str) -> str: + return f"EXPLAIN {query}" + + def _constant_value(self, v): + if v is None: + return "NULL" + elif isinstance(v, str): + return f"'{v}'" + elif isinstance(v, datetime): + # TODO use self.timestamp_value + return f"timestamp '{v}'" + elif isinstance(v, UUID): + return f"'{v}'" + return repr(v) + + def constant_values(self, rows) -> str: + values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows) + return f"VALUES {values}" + + def type_repr(self, t) -> str: + if isinstance(t, str): + return t + return { + int: "INT", + str: "VARCHAR", + bool: "BOOLEAN", + float: "FLOAT", + datetime: "TIMESTAMP", + }[t] + + class Database(AbstractDatabase): """Base abstract class for databases. 
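As a worked illustration of what the `constant_values` / `_constant_value` defaults in the hunk above render, here is a standalone paraphrase (a sketch that mirrors the logic minus the UUID case, so it runs without the abstract machinery; it is not the class itself):

```python
from datetime import datetime

def constant_value(v) -> str:
    # Mirrors BaseDialect._constant_value from the hunk above (UUID case omitted)
    if v is None:
        return "NULL"
    if isinstance(v, str):
        return f"'{v}'"
    if isinstance(v, datetime):
        return f"timestamp '{v}'"
    return repr(v)

rows = [(1, "alice", None), (2, "bob", datetime(2022, 1, 1))]
print("VALUES " + ", ".join("(%s)" % ", ".join(constant_value(v) for v in row) for row in rows))
# VALUES (1, 'alice', NULL), (2, 'bob', timestamp '2022-01-01 00:00:00')
```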
@@ -109,8 +171,9 @@ class Database(AbstractDatabase): TYPE_CLASSES: Dict[str, type] = {} default_schema: str = None + dialect: AbstractDialect = None + SUPPORTS_ALPHANUMS = True - SUPPORTS_PRIMARY_KEY = False SUPPORTS_UNIQUE_CONSTAINT = False _interactive = False @@ -274,7 +337,7 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe if not text_columns: return - fields = [self.normalize_uuid(self.quote(c), String_UUID()) for c in text_columns] + fields = [self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID()) for c in text_columns] samples_by_row = self.query(table(*table_path).select(*fields).where(where or SKIP).limit(sample_size), list) if not samples_by_row: raise ValueError(f"Table {table_path} is empty.") @@ -321,58 +384,6 @@ def _normalize_table_path(self, path: DbPath) -> DbPath: def parse_table_name(self, name: str) -> DbPath: return parse_table_name(name) - def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): - if offset: - raise NotImplementedError("No support for OFFSET in query") - - return f"LIMIT {limit}" - - def concat(self, items: List[str]) -> str: - assert len(items) > 1 - joined_exprs = ", ".join(items) - return f"concat({joined_exprs})" - - def is_distinct_from(self, a: str, b: str) -> str: - return f"{a} is distinct from {b}" - - def timestamp_value(self, t: DbTime) -> str: - return f"'{t.isoformat()}'" - - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - if isinstance(coltype, String_UUID): - return f"TRIM({value})" - return self.to_string(value) - - def random(self) -> str: - return "RANDOM()" - - def _constant_value(self, v): - if v is None: - return "NULL" - elif isinstance(v, str): - return f"'{v}'" - elif isinstance(v, datetime): - # TODO use self.timestamp_value - return f"timestamp '{v}'" - elif isinstance(v, UUID): - return f"'{v}'" - return repr(v) - - def constant_values(self, rows) -> str: - values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows) - return f"VALUES {values}" - - def type_repr(self, t) -> str: - if isinstance(t, str): - return t - return { - int: "INT", - str: "VARCHAR", - bool: "BOOLEAN", - float: "FLOAT", - datetime: "TIMESTAMP", - }[t] - def _query_cursor(self, c, sql_code: str): assert isinstance(sql_code, str), sql_code try: @@ -389,9 +400,6 @@ def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> lis callback = partial(self._query_cursor, c) return apply_query(callback, sql_code) - def explain_as_text(self, query: str) -> str: - return f"EXPLAIN {query}" - class ThreadedDatabase(Database): """Access the database through singleton threads. 
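The shape this partial refactor is moving toward, reduced to a self-contained toy (the class and method names below are made up for illustration, not data-diff APIs): SQL text generation lives on a dialect object, while the database object owns connections and delegates rendering through `self.dialect`, as the updated call sites above already do (`self.dialect.quote(...)`).

```python
class ToyDialect:
    "Renders SQL text; knows nothing about connections."
    def quote(self, name: str) -> str:
        return f'"{name}"'
    def concat(self, items: list) -> str:
        return "concat(%s)" % ", ".join(items)

class ToyDatabase:
    "Owns connections and query execution; delegates SQL rendering to the dialect."
    dialect = ToyDialect()
    def select_all_sql(self, table: str) -> str:
        return f"SELECT * FROM {self.dialect.quote(table)}"

print(ToyDatabase().select_all_sql("users"))  # SELECT * FROM "users"
```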
diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 3d3720b6..58577585 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -1,6 +1,6 @@ from typing import List, Union from .database_types import Timestamp, Datetime, Integer, Decimal, Float, Text, DbPath, FractionalType, TemporalType -from .base import Database, import_helper, parse_table_name, ConnectError, apply_query +from .base import BaseDialect, Database, import_helper, parse_table_name, ConnectError, apply_query from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter @@ -11,7 +11,48 @@ def import_bigquery(): return bigquery +class Dialect(BaseDialect): + name = "BigQuery" + + def random(self) -> str: + return "RAND()" + + def quote(self, s: str): + return f"`{s}`" + + def md5_as_int(self, s: str) -> str: + return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)" + + def to_string(self, s: str): + return f"cast({s} as string)" + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" + return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})" + + if coltype.precision == 0: + return f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000, {value})" + elif coltype.precision == 6: + return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" + + timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return f"format('%.{coltype.precision}f', {value})" + + def type_repr(self, t) -> str: + try: + return {str: "STRING", float: "FLOAT64"}[t] + except KeyError: + return super().type_repr(t) + + class BigQuery(Database): + dialect = Dialect() TYPE_CLASSES = { # Dates "TIMESTAMP": Timestamp, @@ -37,12 +78,6 @@ def __init__(self, project, *, dataset, **kw): self.default_schema = dataset - def quote(self, s: str): - return f"`{s}`" - - def md5_to_int(self, s: str) -> str: - return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)" - def _normalize_returned_value(self, value): if isinstance(value, bytes): return value.decode() @@ -64,9 +99,6 @@ def _query_atom(self, sql_code: str): def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): return apply_query(self._query_atom, sql_code) - def to_string(self, s: str): - return f"cast({s} as string)" - def close(self): self._client.close() @@ -81,37 +113,10 @@ def select_table_schema(self, path: DbPath) -> str: def query_table_unique_columns(self, path: DbPath) -> List[str]: return [] - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" - return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})" - - if coltype.precision == 0: - return f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000, {value})" - elif coltype.precision == 6: - return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" - - timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" - return ( - f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return 
f"format('%.{coltype.precision}f', {value})" - def parse_table_name(self, name: str) -> DbPath: path = parse_table_name(name) return self._normalize_table_path(path) - def random(self) -> str: - return "RAND()" - @property def is_autocommit(self) -> bool: return True - - def type_repr(self, t) -> str: - try: - return {str: "STRING", float: "FLOAT64"}[t] - except KeyError: - return super().type_repr(t) diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index 9b657d89..40657585 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -4,11 +4,22 @@ MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS, + BaseDialect, ThreadedDatabase, import_helper, ConnectError, ) -from .database_types import ColType, Decimal, Float, Integer, FractionalType, Native_UUID, TemporalType, Text, Timestamp +from .database_types import ( + ColType, + Decimal, + Float, + Integer, + FractionalType, + Native_UUID, + TemporalType, + Text, + Timestamp, +) @import_helper("clickhouse") @@ -18,7 +29,78 @@ def import_clickhouse(): return clickhouse_driver +class Dialect(BaseDialect): + name = "Clickhouse" + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped. + # For example: + # select toString(toDecimal128(1.10, 2)); -- the result is 1.1 + # select toString(toDecimal128(1.00, 2)); -- the result is 1 + # So, we should use some custom approach to save these trailing zeros. + # To avoid it, we can add a small value like 0.000001 to prevent dropping of zeros from the end when casting. + # For examples above it looks like: + # select toString(toDecimal128(1.10, 2 + 1) + toDecimal128(0.001, 3)); -- the result is 1.101 + # After that, cut an extra symbol from the string, i.e. 1.101 -> 1.10 + # So, the algorithm is: + # 1. Cast to decimal with precision + 1 + # 2. Add a small value 10^(-precision-1) + # 3. Cast the result to string + # 4. Drop the extra digit from the string. To do that, we need to slice the string + # with length = digits in an integer part + 1 (symbol of ".") + precision + + if coltype.precision == 0: + return self.to_string(f"round({value})") + + precision = coltype.precision + # TODO: too complex, is there better performance way? + value = f""" + if({value} >= 0, '', '-') || left( + toString( + toDecimal128( + round(abs({value}), {precision}), + {precision} + 1 + ) + + + toDecimal128( + exp10(-{precision + 1}), + {precision} + 1 + ) + ), + toUInt8( + greatest( + floor(log10(abs({value}))) + 1, + 1 + ) + ) + 1 + {precision} + ) + """ + return value + + def quote(self, s: str) -> str: + return f'"{s}"' + + def md5_as_int(self, s: str) -> str: + substr_idx = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS + return f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx})))))" + + def to_string(self, s: str) -> str: + return f"toString({s})" + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + prec = coltype.precision + if coltype.rounds: + timestamp = f"toDateTime64(round(toUnixTimestamp64Micro(toDateTime64({value}, 6)) / 1000000, {prec}), 6)" + return self.to_string(timestamp) + + fractional = f"toUnixTimestamp64Micro(toDateTime64({value}, {prec})) % 1000000" + fractional = f"lpad({self.to_string(fractional)}, 6, '0')" + value = f"formatDateTime({value}, '%Y-%m-%d %H:%M:%S') || '.' 
|| {self.to_string(fractional)}" + return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')" + + class Clickhouse(ThreadedDatabase): + dialect = Dialect() TYPE_CLASSES = { "Int8": Integer, "Int16": Integer, @@ -80,77 +162,11 @@ def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: return self.TYPE_CLASSES.get(type_repr) - def quote(self, s: str) -> str: - return f'"{s}"' - - def md5_to_int(self, s: str) -> str: - substr_idx = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS - return f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx})))))" - - def to_string(self, s: str) -> str: - return f"toString({s})" - - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - prec = coltype.precision - if coltype.rounds: - timestamp = f"toDateTime64(round(toUnixTimestamp64Micro(toDateTime64({value}, 6)) / 1000000, {prec}), 6)" - return self.to_string(timestamp) - - fractional = f"toUnixTimestamp64Micro(toDateTime64({value}, {prec})) % 1000000" - fractional = f"lpad({self.to_string(fractional)}, 6, '0')" - value = f"formatDateTime({value}, '%Y-%m-%d %H:%M:%S') || '.' || {self.to_string(fractional)}" - return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')" + @property + def is_autocommit(self) -> bool: + return True def _convert_db_precision_to_digits(self, p: int) -> int: # Done the same as for PostgreSQL but need to rewrite in another way # because it does not help for float with a big integer part. return super()._convert_db_precision_to_digits(p) - 2 - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped. - # For example: - # select toString(toDecimal128(1.10, 2)); -- the result is 1.1 - # select toString(toDecimal128(1.00, 2)); -- the result is 1 - # So, we should use some custom approach to save these trailing zeros. - # To avoid it, we can add a small value like 0.000001 to prevent dropping of zeros from the end when casting. - # For examples above it looks like: - # select toString(toDecimal128(1.10, 2 + 1) + toDecimal128(0.001, 3)); -- the result is 1.101 - # After that, cut an extra symbol from the string, i.e. 1.101 -> 1.10 - # So, the algorithm is: - # 1. Cast to decimal with precision + 1 - # 2. Add a small value 10^(-precision-1) - # 3. Cast the result to string - # 4. Drop the extra digit from the string. To do that, we need to slice the string - # with length = digits in an integer part + 1 (symbol of ".") + precision - - if coltype.precision == 0: - return self.to_string(f"round({value})") - - precision = coltype.precision - # TODO: too complex, is there better performance way? - value = f""" - if({value} >= 0, '', '-') || left( - toString( - toDecimal128( - round(abs({value}), {precision}), - {precision} + 1 - ) - + - toDecimal128( - exp10(-{precision + 1}), - {precision} + 1 - ) - ), - toUInt8( - greatest( - floor(log10(abs({value}))) + 1, - 1 - ) - ) + 1 + {precision} - ) - """ - return value - - @property - def is_autocommit(self) -> bool: - return True diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index 6c9af301..f03861d3 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -186,14 +186,7 @@ def timestamp_value(self, t: datetime) -> str: ... 
-class AbstractDatadiffDialect(ABC): - """Dialect-dependent query expressions, that are specific to data-diff""" - - @abstractmethod - def md5_to_int(self, s: str) -> str: - "Provide SQL for computing md5 and returning an int" - ... - +class AbstractMixin_NormalizeValue(ABC): @abstractmethod def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: """Creates an SQL expression, that converts 'value' to a normalized timestamp. @@ -235,8 +228,41 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: """ ... + def normalize_value_by_type(self, value: str, coltype: ColType) -> str: + """Creates an SQL expression, that converts 'value' to a normalized representation. + + The returned expression must accept any SQL value, and return a string. + + The default implementation dispatches to a method according to `coltype`: + + :: -class AbstractDatabase(AbstractDialect, AbstractDatadiffDialect): + TemporalType -> normalize_timestamp() + FractionalType -> normalize_number() + *else* -> to_string() + + (`Integer` falls in the *else* category) + + """ + if isinstance(coltype, TemporalType): + return self.normalize_timestamp(value, coltype) + elif isinstance(coltype, FractionalType): + return self.normalize_number(value, coltype) + elif isinstance(coltype, ColType_UUID): + return self.normalize_uuid(value, coltype) + return self.to_string(value) + + +class AbstractMixin_MD5(ABC): + """Dialect-dependent query expressions, that are specific to data-diff""" + + @abstractmethod + def md5_as_int(self, s: str) -> str: + "Provide SQL for computing md5 and returning an int" + ... + + +class AbstractDatabase: @abstractmethod def _query(self, sql_code: str) -> list: "Send query to database and return result" @@ -296,30 +322,6 @@ def _normalize_table_path(self, path: DbPath) -> DbPath: def is_autocommit(self) -> bool: ... - def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - """Creates an SQL expression, that converts 'value' to a normalized representation. - - The returned expression must accept any SQL value, and return a string. 
- - The default implementation dispatches to a method according to `coltype`: - - :: - - TemporalType -> normalize_timestamp() - FractionalType -> normalize_number() - *else* -> to_string() - - (`Integer` falls in the *else* category) - - """ - if isinstance(coltype, TemporalType): - return self.normalize_timestamp(value, coltype) - elif isinstance(coltype, FractionalType): - return self.normalize_number(value, coltype) - elif isinstance(coltype, ColType_UUID): - return self.normalize_uuid(value, coltype) - return self.to_string(value) - Schema = CaseAwareMapping diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index e1d76349..496ead10 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -13,7 +13,7 @@ ColType, UnknownColType, ) -from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Database, import_helper, parse_table_name +from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, BaseDialect, Database, import_helper, parse_table_name @import_helper(text="You can install it using 'pip install databricks-sql-connector'") @@ -23,7 +23,34 @@ def import_databricks(): return databricks +class Dialect(BaseDialect): + name = "Databricks" + + def quote(self, s: str): + return f"`{s}`" + + def md5_as_int(self, s: str) -> str: + return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0))" + + def to_string(self, s: str) -> str: + return f"cast({s} as string)" + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + """Databricks timestamp contains no more than 6 digits in precision""" + + if coltype.rounds: + timestamp = f"cast(round(unix_micros({value}) / 1000000, {coltype.precision}) * 1000000 as bigint)" + return f"date_format(timestamp_micros({timestamp}), 'yyyy-MM-dd HH:mm:ss.SSSSSS')" + + precision_format = "S" * coltype.precision + "0" * (6 - coltype.precision) + return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')" + + def normalize_number(self, value: str, coltype: NumericType) -> str: + return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") + + class Databricks(Database): + dialect = Dialect() TYPE_CLASSES = { # Numbers "INT": Integer, @@ -66,15 +93,6 @@ def _query(self, sql_code: str) -> list: "Uses the standard SQL cursor interface" return self._query_conn(self._conn, sql_code) - def quote(self, s: str): - return f"`{s}`" - - def md5_to_int(self, s: str) -> str: - return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0))" - - def to_string(self, s: str) -> str: - return f"cast({s} as string)" - def _convert_db_precision_to_digits(self, p: int) -> int: # Subtracting 1 due to wierd precision issues return max(super()._convert_db_precision_to_digits(p) - 1, 0) @@ -132,19 +150,6 @@ def _process_table_schema( self._refine_coltypes(path, col_dict, where) return col_dict - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - """Databricks timestamp contains no more than 6 digits in precision""" - - if coltype.rounds: - timestamp = f"cast(round(unix_micros({value}) / 1000000, {coltype.precision}) * 1000000 as bigint)" - return f"date_format(timestamp_micros({timestamp}), 'yyyy-MM-dd HH:mm:ss.SSSSSS')" - - precision_format = "S" * coltype.precision + "0" * (6 - coltype.precision) - return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')" - - def normalize_number(self, value: str, coltype: NumericType) -> str: - return self.to_string(f"cast({value} 
as decimal(38, {coltype.precision}))") - def parse_table_name(self, name: str) -> DbPath: path = parse_table_name(name) return self._normalize_table_path(path) diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index 9029ff14..8d394e3c 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -17,7 +17,7 @@ # def quote(self, s: str): # return f"[{s}]" -# def md5_to_int(self, s: str) -> str: +# def md5_as_int(self, s: str) -> str: # return f"CONVERT(decimal(38,0), CONVERT(bigint, HashBytes('MD5', {s}), 2))" # # return f"CONVERT(bigint, (CHECKSUM({s})))" diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index e8e47b1b..2f0484c0 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -9,7 +9,7 @@ FractionalType, ColType_UUID, ) -from .base import ThreadedDatabase, import_helper, ConnectError +from .base import ThreadedDatabase, import_helper, ConnectError, BaseDialect from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS @@ -20,51 +20,14 @@ def import_mysql(): return mysql.connector -class MySQL(ThreadedDatabase): - TYPE_CLASSES = { - # Dates - "datetime": Datetime, - "timestamp": Timestamp, - # Numbers - "double": Float, - "float": Float, - "decimal": Decimal, - "int": Integer, - "bigint": Integer, - # Text - "varchar": Text, - "char": Text, - "varbinary": Text, - "binary": Text, - } - ROUNDS_ON_PREC_LOSS = True - SUPPORTS_ALPHANUMS = False +class Dialect(BaseDialect): + name = "MySQL" SUPPORTS_PRIMARY_KEY = True - SUPPORTS_UNIQUE_CONSTAINT = True - - def __init__(self, *, thread_count, **kw): - self._args = kw - - super().__init__(thread_count=thread_count) - - # In MySQL schema and database are synonymous - self.default_schema = kw["database"] - - def create_connection(self): - mysql = import_mysql() - try: - return mysql.connect(charset="utf8", use_unicode=True, **self._args) - except mysql.Error as e: - if e.errno == mysql.errorcode.ER_ACCESS_DENIED_ERROR: - raise ConnectError("Bad user name or password") from e - elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR: - raise ConnectError("Database does not exist") from e - raise ConnectError(*e.args) from e def quote(self, s: str): return f"`{s}`" - def md5_to_int(self, s: str) -> str: + def md5_as_int(self, s: str) -> str: return f"cast(conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as unsigned)" def to_string(self, s: str): @@ -99,3 +62,45 @@ def type_repr(self, t) -> str: def explain_as_text(self, query: str) -> str: return f"EXPLAIN FORMAT=TREE {query}" + + +class MySQL(ThreadedDatabase): + dialect = Dialect() + TYPE_CLASSES = { + # Dates + "datetime": Datetime, + "timestamp": Timestamp, + # Numbers + "double": Float, + "float": Float, + "decimal": Decimal, + "int": Integer, + "bigint": Integer, + # Text + "varchar": Text, + "char": Text, + "varbinary": Text, + "binary": Text, + } + ROUNDS_ON_PREC_LOSS = True + SUPPORTS_ALPHANUMS = False + SUPPORTS_UNIQUE_CONSTAINT = True + + def __init__(self, *, thread_count, **kw): + self._args = kw + + super().__init__(thread_count=thread_count) + + # In MySQL schema and database are synonymous + self.default_schema = kw["database"] + + def create_connection(self): + mysql = import_mysql() + try: + return mysql.connect(charset="utf8", use_unicode=True, **self._args) + except mysql.Error as e: + if e.errno == mysql.errorcode.ER_ACCESS_DENIED_ERROR: + raise ConnectError("Bad user name or password") from e + elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR: + raise 
ConnectError("Database does not exist") from e + raise ConnectError(*e.args) from e diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 73b53492..7449c0dc 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -14,7 +14,7 @@ TimestampTZ, FractionalType, ) -from .base import ThreadedDatabase, import_helper, ConnectError, QueryError +from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, QueryError from .base import TIMESTAMP_PRECISION_POS SESSION_TIME_ZONE = None # Changed by the tests @@ -27,7 +27,80 @@ def import_oracle(): return cx_Oracle +class Dialect(BaseDialect): + name = "Oracle" + SUPPORTS_PRIMARY_KEY = True + + def md5_as_int(self, s: str) -> str: + # standard_hash is faster than DBMS_CRYPTO.Hash + # TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ? + return f"to_number(substr(standard_hash({s}, 'MD5'), 18), 'xxxxxxxxxxxxxxx')" + + def quote(self, s: str): + return f"{s}" + + def to_string(self, s: str): + return f"cast({s} as varchar(1024))" + + def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): + if offset: + raise NotImplementedError("No support for OFFSET in query") + + return f"FETCH NEXT {limit} ROWS ONLY" + + def concat(self, items: List[str]) -> str: + joined_exprs = " || ".join(items) + return f"({joined_exprs})" + + def timestamp_value(self, t: DbTime) -> str: + return "timestamp '%s'" % t.isoformat(" ") + + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + # Cast is necessary for correct MD5 (trimming not enough) + return f"CAST(TRIM({value}) AS VARCHAR(36))" + + def random(self) -> str: + return "dbms_random.value" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"DECODE({a}, {b}, 1, 0) = 0" + + def type_repr(self, t) -> str: + try: + return { + str: "VARCHAR(1024)", + }[t] + except KeyError: + return super().type_repr(t) + + def constant_values(self, rows) -> str: + return " UNION ALL ".join( + "SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows + ) + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" + + if coltype.precision > 0: + truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.FF{coltype.precision}')" + else: + truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.')" + return f"RPAD({truncated}, {TIMESTAMP_PRECISION_POS+6}, '0')" + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + # FM999.9990 + format_str = "FM" + "9" * (38 - coltype.precision) + if coltype.precision: + format_str += "0." 
+ "9" * (coltype.precision - 1) + "0" + return f"to_char({value}, '{format_str}')" + + def explain_as_text(self, query: str) -> str: + raise NotImplementedError("Explain not yet implemented in Oracle") + + class Oracle(ThreadedDatabase): + dialect = Dialect() TYPE_CLASSES: Dict[str, type] = { "NUMBER": Decimal, "FLOAT": Float, @@ -38,7 +111,6 @@ class Oracle(ThreadedDatabase): "VARCHAR2": Text, } ROUNDS_ON_PREC_LOSS = True - SUPPORTS_PRIMARY_KEY = True def __init__(self, *, host, database, thread_count, **kw): self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) @@ -63,17 +135,6 @@ def _query_cursor(self, c, sql_code: str): except self._oracle.DatabaseError as e: raise QueryError(e) - def md5_to_int(self, s: str) -> str: - # standard_hash is faster than DBMS_CRYPTO.Hash - # TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ? - return f"to_number(substr(standard_hash({s}, 'MD5'), 18), 'xxxxxxxxxxxxxxx')" - - def quote(self, s: str): - return f"{s}" - - def to_string(self, s: str): - return f"cast({s} as varchar(1024))" - def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) @@ -82,23 +143,6 @@ def select_table_schema(self, path: DbPath) -> str: f" FROM ALL_TAB_COLUMNS WHERE table_name = '{table.upper()}' AND owner = '{schema.upper()}'" ) - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" - - if coltype.precision > 0: - truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.FF{coltype.precision}')" - else: - truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.')" - return f"RPAD({truncated}, {TIMESTAMP_PRECISION_POS+6}, '0')" - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - # FM999.9990 - format_str = "FM" + "9" * (38 - coltype.precision) - if coltype.precision: - format_str += "0." 
+ "9" * (coltype.precision - 1) + "0" - return f"to_char({value}, '{format_str}')" - def _parse_type( self, table_path: DbPath, @@ -121,39 +165,3 @@ def _parse_type( return super()._parse_type( table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale ) - - def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): - if offset: - raise NotImplementedError("No support for OFFSET in query") - - return f"FETCH NEXT {limit} ROWS ONLY" - - def concat(self, items: List[str]) -> str: - joined_exprs = " || ".join(items) - return f"({joined_exprs})" - - def timestamp_value(self, t: DbTime) -> str: - return "timestamp '%s'" % t.isoformat(" ") - - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - # Cast is necessary for correct MD5 (trimming not enough) - return f"CAST(TRIM({value}) AS VARCHAR(36))" - - def random(self) -> str: - return "dbms_random.value" - - def is_distinct_from(self, a: str, b: str) -> str: - return f"DECODE({a}, {b}, 1, 0) = 0" - - def type_repr(self, t) -> str: - try: - return { - str: "VARCHAR(1024)", - }[t] - except KeyError: - return super().type_repr(t) - - def constant_values(self, rows) -> str: - return " UNION ALL ".join( - "SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows - ) diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index 3181dab1..a044ba3a 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -9,7 +9,7 @@ Text, FractionalType, ) -from .base import ThreadedDatabase, import_helper, ConnectError +from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, _CHECKSUM_BITSIZE, TIMESTAMP_PRECISION_POS SESSION_TIME_ZONE = None # Changed by the tests @@ -24,7 +24,34 @@ def import_postgresql(): return psycopg2 +class Dialect(BaseDialect): + name = "PostgreSQL" + SUPPORTS_PRIMARY_KEY = True + + def quote(self, s: str): + return f'"{s}"' + + def md5_as_int(self, s: str) -> str: + return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint" + + def to_string(self, s: str): + return f"{s}::varchar" + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" + + timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"{value}::decimal(38, {coltype.precision})") + + class PostgreSQL(ThreadedDatabase): + dialect = Dialect() TYPE_CLASSES = { # Timestamps "timestamp with time zone": TimestampTZ, @@ -46,7 +73,6 @@ class PostgreSQL(ThreadedDatabase): "uuid": Native_UUID, } ROUNDS_ON_PREC_LOSS = True - SUPPORTS_PRIMARY_KEY = True SUPPORTS_UNIQUE_CONSTAINT = True default_schema = "public" @@ -56,10 +82,6 @@ def __init__(self, *, thread_count, **kw): super().__init__(thread_count=thread_count) - def _convert_db_precision_to_digits(self, p: int) -> int: - # Subtracting 2 due to wierd precision issues in PostgreSQL - return super()._convert_db_precision_to_digits(p) - 2 - def create_connection(self): if not self._args: self._args["host"] = None # psycopg2 requires 1+ arguments @@ -73,23 +95,6 @@ def create_connection(self): 
except pg.OperationalError as e: raise ConnectError(*e.args) from e - def quote(self, s: str): - return f'"{s}"' - - def md5_to_int(self, s: str) -> str: - return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint" - - def to_string(self, s: str): - return f"{s}::varchar" - - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" - - timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" - return ( - f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"{value}::decimal(38, {coltype.precision})") + def _convert_db_precision_to_digits(self, p: int) -> int: + # Subtracting 2 due to wierd precision issues in PostgreSQL + return super()._convert_db_precision_to_digits(p) - 2 diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index 56a32d48..b1749893 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -17,7 +17,7 @@ ColType_UUID, TemporalType, ) -from .base import Database, import_helper, ThreadLocalInterpreter +from .base import BaseDialect, Database, import_helper, ThreadLocalInterpreter from .base import ( MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, @@ -41,7 +41,49 @@ def import_presto(): return prestodb +class Dialect(BaseDialect): + name = "Presto" + + def explain_as_text(self, query: str) -> str: + return f"EXPLAIN (FORMAT TEXT) {query}" + + def type_repr(self, t) -> str: + try: + return {float: "REAL"}[t] + except KeyError: + return super().type_repr(t) + + def timestamp_value(self, t: DbTime) -> str: + return f"timestamp '{t.isoformat(' ')}'" + + def quote(self, s: str): + return f'"{s}"' + + def md5_as_int(self, s: str) -> str: + return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0))" + + def to_string(self, s: str): + return f"cast({s} as varchar)" + + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + # Trim doesn't work on CHAR type + return f"TRIM(CAST({value} AS VARCHAR))" + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + # TODO rounds + if coltype.rounds: + s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" + else: + s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" + + return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))") + + class Presto(Database): + dialect = Dialect() default_schema = "public" TYPE_CLASSES = { # Timestamps @@ -74,15 +116,6 @@ def __init__(self, **kw): else: self._conn = prestodb.dbapi.connect(**kw) - def quote(self, s: str): - return f'"{s}"' - - def md5_to_int(self, s: str) -> str: - return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0))" - - def to_string(self, s: str): - return f"cast({s} as varchar)" - def _query(self, sql_code: str) -> list: "Uses the standard SQL cursor interface" c = self._conn.cursor() @@ -95,18 +128,6 @@ def _query(self, sql_code: str) -> list: def close(self): self._conn.close() - def normalize_timestamp(self, 
value: str, coltype: TemporalType) -> str: - # TODO rounds - if coltype.rounds: - s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - else: - s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - - return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))") - def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) @@ -144,22 +165,6 @@ def _parse_type( return super()._parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - # Trim doesn't work on CHAR type - return f"TRIM(CAST({value} AS VARCHAR))" - @property def is_autocommit(self) -> bool: return False - - def explain_as_text(self, query: str) -> str: - return f"EXPLAIN (FORMAT TEXT) {query}" - - def type_repr(self, t) -> str: - try: - return {float: "REAL"}[t] - except KeyError: - return super().type_repr(t) - - def timestamp_value(self, t: DbTime) -> str: - return f"timestamp '{t.isoformat(' ')}'" diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index afaa28a4..7d274b34 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -1,16 +1,12 @@ from typing import List from .database_types import Float, TemporalType, FractionalType, DbPath -from .postgresql import PostgreSQL, MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS +from .postgresql import PostgreSQL, MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS, Dialect -class Redshift(PostgreSQL): - TYPE_CLASSES = { - **PostgreSQL.TYPE_CLASSES, - "double": Float, - "real": Float, - } +class Dialect(Dialect): + name = "Redshift" - def md5_to_int(self, s: str) -> str: + def md5_as_int(self, s: str) -> str: return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: @@ -40,6 +36,18 @@ def concat(self, items: List[str]) -> str: joined_exprs = " || ".join(items) return f"({joined_exprs})" + def is_distinct_from(self, a: str, b: str) -> str: + return f"{a} IS NULL AND NOT {b} IS NULL OR {b} IS NULL OR {a}!={b}" + + +class Redshift(PostgreSQL): + dialect = Dialect() + TYPE_CLASSES = { + **PostgreSQL.TYPE_CLASSES, + "double": Float, + "real": Float, + } + def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) @@ -47,6 +55,3 @@ def select_table_schema(self, path: DbPath) -> str: "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM information_schema.columns " f"WHERE table_name = '{table.lower()}' AND table_schema = '{schema.lower()}'" ) - - def is_distinct_from(self, a: str, b: str) -> str: - return f"{a} IS NULL AND NOT {b} IS NULL OR {b} IS NULL OR {a}!={b}" diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index afd52ba8..985020d7 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -2,7 +2,7 @@ import logging from .database_types import Timestamp, TimestampTZ, Decimal, Float, Text, FractionalType, TemporalType, DbPath -from .base import ConnectError, Database, import_helper, CHECKSUM_MASK, ThreadLocalInterpreter +from .base import BaseDialect, ConnectError, Database, import_helper, 
CHECKSUM_MASK, ThreadLocalInterpreter @import_helper("snowflake") @@ -14,7 +14,35 @@ def import_snowflake(): return snowflake, serialization, default_backend +class Dialect(BaseDialect): + name = "Snowflake" + + def explain_as_text(self, query: str) -> str: + return f"EXPLAIN USING TEXT {query}" + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))" + else: + timestamp = f"cast({value} as timestamp({coltype.precision}))" + + return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") + + def quote(self, s: str): + return f'"{s}"' + + def md5_as_int(self, s: str) -> str: + return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK})" + + def to_string(self, s: str): + return f"cast({s} as string)" + + class Snowflake(Database): + dialect = Dialect() TYPE_CLASSES = { # Timestamps "TIMESTAMP_NTZ": Timestamp, @@ -65,36 +93,13 @@ def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): "Uses the standard SQL cursor interface" return self._query_conn(self._conn, sql_code) - def quote(self, s: str): - return f'"{s}"' - - def md5_to_int(self, s: str) -> str: - return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK})" - - def to_string(self, s: str): - return f"cast({s} as string)" - def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) return super().select_table_schema((schema, table)) - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))" - else: - timestamp = f"cast({value} as timestamp({coltype.precision}))" - - return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") - @property def is_autocommit(self) -> bool: return True - def explain_as_text(self, query: str) -> str: - return f"EXPLAIN USING TEXT {query}" - def query_table_unique_columns(self, path: DbPath) -> List[str]: return [] diff --git a/data_diff/databases/trino.py b/data_diff/databases/trino.py index 73ef4a97..a7b0ef8c 100644 --- a/data_diff/databases/trino.py +++ b/data_diff/databases/trino.py @@ -1,5 +1,5 @@ from .database_types import TemporalType, ColType_UUID -from .presto import Presto +from .presto import Presto, Dialect from .base import import_helper from .base import TIMESTAMP_PRECISION_POS @@ -11,11 +11,8 @@ def import_trino(): return trino -class Trino(Presto): - def __init__(self, **kw): - trino = import_trino() - - self._conn = trino.dbapi.connect(**kw) +class Dialect(Dialect): + name = "Trino" def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -29,3 +26,12 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: return f"TRIM({value})" + + +class Trino(Presto): + dialect = Dialect() + + def __init__(self, **kw): + trino = import_trino() + + self._conn = trino.dbapi.connect(**kw) diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index 6b486555..e50eec60 100644 --- 
a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -5,6 +5,7 @@ CHECKSUM_HEXDIGITS, MD5_HEXDIGITS, TIMESTAMP_PRECISION_POS, + BaseDialect, ConnectError, DbPath, ColType, @@ -12,7 +13,16 @@ ThreadedDatabase, import_helper, ) -from .database_types import Decimal, Float, FractionalType, Integer, TemporalType, Text, Timestamp, TimestampTZ +from .database_types import ( + Decimal, + Float, + FractionalType, + Integer, + TemporalType, + Text, + Timestamp, + TimestampTZ, +) @import_helper("vertica") @@ -22,7 +32,43 @@ def import_vertica(): return vertica_python +class Dialect(BaseDialect): + name = "Vertica" + + def quote(self, s: str): + return f'"{s}"' + + def concat(self, items: List[str]) -> str: + return " || ".join(items) + + def md5_as_int(self, s: str) -> str: + return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0))" + + def to_string(self, s: str) -> str: + return f"CAST({s} AS VARCHAR)" + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + return f"TO_CHAR({value}::TIMESTAMP({coltype.precision}), 'YYYY-MM-DD HH24:MI:SS.US')" + + timestamp6 = f"TO_CHAR({value}::TIMESTAMP(6), 'YYYY-MM-DD HH24:MI:SS.US')" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"CAST({value} AS NUMERIC(38, {coltype.precision}))") + + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + # Trim doesn't work on CHAR type + return f"TRIM(CAST({value} AS VARCHAR))" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"not ({a} <=> {b})" + + class Vertica(ThreadedDatabase): + dialect = Dialect() default_schema = "public" TYPE_CLASSES = { @@ -95,34 +141,3 @@ def select_table_schema(self, path: DbPath) -> str: "FROM V_CATALOG.COLUMNS " f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) - - def quote(self, s: str): - return f'"{s}"' - - def concat(self, items: List[str]) -> str: - return " || ".join(items) - - def md5_to_int(self, s: str) -> str: - return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0))" - - def to_string(self, s: str) -> str: - return f"CAST({s} AS VARCHAR)" - - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - return f"TO_CHAR({value}::TIMESTAMP({coltype.precision}), 'YYYY-MM-DD HH24:MI:SS.US')" - - timestamp6 = f"TO_CHAR({value}::TIMESTAMP(6), 'YYYY-MM-DD HH24:MI:SS.US')" - return ( - f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"CAST({value} AS NUMERIC(38, {coltype.precision}))") - - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - # Trim doesn't work on CHAR type - return f"TRIM(CAST({value} AS VARCHAR))" - - def is_distinct_from(self, a: str, b: str) -> str: - return f"not ({a} <=> {b})" diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 66d62783..13f33193 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -141,14 +141,14 @@ class Concat(ExprNode): def compile(self, c: Compiler) -> str: # We coalesce because on some DBs (e.g. 
MySQL) concat('a', NULL) is NULL - items = [f"coalesce({c.compile(c.database.to_string(c.compile(expr)))}, '')" for expr in self.exprs] + items = [f"coalesce({c.compile(c.dialect.to_string(c.compile(expr)))}, '')" for expr in self.exprs] assert items if len(items) == 1: return items[0] if self.sep: items = list(join_iter(f"'{self.sep}'", items)) - return c.database.concat(items) + return c.dialect.concat(items) @dataclass @@ -239,7 +239,7 @@ class IsDistinctFrom(ExprNode, LazyOps): type = bool def compile(self, c: Compiler) -> str: - return c.database.is_distinct_from(c.compile(self.a), c.compile(self.b)) + return c.dialect.is_distinct_from(c.compile(self.a), c.compile(self.b)) @dataclass(eq=False, order=False) @@ -481,7 +481,7 @@ def compile(self, parent_c: Compiler) -> str: select += " ORDER BY " + ", ".join(map(c.compile, self.order_by_exprs)) if self.limit_expr is not None: - select += " " + c.database.offset_limit(0, self.limit_expr) + select += " " + c.dialect.offset_limit(0, self.limit_expr) if parent_c.in_select: select = f"({select}) {c.new_unique_name()}" @@ -605,7 +605,7 @@ class Random(ExprNode): type = float def compile(self, c: Compiler) -> str: - return c.database.random() + return c.dialect.random() @dataclass @@ -616,7 +616,7 @@ def compile(self, c: Compiler) -> str: raise NotImplementedError() def compile_for_insert(self, c: Compiler): - return c.database.constant_values(self.rows) + return c.dialect.constant_values(self.rows) @dataclass @@ -626,7 +626,7 @@ class Explain(ExprNode): type = str def compile(self, c: Compiler) -> str: - return c.database.explain_as_text(c.compile(self.select)) + return c.dialect.explain_as_text(c.compile(self.select)) # DDL @@ -648,10 +648,10 @@ def compile(self, c: Compiler) -> str: if self.source_table: return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}" - schema = ", ".join(f"{c.database.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items()) + schema = ", ".join(f"{c.dialect.quote(k)} {c.dialect.type_repr(v)}" for k, v in self.path.schema.items()) pks = ( ", PRIMARY KEY (%s)" % ", ".join(self.primary_keys) - if self.primary_keys and c.database.SUPPORTS_PRIMARY_KEY + if self.primary_keys and c.dialect.SUPPORTS_PRIMARY_KEY else "" ) return f"CREATE TABLE {ne}{c.compile(self.path)}({schema}{pks})" @@ -698,6 +698,7 @@ class Commit(Statement): def compile(self, c: Compiler) -> str: return "COMMIT" if not c.database.is_autocommit else SKIP + @dataclass class Param(ExprNode, ITable): """A value placeholder, to be specified at compilation time using the `cv_params` context variable.""" @@ -711,4 +712,3 @@ def source_table(self): def compile(self, c: Compiler) -> str: params = cv_params.get() return c._compile(params[self.name]) - diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index e9a66bed..0a4d1d6f 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -6,7 +6,7 @@ from runtype import dataclass from data_diff.utils import ArithString -from data_diff.databases.database_types import AbstractDialect, DbPath +from data_diff.databases.database_types import AbstractDatabase, AbstractDialect, DbPath import contextvars @@ -15,7 +15,7 @@ @dataclass class Compiler: - database: AbstractDialect + database: AbstractDatabase params: dict = {} in_select: bool = False # Compilation runtime flag in_join: bool = False # Compilation runtime flag @@ -26,6 +26,10 @@ class Compiler: _counter: List = [0] + @property + def dialect(self) -> AbstractDialect: + 
return self.database.dialect + def compile(self, elem, params=None) -> str: if params: cv_params.set(params) @@ -47,7 +51,7 @@ def _compile(self, elem) -> str: elif isinstance(elem, int): return str(elem) elif isinstance(elem, datetime): - return self.database.timestamp_value(elem) + return self.dialect.timestamp_value(elem) elif isinstance(elem, bytes): return f"b'{elem.decode()}'" elif isinstance(elem, ArithString): @@ -66,7 +70,7 @@ def add_table_context(self, *tables: Sequence, **kw): return self.replace(_table_context=self._table_context + list(tables), **kw) def quote(self, s: str): - return self.database.quote(s) + return self.dialect.quote(s) class Compilable(ABC): diff --git a/data_diff/queries/extras.py b/data_diff/queries/extras.py index bcd426df..32d31ce9 100644 --- a/data_diff/queries/extras.py +++ b/data_diff/queries/extras.py @@ -17,7 +17,7 @@ class NormalizeAsString(ExprNode): def compile(self, c: Compiler) -> str: expr = c.compile(self.expr) - return c.database.normalize_value_by_type(expr, self.expr_type or self.expr.type) + return c.dialect.normalize_value_by_type(expr, self.expr_type or self.expr.type) @dataclass @@ -58,5 +58,5 @@ def compile(self, c: Compiler): # No need to coalesce - safe to assume that key cannot be null (expr,) = self.exprs expr = c.compile(expr) - md5 = c.database.md5_to_int(expr) + md5 = c.dialect.md5_as_int(expr) return f"sum({md5})" diff --git a/tests/common.py b/tests/common.py index cd974e34..fccd5ddc 100644 --- a/tests/common.py +++ b/tests/common.py @@ -135,8 +135,8 @@ def setUp(self): self.table_src_path = self.connection.parse_table_name(self.table_src_name) self.table_dst_path = self.connection.parse_table_name(self.table_dst_name) - self.table_src = ".".join(map(self.connection.quote, self.table_src_path)) - self.table_dst = ".".join(map(self.connection.quote, self.table_dst_path)) + self.table_src = ".".join(map(self.connection.dialect.quote, self.table_src_path)) + self.table_dst = ".".join(map(self.connection.dialect.quote, self.table_dst_path)) drop_table(self.connection, self.table_src_path) drop_table(self.connection, self.table_dst_path) diff --git a/tests/test_database.py b/tests/test_database.py index a7e34d1d..d309a4ed 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -11,9 +11,9 @@ def setUp(self): def test_connect_to_db(self): self.assertEqual(1, self.mysql.query("SELECT 1", int)) - def test_md5_to_int(self): + def test_md5_as_int(self): str = "hello world" - query_fragment = self.mysql.md5_to_int("'{0}'".format(str)) + query_fragment = self.mysql.dialect.md5_as_int("'{0}'".format(str)) query = f"SELECT {query_fragment}" self.assertEqual(str_to_checksum(str), self.mysql.query(query, int)) diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 1c36ff7f..c8ed69aa 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -633,8 +633,8 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego self.src_table_path = src_table_path = src_conn.parse_table_name(src_table_name) self.dst_table_path = dst_table_path = dst_conn.parse_table_name(dst_table_name) - self.src_table = src_table = ".".join(map(src_conn.quote, src_table_path)) - self.dst_table = dst_table = ".".join(map(dst_conn.quote, dst_table_path)) + self.src_table = src_table = ".".join(map(src_conn.dialect.quote, src_table_path)) + self.dst_table = dst_table = ".".join(map(dst_conn.dialect.quote, dst_table_path)) start = time.monotonic() if not BENCHMARK: diff --git 
a/tests/test_joindiff.py b/tests/test_joindiff.py index 22ed217d..3f70fa09 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -275,7 +275,7 @@ def test_null_pks(self): self.assertRaises(ValueError, list, x) -@test_each_database_in_list(d for d in TEST_DATABASES if d.SUPPORTS_PRIMARY_KEY and d.SUPPORTS_UNIQUE_CONSTAINT) +@test_each_database_in_list(d for d in TEST_DATABASES if d.dialect.SUPPORTS_PRIMARY_KEY and d.SUPPORTS_UNIQUE_CONSTAINT) class TestUniqueConstraint(TestPerDatabase): def setUp(self): super().setUp() diff --git a/tests/test_query.py b/tests/test_query.py index d02e9745..559d8bcd 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import List, Optional import unittest -from data_diff.databases.database_types import AbstractDialect, CaseInsensitiveDict, CaseSensitiveDict +from data_diff.databases.database_types import AbstractDatabase, AbstractDialect, CaseInsensitiveDict, CaseSensitiveDict from data_diff.queries import this, table, Compiler, outerjoin, cte from data_diff.queries.ast_classes import Random @@ -12,6 +12,8 @@ def normalize_spaces(s: str): class MockDialect(AbstractDialect): + name = "MockDialect" + def quote(self, s: str) -> str: return s @@ -39,12 +41,25 @@ def timestamp_value(self, t: datetime) -> str: return f"timestamp '{t}'" +class MockDatabase(AbstractDatabase): + dialect = MockDialect() + + _query = NotImplemented + query_table_schema = NotImplemented + select_table_schema = NotImplemented + _process_table_schema = NotImplemented + parse_table_name = NotImplemented + close = NotImplemented + _normalize_table_path = NotImplemented + is_autocommit = NotImplemented + + class TestQuery(unittest.TestCase): def setUp(self): pass def test_basic(self): - c = Compiler(MockDialect()) + c = Compiler(MockDatabase()) t = table("point") t2 = t.select(x=this.x + 1, y=t["y"] + this.x) @@ -57,7 +72,7 @@ def test_basic(self): assert c.compile(t) == "SELECT x, y FROM point" def test_outerjoin(self): - c = Compiler(MockDialect()) + c = Compiler(MockDatabase()) a = table("a") b = table("b") @@ -82,7 +97,7 @@ def test_outerjoin(self): # t.group_by(keys=[this.x], values=[this.py]) def test_schema(self): - c = Compiler(MockDialect()) + c = Compiler(MockDatabase()) schema = dict(id="int", comment="varchar") # test table @@ -108,7 +123,7 @@ def test_schema(self): self.assertRaises(KeyError, j.__getitem__, "ysum") def test_commutable_select(self): - # c = Compiler(MockDialect()) + # c = Compiler(MockDatabase()) t = table("a") q1 = t.select("a").where("b") @@ -116,7 +131,7 @@ def test_commutable_select(self): assert q1 == q2, (q1, q2) def test_cte(self): - c = Compiler(MockDialect()) + c = Compiler(MockDatabase()) t = table("a") @@ -128,14 +143,14 @@ def test_cte(self): assert normalize_spaces(c.compile(t3)) == expected # nested cte - c = Compiler(MockDialect()) + c = Compiler(MockDatabase()) t4 = cte(t3).select(this.x) expected = "WITH tmp1 AS (SELECT x FROM a), tmp2 AS (SELECT x FROM tmp1) SELECT x FROM tmp2" assert normalize_spaces(c.compile(t4)) == expected # parameterized cte - c = Compiler(MockDialect()) + c = Compiler(MockDatabase()) t2 = cte(t.select(this.x), params=["y"]) t3 = t2.select(this.y) @@ -143,14 +158,14 @@ def test_cte(self): assert normalize_spaces(c.compile(t3)) == expected def test_funcs(self): - c = Compiler(MockDialect()) + c = Compiler(MockDatabase()) t = table("a") q = c.compile(t.order_by(Random()).limit(10)) assert q == "SELECT * FROM a ORDER BY random() limit 10" def 
test_union(self): - c = Compiler(MockDialect()) + c = Compiler(MockDatabase()) a = table("a").select("x") b = table("b").select("y") From 4b3a1dbbecc3a7c3ef64c7d9deec60e7a3af9c63 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 1 Nov 2022 18:03:49 -0300 Subject: [PATCH 79/93] Continued refactoring (parse_type, TYPE_CLASSES, etc.) --- data_diff/config.py | 4 +- data_diff/databases/base.py | 106 +++++++++++++------------- data_diff/databases/bigquery.py | 30 ++++---- data_diff/databases/clickhouse.py | 40 +++++----- data_diff/databases/database_types.py | 22 ++++++ data_diff/databases/databricks.py | 41 +++++----- data_diff/databases/mysql.py | 34 ++++----- data_diff/databases/oracle.py | 66 ++++++++-------- data_diff/databases/postgresql.py | 55 ++++++------- data_diff/databases/presto.py | 84 ++++++++++---------- data_diff/databases/redshift.py | 14 ++-- data_diff/databases/snowflake.py | 24 +++--- data_diff/databases/vertica.py | 71 +++++++++-------- data_diff/joindiff_tables.py | 4 +- tests/test_query.py | 4 + 15 files changed, 315 insertions(+), 284 deletions(-) diff --git a/data_diff/config.py b/data_diff/config.py index 18c31e6e..9a6b6d54 100644 --- a/data_diff/config.py +++ b/data_diff/config.py @@ -39,7 +39,9 @@ def _apply_config(config: Dict[str, Any], run_name: str, kw: Dict[str, Any]): try: args = run_args.pop(index) except KeyError: - raise ConfigParseError(f"Could not find source #{index}: Expecting a key of '{index}' containing '.database' and '.table'.") + raise ConfigParseError( + f"Could not find source #{index}: Expecting a key of '{index}' containing '.database' and '.table'." + ) for attr in ("database", "table"): if attr not in args: raise ConfigParseError(f"Running 'run.{run_name}': Connection #{index} is missing attribute '{attr}'.") diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 0a907c6b..0a5a2fa6 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -104,6 +104,7 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal class BaseDialect(AbstractDialect, AbstractMixin_MD5, AbstractMixin_NormalizeValue): SUPPORTS_PRIMARY_KEY = False + TYPE_CLASSES: Dict[str, type] = {} def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): if offset: @@ -160,6 +161,56 @@ def type_repr(self, t) -> str: datetime: "TIMESTAMP", }[t] + def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: + return self.TYPE_CLASSES.get(type_repr) + + def parse_type( + self, + table_path: DbPath, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, + numeric_scale: int = None, + ) -> ColType: + """ """ + + cls = self._parse_type_repr(type_repr) + if not cls: + return UnknownColType(type_repr) + + if issubclass(cls, TemporalType): + return cls( + precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, + rounds=self.ROUNDS_ON_PREC_LOSS, + ) + + elif issubclass(cls, Integer): + return cls() + + elif issubclass(cls, Decimal): + if numeric_scale is None: + numeric_scale = 0 # Needed for Oracle. 
+ return cls(precision=numeric_scale) + + elif issubclass(cls, Float): + # assert numeric_scale is None + return cls( + precision=self._convert_db_precision_to_digits( + numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION + ) + ) + + elif issubclass(cls, (Text, Native_UUID)): + return cls() + + raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.") + + def _convert_db_precision_to_digits(self, p: int) -> int: + """Convert from binary precision, used by floats, to decimal precision.""" + # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format + return math.floor(math.log(2**p, 10)) + class Database(AbstractDatabase): """Base abstract class for databases. @@ -169,7 +220,6 @@ class Database(AbstractDatabase): Instanciated using :meth:`~data_diff.connect` """ - TYPE_CLASSES: Dict[str, type] = {} default_schema: str = None dialect: AbstractDialect = None @@ -232,56 +282,6 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list): def enable_interactive(self): self._interactive = True - def _convert_db_precision_to_digits(self, p: int) -> int: - """Convert from binary precision, used by floats, to decimal precision.""" - # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format - return math.floor(math.log(2**p, 10)) - - def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: - return self.TYPE_CLASSES.get(type_repr) - - def _parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - numeric_scale: int = None, - ) -> ColType: - """ """ - - cls = self._parse_type_repr(type_repr) - if not cls: - return UnknownColType(type_repr) - - if issubclass(cls, TemporalType): - return cls( - precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, - rounds=self.ROUNDS_ON_PREC_LOSS, - ) - - elif issubclass(cls, Integer): - return cls() - - elif issubclass(cls, Decimal): - if numeric_scale is None: - numeric_scale = 0 # Needed for Oracle. 
- return cls(precision=numeric_scale) - - elif issubclass(cls, Float): - # assert numeric_scale is None - return cls( - precision=self._convert_db_precision_to_digits( - numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION - ) - ) - - elif issubclass(cls, (Text, Native_UUID)): - return cls() - - raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.") - def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) @@ -320,7 +320,9 @@ def _process_table_schema( ): accept = {i.lower() for i in filter_columns} - col_dict = {row[0]: self._parse_type(path, *row) for name, row in raw_schema.items() if name.lower() in accept} + col_dict = { + row[0]: self.dialect.parse_type(path, *row) for name, row in raw_schema.items() if name.lower() in accept + } self._refine_coltypes(path, col_dict, where) diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 58577585..0aa7670a 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -13,6 +13,21 @@ def import_bigquery(): class Dialect(BaseDialect): name = "BigQuery" + ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation + TYPE_CLASSES = { + # Dates + "TIMESTAMP": Timestamp, + "DATETIME": Datetime, + # Numbers + "INT64": Integer, + "INT32": Integer, + "NUMERIC": Decimal, + "BIGNUMERIC": Decimal, + "FLOAT64": Float, + "FLOAT32": Float, + # Text + "STRING": Text, + } def random(self) -> str: return "RAND()" @@ -53,21 +68,6 @@ def type_repr(self, t) -> str: class BigQuery(Database): dialect = Dialect() - TYPE_CLASSES = { - # Dates - "TIMESTAMP": Timestamp, - "DATETIME": Datetime, - # Numbers - "INT64": Integer, - "INT32": Integer, - "NUMERIC": Decimal, - "BIGNUMERIC": Decimal, - "FLOAT64": Float, - "FLOAT32": Float, - # Text - "STRING": Text, - } - ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation def __init__(self, project, *, dataset, **kw): bigquery = import_bigquery() diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index 40657585..3f59cbd3 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -31,6 +31,7 @@ def import_clickhouse(): class Dialect(BaseDialect): name = "Clickhouse" + ROUNDS_ON_PREC_LOSS = False def normalize_number(self, value: str, coltype: FractionalType) -> str: # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped. @@ -98,6 +99,25 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: value = f"formatDateTime({value}, '%Y-%m-%d %H:%M:%S') || '.' || {self.to_string(fractional)}" return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')" + def _convert_db_precision_to_digits(self, p: int) -> int: + # Done the same as for PostgreSQL but need to rewrite in another way + # because it does not help for float with a big integer part. 
+ return super()._convert_db_precision_to_digits(p) - 2 + + def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: + nullable_prefix = "Nullable(" + if type_repr.startswith(nullable_prefix): + type_repr = type_repr[len(nullable_prefix) :].rstrip(")") + + if type_repr.startswith("Decimal"): + type_repr = "Decimal" + elif type_repr.startswith("FixedString"): + type_repr = "FixedString" + elif type_repr.startswith("DateTime64"): + type_repr = "DateTime64" + + return self.TYPE_CLASSES.get(type_repr) + class Clickhouse(ThreadedDatabase): dialect = Dialect() @@ -123,7 +143,6 @@ class Clickhouse(ThreadedDatabase): "DateTime": Timestamp, "DateTime64": Timestamp, } - ROUNDS_ON_PREC_LOSS = False def __init__(self, *, thread_count: int, **kw): super().__init__(thread_count=thread_count) @@ -148,25 +167,6 @@ def cursor(self, cursor_factory=None): except clickhouse.OperationError as e: raise ConnectError(*e.args) from e - def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: - nullable_prefix = "Nullable(" - if type_repr.startswith(nullable_prefix): - type_repr = type_repr[len(nullable_prefix) :].rstrip(")") - - if type_repr.startswith("Decimal"): - type_repr = "Decimal" - elif type_repr.startswith("FixedString"): - type_repr = "FixedString" - elif type_repr.startswith("DateTime64"): - type_repr = "DateTime64" - - return self.TYPE_CLASSES.get(type_repr) - @property def is_autocommit(self) -> bool: return True - - def _convert_db_precision_to_digits(self, p: int) -> int: - # Done the same as for PostgreSQL but need to rewrite in another way - # because it does not help for float with a big integer part. - return super()._convert_db_precision_to_digits(p) - 2 diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index f03861d3..296ad475 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -145,6 +145,16 @@ class AbstractDialect(ABC): name: str + @property + @abstractmethod + def name(self) -> str: + "Name of the dialect" + + @property + @abstractmethod + def ROUNDS_ON_PREC_LOSS(self) -> bool: + "True if db rounds real values when losing precision, False if it truncates." + @abstractmethod def quote(self, s: str): "Quote SQL name" @@ -185,6 +195,18 @@ def timestamp_value(self, t: datetime) -> str: "Provide SQL for the given timestamp value" ... 
+ @abstractmethod + def parse_type( + self, + table_path: DbPath, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, + numeric_scale: int = None, + ) -> ColType: + "Parse type info as returned by the database" + class AbstractMixin_NormalizeValue(ABC): @abstractmethod diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 496ead10..7942b53a 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -25,6 +25,21 @@ def import_databricks(): class Dialect(BaseDialect): name = "Databricks" + ROUNDS_ON_PREC_LOSS = True + TYPE_CLASSES = { + # Numbers + "INT": Integer, + "SMALLINT": Integer, + "TINYINT": Integer, + "BIGINT": Integer, + "FLOAT": Float, + "DOUBLE": Float, + "DECIMAL": Decimal, + # Timestamps + "TIMESTAMP": Timestamp, + # Text + "STRING": Text, + } def quote(self, s: str): return f"`{s}`" @@ -48,25 +63,13 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: def normalize_number(self, value: str, coltype: NumericType) -> str: return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") + def _convert_db_precision_to_digits(self, p: int) -> int: + # Subtracting 1 due to wierd precision issues + return max(super()._convert_db_precision_to_digits(p) - 1, 0) + class Databricks(Database): dialect = Dialect() - TYPE_CLASSES = { - # Numbers - "INT": Integer, - "SMALLINT": Integer, - "TINYINT": Integer, - "BIGINT": Integer, - "FLOAT": Float, - "DOUBLE": Float, - "DECIMAL": Decimal, - # Timestamps - "TIMESTAMP": Timestamp, - # Text - "STRING": Text, - } - - ROUNDS_ON_PREC_LOSS = True def __init__( self, @@ -93,10 +96,6 @@ def _query(self, sql_code: str) -> list: "Uses the standard SQL cursor interface" return self._query_conn(self._conn, sql_code) - def _convert_db_precision_to_digits(self, p: int) -> int: - # Subtracting 1 due to wierd precision issues - return max(super()._convert_db_precision_to_digits(p) - 1, 0) - def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL. 
# https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html @@ -145,7 +144,7 @@ def _process_table_schema( resulted_rows.append(row) - col_dict: Dict[str, ColType] = {row[0]: self._parse_type(path, *row) for row in resulted_rows} + col_dict: Dict[str, ColType] = {row[0]: self.dialect.parse_type(path, *row) for row in resulted_rows} self._refine_coltypes(path, col_dict, where) return col_dict diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index 2f0484c0..8f8e1730 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -22,7 +22,24 @@ def import_mysql(): class Dialect(BaseDialect): name = "MySQL" + ROUNDS_ON_PREC_LOSS = True SUPPORTS_PRIMARY_KEY = True + TYPE_CLASSES = { + # Dates + "datetime": Datetime, + "timestamp": Timestamp, + # Numbers + "double": Float, + "float": Float, + "decimal": Decimal, + "int": Integer, + "bigint": Integer, + # Text + "varchar": Text, + "char": Text, + "varbinary": Text, + "binary": Text, + } def quote(self, s: str): return f"`{s}`" @@ -66,23 +83,6 @@ def explain_as_text(self, query: str) -> str: class MySQL(ThreadedDatabase): dialect = Dialect() - TYPE_CLASSES = { - # Dates - "datetime": Datetime, - "timestamp": Timestamp, - # Numbers - "double": Float, - "float": Float, - "decimal": Decimal, - "int": Integer, - "bigint": Integer, - # Text - "varchar": Text, - "char": Text, - "varbinary": Text, - "binary": Text, - } - ROUNDS_ON_PREC_LOSS = True SUPPORTS_ALPHANUMS = False SUPPORTS_UNIQUE_CONSTAINT = True diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 7449c0dc..a1ed07ec 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -30,6 +30,16 @@ def import_oracle(): class Dialect(BaseDialect): name = "Oracle" SUPPORTS_PRIMARY_KEY = True + TYPE_CLASSES: Dict[str, type] = { + "NUMBER": Decimal, + "FLOAT": Float, + # Text + "CHAR": Text, + "NCHAR": Text, + "NVARCHAR2": Text, + "VARCHAR2": Text, + } + ROUNDS_ON_PREC_LOSS = True def md5_as_int(self, s: str) -> str: # standard_hash is faster than DBMS_CRYPTO.Hash @@ -98,19 +108,32 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: def explain_as_text(self, query: str) -> str: raise NotImplementedError("Explain not yet implemented in Oracle") + def parse_type( + self, + table_path: DbPath, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, + numeric_scale: int = None, + ) -> ColType: + regexps = { + r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp, + r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ, + r"TIMESTAMP\((\d)\)": Timestamp, + } + + for m, t_cls in match_regexps(regexps, type_repr): + precision = int(m.group(1)) + return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) + + return super()._parse_type( + table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale + ) + class Oracle(ThreadedDatabase): dialect = Dialect() - TYPE_CLASSES: Dict[str, type] = { - "NUMBER": Decimal, - "FLOAT": Float, - # Text - "CHAR": Text, - "NCHAR": Text, - "NVARCHAR2": Text, - "VARCHAR2": Text, - } - ROUNDS_ON_PREC_LOSS = True def __init__(self, *, host, database, thread_count, **kw): self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) @@ -142,26 +165,3 @@ def select_table_schema(self, path: DbPath) -> str: f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale" f" FROM ALL_TAB_COLUMNS 
WHERE table_name = '{table.upper()}' AND owner = '{schema.upper()}'" ) - - def _parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - numeric_scale: int = None, - ) -> ColType: - regexps = { - r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp, - r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ, - r"TIMESTAMP\((\d)\)": Timestamp, - } - - for m, t_cls in match_regexps(regexps, type_repr): - precision = int(m.group(1)) - return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) - - return super()._parse_type( - table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale - ) diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index a044ba3a..27df1273 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -24,10 +24,32 @@ def import_postgresql(): return psycopg2 -class Dialect(BaseDialect): +class PostgresqlDialect(BaseDialect): name = "PostgreSQL" + ROUNDS_ON_PREC_LOSS = True SUPPORTS_PRIMARY_KEY = True + TYPE_CLASSES = { + # Timestamps + "timestamp with time zone": TimestampTZ, + "timestamp without time zone": Timestamp, + "timestamp": Timestamp, + # Numbers + "double precision": Float, + "real": Float, + "decimal": Decimal, + "integer": Integer, + "numeric": Decimal, + "bigint": Integer, + # Text + "character": Text, + "character varying": Text, + "varchar": Text, + "text": Text, + # UUID + "uuid": Native_UUID, + } + def quote(self, s: str): return f'"{s}"' @@ -49,30 +71,13 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: def normalize_number(self, value: str, coltype: FractionalType) -> str: return self.to_string(f"{value}::decimal(38, {coltype.precision})") + def _convert_db_precision_to_digits(self, p: int) -> int: + # Subtracting 2 due to wierd precision issues in PostgreSQL + return super()._convert_db_precision_to_digits(p) - 2 + class PostgreSQL(ThreadedDatabase): - dialect = Dialect() - TYPE_CLASSES = { - # Timestamps - "timestamp with time zone": TimestampTZ, - "timestamp without time zone": Timestamp, - "timestamp": Timestamp, - # Numbers - "double precision": Float, - "real": Float, - "decimal": Decimal, - "integer": Integer, - "numeric": Decimal, - "bigint": Integer, - # Text - "character": Text, - "character varying": Text, - "varchar": Text, - "text": Text, - # UUID - "uuid": Native_UUID, - } - ROUNDS_ON_PREC_LOSS = True + dialect = PostgresqlDialect() SUPPORTS_UNIQUE_CONSTAINT = True default_schema = "public" @@ -94,7 +99,3 @@ def create_connection(self): return c except pg.OperationalError as e: raise ConnectError(*e.args) from e - - def _convert_db_precision_to_digits(self, p: int) -> int: - # Subtracting 2 due to wierd precision issues in PostgreSQL - return super()._convert_db_precision_to_digits(p) - 2 diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index b1749893..de54f5b5 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -43,6 +43,20 @@ def import_presto(): class Dialect(BaseDialect): name = "Presto" + ROUNDS_ON_PREC_LOSS = True + TYPE_CLASSES = { + # Timestamps + "timestamp with time zone": TimestampTZ, + "timestamp without time zone": Timestamp, + "timestamp": Timestamp, + # Numbers + "integer": Integer, + "bigint": Integer, + "real": Float, + "double": Float, + # Text + "varchar": Text, + } def explain_as_text(self, query: str) -> str: return f"EXPLAIN (FORMAT TEXT) {query}" @@ -81,24 
+95,38 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: def normalize_number(self, value: str, coltype: FractionalType) -> str: return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))") + def parse_type( + self, + table_path: DbPath, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, + numeric_scale: int = None, + ) -> ColType: + timestamp_regexps = { + r"timestamp\((\d)\)": Timestamp, + r"timestamp\((\d)\) with time zone": TimestampTZ, + } + for m, t_cls in match_regexps(timestamp_regexps, type_repr): + precision = int(m.group(1)) + return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) + + number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal} + for m, n_cls in match_regexps(number_regexps, type_repr): + _prec, scale = map(int, m.groups()) + return n_cls(scale) + + string_regexps = {r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text} + for m, n_cls in match_regexps(string_regexps, type_repr): + return n_cls() + + return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) + class Presto(Database): dialect = Dialect() default_schema = "public" - TYPE_CLASSES = { - # Timestamps - "timestamp with time zone": TimestampTZ, - "timestamp without time zone": Timestamp, - "timestamp": Timestamp, - # Numbers - "integer": Integer, - "bigint": Integer, - "real": Float, - "double": Float, - # Text - "varchar": Text, - } - ROUNDS_ON_PREC_LOSS = True def __init__(self, **kw): prestodb = import_presto() @@ -137,34 +165,6 @@ def select_table_schema(self, path: DbPath) -> str: f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) - def _parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - numeric_scale: int = None, - ) -> ColType: - timestamp_regexps = { - r"timestamp\((\d)\)": Timestamp, - r"timestamp\((\d)\) with time zone": TimestampTZ, - } - for m, t_cls in match_regexps(timestamp_regexps, type_repr): - precision = int(m.group(1)) - return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) - - number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal} - for m, n_cls in match_regexps(number_regexps, type_repr): - _prec, scale = map(int, m.groups()) - return n_cls(scale) - - string_regexps = {r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text} - for m, n_cls in match_regexps(string_regexps, type_repr): - return n_cls() - - return super()._parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) - @property def is_autocommit(self) -> bool: return False diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index 7d274b34..8113df2e 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -1,10 +1,15 @@ from typing import List from .database_types import Float, TemporalType, FractionalType, DbPath -from .postgresql import PostgreSQL, MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS, Dialect +from .postgresql import PostgreSQL, MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS, PostgresqlDialect -class Dialect(Dialect): +class Dialect(PostgresqlDialect): name = "Redshift" + TYPE_CLASSES = { + **PostgresqlDialect.TYPE_CLASSES, + "double": Float, + "real": Float, + } def md5_as_int(self, s: str) -> str: return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" @@ -42,11 +47,6 @@ def is_distinct_from(self, a: str, b: str) -> str: class 
Redshift(PostgreSQL): dialect = Dialect() - TYPE_CLASSES = { - **PostgreSQL.TYPE_CLASSES, - "double": Float, - "real": Float, - } def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 985020d7..5ab5705b 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -16,6 +16,18 @@ def import_snowflake(): class Dialect(BaseDialect): name = "Snowflake" + ROUNDS_ON_PREC_LOSS = False + TYPE_CLASSES = { + # Timestamps + "TIMESTAMP_NTZ": Timestamp, + "TIMESTAMP_LTZ": Timestamp, + "TIMESTAMP_TZ": TimestampTZ, + # Numbers + "NUMBER": Decimal, + "FLOAT": Float, + # Text + "TEXT": Text, + } def explain_as_text(self, query: str) -> str: return f"EXPLAIN USING TEXT {query}" @@ -43,18 +55,6 @@ def to_string(self, s: str): class Snowflake(Database): dialect = Dialect() - TYPE_CLASSES = { - # Timestamps - "TIMESTAMP_NTZ": Timestamp, - "TIMESTAMP_LTZ": Timestamp, - "TIMESTAMP_TZ": TimestampTZ, - # Numbers - "NUMBER": Decimal, - "FLOAT": Float, - # Text - "TEXT": Text, - } - ROUNDS_ON_PREC_LOSS = False def __init__(self, *, schema: str, **kw): snowflake, serialization, default_backend = import_snowflake() diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index e50eec60..7852800a 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -34,6 +34,20 @@ def import_vertica(): class Dialect(BaseDialect): name = "Vertica" + ROUNDS_ON_PREC_LOSS = True + + TYPE_CLASSES = { + # Timestamps + "timestamp": Timestamp, + "timestamptz": TimestampTZ, + # Numbers + "numeric": Decimal, + "int": Integer, + "float": Float, + # Text + "char": Text, + "varchar": Text, + } def quote(self, s: str): return f'"{s}"' @@ -66,41 +80,7 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: def is_distinct_from(self, a: str, b: str) -> str: return f"not ({a} <=> {b})" - -class Vertica(ThreadedDatabase): - dialect = Dialect() - default_schema = "public" - - TYPE_CLASSES = { - # Timestamps - "timestamp": Timestamp, - "timestamptz": TimestampTZ, - # Numbers - "numeric": Decimal, - "int": Integer, - "float": Float, - # Text - "char": Text, - "varchar": Text, - } - - ROUNDS_ON_PREC_LOSS = True - - def __init__(self, *, thread_count, **kw): - self._args = kw - self._args["AUTOCOMMIT"] = False - - super().__init__(thread_count=thread_count) - - def create_connection(self): - vertica = import_vertica() - try: - c = vertica.connect(**self._args) - return c - except vertica.errors.ConnectionError as e: - raise ConnectError(*e.args) from e - - def _parse_type( + def parse_type( self, table_path: DbPath, col_name: str, @@ -131,7 +111,26 @@ def _parse_type( for m, n_cls in match_regexps(string_regexps, type_repr): return n_cls() - return super()._parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) + return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) + + +class Vertica(ThreadedDatabase): + dialect = Dialect() + default_schema = "public" + + def __init__(self, *, thread_count, **kw): + self._args = kw + self._args["AUTOCOMMIT"] = False + + super().__init__(thread_count=thread_count) + + def create_connection(self): + vertica = import_vertica() + try: + c = vertica.connect(**self._args) + return c + except vertica.errors.ConnectionError as e: + raise ConnectError(*e.args) from e def select_table_schema(self, path: DbPath) -> str: schema, table = 
self._normalize_table_path(path) diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 95e1a4d5..62a70508 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -200,7 +200,9 @@ def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment): # Test duplicate keys for ts in [table1, table2]: - unique = ts.database.query_table_unique_columns(ts.table_path) if ts.database.SUPPORTS_UNIQUE_CONSTAINT else [] + unique = ( + ts.database.query_table_unique_columns(ts.table_path) if ts.database.SUPPORTS_UNIQUE_CONSTAINT else [] + ) t = ts.make_select() key_columns = ts.key_columns diff --git a/tests/test_query.py b/tests/test_query.py index 559d8bcd..36792d23 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -14,6 +14,8 @@ def normalize_spaces(s: str): class MockDialect(AbstractDialect): name = "MockDialect" + ROUNDS_ON_PREC_LOSS = False + def quote(self, s: str) -> str: return s @@ -40,6 +42,8 @@ def explain_as_text(self, query: str) -> str: def timestamp_value(self, t: datetime) -> str: return f"timestamp '{t}'" + parse_type = NotImplemented + class MockDatabase(AbstractDatabase): dialect = MockDialect() From 1b169fdb4cce64c2bf9a534f63e06e1844f10d2c Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 1 Nov 2022 19:55:54 -0300 Subject: [PATCH 80/93] Fix last commit for clickhouse --- data_diff/databases/clickhouse.py | 45 ++++++++++++++++--------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index 3f59cbd3..3d6f9d34 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -32,6 +32,29 @@ def import_clickhouse(): class Dialect(BaseDialect): name = "Clickhouse" ROUNDS_ON_PREC_LOSS = False + TYPE_CLASSES = { + "Int8": Integer, + "Int16": Integer, + "Int32": Integer, + "Int64": Integer, + "Int128": Integer, + "Int256": Integer, + "UInt8": Integer, + "UInt16": Integer, + "UInt32": Integer, + "UInt64": Integer, + "UInt128": Integer, + "UInt256": Integer, + "Float32": Float, + "Float64": Float, + "Decimal": Decimal, + "UUID": Native_UUID, + "String": Text, + "FixedString": Text, + "DateTime": Timestamp, + "DateTime64": Timestamp, + } + def normalize_number(self, value: str, coltype: FractionalType) -> str: # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped. 
@@ -121,28 +144,6 @@ def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: class Clickhouse(ThreadedDatabase): dialect = Dialect() - TYPE_CLASSES = { - "Int8": Integer, - "Int16": Integer, - "Int32": Integer, - "Int64": Integer, - "Int128": Integer, - "Int256": Integer, - "UInt8": Integer, - "UInt16": Integer, - "UInt32": Integer, - "UInt64": Integer, - "UInt128": Integer, - "UInt256": Integer, - "Float32": Float, - "Float64": Float, - "Decimal": Decimal, - "UUID": Native_UUID, - "String": Text, - "FixedString": Text, - "DateTime": Timestamp, - "DateTime64": Timestamp, - } def __init__(self, *, thread_count: int, **kw): super().__init__(thread_count=thread_count) From 86ecac43d1d3ea0b19ffba00dfe3641b817fca67 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 1 Nov 2022 22:17:51 -0300 Subject: [PATCH 81/93] CI:Attempt to fix for vertica, after container name change --- tests/waiting_for_stack_up.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/waiting_for_stack_up.sh b/tests/waiting_for_stack_up.sh index 02ca9cf0..762138de 100644 --- a/tests/waiting_for_stack_up.sh +++ b/tests/waiting_for_stack_up.sh @@ -5,7 +5,7 @@ if [ -n "$DATADIFF_VERTICA_URI" ] echo "Check Vertica DB running..." while true do - if docker logs vertica | tail -n 100 | grep -q -i "vertica is now running" + if docker logs dd-vertica | tail -n 100 | grep -q -i "vertica is now running" then echo "Vertica DB is ready"; break; From b1a14536893da3062feea1d6b4a249e8028c4b28 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 2 Nov 2022 11:49:12 -0300 Subject: [PATCH 82/93] Fix for Oracle --- data_diff/databases/oracle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index a1ed07ec..b68b216a 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -127,7 +127,7 @@ def parse_type( precision = int(m.group(1)) return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) - return super()._parse_type( + return super().parse_type( table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale ) From 683b247015f2c91d139431191bee3de583eec78b Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 2 Nov 2022 16:14:54 -0300 Subject: [PATCH 83/93] Refactor test_database_types --- tests/test_database_types.py | 46 ++++++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/tests/test_database_types.py b/tests/test_database_types.py index c8ed69aa..07b8891b 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -418,22 +418,38 @@ def __iter__(self): "uuid": UUID_Faker(N_SAMPLES), } +def _get_test_db_pairs(full): + if full: + for source_db in DATABASE_TYPES: + for target_db in DATABASE_TYPES: + yield source_db, target_db + else: + for db_cls in DATABASE_TYPES: + yield db_cls, db.PostgreSQL + yield db.PostgreSQL, db_cls + yield db_cls, db.Snowflake + yield db.Snowflake, db_cls + +def get_test_db_pairs(full=True): + active_pairs = {(db1, db2) for db1, db2 in _get_test_db_pairs(full) if db1 in CONN_STRINGS and db2 in CONN_STRINGS} + for db1, db2 in active_pairs: + yield db1, DATABASE_TYPES[db1], db2, DATABASE_TYPES[db2] + + type_pairs = [] -for source_db, source_type_categories in DATABASE_TYPES.items(): - for target_db, target_type_categories in DATABASE_TYPES.items(): - if CONN_STRINGS.get(source_db, False) and CONN_STRINGS.get(target_db, False): - for type_category, source_types in 
source_type_categories.items(): # int, datetime, .. - for source_type in source_types: - for target_type in target_type_categories[type_category]: - type_pairs.append( - ( - source_db, - target_db, - source_type, - target_type, - type_category, - ) - ) +for source_db, source_type_categories, target_db, target_type_categories in get_test_db_pairs(): + for type_category, source_types in source_type_categories.items(): # int, datetime, .. + for source_type in source_types: + for target_type in target_type_categories[type_category]: + type_pairs.append( + ( + source_db, + target_db, + source_type, + target_type, + type_category, + ) + ) def sanitize(name): From abf3d51ec0d9541d5c8d487a798b4ecb6f7c298f Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 2 Nov 2022 16:15:26 -0300 Subject: [PATCH 84/93] Tests: Fix MySQL tests (timezone issues) --- data_diff/databases/base.py | 3 +++ tests/test_cli.py | 5 ++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 0a5a2fa6..672c4e0b 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -267,6 +267,9 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list): if res is None: # May happen due to sum() of 0 items return None return int(res) + elif res_type is datetime: + res = _one(_one(res)) + return res # XXX parse timestamp? elif res_type is tuple: assert len(res) == 1, (sql_code, res) return res[0] diff --git a/tests/test_cli.py b/tests/test_cli.py index 1c131f54..b63b1c7f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -37,7 +37,10 @@ def setUp(self) -> None: src_table = table(table_src_name, schema={"id": int, "datetime": datetime, "text_comment": str}) self.conn.query(src_table.create()) - self.now = now = arrow.get(datetime.now()) + + self.conn.query("SET @@session.time_zone='+00:00'") + db_time = self.conn.query("select now()", datetime) + self.now = now = arrow.get(db_time) rows = [ (now, "now"), From cc126592f0a4b518ffeed1d40c260c7e8d740b44 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 2 Nov 2022 17:07:28 -0300 Subject: [PATCH 85/93] Split tests into FULL_TESTS and regular --- .github/workflows/ci.yml | 1 + .github/workflows/ci_full.yml | 50 +++++++++++++++++++++++++++++++ data_diff/databases/clickhouse.py | 1 - data_diff/databases/oracle.py | 4 +-- tests/common.py | 3 ++ tests/test_database_types.py | 5 +++- 6 files changed, 59 insertions(+), 5 deletions(-) create mode 100644 .github/workflows/ci_full.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b5f8d548..21115c45 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -46,6 +46,7 @@ jobs: env: DATADIFF_SNOWFLAKE_URI: '${{ secrets.DATADIFF_SNOWFLAKE_URI }}' DATADIFF_PRESTO_URI: '${{ secrets.DATADIFF_PRESTO_URI }}' + DATADIFF_TRINO_URI: '${{ secrets.DATADIFF_TRINO_URI }}' DATADIFF_CLICKHOUSE_URI: 'clickhouse://clickhouse:Password1@localhost:9000/clickhouse' DATADIFF_VERTICA_URI: 'vertica://vertica:Password1@localhost:5433/vertica' run: | diff --git a/.github/workflows/ci_full.yml b/.github/workflows/ci_full.yml new file mode 100644 index 00000000..42b1e5fb --- /dev/null +++ b/.github/workflows/ci_full.yml @@ -0,0 +1,50 @@ +name: CI + +on: + push: + paths: + - '**.py' + - '.github/workflows/**' + - '!dev/**' + pull_request: + branches: [ master ] + + workflow_dispatch: + +jobs: + unit_tests: + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: + - "3.10" + + name: Check Python ${{ 
matrix.python-version }} on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v3 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + + - name: Build the stack + run: docker-compose up -d mysql postgres presto trino clickhouse vertica + + - name: Install Poetry + run: pip install poetry + + - name: Install package + run: "poetry install" + + - name: Run unit tests + env: + DATADIFF_SNOWFLAKE_URI: '${{ secrets.DATADIFF_SNOWFLAKE_URI }}' + DATADIFF_PRESTO_URI: '${{ secrets.DATADIFF_PRESTO_URI }}' + DATADIFF_CLICKHOUSE_URI: 'clickhouse://clickhouse:Password1@localhost:9000/clickhouse' + DATADIFF_VERTICA_URI: 'vertica://vertica:Password1@localhost:5433/vertica' + run: | + chmod +x tests/waiting_for_stack_up.sh + ./tests/waiting_for_stack_up.sh && FULL_TESTS=1 poetry run unittest-parallel -j 16 diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index 3d6f9d34..b5f2f577 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -55,7 +55,6 @@ class Dialect(BaseDialect): "DateTime64": Timestamp, } - def normalize_number(self, value: str, coltype: FractionalType) -> str: # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped. # For example: diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index b68b216a..64127e9a 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -127,9 +127,7 @@ def parse_type( precision = int(m.group(1)) return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) - return super().parse_type( - table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale - ) + return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale) class Oracle(ThreadedDatabase): diff --git a/tests/common.py b/tests/common.py index fccd5ddc..aa10f64d 100644 --- a/tests/common.py +++ b/tests/common.py @@ -36,6 +36,7 @@ N_SAMPLES = int(os.environ.get("N_SAMPLES", DEFAULT_N_SAMPLES)) BENCHMARK = os.environ.get("BENCHMARK", False) N_THREADS = int(os.environ.get("N_THREADS", 1)) +FULL_TESTS = bool(os.environ.get("FULL_TESTS", False)) # Should we run the full db<->db test suite? def get_git_revision_short_hash() -> str: @@ -94,6 +95,8 @@ def _print_used_dbs(): logging.info(f"Testing databases: {', '.join(used)}") if unused: logging.info(f"Connection not configured; skipping tests for: {', '.join(unused)}") + if FULL_TESTS: + logging.info("Full tests enabled (every db<->db). 
May take very long when many dbs are involved.") _print_used_dbs() diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 07b8891b..d2f65162 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -25,6 +25,7 @@ N_THREADS, BENCHMARK, GIT_REVISION, + FULL_TESTS, get_conn, random_table_suffix, ) @@ -418,6 +419,7 @@ def __iter__(self): "uuid": UUID_Faker(N_SAMPLES), } + def _get_test_db_pairs(full): if full: for source_db in DATABASE_TYPES: @@ -430,6 +432,7 @@ def _get_test_db_pairs(full): yield db_cls, db.Snowflake yield db.Snowflake, db_cls + def get_test_db_pairs(full=True): active_pairs = {(db1, db2) for db1, db2 in _get_test_db_pairs(full) if db1 in CONN_STRINGS and db2 in CONN_STRINGS} for db1, db2 in active_pairs: @@ -437,7 +440,7 @@ def get_test_db_pairs(full=True): type_pairs = [] -for source_db, source_type_categories, target_db, target_type_categories in get_test_db_pairs(): +for source_db, source_type_categories, target_db, target_type_categories in get_test_db_pairs(FULL_TESTS): for type_category, source_types in source_type_categories.items(): # int, datetime, .. for source_type in source_types: for target_type in target_type_categories[type_category]: From b0c8f606fde9767aeba174799e536075c60c0bcd Mon Sep 17 00:00:00 2001 From: Ilia Pinchuk Date: Wed, 2 Nov 2022 00:11:05 +0600 Subject: [PATCH 86/93] pass a catalog name to connection to find tables not only in hive_metastore --- data_diff/databases/databricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 7942b53a..0d993cdf 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -83,7 +83,7 @@ def __init__( databricks = import_databricks() self._conn = databricks.sql.connect( - server_hostname=server_hostname, http_path=http_path, access_token=access_token + server_hostname=server_hostname, http_path=http_path, access_token=access_token, catalog=catalog ) logging.getLogger("databricks.sql").setLevel(logging.WARNING) From b82b3ed1d0f3f8965ab60e6666b67463a9c70478 Mon Sep 17 00:00:00 2001 From: Ilia Pinchuk Date: Wed, 2 Nov 2022 00:11:23 +0600 Subject: [PATCH 87/93] fix schema parsing --- data_diff/databases/databricks.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 0d993cdf..8327f90a 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -108,7 +108,7 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: if not rows: raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") - d = {r.COLUMN_NAME: r for r in rows} + d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows} assert len(d) == len(rows) return d @@ -120,27 +120,26 @@ def _process_table_schema( resulted_rows = [] for row in rows: - row_type = "DECIMAL" if row.DATA_TYPE == 3 else row.TYPE_NAME + row_type = "DECIMAL" if row[1].startswith("DECIMAL") else row[1] type_cls = self.TYPE_CLASSES.get(row_type, UnknownColType) if issubclass(type_cls, Integer): - row = (row.COLUMN_NAME, row_type, None, None, 0) + row = (row[0], row_type, None, None, 0) elif issubclass(type_cls, Float): - numeric_precision = self._convert_db_precision_to_digits(row.DECIMAL_DIGITS) - row = (row.COLUMN_NAME, row_type, None, numeric_precision, None) + numeric_precision = 
self._convert_db_precision_to_digits(row[2]) + row = (row[0], row_type, None, numeric_precision, None) elif issubclass(type_cls, Decimal): - # TYPE_NAME has a format DECIMAL(x,y) - items = row.TYPE_NAME[8:].rstrip(")").split(",") + items = row[1][8:].rstrip(")").split(",") numeric_precision, numeric_scale = int(items[0]), int(items[1]) - row = (row.COLUMN_NAME, row_type, None, numeric_precision, numeric_scale) + row = (row[0], row_type, None, numeric_precision, numeric_scale) elif issubclass(type_cls, Timestamp): - row = (row.COLUMN_NAME, row_type, row.DECIMAL_DIGITS, None, None) + row = (row[0], row_type, row[2], None, None) else: - row = (row.COLUMN_NAME, row_type, None, None, None) + row = (row[0], row_type, None, None, None) resulted_rows.append(row) From 32a491447df03e21aecbb97f32a7540433779a86 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 2 Nov 2022 17:41:51 -0300 Subject: [PATCH 88/93] Tests: Even better balance using TEST_ACROSS_ALL_DBS --- .github/workflows/ci.yml | 2 +- .github/workflows/ci_full.yml | 2 +- tests/common.py | 8 +++++--- tests/test_database_types.py | 16 +++++++++------- 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 21115c45..0aa3642d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -51,4 +51,4 @@ jobs: DATADIFF_VERTICA_URI: 'vertica://vertica:Password1@localhost:5433/vertica' run: | chmod +x tests/waiting_for_stack_up.sh - ./tests/waiting_for_stack_up.sh && poetry run unittest-parallel -j 16 + ./tests/waiting_for_stack_up.sh && TEST_ACROSS_ALL_DBS=0 poetry run unittest-parallel -j 16 diff --git a/.github/workflows/ci_full.yml b/.github/workflows/ci_full.yml index 42b1e5fb..a6a75806 100644 --- a/.github/workflows/ci_full.yml +++ b/.github/workflows/ci_full.yml @@ -47,4 +47,4 @@ jobs: DATADIFF_VERTICA_URI: 'vertica://vertica:Password1@localhost:5433/vertica' run: | chmod +x tests/waiting_for_stack_up.sh - ./tests/waiting_for_stack_up.sh && FULL_TESTS=1 poetry run unittest-parallel -j 16 + ./tests/waiting_for_stack_up.sh && TEST_ACROSS_ALL_DBS=full poetry run unittest-parallel -j 16 diff --git a/tests/common.py b/tests/common.py index aa10f64d..c1ae30a0 100644 --- a/tests/common.py +++ b/tests/common.py @@ -36,7 +36,7 @@ N_SAMPLES = int(os.environ.get("N_SAMPLES", DEFAULT_N_SAMPLES)) BENCHMARK = os.environ.get("BENCHMARK", False) N_THREADS = int(os.environ.get("N_THREADS", 1)) -FULL_TESTS = bool(os.environ.get("FULL_TESTS", False)) # Should we run the full db<->db test suite? +TEST_ACROSS_ALL_DBS = os.environ.get("TEST_ACROSS_ALL_DBS", True) # Should we run the full db<->db test suite? def get_git_revision_short_hash() -> str: @@ -95,8 +95,10 @@ def _print_used_dbs(): logging.info(f"Testing databases: {', '.join(used)}") if unused: logging.info(f"Connection not configured; skipping tests for: {', '.join(unused)}") - if FULL_TESTS: - logging.info("Full tests enabled (every db<->db). May take very long when many dbs are involved.") + if TEST_ACROSS_ALL_DBS: + logging.info( + f"Full tests enabled (every db<->db). May take very long when many dbs are involved. 
={TEST_ACROSS_ALL_DBS}" + ) _print_used_dbs() diff --git a/tests/test_database_types.py b/tests/test_database_types.py index d2f65162..250b4537 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -25,7 +25,7 @@ N_THREADS, BENCHMARK, GIT_REVISION, - FULL_TESTS, + TEST_ACROSS_ALL_DBS, get_conn, random_table_suffix, ) @@ -420,27 +420,29 @@ def __iter__(self): } -def _get_test_db_pairs(full): - if full: +def _get_test_db_pairs(): + if str(TEST_ACROSS_ALL_DBS).lower() == "full": for source_db in DATABASE_TYPES: for target_db in DATABASE_TYPES: yield source_db, target_db - else: + elif int(TEST_ACROSS_ALL_DBS): for db_cls in DATABASE_TYPES: yield db_cls, db.PostgreSQL yield db.PostgreSQL, db_cls yield db_cls, db.Snowflake yield db.Snowflake, db_cls + else: + yield db.PostgreSQL, db.PostgreSQL -def get_test_db_pairs(full=True): - active_pairs = {(db1, db2) for db1, db2 in _get_test_db_pairs(full) if db1 in CONN_STRINGS and db2 in CONN_STRINGS} +def get_test_db_pairs(): + active_pairs = {(db1, db2) for db1, db2 in _get_test_db_pairs() if db1 in CONN_STRINGS and db2 in CONN_STRINGS} for db1, db2 in active_pairs: yield db1, DATABASE_TYPES[db1], db2, DATABASE_TYPES[db2] type_pairs = [] -for source_db, source_type_categories, target_db, target_type_categories in get_test_db_pairs(FULL_TESTS): +for source_db, source_type_categories, target_db, target_type_categories in get_test_db_pairs(): for type_category, source_types in source_type_categories.items(): # int, datetime, .. for source_type in source_types: for target_type in target_type_categories[type_category]: From f216ecb887959534a9da8b406c968f2ee4b36cd7 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 2 Nov 2022 17:51:27 -0300 Subject: [PATCH 89/93] Better names for CI --- .github/workflows/ci.yml | 2 +- .github/workflows/ci_full.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0aa3642d..7c42e356 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,4 +1,4 @@ -name: CI +name: CI-COVER-VERSIONS on: push: diff --git a/.github/workflows/ci_full.yml b/.github/workflows/ci_full.yml index a6a75806..0de7da52 100644 --- a/.github/workflows/ci_full.yml +++ b/.github/workflows/ci_full.yml @@ -1,4 +1,4 @@ -name: CI +name: CI-COVER-DATABASES on: push: From 1790a38fd7e8a597098d40a1327f0ece2bc4b754 Mon Sep 17 00:00:00 2001 From: Ilia Pinchuk Date: Wed, 2 Nov 2022 17:04:33 +0600 Subject: [PATCH 90/93] support multithreading for databricks The databricks connector is not thread-safe so we should inherit ThreadedDatabase class --- data_diff/databases/databricks.py | 59 +++++++++++++++---------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 8327f90a..c38704b5 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -13,7 +13,7 @@ ColType, UnknownColType, ) -from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, BaseDialect, Database, import_helper, parse_table_name +from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, BaseDialect, ThreadedDatabase, import_helper, parse_table_name @import_helper(text="You can install it using 'pip install databricks-sql-connector'") @@ -68,43 +68,45 @@ def _convert_db_precision_to_digits(self, p: int) -> int: return max(super()._convert_db_precision_to_digits(p) - 1, 0) -class Databricks(Database): +class Databricks(ThreadedDatabase): dialect = Dialect() - def __init__( 
- self, - http_path: str, - access_token: str, - server_hostname: str, - catalog: str = "hive_metastore", - schema: str = "default", - **kwargs, - ): - databricks = import_databricks() - - self._conn = databricks.sql.connect( - server_hostname=server_hostname, http_path=http_path, access_token=access_token, catalog=catalog - ) - + def __init__(self, *, thread_count, **kw): logging.getLogger("databricks.sql").setLevel(logging.WARNING) - self.catalog = catalog - self.default_schema = schema - self.kwargs = kwargs + self._args = kw + self.default_schema = kw.get('schema', 'hive_metastore') + super().__init__(thread_count=thread_count) - def _query(self, sql_code: str) -> list: - "Uses the standard SQL cursor interface" - return self._query_conn(self._conn, sql_code) + def create_connection(self): + databricks = import_databricks() + + try: + return databricks.sql.connect( + server_hostname=self._args['server_hostname'], + http_path=self._args['http_path'], + access_token=self._args['access_token'], + catalog=self._args['catalog'], + ) + except databricks.sql.exc.Error as e: + raise ConnectionError(*e.args) from e def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL. # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html # So, to obtain information about schema, we should use another approach. + conn = self.create_connection() + schema, table = self._normalize_table_path(path) - with self._conn.cursor() as cursor: - cursor.columns(catalog_name=self.catalog, schema_name=schema, table_name=table) - rows = cursor.fetchall() + with conn.cursor() as cursor: + cursor.columns(catalog_name=self._args['catalog'], schema_name=schema, table_name=table) + try: + rows = cursor.fetchall() + except: + rows = None + finally: + conn.close() if not rows: raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") @@ -121,7 +123,7 @@ def _process_table_schema( resulted_rows = [] for row in rows: row_type = "DECIMAL" if row[1].startswith("DECIMAL") else row[1] - type_cls = self.TYPE_CLASSES.get(row_type, UnknownColType) + type_cls = self.dialect.TYPE_CLASSES.get(row_type, UnknownColType) if issubclass(type_cls, Integer): row = (row[0], row_type, None, None, 0) @@ -152,9 +154,6 @@ def parse_table_name(self, name: str) -> DbPath: path = parse_table_name(name) return self._normalize_table_path(path) - def close(self): - self._conn.close() - @property def is_autocommit(self) -> bool: return True From a8c6efce7906a3608cad1d237eadee9700f61bda Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 4 Nov 2022 17:46:39 -0300 Subject: [PATCH 91/93] Bugfix in alphanums (reported by Guarav Singh) --- data_diff/utils.py | 12 ++++++------ tests/test_diff_tables.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/data_diff/utils.py b/data_diff/utils.py index de011d02..a11c4142 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -66,7 +66,7 @@ def numberToAlphanum(num: int, base: str = alphanums) -> str: return "".join(base[i] for i in digits[::-1]) -def alphanumToNumber(alphanum: str, base: str) -> int: +def alphanumToNumber(alphanum: str, base: str = alphanums) -> int: num = 0 for c in alphanum: num = num * len(base) + base.index(c) @@ -82,8 +82,8 @@ def justify_alphanums(s1: str, s2: str): def alphanums_to_numbers(s1: str, s2: str): s1, s2 = justify_alphanums(s1, s2) - n1 = alphanumToNumber(s1, alphanums) - n2 = 
alphanumToNumber(s2, alphanums) + n1 = alphanumToNumber(s1) + n2 = alphanumToNumber(s2) return n1, n2 @@ -121,9 +121,9 @@ def __add__(self, other: "Union[ArithAlphanumeric, int]") -> "ArithAlphanumeric" if isinstance(other, int): if other != 1: raise NotImplementedError("not implemented for arbitrary numbers") - lastchar = self._str[-1] if self._str else alphanums[0] - s = self._str[:-1] + alphanums[alphanums.index(lastchar) + other] - return self.new(s) + num = alphanumToNumber(self._str) + return self.new(numberToAlphanum(num + 1)) + return NotImplemented def range(self, other: "ArithAlphanumeric", count: int): diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index c87b23bf..a632638a 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -414,7 +414,7 @@ def setUp(self): super().setUp() self.src_table = src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) - self.new_alphanum = "aBcDeFgHiJ" + self.new_alphanum = "aBcDeFgHiz" values = [] for i in range(0, 10000, 1000): From 9c93229be75e652f90597ab11ef422815ad84370 Mon Sep 17 00:00:00 2001 From: Ilia Pinchuk Date: Wed, 2 Nov 2022 17:29:42 +0600 Subject: [PATCH 92/93] fix float value precision calculation --- data_diff/databases/databricks.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index c38704b5..79c46fc7 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -1,3 +1,4 @@ +import math from typing import Dict, Sequence import logging @@ -61,11 +62,14 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')" def normalize_number(self, value: str, coltype: NumericType) -> str: - return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") + value = f"cast({value} as decimal(38, {coltype.precision}))" + if coltype.precision > 0: + value = f"format_number({value}, {coltype.precision})" + return f"replace({self.to_string(value)}, ',', '')" def _convert_db_precision_to_digits(self, p: int) -> int: - # Subtracting 1 due to wierd precision issues - return max(super()._convert_db_precision_to_digits(p) - 1, 0) + # Subtracting 2 due to wierd precision issues + return max(super()._convert_db_precision_to_digits(p) - 2, 0) class Databricks(ThreadedDatabase): @@ -75,7 +79,7 @@ def __init__(self, *, thread_count, **kw): logging.getLogger("databricks.sql").setLevel(logging.WARNING) self._args = kw - self.default_schema = kw.get('schema', 'hive_metastore') + self.default_schema = kw.get("schema", "hive_metastore") super().__init__(thread_count=thread_count) def create_connection(self): @@ -83,11 +87,11 @@ def create_connection(self): try: return databricks.sql.connect( - server_hostname=self._args['server_hostname'], - http_path=self._args['http_path'], - access_token=self._args['access_token'], - catalog=self._args['catalog'], - ) + server_hostname=self._args["server_hostname"], + http_path=self._args["http_path"], + access_token=self._args["access_token"], + catalog=self._args["catalog"], + ) except databricks.sql.exc.Error as e: raise ConnectionError(*e.args) from e @@ -100,11 +104,9 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: schema, table = self._normalize_table_path(path) with conn.cursor() as cursor: - cursor.columns(catalog_name=self._args['catalog'], schema_name=schema, table_name=table) + 
cursor.columns(catalog_name=self._args["catalog"], schema_name=schema, table_name=table) try: rows = cursor.fetchall() - except: - rows = None finally: conn.close() if not rows: @@ -129,7 +131,7 @@ def _process_table_schema( row = (row[0], row_type, None, None, 0) elif issubclass(type_cls, Float): - numeric_precision = self._convert_db_precision_to_digits(row[2]) + numeric_precision = math.ceil(row[2] / math.log(2, 10)) row = (row[0], row_type, None, numeric_precision, None) elif issubclass(type_cls, Decimal): From ae45945b688f8e39a0bc94d712175a527d5a50a3 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 7 Nov 2022 10:46:56 -0300 Subject: [PATCH 93/93] Version bump - Release candidate 2 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 212d177d..5c90d2d3 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "data-diff" -version = "0.3.0rc1" +version = "0.3.0rc2" description = "Command-line tool and Python library to efficiently diff rows across two different databases." authors = ["Datafold "] license = "MIT"
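A note on the shape of the refactor in patch 79 ("Continued refactoring"): type parsing (TYPE_CLASSES, parse_type, _convert_db_precision_to_digits) moves off the per-database classes and onto their dialects, and call sites such as _process_table_schema now go through self.dialect. The sketch below only illustrates that delegation pattern; the class bodies and type names are made up for the example and are not data_diff's real ColType hierarchy:

    class Dialect:
        # Stand-in for BaseDialect: the lookup table and the parsing logic live here.
        TYPE_CLASSES = {"bigint": "Integer", "varchar": "Text", "timestamp": "Timestamp"}

        def parse_type(self, type_repr: str) -> str:
            return self.TYPE_CLASSES.get(type_repr, f"Unknown({type_repr!r})")


    class Database:
        # Stand-in for Database: it owns a dialect and delegates all type parsing to it.
        dialect = Dialect()

        def process_schema(self, raw_schema: dict) -> dict:
            return {name: self.dialect.parse_type(t) for name, t in raw_schema.items()}


    print(Database().process_schema({"id": "bigint", "note": "varchar", "geom": "geometry"}))
    # {'id': 'Integer', 'note': 'Text', 'geom': "Unknown('geometry')"}

Dialect subclasses such as Oracle, Presto and Vertica then override parse_type to handle parameterized forms like TIMESTAMP(6) or decimal(38,6) before falling back to the plain lookup, which is what the regex branches in their hunks above do.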
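Patch 88's TEST_ACROSS_ALL_DBS knob is easiest to read as three policies. The following is a stand-alone approximation (the database list, the HUBS constant and the test_pairs helper are illustrative stand-ins, not the real DATABASE_TYPES/CONN_STRINGS machinery): "full" yields the complete db<->db cross product, any other truthy value tests every database against the two hubs (PostgreSQL and Snowflake), and a falsy value keeps only PostgreSQL<->PostgreSQL.

    from itertools import product

    DATABASES = ["PostgreSQL", "MySQL", "Snowflake", "Presto"]  # illustrative subset
    HUBS = ["PostgreSQL", "Snowflake"]

    def test_pairs(mode: str):
        # Roughly the shape of _get_test_db_pairs(): cross product, hub-and-spoke, or one pair.
        if mode.lower() == "full":
            yield from product(DATABASES, DATABASES)
        elif mode not in ("", "0", "false"):
            for db in DATABASES:
                for hub in HUBS:
                    yield (db, hub)
                    yield (hub, db)
        else:
            yield ("PostgreSQL", "PostgreSQL")

    print(len(set(test_pairs("full"))))  # 16 pairs for these 4 databases
    print(len(set(test_pairs("1"))))     # 12 distinct pairs after deduplication
    print(len(set(test_pairs("0"))))     # 1 pair

After patch 89's renames, the two workflows map onto these modes: ci.yml (CI-COVER-VERSIONS) runs with TEST_ACROSS_ALL_DBS=0, while ci_full.yml (CI-COVER-DATABASES) runs with TEST_ACROSS_ALL_DBS=full.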
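Patch 91 ("Bugfix in alphanums") changes ArithAlphanumeric.__add__ from bumping only the last character to converting the whole string to an integer, adding, and converting back. A minimal self-contained sketch of why the round trip matters; the BASE alphabet here is a shortened stand-in, not data_diff's real alphanums string:

    BASE = "0123456789abcdefghijklmnopqrstuvwxyz"  # stand-in alphabet

    def alphanum_to_number(s: str, base: str = BASE) -> int:
        num = 0
        for c in s:
            num = num * len(base) + base.index(c)
        return num

    def number_to_alphanum(num: int, base: str = BASE) -> str:
        if num == 0:
            return base[0]
        digits = []
        while num:
            num, d = divmod(num, len(base))
            digits.append(base[d])
        return "".join(reversed(digits))

    def increment(s: str) -> str:
        # The fixed behaviour: go through an integer so that carries propagate.
        return number_to_alphanum(alphanum_to_number(s) + 1)

    assert increment("aay") == "aaz"
    assert increment("aaz") == "ab0"  # bumping only the last character cannot carry here

With this stand-in base, the pre-patch logic (s[:-1] + base[base.index(last) + 1]) raises IndexError as soon as the last character is the largest digit in the alphabet, which is exactly the carry case the integer round trip handles.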