diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 21115c45..7c42e356 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,4 +1,4 @@ -name: CI +name: CI-COVER-VERSIONS on: push: @@ -51,4 +51,4 @@ jobs: DATADIFF_VERTICA_URI: 'vertica://vertica:Password1@localhost:5433/vertica' run: | chmod +x tests/waiting_for_stack_up.sh - ./tests/waiting_for_stack_up.sh && poetry run unittest-parallel -j 16 + ./tests/waiting_for_stack_up.sh && TEST_ACROSS_ALL_DBS=0 poetry run unittest-parallel -j 16 diff --git a/.github/workflows/ci_full.yml b/.github/workflows/ci_full.yml new file mode 100644 index 00000000..0de7da52 --- /dev/null +++ b/.github/workflows/ci_full.yml @@ -0,0 +1,50 @@ +name: CI-COVER-DATABASES + +on: + push: + paths: + - '**.py' + - '.github/workflows/**' + - '!dev/**' + pull_request: + branches: [ master ] + + workflow_dispatch: + +jobs: + unit_tests: + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: + - "3.10" + + name: Check Python ${{ matrix.python-version }} on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v3 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + + - name: Build the stack + run: docker-compose up -d mysql postgres presto trino clickhouse vertica + + - name: Install Poetry + run: pip install poetry + + - name: Install package + run: "poetry install" + + - name: Run unit tests + env: + DATADIFF_SNOWFLAKE_URI: '${{ secrets.DATADIFF_SNOWFLAKE_URI }}' + DATADIFF_PRESTO_URI: '${{ secrets.DATADIFF_PRESTO_URI }}' + DATADIFF_CLICKHOUSE_URI: 'clickhouse://clickhouse:Password1@localhost:9000/clickhouse' + DATADIFF_VERTICA_URI: 'vertica://vertica:Password1@localhost:5433/vertica' + run: | + chmod +x tests/waiting_for_stack_up.sh + ./tests/waiting_for_stack_up.sh && TEST_ACROSS_ALL_DBS=full poetry run unittest-parallel -j 16 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6a6418d9..9f8c0b10 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -34,7 +34,7 @@ The same goes for other technical requests, like missing features, or gaps in th See [issues](/datafold/data-diff/issues/). -For questions, and non-technical discussions, see [discussions](/datafold/data-diff/discussions). +For questions, and non-technical discussions, see [discussions](https://github.com/datafold/data-diff/discussions). ### Contributing code @@ -79,3 +79,86 @@ New databases should be added as a new module in the `data-diff/databases/` fold If possible, please also add the database setup to `docker-compose.yml`, so that we can run and test it for ourselves. If you do, also update the CI (`ci.yml`). Guide to implementing a new database driver: https://data-diff.readthedocs.io/en/latest/new-database-driver-guide.html + +## Development Setup + +The development setup centers around using `docker-compose` to boot up various +databases, and then inserting data into them. + +For Mac for performance of Docker, we suggest enabling in the UI: + +* Use new Virtualization Framework +* Enable VirtioFS accelerated directory sharing + +**1. Install Data Diff** + +When developing/debugging, it's recommended to install dependencies and run it +directly with `poetry` rather than go through the package. 
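A note on the two CI workflows at the top of this diff: `CI-COVER-VERSIONS` and the new `CI-COVER-DATABASES` differ mainly in the `TEST_ACROSS_ALL_DBS` value (`0` vs. `full`) exported before `unittest-parallel` runs. The sketch below shows one way a test suite might consume such a flag; the helper and its pair-selection policy are hypothetical and not taken from the repository — only the variable name and values come from the workflow files.

```python
import itertools
import os

# Hypothetical helper (not from the repository): decide which database pairs a
# parameterized test should cover, based on the TEST_ACROSS_ALL_DBS variable
# that ci.yml (=0) and ci_full.yml (=full) export before running unittest-parallel.
def database_pairs(all_dbs):
    if os.environ.get("TEST_ACROSS_ALL_DBS", "0") == "full":
        # Full cross-database matrix for the CI-COVER-DATABASES workflow.
        return list(itertools.permutations(all_dbs, 2))
    # Reduced matrix for the per-Python-version workflow, e.g. only pairs involving PostgreSQL.
    return [(a, b) for a, b in itertools.permutations(all_dbs, 2) if "postgresql" in (a, b)]

if __name__ == "__main__":
    print(database_pairs(["postgresql", "mysql", "snowflake"]))
```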
+ +``` +$ brew install mysql postgresql # MacOS dependencies for C bindings +$ apt-get install libpq-dev libmysqlclient-dev # Debian dependencies +$ pip install poetry # Python dependency isolation tool +$ poetry install # Install dependencies +``` +**2. Start Databases** + +[Install **docker-compose**][docker-compose] if you haven't already. + +```shell-session +$ docker-compose up -d mysql postgres # run mysql and postgres dbs in background +``` + +[docker-compose]: https://docs.docker.com/compose/install/ + +**3. Run Unit Tests** + +There are more than 1000 tests for all the different type and database +combinations, so we recommend using `unittest-parallel` that's installed as a +development dependency. + +```shell-session +$ poetry run unittest-parallel -j 16 # run all tests +$ poetry run python -m unittest -k # run individual test +``` + +**4. Seed the Database(s) (optional)** + +First, download the CSVs of seeding data: + +```shell-session +$ curl https://datafold-public.s3.us-west-2.amazonaws.com/1m.csv -o dev/ratings.csv +# For a larger data-set (but takes 25x longer to import): +# - curl https://datafold-public.s3.us-west-2.amazonaws.com/25m.csv -o dev/ratings.csv +``` + +Now you can insert it into the testing database(s): + +```shell-session +# It's optional to seed more than one to run data-diff(1) against. +$ poetry run preql -f dev/prepare_db.pql mysql://mysql:Password1@127.0.0.1:3306/mysql +$ poetry run preql -f dev/prepare_db.pql postgresql://postgres:Password1@127.0.0.1:5432/postgres +# Cloud databases +$ poetry run preql -f dev/prepare_db.pql snowflake:// +$ poetry run preql -f dev/prepare_db.pql mssql:// +$ poetry run preql -f dev/prepare_db.pql bigquery:/// +``` + +**5. Run **data-diff** against seeded database (optional)** + +```bash +poetry run python3 -m data_diff postgresql://postgres:Password1@localhost/postgres rating postgresql://postgres:Password1@localhost/postgres rating_del1 --verbose +``` + +**6. Run benchmarks (optional)** + +```shell-session +$ dev/benchmark.sh # runs benchmarks and puts results in benchmark_.csv +$ poetry run python3 dev/graph.py # create graphs from benchmark_*.csv files +``` + +You can adjust how many rows we benchmark with by passing `N_SAMPLES` to `dev/benchmark.sh`: + +```shell-session +$ N_SAMPLES=100000000 dev/benchmark.sh # 100m which is our canonical target +``` diff --git a/README.md b/README.md index df0d27c5..c7410736 100644 --- a/README.md +++ b/README.md @@ -1,170 +1,55 @@ +

+ Datafold +

+ # **data-diff** -**data-diff is in shape to be run in production, but also under development. If -you run into issues or bugs, please [open an issue](https://github.com/datafold/data-diff/issues/new/choose) and we'll help you out ASAP! You can -also find us in `#tools-data-diff` in the [Locally Optimistic Slack][slack].** - -**We'd love to hear about your experience using data-diff, and learn more your use cases. [Reach out to product team share any product feedback or feature requests!](https://calendly.com/jp-toor/customer-interview-oss)** - -πŸ’ΈπŸ’Έ **Looking for paid contributors!** πŸ’ΈπŸ’Έ If you're up for making money working on awesome open-source tools, we're looking for developers with a deep understanding of databases and solid Python knowledge. [**Apply here!**](https://docs.google.com/forms/d/e/1FAIpQLScEa5tc9CM0uNsb3WigqRFq92OZENkThM04nIs7ZVl_bwsGMw/viewform) - ----- - -**data-diff** is a command-line tool and Python library to efficiently diff -rows across two different databases. - -* ⇄ Verifies across [many different databases][dbs] (e.g. PostgreSQL -> Snowflake) -* πŸ” Outputs [diff of rows](#example-command-and-output) in detail -* 🚨 Simple CLI/API to create monitoring and alerts -* πŸ” Bridges column types of different formats and levels of precision (e.g. Double ⇆ Float ⇆ Decimal) -* πŸ”₯ Verify 25M+ rows in <10s, and 1B+ rows in ~5min. -* ♾️ Works for tables with 10s of billions of rows - -**data-diff** splits the table into smaller segments, then checksums each -segment in both databases. When the checksums for a segment aren't equal, it -will further divide that segment into yet smaller segments, checksumming those -until it gets to the differing row(s). See [Technical Explanation][tech-explain] for more -details. - -This approach has performance within an order of magnitude of `count(*)` when -there are few/no changes, but is able to output each differing row! By pushing -the compute into the databases, it's _much_ faster than querying for and -comparing every row. - -![Performance for 100M rows](https://user-images.githubusercontent.com/97400/175182987-a3900d4e-c097-4732-a4e9-19a40fac8cdc.png) - -**†:** The implementation for downloading all rows that `data-diff` and -`count(*)` is compared to is not optimal. It is a single Python multi-threaded -process. The performance is fairly driver-specific, e.g. PostgreSQL's performs 10x -better than MySQL. - -## Table of Contents - -- [**data-diff**](#data-diff) - - [Table of Contents](#table-of-contents) - - [Common use-cases](#common-use-cases) - - [Example Command and Output](#example-command-and-output) - - [Supported Databases](#supported-databases) -- [How to install](#how-to-install) - - [Install drivers](#install-drivers) -- [How to use](#how-to-use) - - [How to use from the command-line](#how-to-use-from-the-command-line) - - [How to use from Python](#how-to-use-from-python) -- [Technical Explanation](#technical-explanation) - - [Performance Considerations](#performance-considerations) -- [Anonymous Tracking](#anonymous-tracking) -- [Development Setup](#development-setup) -- [License](#license) - -## Common use-cases - -* **Verify data migrations.** Verify that all data was copied when doing a - critical data migration. For example, migrating from Heroku PostgreSQL to Amazon RDS. -* **Verifying data pipelines.** Moving data from a relational database to a - warehouse/data lake with Fivetran, Airbyte, Debezium, or some other pipeline. 
-* **Alerting and maintaining data integrity SLOs.** You can create and monitor - your SLO of e.g. 99.999% data integrity, and alert your team when data is - missing. -* **Debugging complex data pipelines.** When data gets lost in pipelines that - may span a half-dozen systems, without verifying each intermediate datastore - it's extremely difficult to track down where a row got lost. -* **Detecting hard deletes for an `updated_at`-based pipeline**. If you're - copying data to your warehouse based on an `updated_at`-style column, then - you'll miss hard-deletes that **data-diff** can find for you. -* **Make your replication self-healing.** You can use **data-diff** to - self-heal by using the diff output to write/update rows in the target - database. - -## Example Command and Output - -Below we run a comparison with the CLI for 25M rows in PostgreSQL where the -right-hand table is missing single row with `id=12500048`: +## What is `data-diff`? +data-diff is a **free, open-source tool** that enables data professionals to detect differences in values between any two tables. It's fast, easy to use, and reliable. Even at massive scale. -``` -$ data-diff \ - postgresql://user:password@localhost/database rating \ - postgresql://user:password@localhost/database rating_del1 \ - --bisection-threshold 100000 \ # for readability, try default first - --bisection-factor 6 \ # for readability, try default first - --update-column timestamp \ - --verbose - - # Consider running with --interactive the first time. - # Runs `EXPLAIN` for you to verify the queries are using indexes. - # --interactive -[10:15:00] INFO - Diffing tables | segments: 6, bisection threshold: 100000. -[10:15:00] INFO - . Diffing segment 1/6, key-range: 1..4166683, size: 4166682 -[10:15:03] INFO - . Diffing segment 2/6, key-range: 4166683..8333365, size: 4166682 -[10:15:06] INFO - . Diffing segment 3/6, key-range: 8333365..12500047, size: 4166682 -[10:15:09] INFO - . Diffing segment 4/6, key-range: 12500047..16666729, size: 4166682 -[10:15:12] INFO - . . Diffing segment 1/6, key-range: 12500047..13194494, size: 694447 -[10:15:13] INFO - . . . Diffing segment 1/6, key-range: 12500047..12615788, size: 115741 -[10:15:13] INFO - . . . . Diffing segment 1/6, key-range: 12500047..12519337, size: 19290 -[10:15:13] INFO - . . . . Diff found 1 different rows. -[10:15:13] INFO - . . . . Diffing segment 2/6, key-range: 12519337..12538627, size: 19290 -[10:15:13] INFO - . . . . Diffing segment 3/6, key-range: 12538627..12557917, size: 19290 -[10:15:13] INFO - . . . . Diffing segment 4/6, key-range: 12557917..12577207, size: 19290 -[10:15:13] INFO - . . . . Diffing segment 5/6, key-range: 12577207..12596497, size: 19290 -[10:15:13] INFO - . . . . Diffing segment 6/6, key-range: 12596497..12615788, size: 19291 -[10:15:13] INFO - . . . Diffing segment 2/6, key-range: 12615788..12731529, size: 115741 -[10:15:13] INFO - . . . Diffing segment 3/6, key-range: 12731529..12847270, size: 115741 -[10:15:13] INFO - . . . Diffing segment 4/6, key-range: 12847270..12963011, size: 115741 -[10:15:14] INFO - . . . Diffing segment 5/6, key-range: 12963011..13078752, size: 115741 -[10:15:14] INFO - . . . Diffing segment 6/6, key-range: 13078752..13194494, size: 115742 -[10:15:14] INFO - . . Diffing segment 2/6, key-range: 13194494..13888941, size: 694447 -[10:15:14] INFO - . . Diffing segment 3/6, key-range: 13888941..14583388, size: 694447 -[10:15:15] INFO - . . Diffing segment 4/6, key-range: 14583388..15277835, size: 694447 -[10:15:15] INFO - . . 
Diffing segment 5/6, key-range: 15277835..15972282, size: 694447 -[10:15:15] INFO - . . Diffing segment 6/6, key-range: 15972282..16666729, size: 694447 -+ (12500048, 1268104625) -[10:15:16] INFO - . Diffing segment 5/6, key-range: 16666729..20833411, size: 4166682 -[10:15:19] INFO - . Diffing segment 6/6, key-range: 20833411..25000096, size: 4166685 -``` +_Are you a developer with a deep understanding of databases and solid Python knowledge? [We're hiring!](https://www.datafold.com/careers)_ + +## Documentation -## Supported Databases +[**Our detailed documentation**](https://docs.datafold.com/os_diff/about) has everything you need to start diffing. -| Database | Connection string | Status | -|---------------|-------------------------------------------------------------------------------------------------------------------------------------|--------| -| PostgreSQL >=10 | `postgresql://:@:5432/` | πŸ’š | -| MySQL | `mysql://:@:5432/` | πŸ’š | -| Snowflake | `"snowflake://[:]@//?warehouse=&role=[&authenticator=externalbrowser]"` | πŸ’š | -| Oracle | `oracle://:@/database` | πŸ’› | -| BigQuery | `bigquery:///` | πŸ’› | -| Redshift | `redshift://:@:5439/` | πŸ’› | -| Presto | `presto://:@:8080/` | πŸ’› | -| Databricks | `databricks://:@//` | πŸ’› | -| Trino | `trino://:@:8080/` | πŸ’› | -| Clickhouse | `clickhouse://:@:9000/` | πŸ’› | -| Vertica | `vertica://:@:5433/` | πŸ’› | -| ElasticSearch | | πŸ“ | -| Planetscale | | πŸ“ | -| Pinot | | πŸ“ | -| Druid | | πŸ“ | -| Kafka | | πŸ“ | +## Use cases -* πŸ’š: Implemented and thoroughly tested. -* πŸ’›: Implemented, but not thoroughly tested yet. -* ⏳: Implementation in progress. -* πŸ“: Implementation planned. Contributions welcome. +### Diff Tables Between Databases +#### Quickly identify issues when moving data between databases -If a database is not on the list, we'd still love to support it. Open an issue -to discuss it. +

+ diff2 +

-Note: Because URLs allow many special characters, and may collide with the syntax of your command-line, -it's recommended to surround them with quotes. Alternatively, you may provide them in a TOML file via the `--config` option. +### Diff Tables Within a Database (available in pre-release) +#### Improve code reviews by identifying data problems you don't have tests for +

+ + Intro to Diff + +

+  +  -# How to install +## Get started -Requires Python 3.7+ with pip. +### Installation -```pip install data-diff``` +#### First, install `data-diff` using `pip`. + +``` +pip install data-diff +``` -## Install drivers +To try out bleeding-edge features, including materialization of results in your data warehouse: -To connect to a database, we need to have its driver installed, in the form of a Python library. +``` +pip install data-diff --pre +``` -While you may install them manually, we offer an easy way to install them along with data-diff*: +#### Then, install one or more driver(s) specific to the database(s) you want to connect to. - `pip install 'data-diff[mysql]'` @@ -184,429 +69,75 @@ While you may install them manually, we offer an easy way to install them along - For BigQuery, see: https://pypi.org/project/google-cloud-bigquery/ +_Some drivers have dependencies that cannot be installed using `pip` and still need to be installed manually._ -Users can also install several drivers at once: - -```pip install 'data-diff[mysql,postgresql,snowflake]'``` - -_* Some drivers have dependencies that cannot be installed using `pip` and still need to be installed manually._ - - -### Install Psycopg2 - -In order to run Postgresql, you'll need `psycopg2`. This Python package requires some additional dependencies described in their [documentation](https://www.psycopg.org/docs/install.html#build-prerequisites). -An easy solution is to install [psycopg2-binary](https://www.psycopg.org/docs/install.html#quick-install) by running: - -```pip install psycopg2-binary``` - -Which comes with a pre-compiled binary and does not require additonal prerequisites. However, note that for production use it is adviced to use `psycopg2`. - - -# How to use - -## How to use from the command-line - -Usage: `data-diff DB1_URI TABLE1_NAME DB2_URI TABLE2_NAME [OPTIONS]` - -See the [example command](#example-command-and-output) and the [sample -connection strings](#supported-databases). - -Note that for some databases, the arguments that you enter in the command line -may be case-sensitive. This is the case for the Snowflake schema and table names. - -Options: - - - `--help` - Show help message and exit. - - `-k` or `--key-column` - Name of the primary key column - - `-t` or `--update-column` - Name of updated_at/last_updated column - - `-c` or `--columns` - Names of extra columns to compare. Can be used more than once in the same command. - Accepts a name or a pattern like in SQL. - Example: `-c col% -c another_col -c %foorb.r%` - - `-l` or `--limit` - Maximum number of differences to find (limits maximum bandwidth and runtime) - - `-s` or `--stats` - Print stats instead of a detailed diff - - `-d` or `--debug` - Print debug info - - `-v` or `--verbose` - Print extra info - - `-i` or `--interactive` - Confirm queries, implies `--debug` - - `--json` - Print JSONL output for machine readability - - `--min-age` - Considers only rows older than specified. Useful for specifying replication lag. - Example: `--min-age=5min` ignores rows from the last 5 minutes. - Valid units: `d, days, h, hours, min, minutes, mon, months, s, seconds, w, weeks, y, years` - - `--max-age` - Considers only rows younger than specified. See `--min-age`. - - `--bisection-factor` - Segments per iteration. When set to 2, it performs binary search. - - `--bisection-threshold` - Minimal bisection threshold. i.e. maximum size of pages to diff locally. - - `-j` or `--threads` - Number of worker threads to use per database. Default=1. 
- - `-w`, `--where` - An additional 'where' expression to restrict the search space. - - `--conf`, `--run` - Specify the run and configuration from a TOML file. (see below) - - `--no-tracking` - data-diff sends home anonymous usage data. Use this to disable it. - - -### How to use with a configuration file - -Data-diff lets you load the configuration for a run from a TOML file. - -Reasons to use a configuration file: - -- Convenience - Set-up the parameters for diffs that need to run often - -- Easier and more readable - you can define the database connection settings as config values, instead of in a URI. - -- Gives you fine-grained control over the settings switches, without requiring any Python code. - -Use `--conf` to specify that path to the configuration file. data-diff will load the settings from `run.default`, if it's defined. - -Then you can, optionally, use `--run` to choose to load the settings of a specific run, and override the settings `run.default`. (all runs extend `run.default`, like inheritance). - -Finally, CLI switches have the final say, and will override the settings defined by the configuration file, and the current run. - -Example TOML file: - -```toml -# Specify the connection params to the test database. -[database.test_postgresql] -driver = "postgresql" -user = "postgres" -password = "Password1" - -# Specify the default run params -[run.default] -update_column = "timestamp" -verbose = true - -# Specify params for a run 'test_diff'. -[run.test_diff] -verbose = false -# Source 1 ("left") -1.database = "test_postgresql" # Use options from database.test_postgresql -1.table = "rating" -# Source 2 ("right") -2.database = "postgresql://postgres:Password1@/" # Use URI like in the CLI -2.table = "rating_del1" -``` - -In this example, running `data-diff --conf myconfig.toml --run test_diff` will compare between `rating` and `rating_del1`. -It will use the `timestamp` column as the update column, as specified in `run.default`. However, it won't be verbose, since that -flag is overwritten to `false`. - -Running it with `data-diff --conf myconfig.toml --run test_diff -v` will set verbose back to `true`. - - -## How to use from Python - -API reference: [https://data-diff.readthedocs.io/en/latest/](https://data-diff.readthedocs.io/en/latest/) - -Example: - -```python -# Optional: Set logging to display the progress of the diff -import logging -logging.basicConfig(level=logging.INFO) - -from data_diff import connect_to_table, diff_tables - -table1 = connect_to_table("postgresql:///", "table_name", "id") -table2 = connect_to_table("mysql:///", "table_name", "id") - -for different_row in diff_tables(table1, table2): - plus_or_minus, columns = different_row - print(plus_or_minus, columns) -``` - -Run `help(diff_tables)` or [read the docs](https://data-diff.readthedocs.io/en/latest/) to learn about the different options. - -# Technical Explanation - -In this section we'll be doing a walk-through of exactly how **data-diff** -works, and how to tune `--bisection-factor` and `--bisection-threshold`. - -Let's consider a scenario with an `orders` table with 1M rows. 
Fivetran is -replicating it contionously from PostgreSQL to Snowflake: - -``` -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ PostgreSQL β”‚ β”‚ Snowflake β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ table with β”‚ -β”‚ table with β”œβ”€β”€β”€ replication β”œβ”€β”€β”€β”€β”€β”€β–Άβ”‚ ?maybe? all β”‚ -β”‚lots of rows!β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ the same β”‚ -β”‚ β”‚ β”‚ rows. β”‚ -β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -``` - -In order to check whether the two tables are the same, **data-diff** splits -the table into `--bisection-factor=10` segments. - -We also have to choose which columns we want to checksum. In our case, we care -about the primary key, `--key-column=id` and the update column -`--update-column=updated_at`. `updated_at` is updated every time the row is, and -we have an index on it. - -**data-diff** starts by querying both databases for the `min(id)` and `max(id)` -of the table. Then it splits the table into `--bisection-factor=10` segments of -`1M/10 = 100K` keys each: - -``` -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ PostgreSQL β”‚ β”‚ Snowflake β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ id=1..100k β”‚ β”‚ id=1..100k β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ id=100k..200k β”‚ β”‚ id=100k..200k β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ id=200k..300k β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–Άβ”‚ id=200k..300k β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ id=300k..400k β”‚ β”‚ id=300k..400k β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ ... β”‚ β”‚ ... β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ 900k..100k β”‚ β”‚ 900k..100k β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–²β”€β”€β”˜ β””β–²β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ - ┃ ┃ - ┃ ┃ - ┃ checksum queries ┃ - ┃ ┃ - β”Œβ”€β”»β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”»β”€β”€β”€β”€β” - β”‚ data-diff β”‚ - β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -``` - -Now **data-diff** will start running `--threads=1` queries in parallel that -checksum each segment. 
The queries for checksumming each segment will look -something like this, depending on the database: - -```sql -SELECT count(*), - sum(cast(conv(substring(md5(concat(cast(id as char), cast(timestamp as char))), 18), 16, 10) as unsigned)) -FROM `rating_del1` -WHERE (id >= 1) AND (id < 100000) -``` - -This keeps the amount of data that has to be transferred between the databases -to a minimum, making it very performant! Additionally, if you have an index on -`updated_at` (highly recommended) then the query will be fast as the database -only has to do a partial index scan between `id=1..100k`. - -If you are not sure whether the queries are using an index, you can run it with -`--interactive`. This puts **data-diff** in interactive mode where it shows an -`EXPLAIN` before executing each query, requiring confirmation to proceed. - -After running the checksum queries on both sides, we see that all segments -are the same except `id=100k..200k`: - -``` -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ PostgreSQL β”‚ β”‚ Snowflake β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ checksum=0102 β”‚ β”‚ checksum=0102 β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ mismatch! β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ checksum=ffff ◀──────────────▢ checksum=aaab β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ checksum=abab β”‚ β”‚ checksum=abab β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ checksum=f0f0 β”‚ β”‚ checksum=f0f0 β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ ... β”‚ β”‚ ... β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ checksum=9494 β”‚ β”‚ checksum=9494 β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -``` - -Now **data-diff** will do exactly as it just did for the _whole table_ for only -this segment: Split it into `--bisection-factor` segments. - -However, this time, because each segment has `100k/10=10k` entries, which is -less than the `--bisection-threshold` it will pull down every row in the segment -and compare them in memory in **data-diff**. 
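The walk-through above describes the core loop of the algorithm this PR moves into `HashDiffer`. A compact, self-contained sketch of the same idea — not the library's actual implementation — makes the recursion explicit: split the key range into `bisection_factor` segments, checksum each segment on both sides, recurse only where the checksums disagree, and download rows once a segment is smaller than `bisection_threshold`.

```python
# Toy sketch of hashdiff-style bisection. checksum(side, lo, hi) and
# fetch_rows(side, lo, hi) stand in for the per-database checksum query
# and the local row download, respectively; fetch_rows returns a set of rows.
def diff_key_range(lo, hi, checksum, fetch_rows, bisection_factor=10, bisection_threshold=10_000):
    if hi - lo <= bisection_threshold:
        rows1 = fetch_rows(1, lo, hi)
        rows2 = fetch_rows(2, lo, hi)
        yield from (("-", row) for row in rows1 - rows2)
        yield from (("+", row) for row in rows2 - rows1)
        return
    step = max((hi - lo) // bisection_factor, 1)
    bounds = list(range(lo, hi, step)) + [hi]
    for a, b in zip(bounds, bounds[1:]):
        if checksum(1, a, b) != checksum(2, a, b):  # only recurse into mismatching segments
            yield from diff_key_range(a, b, checksum, fetch_rows, bisection_factor, bisection_threshold)
```

In the real tool the checksum is an MD5-based aggregate computed inside each database, so only one number per segment crosses the network.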
- -``` -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ PostgreSQL β”‚ β”‚ Snowflake β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ id=100k..110k β”‚ β”‚ id=100k..110k β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ id=110k..120k β”‚ β”‚ id=110k..120k β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ id=120k..130k β”‚ β”‚ id=120k..130k β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ id=130k..140k β”‚ β”‚ id=130k..140k β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ ... β”‚ β”‚ ... β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ 190k..200k β”‚ β”‚ 190k..200k β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -``` +### Run your first diff -Finally **data-diff** will output the `(id, updated_at)` for each row that was different: +Once you've installed `data-diff`, you can run it from the command line. ``` -(122001, 1653672821) +data-diff DB1_URI TABLE1_NAME DB2_URI TABLE2_NAME [OPTIONS] ``` -If you pass `--stats` you'll see e.g. what % of rows were different. - -## Performance Considerations - -* Ensure that you have indexes on the columns you are comparing. Preferably a - compound index. You can run with `--interactive` to see an `EXPLAIN` for the - queries. -* Consider increasing the number of simultaneous threads executing - queries per database with `--threads`. For databases that limit concurrency - per query, e.g. PostgreSQL/MySQL, this can improve performance dramatically. -* If you are only interested in _whether_ something changed, pass `--limit 1`. - This can be useful if changes are very rare. This is often faster than doing a - `count(*)`, for the reason mentioned above. -* If the table is _very_ large, consider a larger `--bisection-factor`. Explained in - the [technical explanation][tech-explain]. Otherwise you may run into timeouts. -* If there are a lot of changes, consider a larger `--bisection-threshold`. - Explained in the [technical explanation][tech-explain]. -* If there are very large gaps in your key column, e.g. 10s of millions of - continuous rows missing, then **data-diff** may perform poorly doing lots of - queries for ranges of rows that do not exist (see [technical - explanation][tech-explain]). We have ideas on how to tackle this issue, which we have - yet to implement. If you're experiencing this effect, please open an issue and we - will prioritize it. -* The fewer columns you verify (passed with `--columns`), the faster - **data-diff** will be. On one extreme you can verify every column, on the - other you can verify _only_ `updated_at`, if you trust it enough. You can also - _only_ verify `id` if you're interested in only presence, e.g. to detect - missing hard deletes. 
You can do also do a hybrid where you verify - `updated_at` and the most critical value, e.g a money value in `amount` but - not verify a large serialized column like `json_settings`. -* We have ideas for making **data-diff** even faster that - we haven't implemented yet: faster checksums by reducing type-casts - and using a faster hash than MD5, dynamic adaptation of - `bisection_factor`/`threads`/`bisection_threshold` (especially with large key - gaps), and improvements to bypass Python/driver performance limitations when - comparing huge amounts of rows locally (i.e. for very high `bisection_threshold` values). - -# Usage Analytics - -data-diff collects anonymous usage data to help our team improve the tool and to apply development efforts to where our users need them most. - -We capture two events, one when the data-diff run starts and one when it is finished. No user data or potentially sensitive information is or ever will be collected. The captured data is limited to: - -- Operating System and Python version - -- Types of databases used (postgresql, mysql, etc.) +Be sure to read [the docs](https://docs.datafold.com/os_diff/how_to_use) for detailed instructions how to build one of these commands depending on your database setup. -- Sizes of tables diffed, run time, and diff row count (numbers only) +#### Code Example: Diff Tables Between Databases +Here's an example command for your copy/pasting, taken from the screenshot above when we diffed data between Snowflake and Postgres. -- Error message, if any, truncated to the first 20 characters. - -- A persistent UUID to indentify the session, stored in `~/.datadiff.toml` - -If you do not wish to participate, the tracking can be easily disabled with one of the following methods: - -* In the CLI, use the `--no-tracking` flag. - -* In the config file, set `no_tracking = true` (for example, under `[run.default]`) - -* If you're using the Python API: - -```python -import data_diff -data_diff.disable_tracking() # Call this first, before making any API calls - -# Connect and diff your tables without any tracking ``` - - -# Development Setup - -The development setup centers around using `docker-compose` to boot up various -databases, and then inserting data into them. - -For Mac for performance of Docker, we suggest enabling in the UI: - -* Use new Virtualization Framework -* Enable VirtioFS accelerated directory sharing - -**1. Install Data Diff** - -When developing/debugging, it's recommended to install dependencies and run it -directly with `poetry` rather than go through the package. - +data-diff \ + postgresql://:''@localhost:5432/ \ + \ + "snowflake://:@//?warehouse=&role=" \ +
\ + -k activity_id \ + -c activity \ + -w "event_timestamp < '2022-10-10'" ``` -$ brew install mysql postgresql # MacOS dependencies for C bindings -$ apt-get install libpq-dev libmysqlclient-dev # Debian dependencies -$ pip install poetry # Python dependency isolation tool -$ poetry install # Install dependencies -``` -**2. Start Databases** +#### Code Example: Diff Tables Within a Database (available in pre-release) -[Install **docker-compose**][docker-compose] if you haven't already. +Here's a code example from [the video](https://www.loom.com/share/682e4b7d74e84eb4824b983311f0a3b2), where we compare data between two Snowflake tables within one database. -```shell-session -$ docker-compose up -d mysql postgres # run mysql and postgres dbs in background ``` - -[docker-compose]: https://docs.docker.com/compose/install/ - -**3. Run Unit Tests** - -There are more than 1000 tests for all the different type and database -combinations, so we recommend using `unittest-parallel` that's installed as a -development dependency. - -```shell-session -$ poetry run unittest-parallel -j 16 # run all tests -$ poetry run python -m unittest -k # run individual test +data-diff \ + "snowflake://:@//?warehouse=&role=" \ + . \ + -k org_id \ + -c created_at -c is_internal \ + -w "org_id != 1 and org_id < 2000" \ + -m test_results_%t \ + --materialize-all-rows \ + --table-write-limit 10000 ``` -**4. Seed the Database(s) (optional)** - -First, download the CSVs of seeding data: +In both code examples, I've used `<>` carrots to represent values that **should be replaced with your values** in the database connection strings. For the flags (`-k`, `-c`, etc.), I opted for "real" values (`org_id`, `is_internal`) to give you a more realistic view of what your command will look like. -```shell-session -$ curl https://datafold-public.s3.us-west-2.amazonaws.com/1m.csv -o dev/ratings.csv +### We're here to help! -# For a larger data-set (but takes 25x longer to import): -# - curl https://datafold-public.s3.us-west-2.amazonaws.com/25m.csv -o dev/ratings.csv -``` +We know that in some cases, the data-diff command can become long and dense. And maybe you're new to the command line. -Now you can insert it into the testing database(s): +* We're here to help [on slack](https://locallyoptimistic.slack.com/archives/C03HUNGQV0S) if you have ANY questions as you use `data-diff` in your workflow. +* You can also post a question in [GitHub Discussions](https://github.com/datafold/data-diff/discussions). -```shell-session -# It's optional to seed more than one to run data-diff(1) against. -$ poetry run preql -f dev/prepare_db.pql mysql://mysql:Password1@127.0.0.1:3306/mysql -$ poetry run preql -f dev/prepare_db.pql postgresql://postgres:Password1@127.0.0.1:5432/postgres -# Cloud databases -$ poetry run preql -f dev/prepare_db.pql snowflake:// -$ poetry run preql -f dev/prepare_db.pql mssql:// -$ poetry run preql -f dev/prepare_db.pql bigquery:/// -``` +To get a Slack invite - [click here](https://locallyoptimistic.com/community/) -**5. 
Run **data-diff** against seeded database (optional)** +## How to Use -```bash -poetry run python3 -m data_diff postgresql://postgres:Password1@localhost/postgres rating postgresql://postgres:Password1@localhost/postgres rating_del1 --verbose -``` +* [How to use from the shell (or: command-line)](https://docs.datafold.com/os_diff/how_to_use#how-to-use-from-the-command-line) +* [How to use from Python](https://docs.datafold.com/os_diff/how_to_use#how-to-use-from-python) +* [Usage Analytics & Data Privacy](https://docs.datafold.com/os_diff/usage_analytics_data_privacy) -**6. Run benchmarks (optional)** +## How to Contribute +* Feel free to open an issue or contribute to the project by working on an existing issue. +* Please read the [contributing guidelines](https://github.com/leoebfolsom/data-diff/blob/master/CONTRIBUTING.md) to get started. -```shell-session -$ dev/benchmark.sh # runs benchmarks and puts results in benchmark_.csv -$ poetry run python3 dev/graph.py # create graphs from benchmark_*.csv files -``` - -You can adjust how many rows we benchmark with by passing `N_SAMPLES` to `dev/benchmark.sh`: - -```shell-session -$ N_SAMPLES=100000000 dev/benchmark.sh # 100m which is our canonical target -``` +## Technical Explanation -# License +Check out this [technical explanation](https://docs.datafold.com/os_diff/technical_explanation) of how data-diff works. -[MIT License](https://github.com/datafold/data-diff/blob/master/LICENSE) +## License -[dbs]: #supported-databases -[tech-explain]: #technical-explanation -[perf]: #performance-considerations -[slack]: https://locallyoptimistic.com/community/ +This project is licensed under the terms of the [MIT License](https://github.com/datafold/data-diff/blob/master/LICENSE). diff --git a/data_diff/__init__.py b/data_diff/__init__.py index bc5677c8..20c6b57d 100644 --- a/data_diff/__init__.py +++ b/data_diff/__init__.py @@ -1,33 +1,41 @@ -from typing import Tuple, Iterator, Optional, Union +from typing import Sequence, Tuple, Iterator, Optional, Union from .tracking import disable_tracking from .databases.connect import connect from .databases.database_types import DbKey, DbTime, DbPath -from .diff_tables import TableSegment, TableDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR +from .diff_tables import Algorithm +from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR +from .joindiff_tables import JoinDiffer +from .table_segment import TableSegment def connect_to_table( db_info: Union[str, dict], table_name: Union[DbPath, str], - key_column: str = "id", + key_columns: str = ("id",), thread_count: Optional[int] = 1, **kwargs, -): +) -> TableSegment: """Connects to the given database, and creates a TableSegment instance Parameters: db_info: Either a URI string, or a dict of connection options. table_name: Name of the table as a string, or a tuple that signifies the path. 
- key_column: Name of the key column - thread_count: Number of threads for this connection (only if using a threadpooled implementation) + key_columns: Names of the key columns + thread_count: Number of threads for this connection (only if using a threadpooled db implementation) + + See Also: + :meth:`connect` """ + if isinstance(key_columns, str): + key_columns = (key_columns,) db = connect(db_info, thread_count=thread_count) if isinstance(table_name, str): table_name = db.parse_table_name(table_name) - return TableSegment(db, table_name, key_column, **kwargs) + return TableSegment(db, table_name, key_columns, **kwargs) def diff_tables( @@ -35,7 +43,7 @@ def diff_tables( table2: TableSegment, *, # Name of the key column, which uniquely identifies each row (usually id) - key_column: str = None, + key_columns: Sequence[str] = None, # Name of updated column, which signals that rows changed (usually updated_at or last_update) update_column: str = None, # Extra columns to compare @@ -46,31 +54,63 @@ def diff_tables( # Start/end update_column values, used to restrict the segment min_update: DbTime = None, max_update: DbTime = None, - # Into how many segments to bisect per iteration + # Algorithm + algorithm: Algorithm = Algorithm.HASHDIFF, + # Into how many segments to bisect per iteration (hashdiff only) bisection_factor: int = DEFAULT_BISECTION_FACTOR, - # When should we stop bisecting and compare locally (in row count) + # When should we stop bisecting and compare locally (in row count; hashdiff only) bisection_threshold: int = DEFAULT_BISECTION_THRESHOLD, # Enable/disable threaded diffing. Needed to take advantage of database threads. threaded: bool = True, # Maximum size of each threadpool. None = auto. Only relevant when threaded is True. # There may be many pools, so number of actual threads can be a lot higher. max_threadpool_size: Optional[int] = 1, - # Enable/disable debug prints - debug: bool = False, ) -> Iterator: - """Efficiently finds the diff between table1 and table2. + """Finds the diff between table1 and table2. + + Parameters: + key_columns (Tuple[str, ...]): Name of the key column, which uniquely identifies each row (usually id) + update_column (str, optional): Name of updated column, which signals that rows changed. + Usually updated_at or last_update. Used by `min_update` and `max_update`. + extra_columns (Tuple[str, ...], optional): Extra columns to compare + min_key (:data:`DbKey`, optional): Lowest key value, used to restrict the segment + max_key (:data:`DbKey`, optional): Highest key value, used to restrict the segment + min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment + max_update (:data:`DbTime`, optional): Highest update_column value, used to restrict the segment + algorithm (:class:`Algorithm`): Which diffing algorithm to use (`HASHDIFF` or `JOINDIFF`) + bisection_factor (int): Into how many segments to bisect per iteration. (Used when algorithm is `HASHDIFF`) + bisection_threshold (Number): Minimal row count of segment to bisect, otherwise download + and compare locally. (Used when algorithm is `HASHDIFF`). + threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads. + max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. + Only relevant when `threaded` is ``True``. + There may be many pools, so number of actual threads can be a lot higher. 
+ + Note: + The following parameters are used to override the corresponding attributes of the given :class:`TableSegment` instances: + `key_columns`, `update_column`, `extra_columns`, `min_key`, `max_key`. + If different values are needed per table, it's possible to omit them here, and instead set + them directly when creating each :class:`TableSegment`. Example: >>> table1 = connect_to_table('postgresql:///', 'Rating', 'id') >>> list(diff_tables(table1, table1)) [] + See Also: + :class:`TableSegment` + :class:`HashDiffer` + :class:`JoinDiffer` + """ + if isinstance(key_columns, str): + key_columns = (key_columns,) + tables = [table1, table2] override_attrs = { k: v for k, v in dict( - key_column=key_column, + key_columns=key_columns, update_column=update_column, extra_columns=extra_columns, min_key=min_key, @@ -83,11 +123,20 @@ def diff_tables( segments = [t.new(**override_attrs) for t in tables] if override_attrs else tables - differ = TableDiffer( - bisection_factor=bisection_factor, - bisection_threshold=bisection_threshold, - debug=debug, - threaded=threaded, - max_threadpool_size=max_threadpool_size, - ) + algorithm = Algorithm(algorithm) + if algorithm == Algorithm.HASHDIFF: + differ = HashDiffer( + bisection_factor=bisection_factor, + bisection_threshold=bisection_threshold, + threaded=threaded, + max_threadpool_size=max_threadpool_size, + ) + elif algorithm == Algorithm.JOINDIFF: + differ = JoinDiffer( + threaded=threaded, + max_threadpool_size=max_threadpool_size, + ) + else: + raise ValueError(f"Unknown algorithm: {algorithm}") + return differ.diff_tables(*segments) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index bccd132f..0ad6de11 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -4,13 +4,15 @@ import json import logging from itertools import islice +from typing import Optional import rich import click - -from .utils import remove_password_from_url, safezip, match_like -from .diff_tables import TableDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR +from .utils import eval_name_template, remove_password_from_url, safezip, match_like +from .diff_tables import Algorithm +from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR +from .joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer from .table_segment import TableSegment from .databases.database_types import create_schema from .databases.connect import connect @@ -43,13 +45,45 @@ def _get_schema(pair): return db.query_table_schema(table_path) -@click.command() +def diff_schemas(schema1, schema2, columns): + logging.info("Diffing schemas...") + attrs = "name", "type", "datetime_precision", "numeric_precision", "numeric_scale" + for c in columns: + if c is None: # Skip for convenience + continue + diffs = [] + for attr, v1, v2 in safezip(attrs, schema1[c], schema2[c]): + if v1 != v2: + diffs.append(f"{attr}:({v1} != {v2})") + if diffs: + logging.warning(f"Schema mismatch in column '{c}': {', '.join(diffs)}") + + +class MyHelpFormatter(click.HelpFormatter): + def __init__(self, **kwargs): + super().__init__(self, **kwargs) + self.indent_increment = 6 + + def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -> None: + self.write("data-diff - efficiently diff rows across database tables.\n\n") + self.write("Usage:\n") + self.write(f" * In-db diff: {prog} [OPTIONS]\n") + self.write(f" * Cross-db diff: {prog} [OPTIONS]\n") + self.write(f" * Using config: {prog} --conf PATH [--run NAME] [OPTIONS]\n") + + 
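The `diff_schemas` helper added above walks a fixed set of column attributes and logs one warning per mismatched column. Here is a small illustrative call, assuming the helper remains importable from `data_diff.__main__` and that each schema maps a column name to a `(name, type, datetime_precision, numeric_precision, numeric_scale)` tuple, as its `attrs` sequence suggests:

```python
import logging
from data_diff.__main__ import diff_schemas

logging.basicConfig(level=logging.INFO)

# Each column maps to (name, type, datetime_precision, numeric_precision, numeric_scale),
# matching the `attrs` tuple used by diff_schemas.
schema1 = {"id": ("id", "int", None, 32, 0), "amount": ("amount", "decimal", None, 38, 2)}
schema2 = {"id": ("id", "int", None, 32, 0), "amount": ("amount", "float", None, 53, None)}

# Logs a warning along the lines of:
#   Schema mismatch in column 'amount': type:(decimal != float), numeric_precision:(38 != 53), ...
diff_schemas(schema1, schema2, ["id", "amount"])
```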
+click.Context.formatter_class = MyHelpFormatter + + +@click.command(no_args_is_help=True) @click.argument("database1", required=False) @click.argument("table1", required=False) @click.argument("database2", required=False) @click.argument("table2", required=False) -@click.option("-k", "--key-column", default=None, help="Name of primary key column. Default='id'.") -@click.option("-t", "--update-column", default=None, help="Name of updated_at/last_updated column") +@click.option( + "-k", "--key-columns", default=[], multiple=True, help="Names of primary key columns. Default='id'.", metavar="NAME" +) +@click.option("-t", "--update-column", default=None, help="Name of updated_at/last_updated column", metavar="NAME") @click.option( "-c", "--columns", @@ -58,13 +92,27 @@ def _get_schema(pair): help="Names of extra columns to compare." "Can be used more than once in the same command. " "Accepts a name or a pattern like in SQL. Example: -c col% -c another_col", + metavar="NAME", +) +@click.option("-l", "--limit", default=None, help="Maximum number of differences to find", metavar="NUM") +@click.option( + "--bisection-factor", + default=None, + help=f"Segments per iteration. Default={DEFAULT_BISECTION_FACTOR}.", + metavar="NUM", ) -@click.option("-l", "--limit", default=None, help="Maximum number of differences to find") -@click.option("--bisection-factor", default=None, help=f"Segments per iteration. Default={DEFAULT_BISECTION_FACTOR}.") @click.option( "--bisection-threshold", default=None, help=f"Minimal bisection threshold. Below it, data-diff will download the data and compare it locally. Default={DEFAULT_BISECTION_THRESHOLD}.", + metavar="NUM", +) +@click.option( + "-m", + "--materialize", + default=None, + metavar="TABLE_NAME", + help="(joindiff only) Materialize the diff results into a new table in the database. If a table exists by that name, it will be replaced.", ) @click.option( "--min-age", @@ -72,8 +120,11 @@ def _get_schema(pair): help="Considers only rows older than specified. Useful for specifying replication lag." "Example: --min-age=5min ignores rows from the last 5 minutes. " f"\nValid units: {UNITS_STR}", + metavar="AGE", +) +@click.option( + "--max-age", default=None, help="Considers only rows younger than specified. See --min-age.", metavar="AGE" ) -@click.option("--max-age", default=None, help="Considers only rows younger than specified. See --min-age.") @click.option("-s", "--stats", is_flag=True, help="Print stats instead of a detailed diff") @click.option("-d", "--debug", is_flag=True, help="Print debug info") @click.option("--json", "json_output", is_flag=True, help="Print JSONL output for machine readability") @@ -85,6 +136,27 @@ def _get_schema(pair): is_flag=True, help="Column names are treated as case-sensitive. Otherwise, data-diff corrects their case according to schema.", ) +@click.option( + "--assume-unique-key", + is_flag=True, + help="Skip validating the uniqueness of the key column during joindiff, which is costly in non-cloud dbs.", +) +@click.option( + "--sample-exclusive-rows", + is_flag=True, + help="Sample several rows that only appear in one of the tables, but not the other. (joindiff only)", +) +@click.option( + "--materialize-all-rows", + is_flag=True, + help="Materialize every row, even if they are the same, instead of just the differing rows. (joindiff only)", +) +@click.option( + "--table-write-limit", + default=TABLE_WRITE_LIMIT, + help=f"Maximum number of rows to write when creating materialized or sample tables, per thread. 
Default={TABLE_WRITE_LIMIT}", + metavar="COUNT", +) @click.option( "-j", "--threads", @@ -92,21 +164,39 @@ def _get_schema(pair): help="Number of worker threads to use per database. Default=1. " "A higher number will increase performance, but take more capacity from your database. " "'serial' guarantees a single-threaded execution of the algorithm (useful for debugging).", + metavar="COUNT", +) +@click.option( + "-w", "--where", default=None, help="An additional 'where' expression to restrict the search space.", metavar="EXPR" ) -@click.option("-w", "--where", default=None, help="An additional 'where' expression to restrict the search space.") +@click.option("-a", "--algorithm", default=Algorithm.AUTO.value, type=click.Choice([i.value for i in Algorithm])) @click.option( "--conf", default=None, help="Path to a configuration.toml file, to provide a default configuration, and a list of possible runs.", + metavar="PATH", ) @click.option( "--run", default=None, help="Name of run-configuration to run. If used, CLI arguments for database and table must be omitted.", + metavar="NAME", ) def main(conf, run, **kw): + indb_syntax = False + if kw["table2"] is None and kw["database2"]: + # Use the "database table table" form + kw["table2"] = kw["database2"] + kw["database2"] = kw["database1"] + indb_syntax = True + if conf: kw = apply_config_from_file(conf, run, kw) + + kw["algorithm"] = Algorithm(kw["algorithm"]) + if kw["algorithm"] == Algorithm.AUTO: + kw["algorithm"] = Algorithm.JOINDIFF if indb_syntax else Algorithm.HASHDIFF + return _main(**kw) @@ -115,10 +205,11 @@ def _main( table1, database2, table2, - key_column, + key_columns, update_column, columns, limit, + algorithm, bisection_factor, bisection_threshold, min_age, @@ -132,6 +223,11 @@ def _main( case_sensitive, json_output, where, + assume_unique_key, + sample_exclusive_rows, + materialize_all_rows, + table_write_limit, + materialize, threads1=None, threads2=None, __conf__=None, @@ -158,7 +254,7 @@ def _main( logging.error("Cannot specify a limit when using the -s/--stats switch") return - key_column = key_column or "id" + key_columns = key_columns or ("id",) bisection_factor = DEFAULT_BISECTION_FACTOR if bisection_factor is None else int(bisection_factor) bisection_threshold = DEFAULT_BISECTION_THRESHOLD if bisection_threshold is None else int(bisection_threshold) @@ -192,13 +288,6 @@ def _main( logging.error(f"Error while parsing age expression: {e}") return - differ = TableDiffer( - bisection_factor=bisection_factor, - bisection_threshold=bisection_threshold, - threaded=threaded, - max_threadpool_size=threads and threads * 2, - ) - if database1 is None or database2 is None: logging.error( f"Error: Databases not specified. Got {database1} and {database2}. Use --help for more information." 
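The options above add `-a`/`--algorithm`, and `main()` now auto-selects joindiff when the in-database `database table1 table2` form is used. The same choice is exposed through the Python API changed in `data_diff/__init__.py` earlier in this diff. A minimal sketch of a caller, with placeholder connection strings and table names, based on the signatures shown here:

```python
from data_diff import connect_to_table, diff_tables, Algorithm

# Placeholder URIs and tables -- substitute your own.
table1 = connect_to_table("postgresql:///", "rating", ("id",))
table2 = connect_to_table("postgresql:///", "rating_del1", ("id",))

# HASHDIFF checksums and bisects key-range segments (works across databases);
# JOINDIFF compares by joining the two tables inside the database (same-database diffs).
for sign, columns in diff_tables(table1, table2, algorithm=Algorithm.HASHDIFF):
    print(sign, columns)
```

Note that `key_columns` is now a tuple (a bare string is still accepted and wrapped), so compound keys can be passed as, e.g., `("org_id", "id")`.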
@@ -207,7 +296,10 @@ def _main( try: db1 = connect(database1, threads1 or threads) - db2 = connect(database2, threads2 or threads) + if database1 == database2: + db2 = db1 + else: + db2 = connect(database2, threads2 or threads) except Exception as e: logging.error(e) return @@ -218,6 +310,25 @@ def _main( for db in dbs: db.enable_interactive() + if algorithm == Algorithm.JOINDIFF: + differ = JoinDiffer( + threaded=threaded, + max_threadpool_size=threads and threads * 2, + validate_unique_key=not assume_unique_key, + sample_exclusive_rows=sample_exclusive_rows, + materialize_all_rows=materialize_all_rows, + table_write_limit=table_write_limit, + materialize_to_table=materialize and db1.parse_table_name(eval_name_template(materialize)), + ) + else: + assert algorithm == Algorithm.HASHDIFF + differ = HashDiffer( + bisection_factor=bisection_factor, + bisection_threshold=bisection_threshold, + threaded=threaded, + max_threadpool_size=threads and threads * 2, + ) + table_names = table1, table2 table_paths = [db.parse_table_name(t) for db, t in safezip(dbs, table_names)] @@ -237,16 +348,27 @@ def _main( m1 = None if any(match_like(c, schema1.keys())) else f"{db1}/{table1}" m2 = None if any(match_like(c, schema2.keys())) else f"{db2}/{table2}" not_matched = ", ".join(m for m in [m1, m2] if m) - raise ValueError(f"Column {c} not found in: {not_matched}") + raise ValueError(f"Column '{c}' not found in: {not_matched}") expanded_columns |= match - columns = tuple(expanded_columns - {key_column, update_column}) + columns = tuple(expanded_columns - {*key_columns, update_column}) + + if db1 is db2: + diff_schemas( + schema1, + schema2, + ( + *key_columns, + update_column, + *columns, + ), + ) - logging.info(f"Diffing columns: key={key_column} update={update_column} extra={columns}") + logging.info(f"Diffing using columns: key={key_columns} update={update_column} extra={columns}") segments = [ - TableSegment(db, table_path, key_column, update_column, columns, **options)._with_raw_schema(raw_schema) + TableSegment(db, table_path, key_columns, update_column, columns, **options)._with_raw_schema(raw_schema) for db, table_path, raw_schema in safezip(dbs, table_paths, schemas) ] @@ -270,12 +392,17 @@ def _main( "different_+": plus, "different_-": minus, "total": max_table_count, + "stats": differ.stats, } print(json.dumps(json_output)) else: print(f"Diff-Total: {len(diff)} changed rows out of {max_table_count}") print(f"Diff-Percent: {percent:.14f}%") print(f"Diff-Split: +{plus} -{minus}") + if differ.stats: + print("Extra-Info:") + for k, v in differ.stats.items(): + print(f" {k} = {v}") else: for op, values in diff_iter: color = COLOR_SCHEME[op] @@ -284,7 +411,7 @@ def _main( jsonl = json.dumps([op, list(values)]) rich.print(f"[{color}]{jsonl}[/{color}]") else: - text = f"{op} {', '.join(values)}" + text = f"{op} {', '.join(map(str, values))}" rich.print(f"[{color}]{text}[/{color}]") sys.stdout.flush() diff --git a/data_diff/config.py b/data_diff/config.py index ad7c972d..9a6b6d54 100644 --- a/data_diff/config.py +++ b/data_diff/config.py @@ -26,17 +26,22 @@ def _apply_config(config: Dict[str, Any], run_name: str, kw: Dict[str, Any]): else: run_name = "default" - if 'database1' in kw: - for attr in ('table1', 'database2', 'table2'): + if kw.get("database1") is not None: + for attr in ("table1", "database2", "table2"): if kw[attr] is None: raise ValueError(f"Specified database1 but not {attr}. 
Must specify all 4 arguments, or niether.") for index in "12": - run_args[index] = {attr: kw.pop(f"{attr}{index}") for attr in ('database', 'table')} + run_args[index] = {attr: kw.pop(f"{attr}{index}") for attr in ("database", "table")} # Process databases + tables for index in "12": - args = run_args.pop(index, {}) + try: + args = run_args.pop(index) + except KeyError: + raise ConfigParseError( + f"Could not find source #{index}: Expecting a key of '{index}' containing '.database' and '.table'." + ) for attr in ("database", "table"): if attr not in args: raise ConfigParseError(f"Running 'run.{run_name}': Connection #{index} is missing attribute '{attr}'.") diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index b114937a..672c4e0b 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -1,15 +1,21 @@ +from datetime import datetime import math import sys import logging -from typing import Dict, Tuple, Optional, Sequence, Type, List -from functools import wraps +from typing import Any, Callable, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union +from functools import partial, wraps from concurrent.futures import ThreadPoolExecutor import threading from abc import abstractmethod +from uuid import UUID from data_diff.utils import is_uuid, safezip +from data_diff.queries import Expr, Compiler, table, Select, SKIP, Explain from .database_types import ( AbstractDatabase, + AbstractDialect, + AbstractMixin_MD5, + AbstractMixin_NormalizeValue, ColType, Integer, Decimal, @@ -18,14 +24,13 @@ Native_UUID, String_UUID, String_Alphanum, - String_FixedAlphanum, String_VaryingAlphanum, TemporalType, UnknownColType, Text, DbTime, + DbPath, ) -from data_diff.sql import DbPath, SqlOrStr, Compiler, Explain, Select, TableName logger = logging.getLogger("database") @@ -64,73 +69,102 @@ def _one(seq): return x -def _query_conn(conn, sql_code: str) -> list: - c = conn.cursor() - c.execute(sql_code) - if sql_code.lower().startswith("select"): - return c.fetchall() +class ThreadLocalInterpreter: + """An interpeter used to execute a sequence of queries within the same thread. + Useful for cursor-sensitive operations, such as creating a temporary table. + """ -class Database(AbstractDatabase): - """Base abstract class for databases. + def __init__(self, compiler: Compiler, gen: Generator): + self.gen = gen + self.compiler = compiler - Used for providing connection code and implementation specific SQL utilities. 
+ def apply_queries(self, callback: Callable[[str], Any]): + q: Expr = next(self.gen) + while True: + sql = self.compiler.compile(q) + logger.debug("Running SQL (%s-TL): %s", self.compiler.database.name, sql) + try: + try: + res = callback(sql) if sql is not SKIP else SKIP + except Exception as e: + q = self.gen.throw(type(e), e) + else: + q = self.gen.send(res) + except StopIteration: + break - Instanciated using :meth:`~data_diff.connect` - """ +def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocalInterpreter]) -> list: + if isinstance(sql_code, ThreadLocalInterpreter): + return sql_code.apply_queries(callback) + else: + return callback(sql_code) + + +class BaseDialect(AbstractDialect, AbstractMixin_MD5, AbstractMixin_NormalizeValue): + SUPPORTS_PRIMARY_KEY = False TYPE_CLASSES: Dict[str, type] = {} - default_schema: str = None - SUPPORTS_ALPHANUMS = True - @property - def name(self): - return type(self).__name__ + def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): + if offset: + raise NotImplementedError("No support for OFFSET in query") - def query(self, sql_ast: SqlOrStr, res_type: type): - "Query the given SQL code/AST, and attempt to convert the result to type 'res_type'" + return f"LIMIT {limit}" - compiler = Compiler(self) - sql_code = compiler.compile(sql_ast) - logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code) - if getattr(self, "_interactive", False) and isinstance(sql_ast, Select): - explained_sql = compiler.compile(Explain(sql_ast)) - logger.info("EXPLAIN for SQL SELECT") - logger.info(self._query(explained_sql)) - answer = input("Continue? [y/n] ") - if not answer.lower() in ["y", "yes"]: - sys.exit(1) + def concat(self, items: List[str]) -> str: + assert len(items) > 1 + joined_exprs = ", ".join(items) + return f"concat({joined_exprs})" - res = self._query(sql_code) - if res_type is int: - res = _one(_one(res)) - if res is None: # May happen due to sum() of 0 items - return None - return int(res) - elif res_type is tuple: - assert len(res) == 1, (sql_code, res) - return res[0] - elif getattr(res_type, "__origin__", None) is list and len(res_type.__args__) == 1: - if res_type.__args__ in ((int,), (str,)): - return [_one(row) for row in res] - elif res_type.__args__ == (Tuple,): - return [tuple(row) for row in res] - else: - raise ValueError(res_type) - return res + def is_distinct_from(self, a: str, b: str) -> str: + return f"{a} is distinct from {b}" - def enable_interactive(self): - self._interactive = True + def timestamp_value(self, t: DbTime) -> str: + return f"'{t.isoformat()}'" - def _convert_db_precision_to_digits(self, p: int) -> int: - """Convert from binary precision, used by floats, to decimal precision.""" - # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format - return math.floor(math.log(2**p, 10)) + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + if isinstance(coltype, String_UUID): + return f"TRIM({value})" + return self.to_string(value) + + def random(self) -> str: + return "RANDOM()" + + def explain_as_text(self, query: str) -> str: + return f"EXPLAIN {query}" + + def _constant_value(self, v): + if v is None: + return "NULL" + elif isinstance(v, str): + return f"'{v}'" + elif isinstance(v, datetime): + # TODO use self.timestamp_value + return f"timestamp '{v}'" + elif isinstance(v, UUID): + return f"'{v}'" + return repr(v) + + def constant_values(self, rows) -> str: + values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in 
row) for row in rows) + return f"VALUES {values}" + + def type_repr(self, t) -> str: + if isinstance(t, str): + return t + return { + int: "INT", + str: "VARCHAR", + bool: "BOOLEAN", + float: "FLOAT", + datetime: "TIMESTAMP", + }[t] def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: return self.TYPE_CLASSES.get(type_repr) - def _parse_type( + def parse_type( self, table_path: DbPath, col_name: str, @@ -172,6 +206,85 @@ def _parse_type( raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.") + def _convert_db_precision_to_digits(self, p: int) -> int: + """Convert from binary precision, used by floats, to decimal precision.""" + # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format + return math.floor(math.log(2**p, 10)) + + +class Database(AbstractDatabase): + """Base abstract class for databases. + + Used for providing connection code and implementation specific SQL utilities. + + Instanciated using :meth:`~data_diff.connect` + """ + + default_schema: str = None + dialect: AbstractDialect = None + + SUPPORTS_ALPHANUMS = True + SUPPORTS_UNIQUE_CONSTAINT = False + + _interactive = False + + @property + def name(self): + return type(self).__name__ + + def query(self, sql_ast: Union[Expr, Generator], res_type: type = list): + "Query the given SQL code/AST, and attempt to convert the result to type 'res_type'" + + compiler = Compiler(self) + if isinstance(sql_ast, Generator): + sql_code = ThreadLocalInterpreter(compiler, sql_ast) + elif isinstance(sql_ast, list): + for i in sql_ast[:-1]: + self.query(i) + return self.query(sql_ast[-1], res_type) + else: + sql_code = compiler.compile(sql_ast) + if sql_code is SKIP: + return SKIP + + logger.debug("Running SQL (%s): %s", self.name, sql_code) + + if self._interactive and isinstance(sql_ast, Select): + explained_sql = compiler.compile(Explain(sql_ast)) + explain = self._query(explained_sql) + for row in explain: + # Most returned a 1-tuple. Presto returns a string + if isinstance(row, tuple): + (row,) = row + logger.debug("EXPLAIN: %s", row) + answer = input("Continue? [y/n] ") + if answer.lower() not in ["y", "yes"]: + sys.exit(1) + + res = self._query(sql_code) + if res_type is int: + res = _one(_one(res)) + if res is None: # May happen due to sum() of 0 items + return None + return int(res) + elif res_type is datetime: + res = _one(_one(res)) + return res # XXX parse timestamp? 
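> Note: `query()` above now also accepts a generator of queries, which it wraps in a `ThreadLocalInterpreter` so that every statement runs on the same cursor and thread (needed, for example, when creating and then reading a temporary table). The snippet below is a minimal, self-contained sketch of that send/throw generator protocol, using `sqlite3` purely as a stand-in for a real connection; it is an illustration of the pattern, not the library's actual class.

```python
import sqlite3
from typing import Any, Callable, Generator


def run_on_one_cursor(gen: Generator[str, Any, None], execute: Callable[[str], Any]) -> None:
    """Drive a generator of SQL strings against a single cursor."""
    try:
        sql = next(gen)
        while True:
            try:
                res = execute(sql)
            except Exception as e:
                sql = gen.throw(e)      # let the generator decide how to recover
            else:
                sql = gen.send(res)     # feed the result back into the generator
    except StopIteration:
        pass


def create_and_count() -> Generator[str, Any, None]:
    # Cursor-sensitive sequence: the temp table must live on the same connection.
    yield "CREATE TEMPORARY TABLE t (x INT)"
    yield "INSERT INTO t VALUES (1), (2)"
    rows = yield "SELECT count(*) FROM t"
    print("row count:", rows[0][0])


cur = sqlite3.connect(":memory:").cursor()
run_on_one_cursor(create_and_count(), lambda sql: cur.execute(sql).fetchall())
```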
+ elif res_type is tuple: + assert len(res) == 1, (sql_code, res) + return res[0] + elif getattr(res_type, "__origin__", None) is list and len(res_type.__args__) == 1: + if res_type.__args__ in ((int,), (str,)): + return [_one(row) for row in res] + elif res_type.__args__ in [(Tuple,), (tuple,)]: + return [tuple(row) for row in res] + else: + raise ValueError(res_type) + return res + + def enable_interactive(self): + self._interactive = True + def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) @@ -190,19 +303,36 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: assert len(d) == len(rows) return d + def select_table_unique_columns(self, path: DbPath) -> str: + schema, table = self._normalize_table_path(path) + + return ( + "SELECT column_name " + "FROM information_schema.key_column_usage " + f"WHERE table_name = '{table}' AND table_schema = '{schema}'" + ) + + def query_table_unique_columns(self, path: DbPath) -> List[str]: + if not self.SUPPORTS_UNIQUE_CONSTAINT: + raise NotImplementedError("This database doesn't support 'unique' constraints") + res = self.query(self.select_table_unique_columns(path), List[str]) + return list(res) + def _process_table_schema( self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None ): accept = {i.lower() for i in filter_columns} - col_dict = {row[0]: self._parse_type(path, *row) for name, row in raw_schema.items() if name.lower() in accept} + col_dict = { + row[0]: self.dialect.parse_type(path, *row) for name, row in raw_schema.items() if name.lower() in accept + } self._refine_coltypes(path, col_dict, where) # Return a dict of form {name: type} after normalization return col_dict - def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], where: str = None): + def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], where: str = None, sample_size=32): """Refine the types in the column dict, by querying the database for a sample of their values 'where' restricts the rows to be sampled. 
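> Note on the hunk that follows: `_refine_coltypes` now pulls up to `sample_size` values per text column and reclassifies the column as a UUID or alphanumeric string based on the sample. Below is a rough, self-contained sketch of that idea; the real code builds the sample query through the dialect and uses `data_diff.utils.is_uuid`, so the helper and the returned labels here are only stand-ins.

```python
import uuid
from typing import Dict, List, Optional


def looks_like_uuid(s: str) -> bool:
    # Stand-in for data_diff.utils.is_uuid
    try:
        uuid.UUID(s)
        return True
    except ValueError:
        return False


def refine_text_columns(samples: Dict[str, List[Optional[str]]]) -> Dict[str, str]:
    """Classify each text column from a small sample of its values."""
    refined = {}
    for col, values in samples.items():
        non_null = [v for v in values if v is not None]
        if non_null and all(looks_like_uuid(v) for v in non_null):
            refined[col] = "String_UUID"
        elif non_null and all(v.isalnum() for v in non_null):
            refined[col] = "String_VaryingAlphanum"
        else:
            refined[col] = "Text"
    return refined


print(refine_text_columns({
    "id": ["7d84f3a2-93c9-4bbe-9a5c-9d4e3b1b1b10"],
    "code": ["AB12", "ZZ999"],
    "note": ["hello world", None],
}))
```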
@@ -212,8 +342,8 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe if not text_columns: return - fields = [self.normalize_uuid(c, String_UUID()) for c in text_columns] - samples_by_row = self.query(Select(fields, TableName(table_path), limit=16, where=where and [where]), list) + fields = [self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID()) for c in text_columns] + samples_by_row = self.query(table(*table_path).select(*fields).where(where or SKIP).limit(sample_size), list) if not samples_by_row: raise ValueError(f"Table {table_path} is empty.") @@ -241,13 +371,7 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe ) else: assert col_name in col_dict - lens = set(map(len, alphanum_samples)) - if len(lens) > 1: - col_dict[col_name] = String_VaryingAlphanum() - else: - (length,) = lens - col_dict[col_name] = String_FixedAlphanum(length=length) - continue + col_dict[col_name] = String_VaryingAlphanum() # @lru_cache() # def get_table_schema(self, path: DbPath) -> Dict[str, ColType]: @@ -265,27 +389,21 @@ def _normalize_table_path(self, path: DbPath) -> DbPath: def parse_table_name(self, name: str) -> DbPath: return parse_table_name(name) - def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): - if offset: - raise NotImplementedError("No support for OFFSET in query") - - return f"LIMIT {limit}" - - def concat(self, l: List[str]) -> str: - assert len(l) > 1 - joined_exprs = ", ".join(l) - return f"concat({joined_exprs})" - - def is_distinct_from(self, a: str, b: str) -> str: - return f"{a} is distinct from {b}" - - def timestamp_value(self, t: DbTime) -> str: - return f"'{t.isoformat()}'" + def _query_cursor(self, c, sql_code: str): + assert isinstance(sql_code, str), sql_code + try: + c.execute(sql_code) + if sql_code.lower().startswith(("select", "explain", "show")): + return c.fetchall() + except Exception as e: + # logger.exception(e) + # logger.error(f'Caused by SQL: {sql_code}') + raise - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - if isinstance(coltype, String_UUID): - return f"TRIM({value})" - return self.to_string(value) + def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> list: + c = conn.cursor() + callback = partial(self._query_cursor, c) + return apply_query(callback, sql_code) class ThreadedDatabase(Database): @@ -307,15 +425,15 @@ def set_conn(self): except ModuleNotFoundError as e: self._init_error = e - def _query(self, sql_code: str): + def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): r = self._queue.submit(self._query_in_worker, sql_code) return r.result() - def _query_in_worker(self, sql_code: str): + def _query_in_worker(self, sql_code: Union[str, ThreadLocalInterpreter]): "This method runs in a worker thread" if self._init_error: raise self._init_error - return _query_conn(self.thread_local.conn, sql_code) + return self._query_conn(self.thread_local.conn, sql_code) @abstractmethod def create_connection(self): @@ -324,6 +442,10 @@ def create_connection(self): def close(self): self._queue.shutdown() + @property + def is_autocommit(self) -> bool: + return False + CHECKSUM_HEXDIGITS = 15 # Must be 15 or lower MD5_HEXDIGITS = 32 diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 411ae795..0aa7670a 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -1,6 +1,7 @@ -from .database_types import * -from .base import Database, import_helper, 
parse_table_name, ConnectError -from .base import TIMESTAMP_PRECISION_POS +from typing import List, Union +from .database_types import Timestamp, Datetime, Integer, Decimal, Float, Text, DbPath, FractionalType, TemporalType +from .base import BaseDialect, Database, import_helper, parse_table_name, ConnectError, apply_query +from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter @import_helper(text="Please install BigQuery and configure your google-cloud access.") @@ -10,7 +11,9 @@ def import_bigquery(): return bigquery -class BigQuery(Database): +class Dialect(BaseDialect): + name = "BigQuery" + ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation TYPE_CLASSES = { # Dates "TIMESTAMP": Timestamp, @@ -25,7 +28,46 @@ class BigQuery(Database): # Text "STRING": Text, } - ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation + + def random(self) -> str: + return "RAND()" + + def quote(self, s: str): + return f"`{s}`" + + def md5_as_int(self, s: str) -> str: + return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)" + + def to_string(self, s: str): + return f"cast({s} as string)" + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" + return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})" + + if coltype.precision == 0: + return f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000, {value})" + elif coltype.precision == 6: + return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" + + timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return f"format('%.{coltype.precision}f', {value})" + + def type_repr(self, t) -> str: + try: + return {str: "STRING", float: "FLOAT64"}[t] + except KeyError: + return super().type_repr(t) + + +class BigQuery(Database): + dialect = Dialect() def __init__(self, project, *, dataset, **kw): bigquery = import_bigquery() @@ -36,18 +78,12 @@ def __init__(self, project, *, dataset, **kw): self.default_schema = dataset - def quote(self, s: str): - return f"`{s}`" - - def md5_to_int(self, s: str) -> str: - return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)" - def _normalize_returned_value(self, value): if isinstance(value, bytes): return value.decode() return value - def _query(self, sql_code: str): + def _query_atom(self, sql_code: str): from google.cloud import bigquery try: @@ -60,8 +96,8 @@ def _query(self, sql_code: str): res = [tuple(self._normalize_returned_value(v) for v in row.values()) for row in res] return res - def to_string(self, s: str): - return f"cast({s} as string)" + def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): + return apply_query(self._query_atom, sql_code) def close(self): self._client.close() @@ -74,24 +110,13 @@ def select_table_schema(self, path: DbPath) -> str: f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" - return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})" - - if 
coltype.precision == 0: - return f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000, {value})" - elif coltype.precision == 6: - return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" - - timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" - return ( - f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return f"format('%.{coltype.precision}f', {value})" + def query_table_unique_columns(self, path: DbPath) -> List[str]: + return [] def parse_table_name(self, name: str) -> DbPath: path = parse_table_name(name) return self._normalize_table_path(path) + + @property + def is_autocommit(self) -> bool: + return True diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index f0f7a5ad..b5f2f577 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -4,11 +4,22 @@ MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS, + BaseDialect, ThreadedDatabase, import_helper, ConnectError, ) -from .database_types import ColType, Decimal, Float, Integer, FractionalType, Native_UUID, TemporalType, Text, Timestamp +from .database_types import ( + ColType, + Decimal, + Float, + Integer, + FractionalType, + Native_UUID, + TemporalType, + Text, + Timestamp, +) @import_helper("clickhouse") @@ -18,7 +29,9 @@ def import_clickhouse(): return clickhouse_driver -class Clickhouse(ThreadedDatabase): +class Dialect(BaseDialect): + name = "Clickhouse" + ROUNDS_ON_PREC_LOSS = False TYPE_CLASSES = { "Int8": Integer, "Int16": Integer, @@ -41,70 +54,6 @@ class Clickhouse(ThreadedDatabase): "DateTime": Timestamp, "DateTime64": Timestamp, } - ROUNDS_ON_PREC_LOSS = False - - def __init__(self, *, thread_count: int, **kw): - super().__init__(thread_count=thread_count) - - self._args = kw - # In Clickhouse database and schema are the same - self.default_schema = kw["database"] - - def create_connection(self): - clickhouse = import_clickhouse() - - class SingleConnection(clickhouse.dbapi.connection.Connection): - """Not thread-safe connection to Clickhouse""" - - def cursor(self, cursor_factory=None): - if not len(self.cursors): - _ = super().cursor() - return self.cursors[0] - - try: - return SingleConnection(**self._args) - except clickhouse.OperationError as e: - raise ConnectError(*e.args) from e - - def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: - nullable_prefix = "Nullable(" - if type_repr.startswith(nullable_prefix): - type_repr = type_repr[len(nullable_prefix) :].rstrip(")") - - if type_repr.startswith("Decimal"): - type_repr = "Decimal" - elif type_repr.startswith("FixedString"): - type_repr = "FixedString" - elif type_repr.startswith("DateTime64"): - type_repr = "DateTime64" - - return self.TYPE_CLASSES.get(type_repr) - - def quote(self, s: str) -> str: - return f'"{s}"' - - def md5_to_int(self, s: str) -> str: - substr_idx = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS - return f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx})))))" - - def to_string(self, s: str) -> str: - return f"toString({s})" - - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - prec = coltype.precision - if coltype.rounds: - timestamp = f"toDateTime64(round(toUnixTimestamp64Micro(toDateTime64({value}, 6)) / 1000000, {prec}), 6)" - return self.to_string(timestamp) - - fractional = f"toUnixTimestamp64Micro(toDateTime64({value}, {prec})) % 1000000" - fractional = 
f"lpad({self.to_string(fractional)}, 6, '0')" - value = f"formatDateTime({value}, '%Y-%m-%d %H:%M:%S') || '.' || {self.to_string(fractional)}" - return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')" - - def _convert_db_precision_to_digits(self, p: int) -> int: - # Done the same as for PostgreSQL but need to rewrite in another way - # because it does not help for float with a big integer part. - return super()._convert_db_precision_to_digits(p) - 2 def normalize_number(self, value: str, coltype: FractionalType) -> str: # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped. @@ -150,3 +99,74 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: ) """ return value + + def quote(self, s: str) -> str: + return f'"{s}"' + + def md5_as_int(self, s: str) -> str: + substr_idx = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS + return f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx})))))" + + def to_string(self, s: str) -> str: + return f"toString({s})" + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + prec = coltype.precision + if coltype.rounds: + timestamp = f"toDateTime64(round(toUnixTimestamp64Micro(toDateTime64({value}, 6)) / 1000000, {prec}), 6)" + return self.to_string(timestamp) + + fractional = f"toUnixTimestamp64Micro(toDateTime64({value}, {prec})) % 1000000" + fractional = f"lpad({self.to_string(fractional)}, 6, '0')" + value = f"formatDateTime({value}, '%Y-%m-%d %H:%M:%S') || '.' || {self.to_string(fractional)}" + return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')" + + def _convert_db_precision_to_digits(self, p: int) -> int: + # Done the same as for PostgreSQL but need to rewrite in another way + # because it does not help for float with a big integer part. + return super()._convert_db_precision_to_digits(p) - 2 + + def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: + nullable_prefix = "Nullable(" + if type_repr.startswith(nullable_prefix): + type_repr = type_repr[len(nullable_prefix) :].rstrip(")") + + if type_repr.startswith("Decimal"): + type_repr = "Decimal" + elif type_repr.startswith("FixedString"): + type_repr = "FixedString" + elif type_repr.startswith("DateTime64"): + type_repr = "DateTime64" + + return self.TYPE_CLASSES.get(type_repr) + + +class Clickhouse(ThreadedDatabase): + dialect = Dialect() + + def __init__(self, *, thread_count: int, **kw): + super().__init__(thread_count=thread_count) + + self._args = kw + # In Clickhouse database and schema are the same + self.default_schema = kw["database"] + + def create_connection(self): + clickhouse = import_clickhouse() + + class SingleConnection(clickhouse.dbapi.connection.Connection): + """Not thread-safe connection to Clickhouse""" + + def cursor(self, cursor_factory=None): + if not len(self.cursors): + _ = super().cursor() + return self.cursors[0] + + try: + return SingleConnection(**self._args) + except clickhouse.OperationError as e: + raise ConnectError(*e.args) from e + + @property + def is_autocommit(self) -> bool: + return True diff --git a/data_diff/databases/connect.py b/data_diff/databases/connect.py index 94cb52d2..8468a734 100644 --- a/data_diff/databases/connect.py +++ b/data_diff/databases/connect.py @@ -184,6 +184,8 @@ def connect(db_conf: Union[str, dict], thread_count: Optional[int] = 1) -> Datab Configuration can be given either as a URI string, or as a dict of {option: value}. 
+ The dictionary configuration uses the same keys as the TOML 'database' definition given with --conf. + thread_count determines the max number of worker threads per database, if relevant. None means no limit. @@ -205,6 +207,12 @@ def connect(db_conf: Union[str, dict], thread_count: Optional[int] = 1) -> Datab - trino - clickhouse - vertica + + Example: + >>> connect("mysql://localhost/db") + + >>> connect({"driver": "mysql", "host": "localhost", "database": "db"}) + """ if isinstance(db_conf, str): return connect_to_uri(db_conf, thread_count) diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index e93e380e..296ad475 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -140,22 +140,29 @@ class UnknownColType(ColType): supported = False -class AbstractDatabase(ABC): +class AbstractDialect(ABC): + """Dialect-dependent query expressions""" + name: str + @property @abstractmethod - def quote(self, s: str): - "Quote SQL name (implementation specific)" - ... + def name(self) -> str: + "Name of the dialect" + @property @abstractmethod - def to_string(self, s: str) -> str: - "Provide SQL for casting a column to string" + def ROUNDS_ON_PREC_LOSS(self) -> bool: + "True if db rounds real values when losing precision, False if it truncates." + + @abstractmethod + def quote(self, s: str): + "Quote SQL name" ... @abstractmethod - def concat(self, l: List[str]) -> str: - "Provide SQL for concatenating a bunch of column into a string" + def concat(self, items: List[str]) -> str: + "Provide SQL for concatenating a bunch of columns into a string" ... @abstractmethod @@ -164,14 +171,14 @@ def is_distinct_from(self, a: str, b: str) -> str: ... @abstractmethod - def timestamp_value(self, t: DbTime) -> str: - "Provide SQL for the given timestamp value" + def to_string(self, s: str) -> str: + # TODO rewrite using cast_to(x, str) + "Provide SQL for casting a column to string" ... @abstractmethod - def md5_to_int(self, s: str) -> str: - "Provide SQL for computing md5 and returning an int" - ... + def random(self) -> str: + "Provide SQL for generating a random number betweein 0..1" @abstractmethod def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): @@ -179,45 +186,29 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None ... @abstractmethod - def _query(self, sql_code: str) -> list: - "Send query to database and return result" + def explain_as_text(self, query: str) -> str: + "Provide SQL for explaining a query, returned as table(varchar)" ... @abstractmethod - def select_table_schema(self, path: DbPath) -> str: - "Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)" - ... - - @abstractmethod - def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: - """Query the table for its schema for table in 'path', and return {column: tuple} - where the tuple is (table_name, col_name, type_repr, datetime_precision?, numeric_precision?, numeric_scale?) - """ - ... - - @abstractmethod - def _process_table_schema( - self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None - ): - """Process the result of query_table_schema(). - - Done in a separate step, to minimize the amount of processed columns. 
- Needed because processing each column may: - * throw errors and warnings - * query the database to sample values - - """ - - @abstractmethod - def parse_table_name(self, name: str) -> DbPath: - "Parse the given table name into a DbPath" + def timestamp_value(self, t: datetime) -> str: + "Provide SQL for the given timestamp value" ... @abstractmethod - def close(self): - "Close connection(s) to the database instance. Querying will stop functioning." - ... - + def parse_type( + self, + table_path: DbPath, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, + numeric_scale: int = None, + ) -> ColType: + "Parse type info as returned by the database" + + +class AbstractMixin_NormalizeValue(ABC): @abstractmethod def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: """Creates an SQL expression, that converts 'value' to a normalized timestamp. @@ -283,10 +274,76 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: return self.normalize_uuid(value, coltype) return self.to_string(value) + +class AbstractMixin_MD5(ABC): + """Dialect-dependent query expressions, that are specific to data-diff""" + + @abstractmethod + def md5_as_int(self, s: str) -> str: + "Provide SQL for computing md5 and returning an int" + ... + + +class AbstractDatabase: + @abstractmethod + def _query(self, sql_code: str) -> list: + "Send query to database and return result" + ... + + @abstractmethod + def select_table_schema(self, path: DbPath) -> str: + "Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)" + ... + + @abstractmethod + def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: + """Query the table for its schema for table in 'path', and return {column: tuple} + where the tuple is (table_name, col_name, type_repr, datetime_precision?, numeric_precision?, numeric_scale?) + """ + ... + + @abstractmethod + def select_table_unique_columns(self, path: DbPath) -> str: + "Provide SQL for selecting the names of unique columns in the table" + ... + + @abstractmethod + def query_table_unique_columns(self, path: DbPath) -> List[str]: + """Query the table for its unique columns for table in 'path', and return {column}""" + ... + + @abstractmethod + def _process_table_schema( + self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None + ): + """Process the result of query_table_schema(). + + Done in a separate step, to minimize the amount of processed columns. + Needed because processing each column may: + * throw errors and warnings + * query the database to sample values + + """ + + @abstractmethod + def parse_table_name(self, name: str) -> DbPath: + "Parse the given table name into a DbPath" + ... + + @abstractmethod + def close(self): + "Close connection(s) to the database instance. Querying will stop functioning." + ... + @abstractmethod def _normalize_table_path(self, path: DbPath) -> DbPath: ... + @property + @abstractmethod + def is_autocommit(self) -> bool: + ... 
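> Note: the refactor above splits the old monolithic `AbstractDatabase` into an `AbstractDialect` (plus `AbstractMixin_MD5` / `AbstractMixin_NormalizeValue`) that only generates SQL text, and a slimmer `AbstractDatabase` that owns connections and schema queries. The sketch below is a self-contained illustration of the shape this gives a backend; the class names and the SQL expression are illustrative, not the library's real implementations.

```python
from abc import ABC, abstractmethod


class MD5Mixin(ABC):
    # Mirrors the role of AbstractMixin_MD5: SQL that hashes a column to an int
    @abstractmethod
    def md5_as_int(self, s: str) -> str: ...


class MiniDialect(MD5Mixin):
    """Query-building only: no connections, no cursors."""
    name = "Mini"

    def quote(self, s: str) -> str:
        return f'"{s}"'

    def md5_as_int(self, s: str) -> str:
        # PostgreSQL-flavoured example expression; other dialects differ.
        return f"('x' || substring(md5({s}), 18))::bit(60)::bigint"


class MiniDatabase:
    """Connection handling delegates all SQL text to `self.dialect`."""
    dialect = MiniDialect()

    def checksum_query(self, table: str, column: str) -> str:
        col = self.dialect.quote(column)
        return f"SELECT sum({self.dialect.md5_as_int(col)}) FROM {self.dialect.quote(table)}"


print(MiniDatabase().checksum_query("ratings", "id"))
```

Keeping the dialect stateless like this is what lets one backend reuse another's query logic, as the later hunks do when Redshift subclasses `PostgresqlDialect` and Trino subclasses Presto's `Dialect`.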
+ Schema = CaseAwareMapping diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index b0ee9fa5..79c46fc7 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -1,7 +1,20 @@ +import math +from typing import Dict, Sequence import logging -from .database_types import * -from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Database, import_helper, _query_conn, parse_table_name +from .database_types import ( + Integer, + Float, + Decimal, + Timestamp, + Text, + TemporalType, + NumericType, + DbPath, + ColType, + UnknownColType, +) +from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, BaseDialect, ThreadedDatabase, import_helper, parse_table_name @import_helper(text="You can install it using 'pip install databricks-sql-connector'") @@ -11,7 +24,9 @@ def import_databricks(): return databricks -class Databricks(Database): +class Dialect(BaseDialect): + name = "Databricks" + ROUNDS_ON_PREC_LOSS = True TYPE_CLASSES = { # Numbers "INT": Integer, @@ -27,59 +42,77 @@ class Databricks(Database): "STRING": Text, } - ROUNDS_ON_PREC_LOSS = True - - def __init__( - self, - http_path: str, - access_token: str, - server_hostname: str, - catalog: str = "hive_metastore", - schema: str = "default", - **kwargs, - ): - databricks = import_databricks() - - self._conn = databricks.sql.connect( - server_hostname=server_hostname, http_path=http_path, access_token=access_token - ) - - logging.getLogger("databricks.sql").setLevel(logging.WARNING) - - self.catalog = catalog - self.default_schema = schema - self.kwargs = kwargs - - def _query(self, sql_code: str) -> list: - "Uses the standard SQL cursor interface" - return _query_conn(self._conn, sql_code) - def quote(self, s: str): return f"`{s}`" - def md5_to_int(self, s: str) -> str: + def md5_as_int(self, s: str) -> str: return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0))" def to_string(self, s: str) -> str: return f"cast({s} as string)" + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + """Databricks timestamp contains no more than 6 digits in precision""" + + if coltype.rounds: + timestamp = f"cast(round(unix_micros({value}) / 1000000, {coltype.precision}) * 1000000 as bigint)" + return f"date_format(timestamp_micros({timestamp}), 'yyyy-MM-dd HH:mm:ss.SSSSSS')" + + precision_format = "S" * coltype.precision + "0" * (6 - coltype.precision) + return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')" + + def normalize_number(self, value: str, coltype: NumericType) -> str: + value = f"cast({value} as decimal(38, {coltype.precision}))" + if coltype.precision > 0: + value = f"format_number({value}, {coltype.precision})" + return f"replace({self.to_string(value)}, ',', '')" + def _convert_db_precision_to_digits(self, p: int) -> int: - # Subtracting 1 due to wierd precision issues - return max(super()._convert_db_precision_to_digits(p) - 1, 0) + # Subtracting 2 due to wierd precision issues + return max(super()._convert_db_precision_to_digits(p) - 2, 0) + + +class Databricks(ThreadedDatabase): + dialect = Dialect() + + def __init__(self, *, thread_count, **kw): + logging.getLogger("databricks.sql").setLevel(logging.WARNING) + + self._args = kw + self.default_schema = kw.get("schema", "hive_metastore") + super().__init__(thread_count=thread_count) + + def create_connection(self): + databricks = import_databricks() + + try: + return databricks.sql.connect( + server_hostname=self._args["server_hostname"], + 
http_path=self._args["http_path"], + access_token=self._args["access_token"], + catalog=self._args["catalog"], + ) + except databricks.sql.exc.Error as e: + raise ConnectionError(*e.args) from e def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL. # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html # So, to obtain information about schema, we should use another approach. + conn = self.create_connection() + schema, table = self._normalize_table_path(path) - with self._conn.cursor() as cursor: - cursor.columns(catalog_name=self.catalog, schema_name=schema, table_name=table) - rows = cursor.fetchall() + with conn.cursor() as cursor: + cursor.columns(catalog_name=self._args["catalog"], schema_name=schema, table_name=table) + try: + rows = cursor.fetchall() + finally: + conn.close() if not rows: raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") - d = {r.COLUMN_NAME: r for r in rows} + d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows} assert len(d) == len(rows) return d @@ -91,51 +124,38 @@ def _process_table_schema( resulted_rows = [] for row in rows: - row_type = "DECIMAL" if row.DATA_TYPE == 3 else row.TYPE_NAME - type_cls = self.TYPE_CLASSES.get(row_type, UnknownColType) + row_type = "DECIMAL" if row[1].startswith("DECIMAL") else row[1] + type_cls = self.dialect.TYPE_CLASSES.get(row_type, UnknownColType) if issubclass(type_cls, Integer): - row = (row.COLUMN_NAME, row_type, None, None, 0) + row = (row[0], row_type, None, None, 0) elif issubclass(type_cls, Float): - numeric_precision = self._convert_db_precision_to_digits(row.DECIMAL_DIGITS) - row = (row.COLUMN_NAME, row_type, None, numeric_precision, None) + numeric_precision = math.ceil(row[2] / math.log(2, 10)) + row = (row[0], row_type, None, numeric_precision, None) elif issubclass(type_cls, Decimal): - # TYPE_NAME has a format DECIMAL(x,y) - items = row.TYPE_NAME[8:].rstrip(")").split(",") + items = row[1][8:].rstrip(")").split(",") numeric_precision, numeric_scale = int(items[0]), int(items[1]) - row = (row.COLUMN_NAME, row_type, None, numeric_precision, numeric_scale) + row = (row[0], row_type, None, numeric_precision, numeric_scale) elif issubclass(type_cls, Timestamp): - row = (row.COLUMN_NAME, row_type, row.DECIMAL_DIGITS, None, None) + row = (row[0], row_type, row[2], None, None) else: - row = (row.COLUMN_NAME, row_type, None, None, None) + row = (row[0], row_type, None, None, None) resulted_rows.append(row) - col_dict: Dict[str, ColType] = {row[0]: self._parse_type(path, *row) for row in resulted_rows} + col_dict: Dict[str, ColType] = {row[0]: self.dialect.parse_type(path, *row) for row in resulted_rows} self._refine_coltypes(path, col_dict, where) return col_dict - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - """Databricks timestamp contains no more than 6 digits in precision""" - - if coltype.rounds: - timestamp = f"cast(round(unix_micros({value}) / 1000000, {coltype.precision}) * 1000000 as bigint)" - return f"date_format(timestamp_micros({timestamp}), 'yyyy-MM-dd HH:mm:ss.SSSSSS')" - - precision_format = "S" * coltype.precision + "0" * (6 - coltype.precision) - return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')" - - def normalize_number(self, value: str, coltype: NumericType) -> str: - return self.to_string(f"cast({value} as decimal(38, 
{coltype.precision}))") - def parse_table_name(self, name: str) -> DbPath: path = parse_table_name(name) return self._normalize_table_path(path) - def close(self): - self._conn.close() + @property + def is_autocommit(self) -> bool: + return True diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index 9029ff14..8d394e3c 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -17,7 +17,7 @@ # def quote(self, s: str): # return f"[{s}]" -# def md5_to_int(self, s: str) -> str: +# def md5_as_int(self, s: str) -> str: # return f"CONVERT(decimal(38,0), CONVERT(bigint, HashBytes('MD5', {s}), 2))" # # return f"CONVERT(bigint, (CHECKSUM({s})))" diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index 7e89b184..8f8e1730 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -1,5 +1,15 @@ -from .database_types import * -from .base import ThreadedDatabase, import_helper, ConnectError +from .database_types import ( + Datetime, + Timestamp, + Float, + Decimal, + Integer, + Text, + TemporalType, + FractionalType, + ColType_UUID, +) +from .base import ThreadedDatabase, import_helper, ConnectError, BaseDialect from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS @@ -10,7 +20,10 @@ def import_mysql(): return mysql.connector -class MySQL(ThreadedDatabase): +class Dialect(BaseDialect): + name = "MySQL" + ROUNDS_ON_PREC_LOSS = True + SUPPORTS_PRIMARY_KEY = True TYPE_CLASSES = { # Dates "datetime": Datetime, @@ -27,32 +40,11 @@ class MySQL(ThreadedDatabase): "varbinary": Text, "binary": Text, } - ROUNDS_ON_PREC_LOSS = True - SUPPORTS_ALPHANUMS = False - - def __init__(self, *, thread_count, **kw): - self._args = kw - - super().__init__(thread_count=thread_count) - - # In MySQL schema and database are synonymous - self.default_schema = kw["database"] - - def create_connection(self): - mysql = import_mysql() - try: - return mysql.connect(charset="utf8", use_unicode=True, **self._args) - except mysql.Error as e: - if e.errno == mysql.errorcode.ER_ACCESS_DENIED_ERROR: - raise ConnectError("Bad user name or password") from e - elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR: - raise ConnectError("Database does not exist") from e - raise ConnectError(*e.args) from e def quote(self, s: str): return f"`{s}`" - def md5_to_int(self, s: str) -> str: + def md5_as_int(self, s: str) -> str: return f"cast(conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as unsigned)" def to_string(self, s: str): @@ -73,3 +65,42 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: def is_distinct_from(self, a: str, b: str) -> str: return f"not ({a} <=> {b})" + + def random(self) -> str: + return "RAND()" + + def type_repr(self, t) -> str: + try: + return { + str: "VARCHAR(1024)", + }[t] + except KeyError: + return super().type_repr(t) + + def explain_as_text(self, query: str) -> str: + return f"EXPLAIN FORMAT=TREE {query}" + + +class MySQL(ThreadedDatabase): + dialect = Dialect() + SUPPORTS_ALPHANUMS = False + SUPPORTS_UNIQUE_CONSTAINT = True + + def __init__(self, *, thread_count, **kw): + self._args = kw + + super().__init__(thread_count=thread_count) + + # In MySQL schema and database are synonymous + self.default_schema = kw["database"] + + def create_connection(self): + mysql = import_mysql() + try: + return mysql.connect(charset="utf8", use_unicode=True, **self._args) + except mysql.Error as e: + if e.errno == mysql.errorcode.ER_ACCESS_DENIED_ERROR: + raise ConnectError("Bad user 
name or password") from e + elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR: + raise ConnectError("Database does not exist") from e + raise ConnectError(*e.args) from e diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 76387010..64127e9a 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -1,7 +1,20 @@ -from ..utils import match_regexps +from typing import Dict, List, Optional -from .database_types import * -from .base import ThreadedDatabase, import_helper, ConnectError, QueryError +from ..utils import match_regexps +from .database_types import ( + Decimal, + Float, + Text, + DbPath, + TemporalType, + ColType, + DbTime, + ColType_UUID, + Timestamp, + TimestampTZ, + FractionalType, +) +from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, QueryError from .base import TIMESTAMP_PRECISION_POS SESSION_TIME_ZONE = None # Changed by the tests @@ -14,7 +27,9 @@ def import_oracle(): return cx_Oracle -class Oracle(ThreadedDatabase): +class Dialect(BaseDialect): + name = "Oracle" + SUPPORTS_PRIMARY_KEY = True TYPE_CLASSES: Dict[str, type] = { "NUMBER": Decimal, "FLOAT": Float, @@ -26,30 +41,7 @@ class Oracle(ThreadedDatabase): } ROUNDS_ON_PREC_LOSS = True - def __init__(self, *, host, database, thread_count, **kw): - self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) - - self.default_schema = kw.get("user") - - super().__init__(thread_count=thread_count) - - def create_connection(self): - self._oracle = import_oracle() - try: - c = self._oracle.connect(**self.kwargs) - if SESSION_TIME_ZONE: - c.cursor().execute(f"ALTER SESSION SET TIME_ZONE = '{SESSION_TIME_ZONE}'") - return c - except Exception as e: - raise ConnectError(*e.args) from e - - def _query(self, sql_code: str): - try: - return super()._query(sql_code) - except self._oracle.DatabaseError as e: - raise QueryError(e) - - def md5_to_int(self, s: str) -> str: + def md5_as_int(self, s: str) -> str: # standard_hash is faster than DBMS_CRYPTO.Hash # TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ? 
return f"to_number(substr(standard_hash({s}, 'MD5'), 18), 'xxxxxxxxxxxxxxx')" @@ -60,12 +52,40 @@ def quote(self, s: str): def to_string(self, s: str): return f"cast({s} as varchar(1024))" - def select_table_schema(self, path: DbPath) -> str: - schema, table = self._normalize_table_path(path) + def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): + if offset: + raise NotImplementedError("No support for OFFSET in query") - return ( - f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale" - f" FROM ALL_TAB_COLUMNS WHERE table_name = '{table.upper()}' AND owner = '{schema.upper()}'" + return f"FETCH NEXT {limit} ROWS ONLY" + + def concat(self, items: List[str]) -> str: + joined_exprs = " || ".join(items) + return f"({joined_exprs})" + + def timestamp_value(self, t: DbTime) -> str: + return "timestamp '%s'" % t.isoformat(" ") + + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + # Cast is necessary for correct MD5 (trimming not enough) + return f"CAST(TRIM({value}) AS VARCHAR(36))" + + def random(self) -> str: + return "dbms_random.value" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"DECODE({a}, {b}, 1, 0) = 0" + + def type_repr(self, t) -> str: + try: + return { + str: "VARCHAR(1024)", + }[t] + except KeyError: + return super().type_repr(t) + + def constant_values(self, rows) -> str: + return " UNION ALL ".join( + "SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows ) def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: @@ -85,7 +105,10 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: format_str += "0." + "9" * (coltype.precision - 1) + "0" return f"to_char({value}, '{format_str}')" - def _parse_type( + def explain_as_text(self, query: str) -> str: + raise NotImplementedError("Explain not yet implemented in Oracle") + + def parse_type( self, table_path: DbPath, col_name: str, @@ -104,23 +127,39 @@ def _parse_type( precision = int(m.group(1)) return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) - return super()._parse_type( - table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale - ) + return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale) - def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): - if offset: - raise NotImplementedError("No support for OFFSET in query") - return f"FETCH NEXT {limit} ROWS ONLY" +class Oracle(ThreadedDatabase): + dialect = Dialect() - def concat(self, l: List[str]) -> str: - joined_exprs = " || ".join(l) - return f"({joined_exprs})" + def __init__(self, *, host, database, thread_count, **kw): + self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) - def timestamp_value(self, t: DbTime) -> str: - return "timestamp '%s'" % t.isoformat(" ") + self.default_schema = kw.get("user") - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - # Cast is necessary for correct MD5 (trimming not enough) - return f"CAST(TRIM({value}) AS VARCHAR(36))" + super().__init__(thread_count=thread_count) + + def create_connection(self): + self._oracle = import_oracle() + try: + c = self._oracle.connect(**self.kwargs) + if SESSION_TIME_ZONE: + c.cursor().execute(f"ALTER SESSION SET TIME_ZONE = '{SESSION_TIME_ZONE}'") + return c + except Exception as e: + raise ConnectError(*e.args) from e + + def 
_query_cursor(self, c, sql_code: str): + try: + return super()._query_cursor(c, sql_code) + except self._oracle.DatabaseError as e: + raise QueryError(e) + + def select_table_schema(self, path: DbPath) -> str: + schema, table = self._normalize_table_path(path) + + return ( + f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale" + f" FROM ALL_TAB_COLUMNS WHERE table_name = '{table.upper()}' AND owner = '{schema.upper()}'" + ) diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index d65ac7de..27df1273 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -1,5 +1,15 @@ -from .database_types import * -from .base import ThreadedDatabase, import_helper, ConnectError +from .database_types import ( + Timestamp, + TimestampTZ, + Float, + Decimal, + Integer, + TemporalType, + Native_UUID, + Text, + FractionalType, +) +from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, _CHECKSUM_BITSIZE, TIMESTAMP_PRECISION_POS SESSION_TIME_ZONE = None # Changed by the tests @@ -14,7 +24,11 @@ def import_postgresql(): return psycopg2 -class PostgreSQL(ThreadedDatabase): +class PostgresqlDialect(BaseDialect): + name = "PostgreSQL" + ROUNDS_ON_PREC_LOSS = True + SUPPORTS_PRIMARY_KEY = True + TYPE_CLASSES = { # Timestamps "timestamp with time zone": TimestampTZ, @@ -35,36 +49,11 @@ class PostgreSQL(ThreadedDatabase): # UUID "uuid": Native_UUID, } - ROUNDS_ON_PREC_LOSS = True - - default_schema = "public" - - def __init__(self, *, thread_count, **kw): - self._args = kw - - super().__init__(thread_count=thread_count) - - def _convert_db_precision_to_digits(self, p: int) -> int: - # Subtracting 2 due to wierd precision issues in PostgreSQL - return super()._convert_db_precision_to_digits(p) - 2 - - def create_connection(self): - if not self._args: - self._args["host"] = None # psycopg2 requires 1+ arguments - - pg = import_postgresql() - try: - c = pg.connect(**self._args) - if SESSION_TIME_ZONE: - c.cursor().execute(f"SET TIME ZONE '{SESSION_TIME_ZONE}'") - return c - except pg.OperationalError as e: - raise ConnectError(*e.args) from e def quote(self, s: str): return f'"{s}"' - def md5_to_int(self, s: str) -> str: + def md5_as_int(self, s: str) -> str: return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint" def to_string(self, s: str): @@ -81,3 +70,32 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: def normalize_number(self, value: str, coltype: FractionalType) -> str: return self.to_string(f"{value}::decimal(38, {coltype.precision})") + + def _convert_db_precision_to_digits(self, p: int) -> int: + # Subtracting 2 due to wierd precision issues in PostgreSQL + return super()._convert_db_precision_to_digits(p) - 2 + + +class PostgreSQL(ThreadedDatabase): + dialect = PostgresqlDialect() + SUPPORTS_UNIQUE_CONSTAINT = True + + default_schema = "public" + + def __init__(self, *, thread_count, **kw): + self._args = kw + + super().__init__(thread_count=thread_count) + + def create_connection(self): + if not self._args: + self._args["host"] = None # psycopg2 requires 1+ arguments + + pg = import_postgresql() + try: + c = pg.connect(**self._args) + if SESSION_TIME_ZONE: + c.cursor().execute(f"SET TIME ZONE '{SESSION_TIME_ZONE}'") + return c + except pg.OperationalError as e: + raise ConnectError(*e.args) from e diff --git 
a/data_diff/databases/presto.py b/data_diff/databases/presto.py index 5ee98770..de54f5b5 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -1,9 +1,23 @@ +from functools import partial import re -from ..utils import match_regexps - -from .database_types import * -from .base import Database, import_helper +from data_diff.utils import match_regexps + +from .database_types import ( + Timestamp, + TimestampTZ, + Integer, + Float, + Text, + FractionalType, + DbPath, + DbTime, + Decimal, + ColType, + ColType_UUID, + TemporalType, +) +from .base import BaseDialect, Database, import_helper, ThreadLocalInterpreter from .base import ( MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, @@ -11,6 +25,15 @@ ) +def query_cursor(c, sql_code): + c.execute(sql_code) + if sql_code.lower().startswith("select"): + return c.fetchall() + # Required for the query to actually run 🀯 + if re.match(r"(insert|create|truncate|drop|explain)", sql_code, re.IGNORECASE): + return c.fetchone() + + @import_helper("presto") def import_presto(): import prestodb @@ -18,8 +41,9 @@ def import_presto(): return prestodb -class Presto(Database): - default_schema = "public" +class Dialect(BaseDialect): + name = "Presto" + ROUNDS_ON_PREC_LOSS = True TYPE_CLASSES = { # Timestamps "timestamp with time zone": TimestampTZ, @@ -33,48 +57,34 @@ class Presto(Database): # Text "varchar": Text, } - ROUNDS_ON_PREC_LOSS = True - def __init__(self, **kw): - prestodb = import_presto() + def explain_as_text(self, query: str) -> str: + return f"EXPLAIN (FORMAT TEXT) {query}" - if kw.get("schema"): - self.default_schema = kw.get("schema") + def type_repr(self, t) -> str: + try: + return {float: "REAL"}[t] + except KeyError: + return super().type_repr(t) - if kw.get("auth") == "basic": # if auth=basic, add basic authenticator for Presto - kw["auth"] = prestodb.auth.BasicAuthentication(kw.pop("user"), kw.pop("password")) - - if "cert" in kw: # if a certificate was specified in URI, verify session with cert - cert = kw.pop("cert") - self._conn = prestodb.dbapi.connect(**kw) - self._conn._http_session.verify = cert - else: - self._conn = prestodb.dbapi.connect(**kw) + def timestamp_value(self, t: DbTime) -> str: + return f"timestamp '{t.isoformat(' ')}'" def quote(self, s: str): return f'"{s}"' - def md5_to_int(self, s: str) -> str: + def md5_as_int(self, s: str) -> str: return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0))" def to_string(self, s: str): return f"cast({s} as varchar)" - def _query(self, sql_code: str) -> list: - "Uses the standard SQL cursor interface" - c = self._conn.cursor() - c.execute(sql_code) - if sql_code.lower().startswith("select"): - return c.fetchall() - # Required for the query to actually run 🀯 - if re.match(r"(insert|create|truncate|drop)", sql_code, re.IGNORECASE): - return c.fetchone() - - def close(self): - self._conn.close() + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + # Trim doesn't work on CHAR type + return f"TRIM(CAST({value} AS VARCHAR))" def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - # TODO + # TODO rounds if coltype.rounds: s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" else: @@ -85,22 +95,14 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: def normalize_number(self, value: str, coltype: FractionalType) -> str: return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))") - def select_table_schema(self, path: 
DbPath) -> str: - schema, table = self._normalize_table_path(path) - - return ( - "SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision " - "FROM INFORMATION_SCHEMA.COLUMNS " - f"WHERE table_name = '{table}' AND table_schema = '{schema}'" - ) - - def _parse_type( + def parse_type( self, table_path: DbPath, col_name: str, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, + numeric_scale: int = None, ) -> ColType: timestamp_regexps = { r"timestamp\((\d)\)": Timestamp, @@ -119,8 +121,50 @@ def _parse_type( for m, n_cls in match_regexps(string_regexps, type_repr): return n_cls() - return super()._parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) + return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - # Trim doesn't work on CHAR type - return f"TRIM(CAST({value} AS VARCHAR))" + +class Presto(Database): + dialect = Dialect() + default_schema = "public" + + def __init__(self, **kw): + prestodb = import_presto() + + if kw.get("schema"): + self.default_schema = kw.get("schema") + + if kw.get("auth") == "basic": # if auth=basic, add basic authenticator for Presto + kw["auth"] = prestodb.auth.BasicAuthentication(kw.pop("user"), kw.pop("password")) + + if "cert" in kw: # if a certificate was specified in URI, verify session with cert + cert = kw.pop("cert") + self._conn = prestodb.dbapi.connect(**kw) + self._conn._http_session.verify = cert + else: + self._conn = prestodb.dbapi.connect(**kw) + + def _query(self, sql_code: str) -> list: + "Uses the standard SQL cursor interface" + c = self._conn.cursor() + + if isinstance(sql_code, ThreadLocalInterpreter): + return sql_code.apply_queries(partial(query_cursor, c)) + + return query_cursor(c, sql_code) + + def close(self): + self._conn.close() + + def select_table_schema(self, path: DbPath) -> str: + schema, table = self._normalize_table_path(path) + + return ( + "SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision, NULL as numeric_scale " + "FROM INFORMATION_SCHEMA.COLUMNS " + f"WHERE table_name = '{table}' AND table_schema = '{schema}'" + ) + + @property + def is_autocommit(self) -> bool: + return False diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index a512c123..8113df2e 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -1,15 +1,17 @@ -from .database_types import * -from .postgresql import PostgreSQL, MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS +from typing import List +from .database_types import Float, TemporalType, FractionalType, DbPath +from .postgresql import PostgreSQL, MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS, PostgresqlDialect -class Redshift(PostgreSQL): +class Dialect(PostgresqlDialect): + name = "Redshift" TYPE_CLASSES = { - **PostgreSQL.TYPE_CLASSES, + **PostgresqlDialect.TYPE_CLASSES, "double": Float, "real": Float, } - def md5_to_int(self, s: str) -> str: + def md5_as_int(self, s: str) -> str: return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: @@ -35,10 +37,17 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: def normalize_number(self, value: str, coltype: FractionalType) -> str: return self.to_string(f"{value}::decimal(38,{coltype.precision})") - def concat(self, l: 
List[str]) -> str: - joined_exprs = " || ".join(l) + def concat(self, items: List[str]) -> str: + joined_exprs = " || ".join(items) return f"({joined_exprs})" + def is_distinct_from(self, a: str, b: str) -> str: + return f"{a} IS NULL AND NOT {b} IS NULL OR {b} IS NULL OR {a}!={b}" + + +class Redshift(PostgreSQL): + dialect = Dialect() + def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 9b03d833..5ab5705b 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,7 +1,8 @@ +from typing import Union, List import logging -from .database_types import * -from .base import ConnectError, Database, import_helper, _query_conn, CHECKSUM_MASK +from .database_types import Timestamp, TimestampTZ, Decimal, Float, Text, FractionalType, TemporalType, DbPath +from .base import BaseDialect, ConnectError, Database, import_helper, CHECKSUM_MASK, ThreadLocalInterpreter @import_helper("snowflake") @@ -13,7 +14,9 @@ def import_snowflake(): return snowflake, serialization, default_backend -class Snowflake(Database): +class Dialect(BaseDialect): + name = "Snowflake" + ROUNDS_ON_PREC_LOSS = False TYPE_CLASSES = { # Timestamps "TIMESTAMP_NTZ": Timestamp, @@ -25,7 +28,33 @@ class Snowflake(Database): # Text "TEXT": Text, } - ROUNDS_ON_PREC_LOSS = False + + def explain_as_text(self, query: str) -> str: + return f"EXPLAIN USING TEXT {query}" + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))" + else: + timestamp = f"cast({value} as timestamp({coltype.precision}))" + + return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") + + def quote(self, s: str): + return f'"{s}"' + + def md5_as_int(self, s: str) -> str: + return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK})" + + def to_string(self, s: str): + return f"cast({s} as string)" + + +class Snowflake(Database): + dialect = Dialect() def __init__(self, *, schema: str, **kw): snowflake, serialization, default_backend = import_snowflake() @@ -60,30 +89,17 @@ def __init__(self, *, schema: str, **kw): def close(self): self._conn.close() - def _query(self, sql_code: str) -> list: + def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): "Uses the standard SQL cursor interface" - return _query_conn(self._conn, sql_code) - - def quote(self, s: str): - return f'"{s}"' - - def md5_to_int(self, s: str) -> str: - return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK})" - - def to_string(self, s: str): - return f"cast({s} as string)" + return self._query_conn(self._conn, sql_code) def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) return super().select_table_schema((schema, table)) - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))" - else: - timestamp = f"cast({value} as timestamp({coltype.precision}))" - - return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" + @property + def is_autocommit(self) -> bool: + return True - def normalize_number(self, value: str, 
coltype: FractionalType) -> str: - return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") + def query_table_unique_columns(self, path: DbPath) -> List[str]: + return [] diff --git a/data_diff/databases/trino.py b/data_diff/databases/trino.py index c3e3e581..a7b0ef8c 100644 --- a/data_diff/databases/trino.py +++ b/data_diff/databases/trino.py @@ -1,5 +1,5 @@ -from .database_types import * -from .presto import Presto +from .database_types import TemporalType, ColType_UUID +from .presto import Presto, Dialect from .base import import_helper from .base import TIMESTAMP_PRECISION_POS @@ -11,11 +11,8 @@ def import_trino(): return trino -class Trino(Presto): - def __init__(self, **kw): - trino = import_trino() - - self._conn = trino.dbapi.connect(**kw) +class Dialect(Dialect): + name = "Trino" def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -29,3 +26,12 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: return f"TRIM({value})" + + +class Trino(Presto): + dialect = Dialect() + + def __init__(self, **kw): + trino = import_trino() + + self._conn = trino.dbapi.connect(**kw) diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index 78a52363..7852800a 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -5,6 +5,7 @@ CHECKSUM_HEXDIGITS, MD5_HEXDIGITS, TIMESTAMP_PRECISION_POS, + BaseDialect, ConnectError, DbPath, ColType, @@ -12,7 +13,16 @@ ThreadedDatabase, import_helper, ) -from .database_types import Decimal, Float, FractionalType, Integer, TemporalType, Text, Timestamp, TimestampTZ +from .database_types import ( + Decimal, + Float, + FractionalType, + Integer, + TemporalType, + Text, + Timestamp, + TimestampTZ, +) @import_helper("vertica") @@ -22,8 +32,9 @@ def import_vertica(): return vertica_python -class Vertica(ThreadedDatabase): - default_schema = "public" +class Dialect(BaseDialect): + name = "Vertica" + ROUNDS_ON_PREC_LOSS = True TYPE_CLASSES = { # Timestamps @@ -38,23 +49,38 @@ class Vertica(ThreadedDatabase): "varchar": Text, } - ROUNDS_ON_PREC_LOSS = True + def quote(self, s: str): + return f'"{s}"' - def __init__(self, *, thread_count, **kw): - self._args = kw - self._args["AUTOCOMMIT"] = False + def concat(self, items: List[str]) -> str: + return " || ".join(items) - super().__init__(thread_count=thread_count) + def md5_as_int(self, s: str) -> str: + return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0))" - def create_connection(self): - vertica = import_vertica() - try: - c = vertica.connect(**self._args) - return c - except vertica.errors.ConnectionError as e: - raise ConnectError(*e.args) from e + def to_string(self, s: str) -> str: + return f"CAST({s} AS VARCHAR)" - def _parse_type( + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + return f"TO_CHAR({value}::TIMESTAMP({coltype.precision}), 'YYYY-MM-DD HH24:MI:SS.US')" + + timestamp6 = f"TO_CHAR({value}::TIMESTAMP(6), 'YYYY-MM-DD HH24:MI:SS.US')" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + return self.to_string(f"CAST({value} AS NUMERIC(38, {coltype.precision}))") + + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + # Trim doesn't work on CHAR 
type + return f"TRIM(CAST({value} AS VARCHAR))" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"not ({a} <=> {b})" + + def parse_type( self, table_path: DbPath, col_name: str, @@ -85,41 +111,32 @@ def _parse_type( for m, n_cls in match_regexps(string_regexps, type_repr): return n_cls() - return super()._parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) + return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) - def select_table_schema(self, path: DbPath) -> str: - schema, table = self._normalize_table_path(path) - return ( - "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " - "FROM V_CATALOG.COLUMNS " - f"WHERE table_name = '{table}' AND table_schema = '{schema}'" - ) - - def quote(self, s: str): - return f'"{s}"' +class Vertica(ThreadedDatabase): + dialect = Dialect() + default_schema = "public" - def concat(self, l: List[str]) -> str: - return " || ".join(l) + def __init__(self, *, thread_count, **kw): + self._args = kw + self._args["AUTOCOMMIT"] = False - def md5_to_int(self, s: str) -> str: - return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0))" + super().__init__(thread_count=thread_count) - def to_string(self, s: str) -> str: - return f"CAST({s} AS VARCHAR)" + def create_connection(self): + vertica = import_vertica() + try: + c = vertica.connect(**self._args) + return c + except vertica.errors.ConnectionError as e: + raise ConnectError(*e.args) from e - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - return f"TO_CHAR({value}::TIMESTAMP({coltype.precision}), 'YYYY-MM-DD HH24:MI:SS.US')" + def select_table_schema(self, path: DbPath) -> str: + schema, table = self._normalize_table_path(path) - timestamp6 = f"TO_CHAR({value}::TIMESTAMP(6), 'YYYY-MM-DD HH24:MI:SS.US')" return ( - f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " + "FROM V_CATALOG.COLUMNS " + f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) - - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"CAST({value} AS NUMERIC(38, {coltype.precision}))") - - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - # Trim doesn't work on CHAR type - return f"TRIM(CAST({value} AS VARCHAR))" diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 5efee0c7..bf30cd9a 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -1,73 +1,83 @@ """Provides classes for performing a table diff """ +import re import time -import os -from numbers import Number -from operator import attrgetter, methodcaller -from collections import defaultdict +from abc import ABC, abstractmethod +from enum import Enum +from contextlib import contextmanager +from operator import methodcaller from typing import Tuple, Iterator, Optional -import logging from concurrent.futures import ThreadPoolExecutor, as_completed from runtype import dataclass -from .utils import safezip, run_as_daemon +from .utils import run_as_daemon, safezip, getLogger from .thread_utils import ThreadedYielder -from .databases.database_types import IKey, NumericType, PrecisionType, StringType, ColType_UUID from .table_segment import TableSegment from .tracking import create_end_event_json, create_start_event_json, 
send_event_json, is_tracking_enabled +from .databases.database_types import IKey -logger = logging.getLogger("diff_tables") +logger = getLogger(__name__) -BENCHMARK = os.environ.get("BENCHMARK", False) -DEFAULT_BISECTION_THRESHOLD = 1024 * 16 -DEFAULT_BISECTION_FACTOR = 32 +class Algorithm(Enum): + AUTO = "auto" + JOINDIFF = "joindiff" + HASHDIFF = "hashdiff" -def diff_sets(a: set, b: set) -> Iterator: - s1 = set(a) - s2 = set(b) - d = defaultdict(list) - # The first item is always the key (see TableDiffer._relevant_columns) - for i in s1 - s2: - d[i[0]].append(("-", i)) - for i in s2 - s1: - d[i[0]].append(("+", i)) - - for _k, v in sorted(d.items(), key=lambda i: i[0]): - yield from v +DiffResult = Iterator[Tuple[str, tuple]] # Iterator[Tuple[Literal["+", "-"], tuple]] -DiffResult = Iterator[Tuple[str, tuple]] # Iterator[Tuple[Literal["+", "-"], tuple]] +def truncate_error(error: str): + first_line = error.split("\n", 1)[0] + return re.sub("'(.*?)'", "'***'", first_line) @dataclass -class TableDiffer: - """Finds the diff between two SQL tables +class ThreadBase: + "Provides utility methods for optional threading" - The algorithm uses hashing to quickly check if the tables are different, and then applies a - bisection search recursively to find the differences efficiently. + threaded: bool = True + max_threadpool_size: Optional[int] = 1 - Works best for comparing tables that are mostly the same, with minor discrepencies. + def _thread_map(self, func, iterable): + if not self.threaded: + return map(func, iterable) - Parameters: - bisection_factor (int): Into how many segments to bisect per iteration. - bisection_threshold (Number): When should we stop bisecting and compare locally (in row count). - threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads. - max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. Only relevant when `threaded` is ``True``. - There may be many pools, so number of actual threads can be a lot higher. - """ + with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool: + return task_pool.map(func, iterable) - bisection_factor: int = DEFAULT_BISECTION_FACTOR - bisection_threshold: Number = DEFAULT_BISECTION_THRESHOLD # Accepts inf for tests - threaded: bool = True - max_threadpool_size: Optional[int] = 1 + def _threaded_call(self, func, iterable): + "Calls a method for each object in iterable." + return list(self._thread_map(methodcaller(func), iterable)) - # Enable/disable debug prints - debug: bool = False + def _thread_as_completed(self, func, iterable): + if not self.threaded: + yield from map(func, iterable) + return + + with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool: + futures = [task_pool.submit(func, item) for item in iterable] + for future in as_completed(futures): + yield future.result() + def _threaded_call_as_completed(self, func, iterable): + "Calls a method for each object in iterable. Returned in order of completion." 
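The `truncate_error` helper introduced above masks quoted values and drops everything after the first line, so tracking events only ever see the shape of an error, never row contents. A minimal illustrative sketch of that behaviour (the sample message below is made up):

```python
import re

def truncate_error(error: str) -> str:
    # Mirrors the helper above: keep only the first line of the message and
    # mask anything inside single quotes, which may contain row values.
    first_line = error.split("\n", 1)[0]
    return re.sub("'(.*?)'", "'***'", first_line)

msg = "ValueError: could not convert 'jane@example.com'\nTraceback (most recent call last): ..."
print(truncate_error(msg))  # ValueError: could not convert '***'
```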
+ return self._thread_as_completed(methodcaller(func), iterable) + + @contextmanager + def _run_in_background(self, *funcs): + with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool: + futures = [task_pool.submit(f) for f in funcs if f is not None] + yield futures + for f in futures: + f.result() + + +class TableDiffer(ThreadBase, ABC): + bisection_factor = 32 stats: dict = {} def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: @@ -78,19 +88,15 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: table2 (TableSegment): The "after" table to compare. Or: target table Returns: - An iterator that yield pair-tuples, representing the diff. Items can be either - ('-', columns) for items in table1 but not in table2 - ('+', columns) for items in table2 but not in table1 - Where `columns` is a tuple of values for the involved columns, i.e. (id, ...extra) + An iterator that yield pair-tuples, representing the diff. Items can be either - + ('-', row) for items in table1 but not in table2. + ('+', row) for items in table2 but not in table1. + Where `row` is a tuple of values, corresponding to the diffed columns. """ - # Validate options - if self.bisection_factor >= self.bisection_threshold: - raise ValueError("Incorrect param values (bisection factor must be lower than threshold)") - if self.bisection_factor < 2: - raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)") if is_tracking_enabled(): options = dict(self) + options["differ_name"] = type(self).__name__ event_json = create_start_event_json(options) run_as_daemon(send_event_json, event_json) @@ -103,44 +109,7 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: table1, table2 = self._threaded_call("with_schema", [table1, table2]) self._validate_and_adjust_columns(table1, table2) - key_type = table1._schema[table1.key_column] - key_type2 = table2._schema[table2.key_column] - if not isinstance(key_type, IKey): - raise NotImplementedError(f"Cannot use column of type {key_type} as a key") - if not isinstance(key_type2, IKey): - raise NotImplementedError(f"Cannot use column of type {key_type2} as a key") - assert key_type.python_type is key_type2.python_type - - # Query min/max values - key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2]) - - # Start with the first completed value, so we don't waste time waiting - min_key1, max_key1 = self._parse_key_range_result(key_type, next(key_ranges)) - - table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)] - - logger.info( - f"Diffing tables | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}. " - f"key-range: {table1.min_key}..{table2.max_key}, " - f"size: table1 <= {table1.approximate_size()}, table2 <= {table2.approximate_size()}" - ) - - ti = ThreadedYielder(self.max_threadpool_size) - # Bisect (split) the table into segments, and diff them recursively. - ti.submit(self._bisect_and_diff_tables, ti, table1, table2) - - # Now we check for the second min-max, to diff the portions we "missed". 
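For reference, a minimal sketch of consuming the iterator that `diff_tables` yields. It assumes `table1` and `table2` are `TableSegment` objects built elsewhere against live databases (their construction is outside this hunk); `JoinDiffer`, added later in this patch, exposes the same `diff_tables` interface when both tables live in the same database.

```python
from collections import Counter

from data_diff.hashdiff_tables import HashDiffer

# table1, table2: pre-built TableSegment instances (require live connections).
differ = HashDiffer(bisection_factor=32, bisection_threshold=1024 * 16)

signs = Counter()
for sign, row in differ.diff_tables(table1, table2):
    # "-" means the row appears only in table1; "+" means only in table2.
    # `row` is a tuple of the compared column values, with the key first.
    signs[sign] += 1

print(f"{signs['-']} rows missing from table2, {signs['+']} rows missing from table1")
```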
- min_key2, max_key2 = self._parse_key_range_result(key_type, next(key_ranges)) - - if min_key2 < min_key1: - pre_tables = [t.new(min_key=min_key2, max_key=min_key1) for t in (table1, table2)] - ti.submit(self._bisect_and_diff_tables, ti, *pre_tables) - - if max_key2 > max_key1: - post_tables = [t.new(min_key=max_key1, max_key=max_key2) for t in (table1, table2)] - ti.submit(self._bisect_and_diff_tables, ti, *post_tables) - - yield from ti + yield from self._diff_tables(table1, table2) except BaseException as e: # Catch KeyboardInterrupt too error = e @@ -150,7 +119,7 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: table1_count = self.stats.get("table1_count") table2_count = self.stats.get("table2_count") diff_count = self.stats.get("diff_count") - err_message = str(error)[:20] # Truncate possibly sensitive information. + err_message = truncate_error(repr(error)) event_json = create_end_event_json( error is None, runtime, @@ -166,97 +135,87 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: if error: raise error - def _parse_key_range_result(self, key_type, key_range): - mn, mx = key_range - cls = key_type.make_value - # We add 1 because our ranges are exclusive of the end (like in Python) - try: - return cls(mn), cls(mx) + 1 - except (TypeError, ValueError) as e: - raise type(e)(f"Cannot apply {key_type} to {mn}, {mx}.") from e + def _validate_and_adjust_columns(self, table1: TableSegment, table2: TableSegment) -> DiffResult: + pass - def _validate_and_adjust_columns(self, table1, table2): - for c1, c2 in safezip(table1._relevant_columns, table2._relevant_columns): - if c1 not in table1._schema: - raise ValueError(f"Column '{c1}' not found in schema for table {table1}") - if c2 not in table2._schema: - raise ValueError(f"Column '{c2}' not found in schema for table {table2}") + def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: + return self._bisect_and_diff_tables(table1, table2) - # Update schemas to minimal mutual precision - col1 = table1._schema[c1] - col2 = table2._schema[c2] - if isinstance(col1, PrecisionType): - if not isinstance(col2, PrecisionType): - raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") + @abstractmethod + def _diff_segments( + self, + ti: ThreadedYielder, + table1: TableSegment, + table2: TableSegment, + max_rows: int, + level=0, + segment_index=None, + segment_count=None, + ): + ... - lowest = min(col1, col2, key=attrgetter("precision")) + def _bisect_and_diff_tables(self, table1, table2): + if len(table1.key_columns) > 1: + raise NotImplementedError("Composite key not supported yet!") + if len(table2.key_columns) > 1: + raise NotImplementedError("Composite key not supported yet!") + (key1,) = table1.key_columns + (key2,) = table2.key_columns - if col1.precision != col2.precision: - logger.warning(f"Using reduced precision {lowest} for column '{c1}'. 
Types={col1}, {col2}") + key_type = table1._schema[key1] + key_type2 = table2._schema[key2] + if not isinstance(key_type, IKey): + raise NotImplementedError(f"Cannot use column of type {key_type} as a key") + if not isinstance(key_type2, IKey): + raise NotImplementedError(f"Cannot use column of type {key_type2} as a key") + assert key_type.python_type is key_type2.python_type - table1._schema[c1] = col1.replace(precision=lowest.precision, rounds=lowest.rounds) - table2._schema[c2] = col2.replace(precision=lowest.precision, rounds=lowest.rounds) + # Query min/max values + key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2]) - elif isinstance(col1, NumericType): - if not isinstance(col2, NumericType): - raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") + # Start with the first completed value, so we don't waste time waiting + min_key1, max_key1 = self._parse_key_range_result(key_type, next(key_ranges)) - lowest = min(col1, col2, key=attrgetter("precision")) + table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)] - if col1.precision != col2.precision: - logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}") + logger.info( + f"Diffing segments at key-range: {table1.min_key}..{table2.max_key}. " + f"size: table1 <= {table1.approximate_size()}, table2 <= {table2.approximate_size()}" + ) + + ti = ThreadedYielder(self.max_threadpool_size) + # Bisect (split) the table into segments, and diff them recursively. + ti.submit(self._bisect_and_diff_segments, ti, table1, table2) - table1._schema[c1] = col1.replace(precision=lowest.precision) - table2._schema[c2] = col2.replace(precision=lowest.precision) + # Now we check for the second min-max, to diff the portions we "missed". + min_key2, max_key2 = self._parse_key_range_result(key_type, next(key_ranges)) - elif isinstance(col1, ColType_UUID): - if not isinstance(col2, ColType_UUID): - raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") + if min_key2 < min_key1: + pre_tables = [t.new(min_key=min_key2, max_key=min_key1) for t in (table1, table2)] + ti.submit(self._bisect_and_diff_segments, ti, *pre_tables) - elif isinstance(col1, StringType): - if not isinstance(col2, StringType): - raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") + if max_key2 > max_key1: + post_tables = [t.new(min_key=max_key1, max_key=max_key2) for t in (table1, table2)] + ti.submit(self._bisect_and_diff_segments, ti, *post_tables) - for t in [table1, table2]: - for c in t._relevant_columns: - ctype = t._schema[c] - if not ctype.supported: - logger.warning( - f"[{t.database.name}] Column '{c}' of type '{ctype}' has no compatibility handling. " - "If encoding/formatting differs between databases, it may result in false positives." 
- ) + return ti + + def _parse_key_range_result(self, key_type, key_range): + mn, mx = key_range + cls = key_type.make_value + # We add 1 because our ranges are exclusive of the end (like in Python) + try: + return cls(mn), cls(mx) + 1 + except (TypeError, ValueError) as e: + raise type(e)(f"Cannot apply {key_type} to '{mn}', '{mx}'.") from e - def _bisect_and_diff_tables( + def _bisect_and_diff_segments( self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None ): assert table1.is_bounded and table2.is_bounded - max_space_size = max(table1.approximate_size(), table2.approximate_size()) - if max_rows is None: - # We can be sure that row_count <= max_rows iff the table key is unique - max_rows = max_space_size - - # If count is below the threshold, just download and compare the columns locally - # This saves time, as bisection speed is limited by ping and query performance. - if max_rows < self.bisection_threshold or max_space_size < self.bisection_factor * 2: - rows1, rows2 = self._threaded_call("get_values", [table1, table2]) - diff = list(diff_sets(rows1, rows2)) - - # Initial bisection_threshold larger than count. Normally we always - # checksum and count segments, even if we get the values. At the - # first level, however, that won't be true. - if level == 0: - self.stats["table1_count"] = len(rows1) - self.stats["table2_count"] = len(rows2) - - self.stats["diff_count"] += len(diff) - - logger.info(". " * level + f"Diff found {len(diff)} different rows.") - self.stats["rows_downloaded"] = self.stats.get("rows_downloaded", 0) + max(len(rows1), len(rows2)) - return diff - # Choose evenly spaced checkpoints (according to min_key and max_key) - biggest_table = max(table1, table2, key=methodcaller('approximate_size')) + biggest_table = max(table1, table2, key=methodcaller("approximate_size")) checkpoints = biggest_table.choose_checkpoints(self.bisection_factor - 1) # Create new instances of TableSegment between each checkpoint @@ -265,70 +224,4 @@ def _bisect_and_diff_tables( # Recursively compare each pair of corresponding segments between table1 and table2 for i, (t1, t2) in enumerate(safezip(segmented1, segmented2)): - ti.submit(self._diff_tables, ti, t1, t2, max_rows, level + 1, i + 1, len(segmented1), priority=level) - - def _diff_tables( - self, - ti: ThreadedYielder, - table1: TableSegment, - table2: TableSegment, - max_rows: int, - level=0, - segment_index=None, - segment_count=None, - ): - logger.info( - ". " * level + f"Diffing segment {segment_index}/{segment_count}, " - f"key-range: {table1.min_key}..{table2.max_key}, " - f"size <= {max_rows}" - ) - - # When benchmarking, we want the ability to skip checksumming. This - # allows us to download all rows for comparison in performance. By - # default, data-diff will checksum the section first (when it's below - # the threshold) and _then_ download it. - if BENCHMARK: - if max_rows < self.bisection_threshold: - return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max_rows) - - (count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2]) - - if count1 == 0 and count2 == 0: - # logger.warning( - # f"Uneven distribution of keys detected in segment {table1.min_key}..{table2.max_key}. (big gaps in the key column). " - # "For better performance, we recommend to increase the bisection-threshold." 
- # ) - assert checksum1 is None and checksum2 is None - return - - if level == 1: - self.stats["table1_count"] = self.stats.get("table1_count", 0) + count1 - self.stats["table2_count"] = self.stats.get("table2_count", 0) + count2 - - if checksum1 != checksum2: - return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max(count1, count2)) - - def _thread_map(self, func, iterable): - if not self.threaded: - return map(func, iterable) - - with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool: - return task_pool.map(func, iterable) - - def _threaded_call(self, func, iterable): - "Calls a method for each object in iterable." - return list(self._thread_map(methodcaller(func), iterable)) - - def _thread_as_completed(self, func, iterable): - if not self.threaded: - yield from map(func, iterable) - return - - with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool: - futures = [task_pool.submit(func, item) for item in iterable] - for future in as_completed(futures): - yield future.result() - - def _threaded_call_as_completed(self, func, iterable): - "Calls a method for each object in iterable. Returned in order of completion." - return self._thread_as_completed(methodcaller(func), iterable) + ti.submit(self._diff_segments, ti, t1, t2, max_rows, level + 1, i + 1, len(segmented1), priority=level) diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py new file mode 100644 index 00000000..38e6fee5 --- /dev/null +++ b/data_diff/hashdiff_tables.py @@ -0,0 +1,193 @@ +import os +from numbers import Number +import logging +from collections import defaultdict +from typing import Iterator +from operator import attrgetter + +from runtype import dataclass + +from .utils import safezip +from .thread_utils import ThreadedYielder +from .databases.database_types import ColType_UUID, NumericType, PrecisionType, StringType +from .table_segment import TableSegment + +from .diff_tables import TableDiffer + +BENCHMARK = os.environ.get("BENCHMARK", False) + +DEFAULT_BISECTION_THRESHOLD = 1024 * 16 +DEFAULT_BISECTION_FACTOR = 32 + +logger = logging.getLogger("hashdiff_tables") + + +def diff_sets(a: set, b: set) -> Iterator: + s1 = set(a) + s2 = set(b) + d = defaultdict(list) + + # The first item is always the key (see TableDiffer.relevant_columns) + for i in s1 - s2: + d[i[0]].append(("-", i)) + for i in s2 - s1: + d[i[0]].append(("+", i)) + + for _k, v in sorted(d.items(), key=lambda i: i[0]): + yield from v + + +@dataclass +class HashDiffer(TableDiffer): + """Finds the diff between two SQL tables + + The algorithm uses hashing to quickly check if the tables are different, and then applies a + bisection search recursively to find the differences efficiently. + + Works best for comparing tables that are mostly the same, with minor discrepencies. + + Parameters: + bisection_factor (int): Into how many segments to bisect per iteration. + bisection_threshold (Number): When should we stop bisecting and compare locally (in row count). + threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads. + max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. + Only relevant when `threaded` is ``True``. + There may be many pools, so number of actual threads can be a lot higher. 
+ """ + + bisection_factor: int = DEFAULT_BISECTION_FACTOR + bisection_threshold: Number = DEFAULT_BISECTION_THRESHOLD # Accepts inf for tests + + stats: dict = {} + + def __post_init__(self): + # Validate options + if self.bisection_factor >= self.bisection_threshold: + raise ValueError("Incorrect param values (bisection factor must be lower than threshold)") + if self.bisection_factor < 2: + raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)") + + def _validate_and_adjust_columns(self, table1, table2): + for c1, c2 in safezip(table1.relevant_columns, table2.relevant_columns): + if c1 not in table1._schema: + raise ValueError(f"Column '{c1}' not found in schema for table {table1}") + if c2 not in table2._schema: + raise ValueError(f"Column '{c2}' not found in schema for table {table2}") + + # Update schemas to minimal mutual precision + col1 = table1._schema[c1] + col2 = table2._schema[c2] + if isinstance(col1, PrecisionType): + if not isinstance(col2, PrecisionType): + raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") + + lowest = min(col1, col2, key=attrgetter("precision")) + + if col1.precision != col2.precision: + logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}") + + table1._schema[c1] = col1.replace(precision=lowest.precision, rounds=lowest.rounds) + table2._schema[c2] = col2.replace(precision=lowest.precision, rounds=lowest.rounds) + + elif isinstance(col1, NumericType): + if not isinstance(col2, NumericType): + raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") + + lowest = min(col1, col2, key=attrgetter("precision")) + + if col1.precision != col2.precision: + logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}") + + table1._schema[c1] = col1.replace(precision=lowest.precision) + table2._schema[c2] = col2.replace(precision=lowest.precision) + + elif isinstance(col1, ColType_UUID): + if not isinstance(col2, ColType_UUID): + raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") + + elif isinstance(col1, StringType): + if not isinstance(col2, StringType): + raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") + + for t in [table1, table2]: + for c in t.relevant_columns: + ctype = t._schema[c] + if not ctype.supported: + logger.warning( + f"[{t.database.name}] Column '{c}' of type '{ctype}' has no compatibility handling. " + "If encoding/formatting differs between databases, it may result in false positives." + ) + + def _diff_segments( + self, + ti: ThreadedYielder, + table1: TableSegment, + table2: TableSegment, + max_rows: int, + level=0, + segment_index=None, + segment_count=None, + ): + logger.info( + ". " * level + f"Diffing segment {segment_index}/{segment_count}, " + f"key-range: {table1.min_key}..{table2.max_key}, " + f"size <= {max_rows}" + ) + + # When benchmarking, we want the ability to skip checksumming. This + # allows us to download all rows for comparison in performance. By + # default, data-diff will checksum the section first (when it's below + # the threshold) and _then_ download it. 
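To build intuition for how `bisection_factor` and `bisection_threshold` interact: every level of recursion splits a segment into `bisection_factor` pieces, and once a segment's estimated size falls below `bisection_threshold`, its rows are downloaded and compared locally instead of being checksummed further. A rough, illustrative sketch of the resulting recursion depth (the real differ also stops early for any segment whose checksums already match):

```python
def approx_bisection_depth(n_rows: int, factor: int = 32, threshold: int = 1024 * 16) -> int:
    """Rough number of bisection levels before segments drop below the
    download threshold. Purely illustrative; matching checksums prune
    most segments long before this worst case."""
    depth = 0
    size = float(n_rows)
    while size >= threshold:
        size /= factor
        depth += 1
    return depth

print(approx_bisection_depth(100_000_000))  # ~3 levels for 100M rows with the defaults
```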
+ if BENCHMARK: + if max_rows < self.bisection_threshold: + return self._bisect_and_diff_segments(ti, table1, table2, level=level, max_rows=max_rows) + + (count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2]) + + if count1 == 0 and count2 == 0: + logger.debug( + "Uneven distribution of keys detected in segment %s..%s (big gaps in the key column). " + "For better performance, we recommend to increase the bisection-threshold.", + table1.min_key, + table1.max_key, + ) + assert checksum1 is None and checksum2 is None + return + + if level == 1: + self.stats["table1_count"] = self.stats.get("table1_count", 0) + count1 + self.stats["table2_count"] = self.stats.get("table2_count", 0) + count2 + + if checksum1 != checksum2: + return self._bisect_and_diff_segments(ti, table1, table2, level=level, max_rows=max(count1, count2)) + + def _bisect_and_diff_segments( + self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None + ): + assert table1.is_bounded and table2.is_bounded + + max_space_size = max(table1.approximate_size(), table2.approximate_size()) + if max_rows is None: + # We can be sure that row_count <= max_rows iff the table key is unique + max_rows = max_space_size + + # If count is below the threshold, just download and compare the columns locally + # This saves time, as bisection speed is limited by ping and query performance. + if max_rows < self.bisection_threshold or max_space_size < self.bisection_factor * 2: + rows1, rows2 = self._threaded_call("get_values", [table1, table2]) + diff = list(diff_sets(rows1, rows2)) + + # Initial bisection_threshold larger than count. Normally we always + # checksum and count segments, even if we get the values. At the + # first level, however, that won't be true. + if level == 0: + self.stats["table1_count"] = len(rows1) + self.stats["table2_count"] = len(rows2) + + self.stats["diff_count"] += len(diff) + + logger.info(". 
" * level + f"Diff found {len(diff)} different rows.") + self.stats["rows_downloaded"] = self.stats.get("rows_downloaded", 0) + max(len(rows1), len(rows2)) + return diff + + return super()._bisect_and_diff_segments(ti, table1, table2, level, max_rows) diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py new file mode 100644 index 00000000..62a70508 --- /dev/null +++ b/data_diff/joindiff_tables.py @@ -0,0 +1,325 @@ +"""Provides classes for performing a table diff using JOIN + +""" + +from decimal import Decimal +from functools import partial +import logging +from typing import List + +from runtype import dataclass + +from .databases.database_types import DbPath, NumericType +from .query_utils import append_to_table, drop_table + + +from .utils import safezip +from .databases.base import Database +from .databases import MySQL, BigQuery, Presto, Oracle, Snowflake +from .table_segment import TableSegment +from .diff_tables import TableDiffer, DiffResult +from .thread_utils import ThreadedYielder + +from .queries import table, sum_, min_, max_, avg +from .queries.api import and_, if_, or_, outerjoin, leftjoin, rightjoin, this, ITable +from .queries.ast_classes import Concat, Count, Expr, Random, TablePath +from .queries.compiler import Compiler +from .queries.extras import NormalizeAsString + +logger = logging.getLogger("joindiff_tables") + +TABLE_WRITE_LIMIT = 1000 + + +def merge_dicts(dicts): + i = iter(dicts) + try: + res = next(i) + except StopIteration: + return {} + + for d in i: + res.update(d) + return res + + +def sample(table_expr): + return table_expr.order_by(Random()).limit(10) + + +def create_temp_table(c: Compiler, path: TablePath, expr: Expr) -> str: + db = c.database + if isinstance(db, BigQuery): + return f"create table {c.compile(path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}" + elif isinstance(db, Presto): + return f"create table {c.compile(path)} as {c.compile(expr)}" + elif isinstance(db, Oracle): + return f"create global temporary table {c.compile(path)} as {c.compile(expr)}" + else: + return f"create temporary table {c.compile(path)} as {c.compile(expr)}" + + +def bool_to_int(x): + return if_(x, 1, 0) + + +def _outerjoin(db: Database, a: ITable, b: ITable, keys1: List[str], keys2: List[str], select_fields: dict) -> ITable: + on = [a[k1] == b[k2] for k1, k2 in safezip(keys1, keys2)] + + is_exclusive_a = and_(b[k] == None for k in keys2) + is_exclusive_b = and_(a[k] == None for k in keys1) + if isinstance(db, Oracle): + is_exclusive_a = bool_to_int(is_exclusive_a) + is_exclusive_b = bool_to_int(is_exclusive_b) + + if isinstance(db, MySQL): + # No outer join + l = leftjoin(a, b).on(*on).select(is_exclusive_a=is_exclusive_a, is_exclusive_b=False, **select_fields) + r = rightjoin(a, b).on(*on).select(is_exclusive_a=False, is_exclusive_b=is_exclusive_b, **select_fields) + return l.union(r) + + return outerjoin(a, b).on(*on).select(is_exclusive_a=is_exclusive_a, is_exclusive_b=is_exclusive_b, **select_fields) + + +def _slice_tuple(t, *sizes): + i = 0 + for size in sizes: + yield t[i : i + size] + i += size + assert i == len(t) + + +def json_friendly_value(v): + if isinstance(v, Decimal): + return float(v) + return v + + +@dataclass +class JoinDiffer(TableDiffer): + """Finds the diff between two SQL tables in the same database, using JOINs. + + The algorithm uses an OUTER JOIN (or equivalent) with extra checks and statistics. 
+ The two tables must reside in the same database, and their primary keys must be unique and not null. + + All parameters are optional. + + Parameters: + threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads. + max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. + Only relevant when `threaded` is ``True``. + There may be many pools, so number of actual threads can be a lot higher. + validate_unique_key (bool): Enable/disable validating that the key columns are unique. + Single query, and can't be threaded, so it's very slow on non-cloud dbs. + Future versions will detect UNIQUE constraints in the schema. + sample_exclusive_rows (bool): Enable/disable sampling of exclusive rows. Creates a temporary table. + materialize_to_table (DbPath, optional): Path of new table to write diff results to. Disabled if not provided. + table_write_limit (int): Maximum number of rows to write when materializing, per thread. + """ + + validate_unique_key: bool = True + sample_exclusive_rows: bool = True + materialize_to_table: DbPath = None + materialize_all_rows: bool = False + table_write_limit: int = TABLE_WRITE_LIMIT + stats: dict = {} + + def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: + db = table1.database + + if table1.database is not table2.database: + raise ValueError("Join-diff only works when both tables are in the same database") + + table1, table2 = self._threaded_call("with_schema", [table1, table2]) + + bg_funcs = [partial(self._test_duplicate_keys, table1, table2)] if self.validate_unique_key else [] + if self.materialize_to_table: + drop_table(db, self.materialize_to_table) + + with self._run_in_background(*bg_funcs): + + if isinstance(db, (Snowflake, BigQuery)): + # Don't segment the table; let the database handling parallelization + yield from self._diff_segments(None, table1, table2, None) + else: + yield from self._bisect_and_diff_tables(table1, table2) + logger.info("Diffing complete") + + def _diff_segments( + self, + ti: ThreadedYielder, + table1: TableSegment, + table2: TableSegment, + max_rows: int, + level=0, + segment_index=None, + segment_count=None, + ): + assert table1.database is table2.database + + if segment_index or table1.min_key or max_rows: + logger.info( + ". 
" * level + f"Diffing segment {segment_index}/{segment_count}, " + f"key-range: {table1.min_key}..{table2.max_key}, " + f"size <= {max_rows}" + ) + + db = table1.database + diff_rows, a_cols, b_cols, is_diff_cols, all_rows = self._create_outer_join(table1, table2) + + with self._run_in_background( + partial(self._collect_stats, 1, table1), + partial(self._collect_stats, 2, table2), + partial(self._test_null_keys, table1, table2), + partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols), + partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols), + partial( + self._materialize_diff, + db, + all_rows if self.materialize_all_rows else diff_rows, + segment_index=segment_index, + ) + if self.materialize_to_table + else None, + ): + + logger.debug("Querying for different rows") + for is_xa, is_xb, *x in db.query(diff_rows, list): + if is_xa and is_xb: + # Can't both be exclusive, meaning a pk is NULL + # This can happen if the explicit null test didn't finish running yet + raise ValueError("NULL values in one or more primary keys") + _is_diff, a_row, b_row = _slice_tuple(x, len(is_diff_cols), len(a_cols), len(b_cols)) + if not is_xb: + yield "-", tuple(a_row) + if not is_xa: + yield "+", tuple(b_row) + + def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment): + logger.debug("Testing for duplicate keys") + + # Test duplicate keys + for ts in [table1, table2]: + unique = ( + ts.database.query_table_unique_columns(ts.table_path) if ts.database.SUPPORTS_UNIQUE_CONSTAINT else [] + ) + + t = ts.make_select() + key_columns = ts.key_columns + + unvalidated = list(set(key_columns) - set(unique)) + if unvalidated: + # Validate that there are no duplicate keys + self.stats["validated_unique_keys"] = self.stats.get("validated_unique_keys", []) + [unvalidated] + q = t.select(total=Count(), total_distinct=Count(Concat(this[unvalidated]), distinct=True)) + total, total_distinct = ts.database.query(q, tuple) + if total != total_distinct: + raise ValueError("Duplicate primary keys") + + def _test_null_keys(self, table1, table2): + logger.debug("Testing for null keys") + + # Test null keys + for ts in [table1, table2]: + t = ts.make_select() + key_columns = ts.key_columns + + q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns)) + nulls = ts.database.query(q, list) + if nulls: + raise ValueError("NULL values in one or more primary keys") + + def _collect_stats(self, i, table_seg: TableSegment): + logger.info(f"Collecting stats for table #{i}") + db = table_seg.database + + # Metrics + col_exprs = merge_dicts( + { + f"sum_{c}": sum_(this[c]), + f"avg_{c}": avg(this[c]), + f"min_{c}": min_(this[c]), + f"max_{c}": max_(this[c]), + } + for c in table_seg.relevant_columns + if isinstance(table_seg._schema[c], NumericType) + ) + col_exprs["count"] = Count() + + res = db.query(table_seg.make_select().select(**col_exprs), tuple) + res = dict(zip([f"table{i}_{n}" for n in col_exprs], map(json_friendly_value, res))) + for k, v in res.items(): + self.stats[k] = self.stats.get(k, 0) + (v or 0) + + logger.debug("Done collecting stats for table #%s", i) + + def _create_outer_join(self, table1, table2): + db = table1.database + if db is not table2.database: + raise ValueError("Joindiff only applies to tables within the same database") + + keys1 = table1.key_columns + keys2 = table2.key_columns + if len(keys1) != len(keys2): + raise ValueError("The provided key columns are of a different count") + + cols1 = table1.relevant_columns + cols2 = 
table2.relevant_columns + if len(cols1) != len(cols2): + raise ValueError("The provided columns are of a different count") + + a = table1.make_select() + b = table2.make_select() + + is_diff_cols = {f"is_diff_{c1}": bool_to_int(a[c1].is_distinct_from(b[c2])) for c1, c2 in safezip(cols1, cols2)} + + a_cols = {f"table1_{c}": NormalizeAsString(a[c]) for c in cols1} + b_cols = {f"table2_{c}": NormalizeAsString(b[c]) for c in cols2} + + all_rows = _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **a_cols, **b_cols}) + diff_rows = all_rows.where(or_(this[c] == 1 for c in is_diff_cols)) + return diff_rows, a_cols, b_cols, is_diff_cols, all_rows + + def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols): + logger.info("Counting differences per column") + is_diff_cols_counts = db.query(diff_rows.select(sum_(this[c]) for c in is_diff_cols), tuple) + diff_counts = {} + for name, count in safezip(cols, is_diff_cols_counts): + diff_counts[name] = diff_counts.get(name, 0) + (count or 0) + self.stats["diff_counts"] = diff_counts + + def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): + if isinstance(db, Oracle): + exclusive_rows_query = diff_rows.where((this.is_exclusive_a == 1) | (this.is_exclusive_b == 1)) + else: + exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b) + + if not self.sample_exclusive_rows: + logger.info("Counting exclusive rows") + self.stats["exclusive_count"] = db.query(exclusive_rows_query.count(), int) + return + + logger.info("Counting and sampling exclusive rows") + + def exclusive_rows(expr): + c = Compiler(db) + name = c.new_unique_table_name("temp_table") + exclusive_rows = table(name, schema=expr.source_table.schema) + yield create_temp_table(c, exclusive_rows, expr.limit(self.table_write_limit)) + + count = yield exclusive_rows.count() + self.stats["exclusive_count"] = self.stats.get("exclusive_count", 0) + count[0][0] + sample_rows = yield sample(exclusive_rows.select(*this[list(a_cols)], *this[list(b_cols)])) + self.stats["exclusive_sample"] = self.stats.get("exclusive_sample", []) + sample_rows + + # Only drops if create table succeeded (meaning, the table didn't already exist) + yield exclusive_rows.drop() + + # Run as a sequence of thread-local queries (compiled into a ThreadLocalInterpreter) + db.query(exclusive_rows(exclusive_rows_query), None) + + def _materialize_diff(self, db, diff_rows, segment_index=None): + assert self.materialize_to_table + + append_to_table(db, self.materialize_to_table, diff_rows.limit(self.table_write_limit)) + logger.info("Materialized diff to table '%s'.", ".".join(self.materialize_to_table)) diff --git a/data_diff/queries/__init__.py b/data_diff/queries/__init__.py new file mode 100644 index 00000000..172e73e4 --- /dev/null +++ b/data_diff/queries/__init__.py @@ -0,0 +1,4 @@ +from .compiler import Compiler +from .api import this, join, outerjoin, table, SKIP, sum_, avg, min_, max_, cte, commit +from .ast_classes import Expr, ExprNode, Select, Count, BinOp, Explain, In +from .extras import Checksum, NormalizeAsString, ApplyFuncAndNormalizeAsString diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py new file mode 100644 index 00000000..797fafa5 --- /dev/null +++ b/data_diff/queries/api.py @@ -0,0 +1,80 @@ +from typing import Optional + +from data_diff.utils import CaseAwareMapping, CaseSensitiveDict +from .ast_classes import * +from .base import args_as_tuple + + +this = This() + + +def join(*tables: ITable): + "Joins each table into a 'struct'" + return 
Join(tables) + + +def leftjoin(*tables: ITable): + "Left-joins each table into a 'struct'" + return Join(tables, "LEFT") + + +def rightjoin(*tables: ITable): + "Right-joins each table into a 'struct'" + return Join(tables, "RIGHT") + + +def outerjoin(*tables: ITable): + "Outer-joins each table into a 'struct'" + return Join(tables, "FULL OUTER") + + +def cte(expr: Expr, *, name: Optional[str] = None, params: Sequence[str] = None): + return Cte(expr, name, params) + + +def table(*path: str, schema: Union[dict, CaseAwareMapping] = None) -> TablePath: + if len(path) == 1 and isinstance(path[0], tuple): + (path,) = path + if not all(isinstance(i, str) for i in path): + raise TypeError(f"All elements of table path must be of type 'str'. Got: {path}") + if schema and not isinstance(schema, CaseAwareMapping): + assert isinstance(schema, dict) + schema = CaseSensitiveDict(schema) + return TablePath(path, schema) + + +def or_(*exprs: Expr): + exprs = args_as_tuple(exprs) + if len(exprs) == 1: + return exprs[0] + return BinBoolOp("OR", exprs) + + +def and_(*exprs: Expr): + exprs = args_as_tuple(exprs) + if len(exprs) == 1: + return exprs[0] + return BinBoolOp("AND", exprs) + + +def sum_(expr: Expr): + return Func("sum", [expr]) + + +def avg(expr: Expr): + return Func("avg", [expr]) + + +def min_(expr: Expr): + return Func("min", [expr]) + + +def max_(expr: Expr): + return Func("max", [expr]) + + +def if_(cond: Expr, then: Expr, else_: Optional[Expr] = None): + return CaseWhen([(cond, then)], else_=else_) + + +commit = Commit() diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py new file mode 100644 index 00000000..13f33193 --- /dev/null +++ b/data_diff/queries/ast_classes.py @@ -0,0 +1,714 @@ +from dataclasses import field +from datetime import datetime +from typing import Any, Generator, List, Optional, Sequence, Tuple, Union + +from runtype import dataclass + +from data_diff.utils import ArithString, join_iter + +from .compiler import Compilable, Compiler, cv_params +from .base import SKIP, CompileError, DbPath, Schema, args_as_tuple + + +class ExprNode(Compilable): + type: Any = None + + def _dfs_values(self): + yield self + for k, vs in dict(self).items(): # __dict__ provided by runtype.dataclass + if k == "source_table": + # Skip data-sources, we're only interested in data-parameters + continue + if not isinstance(vs, (list, tuple)): + vs = [vs] + for v in vs: + if isinstance(v, ExprNode): + yield from v._dfs_values() + + def cast_to(self, to): + return Cast(self, to) + + +Expr = Union[ExprNode, str, bool, int, datetime, ArithString, None] + + +def _expr_type(e: Expr) -> type: + if isinstance(e, ExprNode): + return e.type + return type(e) + + +@dataclass +class Alias(ExprNode): + expr: Expr + name: str + + def compile(self, c: Compiler) -> str: + return f"{c.compile(self.expr)} AS {c.quote(self.name)}" + + @property + def type(self): + return _expr_type(self.expr) + + +def _drop_skips(exprs): + return [e for e in exprs if e is not SKIP] + + +def _drop_skips_dict(exprs_dict): + return {k: v for k, v in exprs_dict.items() if v is not SKIP} + + +class ITable: + source_table: Any + schema: Schema = None + + def select(self, *exprs, **named_exprs): + exprs = args_as_tuple(exprs) + exprs = _drop_skips(exprs) + named_exprs = _drop_skips_dict(named_exprs) + exprs += _named_exprs_as_aliases(named_exprs) + resolve_names(self.source_table, exprs) + return Select.make(self, columns=exprs) + + def where(self, *exprs): + exprs = args_as_tuple(exprs) + exprs = _drop_skips(exprs) + 
if not exprs: + return self + + resolve_names(self.source_table, exprs) + return Select.make(self, where_exprs=exprs, _concat=True) + + def order_by(self, *exprs): + exprs = _drop_skips(exprs) + if not exprs: + return self + + resolve_names(self.source_table, exprs) + return Select.make(self, order_by_exprs=exprs) + + def limit(self, limit: int): + if limit is SKIP: + return self + + return Select.make(self, limit_expr=limit) + + def at(self, *exprs): + # TODO + exprs = _drop_skips(exprs) + if not exprs: + return self + + raise NotImplementedError() + + def join(self, target): + return Join(self, target) + + def group_by(self, *, keys=None, values=None): + # TODO + assert keys or values + raise NotImplementedError() + + def with_schema(self): + # TODO + raise NotImplementedError() + + def _get_column(self, name: str): + if self.schema: + name = self.schema.get_key(name) # Get the actual name. Might be case-insensitive. + return Column(self, name) + + # def __getattr__(self, column): + # return self._get_column(column) + + def __getitem__(self, column): + if not isinstance(column, str): + raise TypeError() + return self._get_column(column) + + def count(self): + return Select(self, [Count()]) + + def union(self, other: "ITable"): + return SetUnion(self, other) + + +@dataclass +class Concat(ExprNode): + exprs: list + sep: str = None + + def compile(self, c: Compiler) -> str: + # We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL + items = [f"coalesce({c.compile(c.dialect.to_string(c.compile(expr)))}, '')" for expr in self.exprs] + assert items + if len(items) == 1: + return items[0] + + if self.sep: + items = list(join_iter(f"'{self.sep}'", items)) + return c.dialect.concat(items) + + +@dataclass +class Count(ExprNode): + expr: Expr = "*" + distinct: bool = False + + type = int + + def compile(self, c: Compiler) -> str: + expr = c.compile(self.expr) + if self.distinct: + return f"count(distinct {expr})" + + return f"count({expr})" + + +@dataclass +class Func(ExprNode): + name: str + args: Sequence[Expr] + + def compile(self, c: Compiler) -> str: + args = ", ".join(c.compile(e) for e in self.args) + return f"{self.name}({args})" + + +@dataclass +class CaseWhen(ExprNode): + cases: Sequence[Tuple[Expr, Expr]] + else_: Expr = None + + def compile(self, c: Compiler) -> str: + assert self.cases + when_thens = " ".join(f"WHEN {c.compile(when)} THEN {c.compile(then)}" for when, then in self.cases) + else_ = (" ELSE " + c.compile(self.else_)) if self.else_ is not None else "" + return f"CASE {when_thens}{else_} END" + + @property + def type(self): + when_types = {_expr_type(w) for _c, w in self.cases} + if self.else_: + when_types |= _expr_type(self.else_) + if len(when_types) > 1: + raise RuntimeError(f"Non-matching types in when: {when_types}") + (t,) = when_types + return t + + +class LazyOps: + def __add__(self, other): + return BinOp("+", [self, other]) + + def __gt__(self, other): + return BinBoolOp(">", [self, other]) + + def __ge__(self, other): + return BinBoolOp(">=", [self, other]) + + def __eq__(self, other): + if other is None: + return BinBoolOp("IS", [self, None]) + return BinBoolOp("=", [self, other]) + + def __lt__(self, other): + return BinBoolOp("<", [self, other]) + + def __le__(self, other): + return BinBoolOp("<=", [self, other]) + + def __or__(self, other): + return BinBoolOp("OR", [self, other]) + + def __and__(self, other): + return BinBoolOp("AND", [self, other]) + + def is_distinct_from(self, other): + return IsDistinctFrom(self, other) + + def 
sum(self): + return Func("SUM", [self]) + + +@dataclass(eq=False, order=False) +class IsDistinctFrom(ExprNode, LazyOps): + a: Expr + b: Expr + type = bool + + def compile(self, c: Compiler) -> str: + return c.dialect.is_distinct_from(c.compile(self.a), c.compile(self.b)) + + +@dataclass(eq=False, order=False) +class BinOp(ExprNode, LazyOps): + op: str + args: Sequence[Expr] + + def compile(self, c: Compiler) -> str: + expr = f" {self.op} ".join(c.compile(a) for a in self.args) + return f"({expr})" + + @property + def type(self): + types = {_expr_type(i) for i in self.args} + if len(types) > 1: + raise TypeError(f"Expected all args to have the same type, got {types}") + (t,) = types + return t + + +class BinBoolOp(BinOp): + type = bool + + +@dataclass(eq=False, order=False) +class Column(ExprNode, LazyOps): + source_table: ITable + name: str + + @property + def type(self): + if self.source_table.schema is None: + raise RuntimeError(f"Schema required for table {self.source_table}") + return self.source_table.schema[self.name] + + def compile(self, c: Compiler) -> str: + if c._table_context: + if len(c._table_context) > 1: + aliases = [ + t for t in c._table_context if isinstance(t, TableAlias) and t.source_table is self.source_table + ] + if not aliases: + return c.quote(self.name) + elif len(aliases) > 1: + raise CompileError(f"Too many aliases for column {self.name}") + (alias,) = aliases + + return f"{c.quote(alias.name)}.{c.quote(self.name)}" + + return c.quote(self.name) + + +@dataclass +class TablePath(ExprNode, ITable): + path: DbPath + schema: Optional[Schema] = field(default=None, repr=False) + + @property + def source_table(self): + return self + + def compile(self, c: Compiler) -> str: + path = self.path # c.database._normalize_table_path(self.name) + return ".".join(map(c.quote, path)) + + # Statement shorthands + def create(self, source_table: ITable = None, *, if_not_exists=False, primary_keys=None): + + if source_table is None and not self.schema: + raise ValueError("Either schema or source table needed to create table") + if isinstance(source_table, TablePath): + source_table = source_table.select() + return CreateTable(self, source_table, if_not_exists=if_not_exists, primary_keys=primary_keys) + + def drop(self, if_exists=False): + return DropTable(self, if_exists=if_exists) + + def truncate(self): + return TruncateTable(self) + + def insert_rows(self, rows, *, columns=None): + rows = list(rows) + return InsertToTable(self, ConstantTable(rows), columns=columns) + + def insert_row(self, *values, columns=None): + return InsertToTable(self, ConstantTable([values]), columns=columns) + + def insert_expr(self, expr: Expr): + if isinstance(expr, TablePath): + expr = expr.select() + return InsertToTable(self, expr) + + +@dataclass +class TableAlias(ExprNode, ITable): + source_table: ITable + name: str + + def compile(self, c: Compiler) -> str: + return f"{c.compile(self.source_table)} {c.quote(self.name)}" + + +@dataclass +class Join(ExprNode, ITable): + source_tables: Sequence[ITable] + op: str = None + on_exprs: Sequence[Expr] = None + columns: Sequence[Expr] = None + + @property + def source_table(self): + return self # TODO is this right? 
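Much of this patch introduces the query-builder that joindiff relies on (`data_diff/queries/api.py` plus these AST classes), so a small illustrative sketch of how the pieces compose may help. Here `db` and the `"public"."ratings"` table are placeholders, and the exact quoting depends on the target dialect.

```python
from data_diff.queries import Compiler, Count, table, this, sum_

ratings = table("public", "ratings")        # -> TablePath(("public", "ratings"))
q = (
    ratings
    .select(Count(), sum_(this.rating))     # `this.rating` is resolved lazily against the table
    .where(this.rating >= 4)                # LazyOps turns `>=` into a BinBoolOp
)

# The Compiler carries the database (and therefore its dialect) and renders SQL,
# roughly: SELECT count(*), sum("rating") FROM "public"."ratings" WHERE ("rating" >= 4)
sql = Compiler(db).compile(q)               # `db`: an already-connected Database instance
```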
+ + @property + def schema(self): + assert self.columns # TODO Implement SELECT * + s = self.source_tables[0].schema # XXX + return type(s)({c.name: c.type for c in self.columns}) + + def on(self, *exprs): + if len(exprs) == 1: + (e,) = exprs + if isinstance(e, Generator): + exprs = tuple(e) + + exprs = _drop_skips(exprs) + if not exprs: + return self + + return self.replace(on_exprs=(self.on_exprs or []) + exprs) + + def select(self, *exprs, **named_exprs): + if self.columns is not None: + # join-select already applied + return super().select(*exprs, **named_exprs) + + exprs = _drop_skips(exprs) + named_exprs = _drop_skips_dict(named_exprs) + exprs += _named_exprs_as_aliases(named_exprs) + # resolve_names(self.source_table, exprs) + # TODO Ensure exprs <= self.columns ? + return self.replace(columns=exprs) + + def compile(self, parent_c: Compiler) -> str: + tables = [ + t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in self.source_tables + ] + c = parent_c.add_table_context(*tables, in_join=True, in_select=False) + op = " JOIN " if self.op is None else f" {self.op} JOIN " + joined = op.join(c.compile(t) for t in tables) + + if self.on_exprs: + on = " AND ".join(c.compile(e) for e in self.on_exprs) + res = f"{joined} ON {on}" + else: + res = joined + + columns = "*" if self.columns is None else ", ".join(map(c.compile, self.columns)) + select = f"SELECT {columns} FROM {res}" + + if parent_c.in_select: + select = f"({select}) {c.new_unique_name()}" + elif parent_c.in_join: + select = f"({select})" + return select + + +class GroupBy(ITable): + def having(self): + raise NotImplementedError() + + +@dataclass +class SetUnion(ExprNode, ITable): + table1: ITable + table2: ITable + + @property + def source_table(self): + return self # TODO is this right? 
+ + @property + def type(self): + return self.table1.type + + @property + def schema(self): + s1 = self.table1.schema + s2 = self.table2.schema + assert len(s1) == len(s2) + return s1 + + def compile(self, parent_c: Compiler) -> str: + c = parent_c.replace(in_select=False) + union = f"{c.compile(self.table1)} UNION {c.compile(self.table2)}" + if parent_c.in_select: + union = f"({union}) {c.new_unique_name()}" + elif parent_c.in_join: + union = f"({union})" + return union + + +@dataclass +class Select(ExprNode, ITable): + table: Expr = None + columns: Sequence[Expr] = None + where_exprs: Sequence[Expr] = None + order_by_exprs: Sequence[Expr] = None + group_by_exprs: Sequence[Expr] = None + limit_expr: int = None + + @property + def schema(self): + s = self.table.schema + if s is None or self.columns is None: + return s + return type(s)({c.name: c.type for c in self.columns}) + + @property + def source_table(self): + return self + + def compile(self, parent_c: Compiler) -> str: + c = parent_c.replace(in_select=True) # .add_table_context(self.table) + + columns = ", ".join(map(c.compile, self.columns)) if self.columns else "*" + select = f"SELECT {columns}" + + if self.table: + select += " FROM " + c.compile(self.table) + + if self.where_exprs: + select += " WHERE " + " AND ".join(map(c.compile, self.where_exprs)) + + if self.group_by_exprs: + select += " GROUP BY " + ", ".join(map(c.compile, self.group_by_exprs)) + + if self.order_by_exprs: + select += " ORDER BY " + ", ".join(map(c.compile, self.order_by_exprs)) + + if self.limit_expr is not None: + select += " " + c.dialect.offset_limit(0, self.limit_expr) + + if parent_c.in_select: + select = f"({select}) {c.new_unique_name()}" + elif parent_c.in_join: + select = f"({select})" + return select + + @classmethod + def make(cls, table: ITable, _concat: bool = False, **kwargs): + if not isinstance(table, cls): + return cls(table, **kwargs) + + # Fill in missing attributes, instead of creating a new instance. 
+ for k, v in kwargs.items(): + if getattr(table, k) is not None: + if _concat: + kwargs[k] = getattr(table, k) + v + else: + raise ValueError("...") + + return table.replace(**kwargs) + + +@dataclass +class Cte(ExprNode, ITable): + source_table: Expr + name: str = None + params: Sequence[str] = None + + def compile(self, parent_c: Compiler) -> str: + c = parent_c.replace(_table_context=[], in_select=False) + compiled = c.compile(self.source_table) + + name = self.name or parent_c.new_unique_name() + name_params = f"{name}({', '.join(self.params)})" if self.params else name + parent_c._subqueries[name_params] = compiled + + return name + + @property + def schema(self): + # TODO add cte to schema + return self.source_table.schema + + +def _named_exprs_as_aliases(named_exprs): + return [Alias(expr, name) for name, expr in named_exprs.items()] + + +def resolve_names(source_table, exprs): + i = 0 + for expr in exprs: + # Iterate recursively and update _ResolveColumn with the right expression + if isinstance(expr, ExprNode): + for v in expr._dfs_values(): + if isinstance(v, _ResolveColumn): + v.resolve(source_table._get_column(v.resolve_name)) + i += 1 + + +@dataclass(frozen=False, eq=False, order=False) +class _ResolveColumn(ExprNode, LazyOps): + resolve_name: str + resolved: Expr = None + + def resolve(self, expr: Expr): + if self.resolved is not None: + raise RuntimeError("Already resolved!") + self.resolved = expr + + def _get_resolved(self) -> Expr: + if self.resolved is None: + raise RuntimeError(f"Column not resolved: {self.resolve_name}") + return self.resolved + + def compile(self, c: Compiler) -> str: + return self._get_resolved().compile(c) + + @property + def type(self): + return self._get_resolved().type + + @property + def name(self): + return self._get_resolved().name + + +class This: + def __getattr__(self, name): + return _ResolveColumn(name) + + def __getitem__(self, name): + if isinstance(name, (list, tuple)): + return [_ResolveColumn(n) for n in name] + return _ResolveColumn(name) + + +@dataclass +class In(ExprNode): + expr: Expr + list: Sequence[Expr] + + type = bool + + def compile(self, c: Compiler): + elems = ", ".join(map(c.compile, self.list)) + return f"({c.compile(self.expr)} IN ({elems}))" + + +@dataclass +class Cast(ExprNode): + expr: Expr + target_type: Expr + + def compile(self, c: Compiler) -> str: + return f"cast({c.compile(self.expr)} as {c.compile(self.target_type)})" + + +@dataclass +class Random(ExprNode): + type = float + + def compile(self, c: Compiler) -> str: + return c.dialect.random() + + +@dataclass +class ConstantTable(ExprNode): + rows: Sequence[Sequence] + + def compile(self, c: Compiler) -> str: + raise NotImplementedError() + + def compile_for_insert(self, c: Compiler): + return c.dialect.constant_values(self.rows) + + +@dataclass +class Explain(ExprNode): + select: Select + + type = str + + def compile(self, c: Compiler) -> str: + return c.dialect.explain_as_text(c.compile(self.select)) + + +# DDL + + +class Statement(Compilable): + type = None + + +@dataclass +class CreateTable(Statement): + path: TablePath + source_table: Expr = None + if_not_exists: bool = False + primary_keys: List[str] = None + + def compile(self, c: Compiler) -> str: + ne = "IF NOT EXISTS " if self.if_not_exists else "" + if self.source_table: + return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}" + + schema = ", ".join(f"{c.dialect.quote(k)} {c.dialect.type_repr(v)}" for k, v in self.path.schema.items()) + pks = ( + ", PRIMARY KEY (%s)" % 
", ".join(self.primary_keys) + if self.primary_keys and c.dialect.SUPPORTS_PRIMARY_KEY + else "" + ) + return f"CREATE TABLE {ne}{c.compile(self.path)}({schema}{pks})" + + +@dataclass +class DropTable(Statement): + path: TablePath + if_exists: bool = False + + def compile(self, c: Compiler) -> str: + ie = "IF EXISTS " if self.if_exists else "" + return f"DROP TABLE {ie}{c.compile(self.path)}" + + +@dataclass +class TruncateTable(Statement): + path: TablePath + + def compile(self, c: Compiler) -> str: + return f"TRUNCATE TABLE {c.compile(self.path)}" + + +@dataclass +class InsertToTable(Statement): + # TODO Support insert for only some columns + path: TablePath + expr: Expr + columns: List[str] = None + + def compile(self, c: Compiler) -> str: + if isinstance(self.expr, ConstantTable): + expr = self.expr.compile_for_insert(c) + else: + expr = c.compile(self.expr) + + columns = f"(%s)" % ", ".join(map(c.quote, self.columns)) if self.columns is not None else "" + + return f"INSERT INTO {c.compile(self.path)}{columns} {expr}" + + +@dataclass +class Commit(Statement): + def compile(self, c: Compiler) -> str: + return "COMMIT" if not c.database.is_autocommit else SKIP + + +@dataclass +class Param(ExprNode, ITable): + """A value placeholder, to be specified at compilation time using the `cv_params` context variable.""" + + name: str + + @property + def source_table(self): + return self + + def compile(self, c: Compiler) -> str: + params = cv_params.get() + return c._compile(params[self.name]) diff --git a/data_diff/queries/base.py b/data_diff/queries/base.py new file mode 100644 index 00000000..7b0d96cb --- /dev/null +++ b/data_diff/queries/base.py @@ -0,0 +1,23 @@ +from typing import Generator + +from data_diff.databases.database_types import DbPath, DbKey, Schema + + +class _SKIP: + def __repr__(self): + return "SKIP" + + +SKIP = _SKIP() + + +class CompileError(Exception): + pass + + +def args_as_tuple(exprs): + if len(exprs) == 1: + (e,) = exprs + if isinstance(e, Generator): + return tuple(e) + return exprs diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py new file mode 100644 index 00000000..0a4d1d6f --- /dev/null +++ b/data_diff/queries/compiler.py @@ -0,0 +1,79 @@ +import random +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Dict, Sequence, List + +from runtype import dataclass + +from data_diff.utils import ArithString +from data_diff.databases.database_types import AbstractDatabase, AbstractDialect, DbPath + +import contextvars + +cv_params = contextvars.ContextVar("params") + + +@dataclass +class Compiler: + database: AbstractDatabase + params: dict = {} + in_select: bool = False # Compilation runtime flag + in_join: bool = False # Compilation runtime flag + + _table_context: List = [] # List[ITable] + _subqueries: Dict[str, Any] = {} # XXX not thread-safe + root: bool = True + + _counter: List = [0] + + @property + def dialect(self) -> AbstractDialect: + return self.database.dialect + + def compile(self, elem, params=None) -> str: + if params: + cv_params.set(params) + + res = self._compile(elem) + if self.root and self._subqueries: + subq = ", ".join(f"\n {k} AS ({v})" for k, v in self._subqueries.items()) + self._subqueries.clear() + return f"WITH {subq}\n{res}" + return res + + def _compile(self, elem) -> str: + if elem is None: + return "NULL" + elif isinstance(elem, Compilable): + return elem.compile(self.replace(root=False)) + elif isinstance(elem, str): + return elem + elif isinstance(elem, int): + return 
str(elem) + elif isinstance(elem, datetime): + return self.dialect.timestamp_value(elem) + elif isinstance(elem, bytes): + return f"b'{elem.decode()}'" + elif isinstance(elem, ArithString): + return f"'{elem}'" + assert False, elem + + def new_unique_name(self, prefix="tmp"): + self._counter[0] += 1 + return f"{prefix}{self._counter[0]}" + + def new_unique_table_name(self, prefix="tmp") -> DbPath: + self._counter[0] += 1 + return self.database.parse_table_name(f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}") + + def add_table_context(self, *tables: Sequence, **kw): + return self.replace(_table_context=self._table_context + list(tables), **kw) + + def quote(self, s: str): + return self.dialect.quote(s) + + +class Compilable(ABC): + @abstractmethod + def compile(self, c: Compiler) -> str: + ... diff --git a/data_diff/queries/extras.py b/data_diff/queries/extras.py new file mode 100644 index 00000000..32d31ce9 --- /dev/null +++ b/data_diff/queries/extras.py @@ -0,0 +1,62 @@ +"Useful AST classes that don't quite fall within the scope of regular SQL" + +from typing import Callable, Sequence +from runtype import dataclass + +from data_diff.databases.database_types import ColType, Native_UUID + +from .compiler import Compiler +from .ast_classes import Expr, ExprNode, Concat + + +@dataclass +class NormalizeAsString(ExprNode): + expr: ExprNode + expr_type: ColType = None + type = str + + def compile(self, c: Compiler) -> str: + expr = c.compile(self.expr) + return c.dialect.normalize_value_by_type(expr, self.expr_type or self.expr.type) + + +@dataclass +class ApplyFuncAndNormalizeAsString(ExprNode): + expr: ExprNode + apply_func: Callable = None + + def compile(self, c: Compiler) -> str: + expr = self.expr + expr_type = expr.type + + if isinstance(expr_type, Native_UUID): + # Normalize first, apply template after (for uuids) + # Needed because min/max(uuid) fails in postgresql + expr = NormalizeAsString(expr, expr_type) + if self.apply_func is not None: + expr = self.apply_func(expr) # Apply template using Python's string formatting + + else: + # Apply template before normalizing (for ints) + if self.apply_func is not None: + expr = self.apply_func(expr) # Apply template using Python's string formatting + expr = NormalizeAsString(expr, expr_type) + + return c.compile(expr) + + +@dataclass +class Checksum(ExprNode): + exprs: Sequence[Expr] + + def compile(self, c: Compiler): + if len(self.exprs) > 1: + exprs = [f"coalesce({c.compile(expr)}, '')" for expr in self.exprs] + # exprs = [c.compile(e) for e in exprs] + expr = Concat(exprs, "|") + else: + # No need to coalesce - safe to assume that key cannot be null + (expr,) = self.exprs + expr = c.compile(expr) + md5 = c.dialect.md5_as_int(expr) + return f"sum({md5})" diff --git a/data_diff/query_utils.py b/data_diff/query_utils.py new file mode 100644 index 00000000..825dbdc3 --- /dev/null +++ b/data_diff/query_utils.py @@ -0,0 +1,57 @@ +"Module for query utilities that didn't make it into the query-builder (yet)" + +from contextlib import suppress + +from data_diff.databases.database_types import DbPath +from data_diff.databases.base import QueryError + +from .databases import Oracle +from .queries import table, commit, Expr + + +def _drop_table_oracle(name: DbPath): + t = table(name) + # Experience shows double drop is necessary + with suppress(QueryError): + yield t.drop() + yield t.drop() + yield commit + + +def _drop_table(name: DbPath): + t = table(name) + yield t.drop(if_exists=True) + yield commit + + +def drop_table(db, tbl): + if 
isinstance(db, Oracle): + db.query(_drop_table_oracle(tbl)) + else: + db.query(_drop_table(tbl)) + + +def _append_to_table_oracle(path: DbPath, expr: Expr): + """See append_to_table""" + assert expr.schema, expr + t = table(path, schema=expr.schema) + with suppress(QueryError): + yield t.create() # uses expr.schema + yield commit + yield t.insert_expr(expr) + yield commit + + +def _append_to_table(path: DbPath, expr: Expr): + """Append to table""" + assert expr.schema, expr + t = table(path, schema=expr.schema) + yield t.create(if_not_exists=True) # uses expr.schema + yield commit + yield t.insert_expr(expr) + yield commit + + +def append_to_table(db, path, expr): + f = _append_to_table_oracle if isinstance(db, Oracle) else _append_to_table + db.query(f(path, expr)) diff --git a/data_diff/sql.py b/data_diff/sql.py deleted file mode 100644 index 46332797..00000000 --- a/data_diff/sql.py +++ /dev/null @@ -1,198 +0,0 @@ -"""Provides classes for a pseudo-SQL AST that compiles to SQL code -""" - -from typing import Sequence, Union, Optional -from datetime import datetime - -from runtype import dataclass - -from .utils import join_iter, ArithString - -from .databases.database_types import AbstractDatabase, DbPath - - -class Sql: - pass - - -SqlOrStr = Union[Sql, str] - -CONCAT_SEP = "|" - - -@dataclass -class Compiler: - """Provides a set of utility methods for compiling SQL - - For internal use. - """ - - database: AbstractDatabase - in_select: bool = False # Compilation - - def quote(self, s: str): - return self.database.quote(s) - - def compile(self, elem): - if isinstance(elem, Sql): - return elem.compile(self) - elif isinstance(elem, str): - return elem - elif isinstance(elem, int): - return str(elem) - assert False - - -@dataclass -class TableName(Sql): - name: DbPath - - def compile(self, c: Compiler): - path = c.database._normalize_table_path(self.name) - return ".".join(map(c.quote, path)) - - -@dataclass -class ColumnName(Sql): - name: str - - def compile(self, c: Compiler): - return c.quote(self.name) - - -@dataclass -class Value(Sql): - value: object # Primitive - - def compile(self, c: Compiler): - if isinstance(self.value, bytes): - return f"b'{self.value.decode()}'" - elif isinstance(self.value, str): - return f"'{self.value}'" % self.value - elif isinstance(self.value, ArithString): - return f"'{self.value}'" - return str(self.value) - - -@dataclass -class Select(Sql): - columns: Sequence[SqlOrStr] - table: SqlOrStr = None - where: Sequence[SqlOrStr] = None - order_by: Sequence[SqlOrStr] = None - group_by: Sequence[SqlOrStr] = None - limit: int = None - - def compile(self, parent_c: Compiler): - c = parent_c.replace(in_select=True) - columns = ", ".join(map(c.compile, self.columns)) - select = f"SELECT {columns}" - - if self.table: - select += " FROM " + c.compile(self.table) - - if self.where: - select += " WHERE " + " AND ".join(map(c.compile, self.where)) - - if self.group_by: - select += " GROUP BY " + ", ".join(map(c.compile, self.group_by)) - - if self.order_by: - select += " ORDER BY " + ", ".join(map(c.compile, self.order_by)) - - if self.limit is not None: - select += " " + c.database.offset_limit(0, self.limit) - - if parent_c.in_select: - select = "(%s)" % select - return select - - -@dataclass -class Enum(Sql): - table: DbPath - order_by: SqlOrStr - - def compile(self, c: Compiler): - table = ".".join(map(c.quote, self.table)) - order = c.compile(self.order_by) - return f"(SELECT *, (row_number() over (ORDER BY {order})) as idx FROM {table} ORDER BY {order}) tmp" - - 
-@dataclass -class Checksum(Sql): - exprs: Sequence[SqlOrStr] - - def compile(self, c: Compiler): - if len(self.exprs) > 1: - compiled_exprs = [f"coalesce({c.compile(expr)}, '')" for expr in self.exprs] - separated = list(join_iter(f"'|'", compiled_exprs)) - expr = c.database.concat(separated) - else: - # No need to coalesce - safe to assume that key cannot be null - (expr,) = self.exprs - expr = c.compile(expr) - md5 = c.database.md5_to_int(expr) - return f"sum({md5})" - - -@dataclass -class Compare(Sql): - op: str - a: SqlOrStr - b: SqlOrStr - - def compile(self, c: Compiler): - return f"({c.compile(self.a)} {self.op} {c.compile(self.b)})" - - -@dataclass -class In(Sql): - expr: SqlOrStr - list: Sequence # List[SqlOrStr] - - def compile(self, c: Compiler): - elems = ", ".join(map(c.compile, self.list)) - return f"({c.compile(self.expr)} IN ({elems}))" - - -@dataclass -class Count(Sql): - column: Optional[SqlOrStr] = None - - def compile(self, c: Compiler): - if self.column: - return f"count({c.compile(self.column)})" - return "count(*)" - - -@dataclass -class Min(Sql): - column: SqlOrStr - - def compile(self, c: Compiler): - return f"min({c.compile(self.column)})" - - -@dataclass -class Max(Sql): - column: SqlOrStr - - def compile(self, c: Compiler): - return f"max({c.compile(self.column)})" - - -@dataclass -class Time(Sql): - time: datetime - - def compile(self, c: Compiler): - return c.database.timestamp_value(self.time) - - -@dataclass -class Explain(Sql): - sql: Select - - def compile(self, c: Compiler): - return f"EXPLAIN {c.compile(self.sql)}" diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index 8b95458f..cddbe9f5 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -4,15 +4,15 @@ from runtype import dataclass -from .utils import ArithString, split_space, ArithAlphanumeric - +from .utils import ArithString, split_space from .databases.base import Database -from .databases.database_types import DbPath, DbKey, DbTime, Native_UUID, Schema, create_schema -from .sql import Select, Checksum, Compare, Count, TableName, Time, Value +from .databases.database_types import DbPath, DbKey, DbTime, Schema, create_schema +from .queries import Count, Checksum, SKIP, table, this, Expr, min_, max_ +from .queries.extras import ApplyFuncAndNormalizeAsString, NormalizeAsString logger = logging.getLogger("table_segment") -RECOMMENDED_CHECKSUM_DURATION = 10 +RECOMMENDED_CHECKSUM_DURATION = 20 @dataclass @@ -22,11 +22,12 @@ class TableSegment: Parameters: database (Database): Database instance. See :meth:`connect` table_path (:data:`DbPath`): Path to table in form of a tuple. e.g. `('my_dataset', 'table_name')` - key_column (str): Name of the key column, which uniquely identifies each row (usually id) - update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update) + key_columns (Tuple[str]): Name of the key column, which uniquely identifies each row (usually id) + update_column (str, optional): Name of updated column, which signals that rows changed. + Usually updated_at or last_update. Used by `min_update` and `max_update`. 
extra_columns (Tuple[str, ...], optional): Extra columns to compare - min_key (:data:`DbKey`, optional): Lowest key_column value, used to restrict the segment - max_key (:data:`DbKey`, optional): Highest key_column value, used to restrict the segment + min_key (:data:`DbKey`, optional): Lowest key value, used to restrict the segment + max_key (:data:`DbKey`, optional): Highest key value, used to restrict the segment min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment max_update (:data:`DbTime`, optional): Highest update_column value, used to restrict the segment where (str, optional): An additional 'where' expression to restrict the search space. @@ -40,7 +41,7 @@ class TableSegment: table_path: DbPath # Columns - key_column: str + key_columns: Tuple[str, ...] update_column: str = None extra_columns: Tuple[str, ...] = () @@ -66,40 +67,8 @@ def __post_init__(self): f"Error: min_update expected to be smaller than max_update! ({self.min_update} >= {self.max_update})" ) - @property - def _update_column(self): - return self._quote_column(self.update_column) - - def _quote_column(self, c: str) -> str: - if self._schema: - c = self._schema.get_key(c) # Get the actual name. Might be case-insensitive. - return self.database.quote(c) - - def _normalize_column(self, name: str, template: str = None) -> str: - if not self._schema: - raise RuntimeError( - "Cannot compile query when the schema is unknown. Please use TableSegment.with_schema()." - ) - - col_type = self._schema[name] - col = self._quote_column(name) - - if isinstance(col_type, Native_UUID): - # Normalize first, apply template after (for uuids) - # Needed because min/max(uuid) fails in postgresql - col = self.database.normalize_value_by_type(col, col_type) - if template is not None: - col = template % col # Apply template using Python's string formatting - return col - - # Apply template before normalizing (for ints) - if template is not None: - col = template % col # Apply template using Python's string formatting - - return self.database.normalize_value_by_type(col, col_type) - def _with_raw_schema(self, raw_schema: dict) -> "TableSegment": - schema = self.database._process_table_schema(self.table_path, raw_schema, self._relevant_columns, self.where) + schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self.where) return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive)) def with_schema(self) -> "TableSegment": @@ -111,41 +80,38 @@ def with_schema(self) -> "TableSegment": def _make_key_range(self): if self.min_key is not None: - yield Compare("<=", Value(self.min_key), self._quote_column(self.key_column)) + assert len(self.key_columns) == 1 + (k,) = self.key_columns + yield self.min_key <= this[k] if self.max_key is not None: - yield Compare("<", self._quote_column(self.key_column), Value(self.max_key)) + assert len(self.key_columns) == 1 + (k,) = self.key_columns + yield this[k] < self.max_key def _make_update_range(self): if self.min_update is not None: - yield Compare("<=", Time(self.min_update), self._update_column) + yield self.min_update <= this[self.update_column] if self.max_update is not None: - yield Compare("<", self._update_column, Time(self.max_update)) - - def _make_select(self, *, table=None, columns=None, where=None, group_by=None, order_by=None): - if columns is None: - columns = [self._normalize_column(self.key_column)] - where = [ - *self._make_key_range(), - *self._make_update_range(), - 
*([] if where is None else [where]), - *([] if self.where is None else [self.where]), - ] - order_by = None if order_by is None else [order_by] - return Select( - table=table or TableName(self.table_path), - where=where, - columns=columns, - group_by=group_by, - order_by=order_by, - ) + yield this[self.update_column] < self.max_update + + @property + def source_table(self): + return table(*self.table_path, schema=self._schema) + + def make_select(self): + return self.source_table.where(*self._make_key_range(), *self._make_update_range(), self.where or SKIP) def get_values(self) -> list: "Download all the relevant values of the segment from the database" - select = self._make_select(columns=self._relevant_columns_repr) + select = self.make_select().select(*self._relevant_columns_repr) return self.database.query(select, List[Tuple]) def choose_checkpoints(self, count: int) -> List[DbKey]: "Suggests a bunch of evenly-spaced checkpoints to split by (not including start, end)" + + if self.max_key - self.min_key <= count: + count = 1 + assert self.is_bounded if isinstance(self.min_key, ArithString): assert type(self.min_key) is type(self.max_key) @@ -176,33 +142,33 @@ def new(self, **kwargs) -> "TableSegment": return self.replace(**kwargs) @property - def _relevant_columns(self) -> List[str]: + def relevant_columns(self) -> List[str]: extras = list(self.extra_columns) if self.update_column and self.update_column not in extras: extras = [self.update_column] + extras - return [self.key_column] + extras + return list(self.key_columns) + extras @property - def _relevant_columns_repr(self) -> List[str]: - return [self._normalize_column(c) for c in self._relevant_columns] + def _relevant_columns_repr(self) -> List[Expr]: + return [NormalizeAsString(this[c]) for c in self.relevant_columns] def count(self) -> Tuple[int, int]: """Count how many rows are in the segment, in one pass.""" - return self.database.query(self._make_select(columns=[Count()]), int) + return self.database.query(self.make_select().select(Count()), int) def count_and_checksum(self) -> Tuple[int, int]: """Count and checksum the rows in the segment, in one pass.""" start = time.monotonic() - count, checksum = self.database.query( - self._make_select(columns=[Count(), Checksum(self._relevant_columns_repr)]), tuple - ) + q = self.make_select().select(Count(), Checksum(self._relevant_columns_repr)) + count, checksum = self.database.query(q, tuple) duration = time.monotonic() - start if duration > RECOMMENDED_CHECKSUM_DURATION: logger.warning( - f"Checksum is taking longer than expected ({duration:.2f}s). " - "We recommend increasing --bisection-factor or decreasing --threads." + "Checksum is taking longer than expected (%.2f). " + "We recommend increasing --bisection-factor or decreasing --threads.", + duration, ) if count: @@ -212,11 +178,10 @@ def count_and_checksum(self) -> Tuple[int, int]: def query_key_range(self) -> Tuple[int, int]: """Query database for minimum and maximum key. 
This is used for setting the initial bounds.""" # Normalizes the result (needed for UUIDs) after the min/max computation - select = self._make_select( - columns=[ - self._normalize_column(self.key_column, "min(%s)"), - self._normalize_column(self.key_column, "max(%s)"), - ] + (k,) = self.key_columns + select = self.make_select().select( + ApplyFuncAndNormalizeAsString(this[k], min_), + ApplyFuncAndNormalizeAsString(this[k], max_), ) min_key, max_key = self.database.query(select, tuple) diff --git a/data_diff/thread_utils.py b/data_diff/thread_utils.py index 1e0d26b8..1be94ad4 100644 --- a/data_diff/thread_utils.py +++ b/data_diff/thread_utils.py @@ -1,9 +1,9 @@ import itertools -from concurrent.futures.thread import _WorkItem from queue import PriorityQueue from collections import deque from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor +from concurrent.futures.thread import _WorkItem from time import sleep from typing import Callable, Iterator, Optional diff --git a/data_diff/utils.py b/data_diff/utils.py index 5911f8f8..a11c4142 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -1,20 +1,24 @@ +import logging import re import math -from typing import Iterable, Tuple, Union, Any, Sequence, Dict -from typing import TypeVar, Generic -from abc import ABC, abstractmethod +from typing import Iterable, Iterator, MutableMapping, Union, Any, Sequence, Dict +from typing import TypeVar +from abc import abstractmethod from urllib.parse import urlparse from uuid import UUID import operator import string import threading +from datetime import datetime alphanums = " -" + string.digits + string.ascii_uppercase + "_" + string.ascii_lowercase def safezip(*args): "zip but makes sure all sequences are the same length" - assert len(set(map(len, args))) == 1 + lens = list(map(len, args)) + if len(set(lens)) != 1: + raise ValueError(f"Mismatching lengths in arguments to safezip: {lens}") return zip(*args) @@ -62,7 +66,7 @@ def numberToAlphanum(num: int, base: str = alphanums) -> str: return "".join(base[i] for i in digits[::-1]) -def alphanumToNumber(alphanum: str, base: str) -> int: +def alphanumToNumber(alphanum: str, base: str = alphanums) -> int: num = 0 for c in alphanum: num = num * len(base) + base.index(c) @@ -78,8 +82,8 @@ def justify_alphanums(s1: str, s2: str): def alphanums_to_numbers(s1: str, s2: str): s1, s2 = justify_alphanums(s1, s2) - n1 = alphanumToNumber(s1, alphanums) - n2 = alphanumToNumber(s2, alphanums) + n1 = alphanumToNumber(s1) + n2 = alphanumToNumber(s2) return n1, n2 @@ -117,9 +121,9 @@ def __add__(self, other: "Union[ArithAlphanumeric, int]") -> "ArithAlphanumeric" if isinstance(other, int): if other != 1: raise NotImplementedError("not implemented for arbitrary numbers") - lastchar = self._str[-1] if self._str else alphanums[0] - s = self._str[:-1] + alphanums[alphanums.index(lastchar) + other] - return self.new(s) + num = alphanumToNumber(self._str) + return self.new(numberToAlphanum(num + 1)) + return NotImplemented def range(self, other: "ArithAlphanumeric", count: int): @@ -188,7 +192,10 @@ def remove_password_from_url(url: str, replace_with: str = "***") -> str: def join_iter(joiner: Any, iterable: Iterable) -> Iterable: it = iter(iterable) - yield next(it) + try: + yield next(it) + except StopIteration: + return for i in it: yield joiner yield i @@ -197,48 +204,39 @@ def join_iter(joiner: Any, iterable: Iterable) -> Iterable: V = TypeVar("V") -class CaseAwareMapping(ABC, Generic[V]): +class CaseAwareMapping(MutableMapping[str, 
V]): @abstractmethod def get_key(self, key: str) -> str: ... - @abstractmethod - def __getitem__(self, key: str) -> V: - ... - - @abstractmethod - def __setitem__(self, key: str, value: V): - ... - - @abstractmethod - def __contains__(self, key: str) -> bool: - ... - class CaseInsensitiveDict(CaseAwareMapping): def __init__(self, initial): self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()} - def get_key(self, key: str) -> str: - return self._dict[key.lower()][0] - def __getitem__(self, key: str) -> V: return self._dict[key.lower()][1] + def __iter__(self) -> Iterator[V]: + return iter(self._dict) + + def __len__(self) -> int: + return len(self._dict) + def __setitem__(self, key: str, value): k = key.lower() if k in self._dict: key = self._dict[k][0] self._dict[k] = key, value - def __contains__(self, key): - return key.lower() in self._dict + def __delitem__(self, key: str): + del self._dict[key.lower()] - def keys(self) -> Iterable[str]: - return self._dict.keys() + def get_key(self, key: str) -> str: + return self._dict[key.lower()][0] - def items(self) -> Iterable[Tuple[str, V]]: - return ((k, v[1]) for k, v in self._dict.items()) + def __repr__(self) -> str: + return repr(dict(self.items())) class CaseSensitiveDict(dict, CaseAwareMapping): @@ -285,3 +283,14 @@ def run_as_daemon(threadfunc, *args): th.daemon = True th.start() return th + + +def getLogger(name): + return logging.getLogger(name.rsplit(".", 1)[-1]) + + +def eval_name_template(name): + def get_timestamp(_match): + return datetime.now().isoformat("_", "seconds").replace(":", "_") + + return re.sub("%t", get_timestamp, name) diff --git a/docker-compose.yml b/docker-compose.yml index b23c3b1e..60bab061 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,7 +2,7 @@ version: "3.8" services: postgres: - container_name: postgresql + container_name: dd-postgresql image: postgres:14.1-alpine # work_mem: less tmp files # maintenance_work_mem: improve table-level op perf @@ -25,7 +25,7 @@ services: - local mysql: - container_name: mysql + container_name: dd-mysql image: mysql:oracle # fsync less aggressively for insertion perf for test setup command: > @@ -52,7 +52,7 @@ services: - local clickhouse: - container_name: clickhouse + container_name: dd-clickhouse image: clickhouse/clickhouse-server:21.12.3.32 restart: always volumes: @@ -76,6 +76,7 @@ services: # prestodb.dbapi.connect(host="127.0.0.1", user="presto").cursor().execute('SELECT * FROM system.runtime.nodes') presto: + container_name: dd-presto build: context: ./dev dockerfile: ./Dockerfile.prestosql.340 @@ -88,6 +89,7 @@ services: - local trino: + container_name: dd-trino image: 'trinodb/trino:389' hostname: trino ports: @@ -98,7 +100,7 @@ services: - local vertica: - container_name: vertica + container_name: dd-vertica image: vertica/vertica-ce:12.0.0-0 restart: always volumes: diff --git a/docs/conf.py b/docs/conf.py index ef75ecc0..dc58fb90 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -41,6 +41,7 @@ "recommonmark", "sphinx_markdown_tables", "sphinx_copybutton", + "enum_tools.autoenum", # 'sphinx_gallery.gen_gallery' ] diff --git a/docs/how-to-use.md b/docs/how-to-use.md new file mode 100644 index 00000000..b5a1f5bb --- /dev/null +++ b/docs/how-to-use.md @@ -0,0 +1,164 @@ +# How to use + +## How to use from the shell (or: command-line) + +Run the following command: + +```bash + # Same-DB diff, using outer join + $ data-diff DB TABLE1 TABLE2 [options] + + # Cross-DB diff, using hashes + $ data-diff DB1 TABLE1 DB2 TABLE2 [options] +``` + +Where 
DB is either a database URL that's compatible with SQLAlchemy, or the name of a database specified in a configuration file.
+
+We recommend using a configuration file, with the ``--conf`` switch, to keep the command simple and manageable.
+
+For a list of example URLs, see [list of supported databases](supported-databases.md).
+
+Note: Because URLs allow many special characters, and may collide with the syntax of your command-line,
+it's recommended to surround them with quotes.
+
+### Options
+
+ - `--help` - Show help message and exit.
+ - `-k` or `--key-columns` - Name of the primary key column. If none provided, default is 'id'.
+ - `-t` or `--update-column` - Name of updated_at/last_updated column
+ - `-c` or `--columns` - Names of extra columns to compare. Can be used more than once in the same command.
+   Accepts a name or a pattern like in SQL.
+   Example: `-c col% -c another_col -c %foorb.r%`
+ - `-l` or `--limit` - Maximum number of differences to find (limits maximum bandwidth and runtime)
+ - `-s` or `--stats` - Print stats instead of a detailed diff
+ - `-d` or `--debug` - Print debug info
+ - `-v` or `--verbose` - Print extra info
+ - `-i` or `--interactive` - Confirm queries, implies `--debug`
+ - `--json` - Print JSONL output for machine readability
+ - `--min-age` - Considers only rows older than specified. Useful for specifying replication lag.
+   Example: `--min-age=5min` ignores rows from the last 5 minutes.
+   Valid units: `d, days, h, hours, min, minutes, mon, months, s, seconds, w, weeks, y, years`
+ - `--max-age` - Considers only rows younger than specified. See `--min-age`.
+ - `-j` or `--threads` - Number of worker threads to use per database. Default=1.
+ - `-w`, `--where` - An additional 'where' expression to restrict the search space.
+ - `--conf`, `--run` - Specify the run and configuration from a TOML file. (see below)
+ - `--no-tracking` - data-diff sends home anonymous usage data. Use this to disable it.
+
+ **The following two options are not available when using the pre-release In-DB feature:**
+
+ - `--bisection-threshold` - Minimal size of segment to be split. Smaller segments will be downloaded and compared locally.
+ - `--bisection-factor` - Segments per iteration. When set to 2, it performs binary search.
+
+**In-DB commands, available in pre-release only:**
+ - `-m`, `--materialize` - Materialize the diff results into a new table in the database.
+   If a table exists by that name, it will be replaced.
+   Use `%t` in the name to place a timestamp.
+   Example: `-m test_mat_%t`
+ - `--assume-unique-key` - Skip validating the uniqueness of the key column during joindiff, which is costly in non-cloud dbs.
+ - `--sample-exclusive-rows` - Sample several rows that only appear in one of the tables, but not the other. Use with `-s`.
+ - `--materialize-all-rows` - Materialize every row, even if they are the same, instead of just the differing rows.
+ - `--table-write-limit` - Maximum number of rows to write when creating materialized or sample tables, per thread. Default=1000.
+ - `-a`, `--algorithm` `[auto|joindiff|hashdiff]` - Force algorithm choice
+
+
+
+### How to use with a configuration file
+
+Data-diff lets you load the configuration for a run from a TOML file.
+
+**Reasons to use a configuration file:**
+
+- Convenience: Set up the parameters for diffs that need to run often
+
+- Easier and more readable: You can define the database connection settings as config values, instead of in a URI.
+
+- Gives you fine-grained control over the settings switches, without requiring any Python code.
+
+Use `--conf` to specify the path to the configuration file. data-diff will load the settings from `run.default`, if it's defined.
+
+Then you can, optionally, use `--run` to choose to load the settings of a specific run, and override the settings of `run.default` (all runs extend `run.default`, like inheritance).
+
+Finally, CLI switches have the final say, and will override the settings defined by both the configuration file and the current run.
+
+Example TOML file:
+
+```toml
+# Specify the connection params to the test database.
+[database.test_postgresql]
+driver = "postgresql"
+user = "postgres"
+password = "Password1"
+
+# Specify the default run params
+[run.default]
+update_column = "timestamp"
+verbose = true
+
+# Specify params for a run 'test_diff'.
+[run.test_diff]
+verbose = false
+# Source 1 ("left")
+1.database = "test_postgresql" # Use options from database.test_postgresql
+1.table = "rating"
+# Source 2 ("right")
+2.database = "postgresql://postgres:Password1@/" # Use URI like in the CLI
+2.table = "rating_del1"
+```
+
+In this example, running `data-diff --conf myconfig.toml --run test_diff` will compare between `rating` and `rating_del1`.
+It will use the `timestamp` column as the update column, as specified in `run.default`. However, it won't be verbose, since that
+flag is overridden to `false`.
+
+Running it with `data-diff --conf myconfig.toml --run test_diff -v` will set verbose back to `true`.
+
+
+## How to use from Python
+
+Import the `data_diff` module, and use the following functions:
+
+- `connect_to_table()` to connect to a specific table in the database
+
+- `diff_tables()` to diff those tables
+
+
+Example:
+
+```python
+# Optional: Set logging to display the progress of the diff
+import logging
+logging.basicConfig(level=logging.INFO)
+
+from data_diff import connect_to_table, diff_tables
+
+table1 = connect_to_table("postgresql:///", "table_name", "id")
+table2 = connect_to_table("mysql:///", "table_name", "id")
+
+for different_row in diff_tables(table1, table2):
+    plus_or_minus, columns = different_row
+    print(plus_or_minus, columns)
+```
+
+Run `help(diff_tables)` or [read the docs](https://data-diff.readthedocs.io/en/latest/) to learn about the different options.
+
+## Usage Analytics & Data Privacy
+
+data-diff collects anonymous usage data to help our team improve the tool and to apply development efforts to where our users need them most.
+
+We capture two events: one when the data-diff run starts, and one when it is finished. No user data or potentially sensitive information is or ever will be collected. The captured data is limited to:
+
+- Operating System and Python version
+- Types of databases used (postgresql, mysql, etc.)
+- Sizes of tables diffed, run time, and diff row count (numbers only)
+- Error message, if any, truncated to the first 20 characters.
+- A persistent UUID to identify the session, stored in `~/.datadiff.toml`
+
+If you do not wish to participate, the tracking can be easily disabled with one of the following methods:
+
+* In the CLI, use the `--no-tracking` flag.
+* In the config file, set `no_tracking = true` (for example, under `[run.default]`) +* If you're using the Python API: +```python +import data_diff +data_diff.disable_tracking() # Call this first, before making any API calls +# Connect and diff your tables without any tracking +``` diff --git a/docs/index.rst b/docs/index.rst index af5f5c5d..7b78b66f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,7 +3,10 @@ :caption: Reference :hidden: + supported-databases + how-to-use python-api + technical-explanation new-database-driver-guide Introduction @@ -33,45 +36,36 @@ Requires Python 3.7+ with pip. pip install data-diff -or when you need extras like mysql and postgresql: +For installing with 3rd-party database connectors, use the following syntax: :: - pip install "data-diff[mysql,postgresql]" - - -How to use from Python ----------------------- - -.. code-block:: python + pip install "data-diff[db1,db2]" - # Optional: Set logging to display the progress of the diff - import logging - logging.basicConfig(level=logging.INFO) - - from data_diff import connect_to_table, diff_tables + e.g. + pip install "data-diff[mysql,postgresql]" - table1 = connect_to_table("postgresql:///", "table_name", "id") - table2 = connect_to_table("mysql:///", "table_name", "id") +Supported connectors: - for sign, columns in diff_tables(table1, table2): - print(sign, columns) +- mysql +- postgresql +- snowflake +- presto +- oracle +- trino +- clickhouse +- vertica - # Example output: - + ('4775622148347', '2022-06-05 16:57:32.000000') - - ('4775622312187', '2022-06-05 16:57:32.000000') - - ('4777375432955', '2022-06-07 16:57:36.000000') Resources --------- -- Source code (git): ``_ -- API Reference - - :doc:`python-api` -- Guides +- Users + - Source code (git): ``_ + - :doc:`supported-databases` + - :doc:`how-to-use` + - :doc:`python-api` + - :doc:`technical-explanation` +- Contributors - :doc:`new-database-driver-guide` -- Tutorials - - TODO - - diff --git a/docs/python-api.rst b/docs/python-api.rst index d2b18636..ada633d1 100644 --- a/docs/python-api.rst +++ b/docs/python-api.rst @@ -5,15 +5,26 @@ Python API Reference .. autofunction:: connect -.. autoclass:: TableDiffer +.. autofunction:: connect_to_table + +.. autofunction:: diff_tables + +.. autoclass:: HashDiffer + :members: __init__, diff_tables + +.. autoclass:: JoinDiffer :members: __init__, diff_tables .. autoclass:: TableSegment - :members: __init__, get_values, choose_checkpoints, segment_by_checkpoints, count, count_and_checksum, is_bounded, new + :members: __init__, get_values, choose_checkpoints, segment_by_checkpoints, count, count_and_checksum, is_bounded, new, with_schema .. autoclass:: data_diff.databases.database_types.AbstractDatabase :members: +.. autoclass:: data_diff.databases.database_types.AbstractDialect + :members: + .. autodata:: DbKey .. autodata:: DbTime .. autodata:: DbPath +.. autoenum:: Algorithm diff --git a/docs/requirements.txt b/docs/requirements.txt index 0d1d793a..252c7acb 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,6 +4,6 @@ sphinx_markdown_tables sphinx-copybutton sphinx-rtd-theme recommonmark +enum-tools[sphinx] -# Requirements. 
TODO Use poetry instead of this redundant list
 data_diff
diff --git a/docs/supported-databases.md b/docs/supported-databases.md
new file mode 100644
index 00000000..7cfef6ad
--- /dev/null
+++ b/docs/supported-databases.md
@@ -0,0 +1,29 @@
+# List of supported databases
+
+| Database | Status | Connection string |
+|---------------|--------------------------------------------------------------------------------------------------------------------|--------|
+| PostgreSQL >=10 | πŸ’š | `postgresql://<user>:<password>@<host>:5432/<database>` |
+| MySQL | πŸ’š | `mysql://<user>:<password>@<host>:5432/<database>` |
+| Snowflake | πŸ’š | `"snowflake://<user>[:<password>]@<account>/<database>/<schema>?warehouse=<warehouse>&role=<role>[&authenticator=externalbrowser]"` |
+| BigQuery | πŸ’š | `bigquery://<project>/<dataset>` |
+| Redshift | πŸ’š | `redshift://<user>:<password>@<host>:5439/<database>` |
+| Oracle | πŸ’› | `oracle://<user>:<password>@<host>/database` |
+| Presto | πŸ’› | `presto://<user>:<password>@<host>:8080/<database>` |
+| Databricks | πŸ’› | `databricks://<http_path>:<access_token>@<server_hostname>/<catalog>/<schema>` |
+| Trino | πŸ’› | `trino://<user>:<password>@<host>:8080/<database>` |
+| Clickhouse | πŸ’› | `clickhouse://<user>:<password>@<host>:9000/<database>` |
+| Vertica | πŸ’› | `vertica://<user>:<password>@<host>:5433/<database>` |
+| ElasticSearch | πŸ“ | |
+| Planetscale | πŸ“ | |
+| Pinot | πŸ“ | |
+| Druid | πŸ“ | |
+| Kafka | πŸ“ | |
+| DuckDB | πŸ“ | |
+| SQLite | πŸ“ | |
+
+* πŸ’š: Implemented and thoroughly tested.
+* πŸ’›: Implemented, but not thoroughly tested yet.
+* ⏳: Implementation in progress.
+* πŸ“: Implementation planned. Contributions welcome.
+
+Is your database not listed here? We accept pull-requests!
diff --git a/docs/technical-explanation.md b/docs/technical-explanation.md
new file mode 100644
index 00000000..572bd5eb
--- /dev/null
+++ b/docs/technical-explanation.md
@@ -0,0 +1,191 @@
+# Technical explanation
+
+data-diff can diff tables within the same database, or across different databases.
+
+**Same-DB Diff:**
+- Uses an outer-join to diff the rows as efficiently and accurately as possible.
+- Supports materializing the diff results to a database table.
+- Can also collect various extra statistics about the tables.
+
+**Cross-DB Diff:** Employs a divide-and-conquer algorithm based on hashing, optimized for few changes.
+
+The following is a technical explanation of the cross-db diff.
+
+### Overview
+
+data-diff splits the table into smaller segments, then checksums each segment in both databases. When the checksums for a segment aren't equal, it will further divide that segment into yet smaller segments, checksumming those until it gets to the differing row(s).
+
+This approach has performance within an order of magnitude of count(*) when there are few/no changes, but is able to output each differing row! By pushing the compute into the databases, it's much faster than querying for and comparing every row.
+
+![Performance for 100M rows](https://user-images.githubusercontent.com/97400/175182987-a3900d4e-c097-4732-a4e9-19a40fac8cdc.png)
+
+**†:** The implementation for downloading all rows, against which `data-diff` and
+`count(*)` are compared, is not optimal. It is a single multi-threaded
+Python process. Performance is fairly driver-specific; e.g. PostgreSQL's driver performs 10x
+better than MySQL's.
+
+### Deep Dive
+
+In this section we'll walk through exactly how **data-diff**
+works, and how to tune `--bisection-factor` and `--bisection-threshold`.
+
+Let's consider a scenario with an `orders` table with 1M rows. 
Fivetran is +replicating it contionously from PostgreSQL to Snowflake: + +``` +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ PostgreSQL β”‚ β”‚ Snowflake β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ table with β”‚ +β”‚ table with β”œβ”€β”€β”€ replication β”œβ”€β”€β”€β”€β”€β”€β–Άβ”‚ ?maybe? all β”‚ +β”‚lots of rows!β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ the same β”‚ +β”‚ β”‚ β”‚ rows. β”‚ +β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +In order to check whether the two tables are the same, **data-diff** splits +the table into `--bisection-factor=10` segments. + +We also have to choose which columns we want to checksum. In our case, we care +about the primary key, `--key-column=id` and the update column +`--update-column=updated_at`. `updated_at` is updated every time the row is, and +we have an index on it. + +**data-diff** starts by querying both databases for the `min(id)` and `max(id)` +of the table. Then it splits the table into `--bisection-factor=10` segments of +`1M/10 = 100K` keys each: + +``` +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ PostgreSQL β”‚ β”‚ Snowflake β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ id=1..100k β”‚ β”‚ id=1..100k β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ id=100k..200k β”‚ β”‚ id=100k..200k β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ id=200k..300k β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–Άβ”‚ id=200k..300k β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ id=300k..400k β”‚ β”‚ id=300k..400k β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ ... β”‚ β”‚ ... β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ 900k..100k β”‚ β”‚ 900k..100k β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–²β”€β”€β”˜ β””β–²β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + ┃ ┃ + ┃ ┃ + ┃ checksum queries ┃ + ┃ ┃ + β”Œβ”€β”»β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”»β”€β”€β”€β”€β” + β”‚ data-diff β”‚ + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +Now **data-diff** will start running `--threads=1` queries in parallel that +checksum each segment. 
The queries for checksumming each segment will look +something like this, depending on the database: + +```sql +SELECT count(*), + sum(cast(conv(substring(md5(concat(cast(id as char), cast(timestamp as char))), 18), 16, 10) as unsigned)) +FROM `rating_del1` +WHERE (id >= 1) AND (id < 100000) +``` + +This keeps the amount of data that has to be transferred between the databases +to a minimum, making it very performant! Additionally, if you have an index on +`updated_at` (highly recommended), then the query will be fast, as the database +only has to do a partial index scan between `id=1..100k`. + +If you are not sure whether the queries are using an index, you can run it with +`--interactive`. This puts **data-diff** in interactive mode, where it shows an +`EXPLAIN` before executing each query, requiring confirmation to proceed. + +After running the checksum queries on both sides, we see that all segments +are the same except `id=100k..200k`: + +``` +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ PostgreSQL β”‚ β”‚ Snowflake β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ checksum=0102 β”‚ β”‚ checksum=0102 β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ mismatch! β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ checksum=ffff ◀──────────────▢ checksum=aaab β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ checksum=abab β”‚ β”‚ checksum=abab β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ checksum=f0f0 β”‚ β”‚ checksum=f0f0 β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ ... β”‚ β”‚ ... β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ checksum=9494 β”‚ β”‚ checksum=9494 β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +Now **data-diff** will do exactly as it just did for the _whole table_ for only +this segment: Split it into `--bisection-factor` segments. + +However, this time, because each segment has `100k/10=10k` entries, which is +less than the `--bisection-threshold`, it will pull down every row in the segment +and compare them in memory in **data-diff**. 
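+
+Conceptually, comparing a small segment "in memory" boils down to a set difference over
+the downloaded rows. The following is only an illustrative sketch of that idea, with
+made-up rows and a hypothetical helper function; it is not data-diff's actual code:
+
+```python
+# Illustrative sketch only -- not data-diff's real implementation.
+# Assume each row is an (id, updated_at) tuple already fetched from both databases.
+def diff_segment_in_memory(rows1, rows2):
+    """Yield ('-', row) for rows only in rows1, and ('+', row) for rows only in rows2."""
+    set1, set2 = set(rows1), set(rows2)
+    for row in sorted(set1 - set2):
+        yield "-", row
+    for row in sorted(set2 - set1):
+        yield "+", row
+
+rows_pg = [(122001, 1653672821), (122002, 1653672822)]  # hypothetical rows from PostgreSQL
+rows_sf = [(122001, 1653672999), (122002, 1653672822)]  # hypothetical rows from Snowflake
+print(list(diff_segment_in_memory(rows_pg, rows_sf)))
+# [('-', (122001, 1653672821)), ('+', (122001, 1653672999))]
+```
+
+Schematically, the sub-segments that get pulled down look like this: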
+ +``` +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ PostgreSQL β”‚ β”‚ Snowflake β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ id=100k..110k β”‚ β”‚ id=100k..110k β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ id=110k..120k β”‚ β”‚ id=110k..120k β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ id=120k..130k β”‚ β”‚ id=120k..130k β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ id=130k..140k β”‚ β”‚ id=130k..140k β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ ... β”‚ β”‚ ... β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ 190k..200k β”‚ β”‚ 190k..200k β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +Finally **data-diff** will output the `(id, updated_at)` for each row that was different: + +``` +(122001, 1653672821) +``` + +If you pass `--stats` you'll see stats such as the % of rows were different. + +### Performance Considerations + +* Ensure that you have indexes on the columns you are comparing. Preferably a + compound index. You can run with `--interactive` to see an `EXPLAIN` for the + queries. +* Consider increasing the number of simultaneous threads executing + queries per database with `--threads`. For databases that limit concurrency + per query, such as PostgreSQL/MySQL, this can improve performance dramatically. +* If you are only interested in _whether_ something changed, pass `--limit 1`. + This can be useful if changes are very rare. This is often faster than doing a + `count(*)`, for the reason mentioned above. +* If the table is _very_ large, consider a larger `--bisection-factor`. Otherwise, you may run into timeouts. +* If there are a lot of changes, consider a larger `--bisection-threshold`. +* If there are very large gaps in your key column (e.g., 10s of millions of + continuous rows missing), then **data-diff** may perform poorly, doing lots of + queries for ranges of rows that do not exist. We have ideas on how to tackle this issue, which we have yet to implement. If you're experiencing this effect, please open an issue, and we + will prioritize it. +* The fewer columns you verify (passed with `--columns`), the faster + **data-diff** will be. On one extreme, you can verify every column; on the + other, you can verify _only_ `updated_at`, if you trust it enough. You can also + _only_ verify `id` if you're interested in only presence, such as to detect + missing hard deletes. You can do also do a hybrid where you verify + `updated_at` and the most critical value, such as a money value in `amount`, but + not verify a large serialized column like `json_settings`. 
+* We have ideas for making **data-diff** even faster that + we haven't implemented yet: faster checksums by reducing type-casts + and using a faster hash than MD5, dynamic adaptation of + `bisection_factor`/`threads`/`bisection_threshold` (especially with large key + gaps), and improvements to bypass Python/driver performance limitations when + comparing huge amounts of rows locally (i.e. for very high `bisection_threshold` values). diff --git a/poetry.lock b/poetry.lock index 51ad249e..e8adc739 100644 --- a/poetry.lock +++ b/poetry.lock @@ -19,7 +19,7 @@ optional = false python-versions = "*" [[package]] -name = "backports.zoneinfo" +name = "backports-zoneinfo" version = "0.2.1" description = "Backport of the standard library zoneinfo module" category = "main" @@ -57,7 +57,7 @@ optional = false python-versions = ">=3.6.0" [package.extras] -unicode_backport = ["unicodedata2"] +unicode-backport = ["unicodedata2"] [[package]] name = "click" @@ -90,11 +90,11 @@ zstd = ["clickhouse-cityhash (>=1.0.2.1)", "zstd"] [[package]] name = "colorama" -version = "0.4.5" +version = "0.4.6" description = "Cross-platform colored terminal text." category = "main" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" [[package]] name = "commonmark" @@ -191,25 +191,25 @@ optional = false python-versions = "*" [package.extras] -atomic_cache = ["atomicwrites"] +atomic-cache = ["atomicwrites"] nearley = ["js2py"] regex = ["regex"] [[package]] name = "mysql-connector-python" -version = "8.0.30" +version = "8.0.29" description = "MySQL driver written in Python" category = "main" optional = false python-versions = "*" [package.dependencies] -protobuf = ">=3.11.0,<=3.20.1" +protobuf = ">=3.0.0" [package.extras] -compression = ["lz4 (>=2.1.6,<=3.1.3)", "zstandard (>=0.12.0,<=0.15.2)"] -dns-srv = ["dnspython (>=1.16.0,<=2.1.0)"] -gssapi = ["gssapi (>=1.6.9,<=1.7.3)"] +compression = ["lz4 (>=2.1.6)", "zstandard (>=0.12.0)"] +dns-srv = ["dnspython (>=1.16.0)"] +gssapi = ["gssapi (>=1.6.9)"] [[package]] name = "oscrypto" @@ -270,7 +270,7 @@ six = "*" [package.extras] all = ["google-auth", "requests-kerberos"] -google_auth = ["google-auth"] +google-auth = ["google-auth"] kerberos = ["requests-kerberos"] tests = ["google-auth", "httpretty", "pytest", "pytest-runner", "requests-kerberos"] @@ -287,15 +287,15 @@ wcwidth = "*" [[package]] name = "protobuf" -version = "3.20.1" -description = "Protocol Buffers" +version = "4.21.9" +description = "" category = "main" optional = false python-versions = ">=3.7" [[package]] name = "psycopg2" -version = "2.9.4" +version = "2.9.5" description = "psycopg2 - Python-PostgreSQL Database Adapter" category = "main" optional = false @@ -318,7 +318,7 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [[package]] -name = "Pygments" +name = "pygments" version = "2.13.0" description = "Pygments is a syntax highlighting package written in Python." 
category = "main" @@ -329,21 +329,21 @@ python-versions = ">=3.6" plugins = ["importlib-metadata"] [[package]] -name = "PyJWT" -version = "2.5.0" +name = "pyjwt" +version = "2.6.0" description = "JSON Web Token implementation in Python" category = "main" optional = false python-versions = ">=3.7" [package.extras] -crypto = ["cryptography (>=3.3.1)", "types-cryptography (>=3.3.21)"] -dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.3.1)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "types-cryptography (>=3.3.21)", "zope.interface"] +crypto = ["cryptography (>=3.4.0)"] +dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] [[package]] -name = "pyOpenSSL" +name = "pyopenssl" version = "22.0.0" description = "Python wrapper module around the OpenSSL library" category = "main" @@ -370,7 +370,7 @@ six = ">=1.5" [[package]] name = "pytz" -version = "2022.4" +version = "2022.5" description = "World timezone definitions, modern and historical" category = "main" optional = false @@ -404,7 +404,7 @@ urllib3 = ">=1.21.1,<1.27" [package.extras] socks = ["PySocks (>=1.5.6,!=1.5.7)"] -use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"] +use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "rich" @@ -432,7 +432,7 @@ python-versions = ">=3.6,<4.0" [[package]] name = "setuptools" -version = "65.4.1" +version = "65.5.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" category = "main" optional = false @@ -519,7 +519,7 @@ python-versions = ">=3.7" [[package]] name = "tzdata" -version = "2022.4" +version = "2022.6" description = "Provider of IANA time zone data" category = "main" optional = false @@ -544,7 +544,7 @@ test = ["pytest (>=4.3)", "pytest-mock (>=3.3)"] [[package]] name = "unittest-parallel" -version = "1.5.2" +version = "1.5.3" description = "Parallel unit test runner with coverage support" category = "dev" optional = false @@ -588,7 +588,7 @@ python-versions = "*" [[package]] name = "zipp" -version = "3.9.0" +version = "3.10.0" description = "Backport of pathlib-compatible object wrapper for zip files" category = "main" optional = false @@ -612,7 +612,7 @@ vertica = [] [metadata] lock-version = "1.1" python-versions = "^3.7" -content-hash = "6f68ef35366f62a4a47721baed5b2734e0a9015d7c4a49ff9a8410284acb6e71" +content-hash = "2ee6a778364480c8f72eda863926460037a8a7f580dc9d920388ec8c178ddb35" [metadata.files] arrow = [ @@ -623,7 +623,7 @@ asn1crypto = [ {file = "asn1crypto-1.5.1-py2.py3-none-any.whl", hash = "sha256:db4e40728b728508912cbb3d44f19ce188f218e9eba635821bb4b68564f8fd67"}, {file = "asn1crypto-1.5.1.tar.gz", hash = "sha256:13ae38502be632115abf8a24cbe5f4da52e3b5231990aff31123c805306ccb9c"}, ] -"backports.zoneinfo" = [ +backports-zoneinfo = [ {file = "backports.zoneinfo-0.2.1-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:da6013fd84a690242c310d77ddb8441a559e9cb3d3d59ebac9aca1a57b2e18bc"}, {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:89a48c0d158a3cc3f654da4c2de1ceba85263fafb861b98b59040a5086259722"}, {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:1c5742112073a563c81f786e77514969acb58649bcdf6cdf0b4ed31a348d4546"}, @@ -734,6 +734,19 @@ clickhouse-driver = [ {file = 
"clickhouse_driver-0.2.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8c776ab9592d456351ba2c4b05a8149761f36991033257d9c31ef6d26952dbe7"}, {file = "clickhouse_driver-0.2.4-cp310-cp310-win32.whl", hash = "sha256:bb1423d6daa8736aade0f7d31870c28ab2e7553a21cf923af6f1ff4a219c0ac9"}, {file = "clickhouse_driver-0.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:3955616073d030dc8cc7b0ef68ffe045510334137c1b5d11447347352d0dec28"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cd217a28e3821cbba49fc0ea87e0e6dde799e62af327d1f9c5f9480abd71e17c"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f4863a8eb369a36266f372e9eacc3fe222e4d31e1b2a7a2da759b521fffad1c"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c87ce674ebb2f5e38b68c36464538fffdea4f5432bb136cb6980489ae3c6dbe9"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:383690650fccaffa7f0e56d6fb0b00b9227b408fb3d92291a1f1ed66ce83df7c"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a68ac4633fd4cf265e619adeec1c0ee67ff1d9b5373c140b8400adcc4831c19"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5551056e5ab9e1dac67abbdf202b6e67590aa79a013a9da8ecbaec828e0790fe"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:93b601573779c5e8c2344cd983ef87de0fc8a31392bdc8571e79ed318f30dbbb"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:f228edc258f9ef6ee29b5618832b38479a0cfaa5bb837150ba62bbc1357a58cd"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:4b39c0f962a3664a72e0bfaa30929d0317b5e3427ff8b36d7889f8d31f4ff89e"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:47465598241cdf0b3a7810667c6104ada7b992a797768883ce30635a213568c3"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e9f31e7ccfd5cf526fd9db50ade94504007992922c8e556ba54e8ba637e9cca0"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-win32.whl", hash = "sha256:da0bcd41aeb50ec4316808c11d591aef60fe1b9de997df10ffcad9ab3cb0efa2"}, + {file = "clickhouse_driver-0.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:bb64ad0dfcc5ee158b01411e7a828bb3b20e4a2bc2f99da219acff0a9d18808c"}, {file = "clickhouse_driver-0.2.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b746e83652fbb89cb907adfdd69d6a7c7019bb5bbbdf68454bdcd16b09959a00"}, {file = "clickhouse_driver-0.2.4-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b9cf37f0b7165619d2e0188a71018300ed1a937ca518b01a5d168aec0a09add"}, {file = "clickhouse_driver-0.2.4-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:be4da882979c5a6d5631b5db6acb6407d76c73be6635ab0a76378b98aac8e5ab"}, @@ -793,8 +806,8 @@ clickhouse-driver = [ {file = "clickhouse_driver-0.2.4-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:821c7efff84fda8b68680140e07241991ab27296dc62213eb788efef4470fdd5"}, ] colorama = [ - {file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"}, - {file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"}, + 
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] commonmark = [ {file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"}, @@ -893,27 +906,24 @@ lark-parser = [ {file = "lark-parser-0.11.3.tar.gz", hash = "sha256:e29ca814a98bb0f81674617d878e5f611cb993c19ea47f22c80da3569425f9bd"}, ] mysql-connector-python = [ - {file = "mysql-connector-python-8.0.30.tar.gz", hash = "sha256:59a8592e154c874c299763bb8aa12c518384c364bcfd0d193e85c869ea81a895"}, - {file = "mysql_connector_python-8.0.30-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f1eb74eb30bb04ff314f5e19af5421d23b504e41d16ddcee2603b4100d18fd68"}, - {file = "mysql_connector_python-8.0.30-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:712cdfa97f35fec715e8d7aaa15ed9ce04f3cf71b3c177fcca273047040de9f2"}, - {file = "mysql_connector_python-8.0.30-cp310-cp310-manylinux1_i686.whl", hash = "sha256:ce23ca9c27e1f7b4707b3299ce515125f312736d86a7e5b2aa778484fa3ffa10"}, - {file = "mysql_connector_python-8.0.30-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:8876b1d51cae33cdfe7021d68206661e94dcd2666e5e14a743f8321e2b068e84"}, - {file = "mysql_connector_python-8.0.30-cp310-cp310-win_amd64.whl", hash = "sha256:41a04d1900e366bf6c2a645ead89ab9a567806d5ada7d417a3a31f170321dd14"}, - {file = "mysql_connector_python-8.0.30-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:7f771bd5cba3ade6d9f7a649e65d7c030f69f0e69980632b5cbbd3d19c39cee5"}, - {file = "mysql_connector_python-8.0.30-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:611c6945805216104575f7143ff6497c87396ce82d3257e6da7257b65406f13e"}, - {file = "mysql_connector_python-8.0.30-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:47deb8c3324db7eb2bfb720ec8084d547b1bce457672ea261bc21836024249db"}, - {file = "mysql_connector_python-8.0.30-cp37-cp37m-win_amd64.whl", hash = "sha256:234c6b156a1989bebca6eb564dc8f2e9d352f90a51bd228ccd68eb66fcd5fd7a"}, - {file = "mysql_connector_python-8.0.30-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8b7d50c221320b0e609dce9ca8801ab2f2a748dfee65cd76b1e4c6940757734a"}, - {file = "mysql_connector_python-8.0.30-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:d8f74c9388176635f75c01d47d0abc783a47e58d7f36d04fb6ee40ab6fb35c9b"}, - {file = "mysql_connector_python-8.0.30-cp38-cp38-manylinux1_i686.whl", hash = "sha256:1d9d3af14594aceda2c3096564b4c87ffac21e375806a802daeaf7adcd18d36b"}, - {file = "mysql_connector_python-8.0.30-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:f5d812245754d4759ebc8c075662fef65397e1e2a438a3c391eac9d545077b8b"}, - {file = "mysql_connector_python-8.0.30-cp38-cp38-win_amd64.whl", hash = "sha256:a130c5489861c7ff2990e5b503c37beb2fb7b32211b92f9107ad864ee90654c0"}, - {file = "mysql_connector_python-8.0.30-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:954a1fc2e9a811662c5b17cea24819c020ff9d56b2ff8e583dd0a233fb2399f6"}, - {file = "mysql_connector_python-8.0.30-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:62266d1b18cb4e286a05df0e1c99163a4955c82d41045305bcf0ab2aac107843"}, - {file = "mysql_connector_python-8.0.30-cp39-cp39-manylinux1_i686.whl", hash = "sha256:36e763f21e62b3c9623a264f2513ee11924ea1c9cc8640c115a279d3087064be"}, - {file = "mysql_connector_python-8.0.30-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:b5dc0f3295e404f93b674bfaff7589a9fbb8b5ae6c1c134112a1d1beb2f664b2"}, - {file 
= "mysql_connector_python-8.0.30-cp39-cp39-win_amd64.whl", hash = "sha256:33c4e567547a9a1868462fda8f2b19ea186a7b1afe498171dca39c0f3aa43a75"}, - {file = "mysql_connector_python-8.0.30-py2.py3-none-any.whl", hash = "sha256:f1d40cac9c786e292433716c1ade7a8968cbc3ea177026697b86a63188ddba34"}, + {file = "mysql-connector-python-8.0.29.tar.gz", hash = "sha256:29ec05ded856b4da4e47239f38489c03b31673ae0f46a090d0e4e29c670e6181"}, + {file = "mysql_connector_python-8.0.29-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:bed43ea3a11f8d4e7c2e3f20c891214e68b45451314f91fddf9ca701de7a53ac"}, + {file = "mysql_connector_python-8.0.29-cp310-cp310-manylinux1_i686.whl", hash = "sha256:6e2267ad75b37b5e1c480cde77cdc4f795427a54266ead30aabcdbf75ac70064"}, + {file = "mysql_connector_python-8.0.29-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:d5afb766b379111942d4260f29499f93355823c7241926471d843c9281fe477c"}, + {file = "mysql_connector_python-8.0.29-cp310-cp310-win_amd64.whl", hash = "sha256:4de5959e27038cbd11dfccb1afaa2fd258c013e59d3e15709dd1992086103050"}, + {file = "mysql_connector_python-8.0.29-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:895135cde57622edf48e1fce3beb4ed85f18332430d48f5c1d9630d49f7712b0"}, + {file = "mysql_connector_python-8.0.29-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:fdd262d8538aa504475f8860cfda939a297d3b213c8d15f7ceed52508aeb2aa3"}, + {file = "mysql_connector_python-8.0.29-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:89597c091c4f25b6e023cbbcd32be73affbb0b44256761fe3b8e1d4b14d14d02"}, + {file = "mysql_connector_python-8.0.29-cp37-cp37m-win_amd64.whl", hash = "sha256:ab0e9d9b5fc114b78dfa9c74e8bfa30b48fcfa17dbb9241ad6faada08a589900"}, + {file = "mysql_connector_python-8.0.29-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:245087999f081b389d66621f2abfe2463e3927f63c7c4c0f70ce0f82786ccb93"}, + {file = "mysql_connector_python-8.0.29-cp38-cp38-manylinux1_i686.whl", hash = "sha256:5eef51e48b22aadd633563bbdaf02112d98d954a4ead53f72fde283ea3f88152"}, + {file = "mysql_connector_python-8.0.29-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:b7dccd7f72f19c97b58428ebf8e709e24eb7e9b67a408af7e77b60efde44bea4"}, + {file = "mysql_connector_python-8.0.29-cp38-cp38-win_amd64.whl", hash = "sha256:7be3aeff73b85eab3af2a1e80c053a98cbcb99e142192e551ebd4c1e41ce2596"}, + {file = "mysql_connector_python-8.0.29-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:a7fd6a71df824f5a7d9a94060598d67b3a32eeccdc9837ee2cd98a44e2536cae"}, + {file = "mysql_connector_python-8.0.29-cp39-cp39-manylinux1_i686.whl", hash = "sha256:fd608c288f596c4c8767d9a8e90f129385bd19ee6e3adaf6974ad8012c6138b8"}, + {file = "mysql_connector_python-8.0.29-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:f353893481476a537cca7afd4e81e0ed84dd2173932b7f1721ab3e3351cbf324"}, + {file = "mysql_connector_python-8.0.29-cp39-cp39-win_amd64.whl", hash = "sha256:1bef2a4a2b529c6e9c46414100ab7032c252244e8a9e017d2b6a41bb9cea9312"}, + {file = "mysql_connector_python-8.0.29-py2.py3-none-any.whl", hash = "sha256:047420715bbb51d3cba78de446c8a6db4666459cd23e168568009c620a3f5b90"}, ] oscrypto = [ {file = "oscrypto-1.3.0-py2.py3-none-any.whl", hash = "sha256:2b2f1d2d42ec152ca90ccb5682f3e051fb55986e1b170ebde472b133713e7085"}, @@ -936,43 +946,33 @@ prompt-toolkit = [ {file = "prompt_toolkit-3.0.31.tar.gz", hash = "sha256:9ada952c9d1787f52ff6d5f3484d0b4df8952787c087edf6a1f7c2cb1ea88148"}, ] protobuf = [ - {file = "protobuf-3.20.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3cc797c9d15d7689ed507b165cd05913acb992d78b379f6014e013f9ecb20996"}, - 
{file = "protobuf-3.20.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ff8d8fa42675249bb456f5db06c00de6c2f4c27a065955917b28c4f15978b9c3"}, - {file = "protobuf-3.20.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:cd68be2559e2a3b84f517fb029ee611546f7812b1fdd0aa2ecc9bc6ec0e4fdde"}, - {file = "protobuf-3.20.1-cp310-cp310-win32.whl", hash = "sha256:9016d01c91e8e625141d24ec1b20fed584703e527d28512aa8c8707f105a683c"}, - {file = "protobuf-3.20.1-cp310-cp310-win_amd64.whl", hash = "sha256:32ca378605b41fd180dfe4e14d3226386d8d1b002ab31c969c366549e66a2bb7"}, - {file = "protobuf-3.20.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:9be73ad47579abc26c12024239d3540e6b765182a91dbc88e23658ab71767153"}, - {file = "protobuf-3.20.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:097c5d8a9808302fb0da7e20edf0b8d4703274d140fd25c5edabddcde43e081f"}, - {file = "protobuf-3.20.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e250a42f15bf9d5b09fe1b293bdba2801cd520a9f5ea2d7fb7536d4441811d20"}, - {file = "protobuf-3.20.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:cdee09140e1cd184ba9324ec1df410e7147242b94b5f8b0c64fc89e38a8ba531"}, - {file = "protobuf-3.20.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:af0ebadc74e281a517141daad9d0f2c5d93ab78e9d455113719a45a49da9db4e"}, - {file = "protobuf-3.20.1-cp37-cp37m-win32.whl", hash = "sha256:755f3aee41354ae395e104d62119cb223339a8f3276a0cd009ffabfcdd46bb0c"}, - {file = "protobuf-3.20.1-cp37-cp37m-win_amd64.whl", hash = "sha256:62f1b5c4cd6c5402b4e2d63804ba49a327e0c386c99b1675c8a0fefda23b2067"}, - {file = "protobuf-3.20.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:06059eb6953ff01e56a25cd02cca1a9649a75a7e65397b5b9b4e929ed71d10cf"}, - {file = "protobuf-3.20.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:cb29edb9eab15742d791e1025dd7b6a8f6fcb53802ad2f6e3adcb102051063ab"}, - {file = "protobuf-3.20.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:69ccfdf3657ba59569c64295b7d51325f91af586f8d5793b734260dfe2e94e2c"}, - {file = "protobuf-3.20.1-cp38-cp38-win32.whl", hash = "sha256:dd5789b2948ca702c17027c84c2accb552fc30f4622a98ab5c51fcfe8c50d3e7"}, - {file = "protobuf-3.20.1-cp38-cp38-win_amd64.whl", hash = "sha256:77053d28427a29987ca9caf7b72ccafee011257561259faba8dd308fda9a8739"}, - {file = "protobuf-3.20.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f50601512a3d23625d8a85b1638d914a0970f17920ff39cec63aaef80a93fb7"}, - {file = "protobuf-3.20.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:284f86a6207c897542d7e956eb243a36bb8f9564c1742b253462386e96c6b78f"}, - {file = "protobuf-3.20.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7403941f6d0992d40161aa8bb23e12575637008a5a02283a930addc0508982f9"}, - {file = "protobuf-3.20.1-cp39-cp39-win32.whl", hash = "sha256:db977c4ca738dd9ce508557d4fce0f5aebd105e158c725beec86feb1f6bc20d8"}, - {file = "protobuf-3.20.1-cp39-cp39-win_amd64.whl", hash = "sha256:7e371f10abe57cee5021797126c93479f59fccc9693dafd6bd5633ab67808a91"}, - {file = "protobuf-3.20.1-py2.py3-none-any.whl", hash = "sha256:adfc6cf69c7f8c50fd24c793964eef18f0ac321315439d94945820612849c388"}, - {file = "protobuf-3.20.1.tar.gz", hash = "sha256:adc31566d027f45efe3f44eeb5b1f329da43891634d61c75a5944e9be6dd42c9"}, + {file = "protobuf-4.21.9-cp310-abi3-win32.whl", hash = "sha256:6e0be9f09bf9b6cf497b27425487706fa48c6d1632ddd94dab1a5fe11a422392"}, + {file = "protobuf-4.21.9-cp310-abi3-win_amd64.whl", hash = 
"sha256:a7d0ea43949d45b836234f4ebb5ba0b22e7432d065394b532cdca8f98415e3cf"}, + {file = "protobuf-4.21.9-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:b5ab0b8918c136345ff045d4b3d5f719b505b7c8af45092d7f45e304f55e50a1"}, + {file = "protobuf-4.21.9-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:2c9c2ed7466ad565f18668aa4731c535511c5d9a40c6da39524bccf43e441719"}, + {file = "protobuf-4.21.9-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:e575c57dc8b5b2b2caa436c16d44ef6981f2235eb7179bfc847557886376d740"}, + {file = "protobuf-4.21.9-cp37-cp37m-win32.whl", hash = "sha256:9227c14010acd9ae7702d6467b4625b6fe853175a6b150e539b21d2b2f2b409c"}, + {file = "protobuf-4.21.9-cp37-cp37m-win_amd64.whl", hash = "sha256:a419cc95fca8694804709b8c4f2326266d29659b126a93befe210f5bbc772536"}, + {file = "protobuf-4.21.9-cp38-cp38-win32.whl", hash = "sha256:5b0834e61fb38f34ba8840d7dcb2e5a2f03de0c714e0293b3963b79db26de8ce"}, + {file = "protobuf-4.21.9-cp38-cp38-win_amd64.whl", hash = "sha256:84ea107016244dfc1eecae7684f7ce13c788b9a644cd3fca5b77871366556444"}, + {file = "protobuf-4.21.9-cp39-cp39-win32.whl", hash = "sha256:f9eae277dd240ae19bb06ff4e2346e771252b0e619421965504bd1b1bba7c5fa"}, + {file = "protobuf-4.21.9-cp39-cp39-win_amd64.whl", hash = "sha256:6e312e280fbe3c74ea9e080d9e6080b636798b5e3939242298b591064470b06b"}, + {file = "protobuf-4.21.9-py2.py3-none-any.whl", hash = "sha256:7eb8f2cc41a34e9c956c256e3ac766cf4e1a4c9c925dc757a41a01be3e852965"}, + {file = "protobuf-4.21.9-py3-none-any.whl", hash = "sha256:48e2cd6b88c6ed3d5877a3ea40df79d08374088e89bedc32557348848dff250b"}, + {file = "protobuf-4.21.9.tar.gz", hash = "sha256:61f21493d96d2a77f9ca84fefa105872550ab5ef71d21c458eb80edcf4885a99"}, ] psycopg2 = [ - {file = "psycopg2-2.9.4-cp310-cp310-win32.whl", hash = "sha256:8de6a9fc5f42fa52f559e65120dcd7502394692490c98fed1221acf0819d7797"}, - {file = "psycopg2-2.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:1da77c061bdaab450581458932ae5e469cc6e36e0d62f988376e9f513f11cb5c"}, - {file = "psycopg2-2.9.4-cp36-cp36m-win32.whl", hash = "sha256:a11946bad3557ca254f17357d5a4ed63bdca45163e7a7d2bfb8e695df069cc3a"}, - {file = "psycopg2-2.9.4-cp36-cp36m-win_amd64.whl", hash = "sha256:46361c054df612c3cc813fdb343733d56543fb93565cff0f8ace422e4da06acb"}, - {file = "psycopg2-2.9.4-cp37-cp37m-win32.whl", hash = "sha256:aafa96f2da0071d6dd0cbb7633406d99f414b40ab0f918c9d9af7df928a1accb"}, - {file = "psycopg2-2.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:aa184d551a767ad25df3b8d22a0a62ef2962e0e374c04f6cbd1204947f540d61"}, - {file = "psycopg2-2.9.4-cp38-cp38-win32.whl", hash = "sha256:839f9ea8f6098e39966d97fcb8d08548fbc57c523a1e27a1f0609addf40f777c"}, - {file = "psycopg2-2.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:c7fa041b4acb913f6968fce10169105af5200f296028251d817ab37847c30184"}, - {file = "psycopg2-2.9.4-cp39-cp39-win32.whl", hash = "sha256:07b90a24d5056687781ddaef0ea172fd951f2f7293f6ffdd03d4f5077801f426"}, - {file = "psycopg2-2.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:849bd868ae3369932127f0771c08d1109b254f08d48dc42493c3d1b87cb2d308"}, - {file = "psycopg2-2.9.4.tar.gz", hash = "sha256:d529926254e093a1b669f692a3aa50069bc71faf5b0ecd91686a78f62767d52f"}, + {file = "psycopg2-2.9.5-cp310-cp310-win32.whl", hash = "sha256:d3ef67e630b0de0779c42912fe2cbae3805ebaba30cda27fea2a3de650a9414f"}, + {file = "psycopg2-2.9.5-cp310-cp310-win_amd64.whl", hash = "sha256:4cb9936316d88bfab614666eb9e32995e794ed0f8f6b3b718666c22819c1d7ee"}, + {file = "psycopg2-2.9.5-cp36-cp36m-win32.whl", hash = 
"sha256:b9ac1b0d8ecc49e05e4e182694f418d27f3aedcfca854ebd6c05bb1cffa10d6d"}, + {file = "psycopg2-2.9.5-cp36-cp36m-win_amd64.whl", hash = "sha256:fc04dd5189b90d825509caa510f20d1d504761e78b8dfb95a0ede180f71d50e5"}, + {file = "psycopg2-2.9.5-cp37-cp37m-win32.whl", hash = "sha256:922cc5f0b98a5f2b1ff481f5551b95cd04580fd6f0c72d9b22e6c0145a4840e0"}, + {file = "psycopg2-2.9.5-cp37-cp37m-win_amd64.whl", hash = "sha256:1e5a38aa85bd660c53947bd28aeaafb6a97d70423606f1ccb044a03a1203fe4a"}, + {file = "psycopg2-2.9.5-cp38-cp38-win32.whl", hash = "sha256:f5b6320dbc3cf6cfb9f25308286f9f7ab464e65cfb105b64cc9c52831748ced2"}, + {file = "psycopg2-2.9.5-cp38-cp38-win_amd64.whl", hash = "sha256:1a5c7d7d577e0eabfcf15eb87d1e19314c8c4f0e722a301f98e0e3a65e238b4e"}, + {file = "psycopg2-2.9.5-cp39-cp39-win32.whl", hash = "sha256:322fd5fca0b1113677089d4ebd5222c964b1760e361f151cbb2706c4912112c5"}, + {file = "psycopg2-2.9.5-cp39-cp39-win_amd64.whl", hash = "sha256:190d51e8c1b25a47484e52a79638a8182451d6f6dff99f26ad9bd81e5359a0fa"}, + {file = "psycopg2-2.9.5.tar.gz", hash = "sha256:a5246d2e683a972e2187a8714b5c2cf8156c064629f9a9b1a873c1730d9e245a"}, ] pycparser = [ {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, @@ -1010,15 +1010,15 @@ pycryptodomex = [ {file = "pycryptodomex-3.15.0-pp36-pypy36_pp73-win32.whl", hash = "sha256:35a8f7afe1867118330e2e0e0bf759c409e28557fb1fc2fbb1c6c937297dbe9a"}, {file = "pycryptodomex-3.15.0.tar.gz", hash = "sha256:7341f1bb2dadb0d1a0047f34c3a58208a92423cdbd3244d998e4b28df5eac0ed"}, ] -Pygments = [ +pygments = [ {file = "Pygments-2.13.0-py3-none-any.whl", hash = "sha256:f643f331ab57ba3c9d89212ee4a2dabc6e94f117cf4eefde99a0574720d14c42"}, {file = "Pygments-2.13.0.tar.gz", hash = "sha256:56a8508ae95f98e2b9bdf93a6be5ae3f7d8af858b43e02c5a2ff083726be40c1"}, ] -PyJWT = [ - {file = "PyJWT-2.5.0-py3-none-any.whl", hash = "sha256:8d82e7087868e94dd8d7d418e5088ce64f7daab4b36db654cbaedb46f9d1ca80"}, - {file = "PyJWT-2.5.0.tar.gz", hash = "sha256:e77ab89480905d86998442ac5788f35333fa85f65047a534adc38edf3c88fc3b"}, +pyjwt = [ + {file = "PyJWT-2.6.0-py3-none-any.whl", hash = "sha256:d83c3d892a77bbb74d3e1a2cfa90afaadb60945205d1095d9221f04466f64c14"}, + {file = "PyJWT-2.6.0.tar.gz", hash = "sha256:69285c7e31fc44f68a1feb309e948e0df53259d579295e6cfe2b1792329f05fd"}, ] -pyOpenSSL = [ +pyopenssl = [ {file = "pyOpenSSL-22.0.0-py2.py3-none-any.whl", hash = "sha256:ea252b38c87425b64116f808355e8da644ef9b07e429398bfece610f893ee2e0"}, {file = "pyOpenSSL-22.0.0.tar.gz", hash = "sha256:660b1b1425aac4a1bea1d94168a85d99f0b3144c869dd4390d27629d0087f1bf"}, ] @@ -1027,8 +1027,8 @@ python-dateutil = [ {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, ] pytz = [ - {file = "pytz-2022.4-py2.py3-none-any.whl", hash = "sha256:2c0784747071402c6e99f0bafdb7da0fa22645f06554c7ae06bf6358897e9c91"}, - {file = "pytz-2022.4.tar.gz", hash = "sha256:48ce799d83b6f8aab2020e369b627446696619e79645419610b9facd909b3174"}, + {file = "pytz-2022.5-py2.py3-none-any.whl", hash = "sha256:335ab46900b1465e714b4fda4963d87363264eb662aab5e65da039c25f1f5b22"}, + {file = "pytz-2022.5.tar.gz", hash = "sha256:c4d88f472f54d615e9cd582a5004d1e5f624854a6a27a6211591c251f22a6914"}, ] pytz-deprecation-shim = [ {file = "pytz_deprecation_shim-0.1.0.post0-py2.py3-none-any.whl", hash = "sha256:8314c9692a636c8eb3bda879b9f119e350e93223ae83e70e80c31675a0fdc1a6"}, @@ -1047,8 +1047,8 @@ runtype = [ {file = 
"runtype-0.2.7.tar.gz", hash = "sha256:5a9e1212846b3e54d4ba29fd7db602af5544a2a4253d1f8d829087214a8766ad"}, ] setuptools = [ - {file = "setuptools-65.4.1-py3-none-any.whl", hash = "sha256:1b6bdc6161661409c5f21508763dc63ab20a9ac2f8ba20029aaaa7fdb9118012"}, - {file = "setuptools-65.4.1.tar.gz", hash = "sha256:3050e338e5871e70c72983072fe34f6032ae1cdeeeb67338199c2f74e083a80e"}, + {file = "setuptools-65.5.0-py3-none-any.whl", hash = "sha256:f62ea9da9ed6289bfe868cd6845968a2c854d1427f8548d52cae02a42b4f0356"}, + {file = "setuptools-65.5.0.tar.gz", hash = "sha256:512e5536220e38146176efb833d4a62aa726b7bbff82cfbc8ba9eaa3996e0b17"}, ] six = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, @@ -1089,15 +1089,16 @@ typing-extensions = [ {file = "typing_extensions-4.4.0.tar.gz", hash = "sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa"}, ] tzdata = [ - {file = "tzdata-2022.4-py2.py3-none-any.whl", hash = "sha256:74da81ecf2b3887c94e53fc1d466d4362aaf8b26fc87cda18f22004544694583"}, - {file = "tzdata-2022.4.tar.gz", hash = "sha256:ada9133fbd561e6ec3d1674d3fba50251636e918aa97bd59d63735bef5a513bb"}, + {file = "tzdata-2022.6-py2.py3-none-any.whl", hash = "sha256:04a680bdc5b15750c39c12a448885a51134a27ec9af83667663f0b3a1bf3f342"}, + {file = "tzdata-2022.6.tar.gz", hash = "sha256:91f11db4503385928c15598c98573e3af07e7229181bee5375bd30f1695ddcae"}, ] tzlocal = [ {file = "tzlocal-4.2-py3-none-any.whl", hash = "sha256:89885494684c929d9191c57aa27502afc87a579be5cdd3225c77c463ea043745"}, {file = "tzlocal-4.2.tar.gz", hash = "sha256:ee5842fa3a795f023514ac2d801c4a81d1743bbe642e3940143326b3a00addd7"}, ] unittest-parallel = [ - {file = "unittest-parallel-1.5.2.tar.gz", hash = "sha256:42e82215862619ba7ce269db30eb63b878671ebb2ab9bfcead1fede43800b7ef"}, + {file = "unittest-parallel-1.5.3.tar.gz", hash = "sha256:32182bb2230371d651e6fc9795ddf52c134eb36f5064dc339fdbb5984a639517"}, + {file = "unittest_parallel-1.5.3-py3-none-any.whl", hash = "sha256:5670c9eca19450dedb493e9dad2ca4dcbbe12e04477d934ff6c92071d36bace7"}, ] urllib3 = [ {file = "urllib3-1.26.12-py2.py3-none-any.whl", hash = "sha256:b930dd878d5a8afb066a637fbb35144fe7901e3b209d1cd4f524bd0e9deee997"}, @@ -1112,6 +1113,6 @@ wcwidth = [ {file = "wcwidth-0.2.5.tar.gz", hash = "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83"}, ] zipp = [ - {file = "zipp-3.9.0-py3-none-any.whl", hash = "sha256:972cfa31bc2fedd3fa838a51e9bc7e64b7fb725a8c00e7431554311f180e9980"}, - {file = "zipp-3.9.0.tar.gz", hash = "sha256:3a7af91c3db40ec72dd9d154ae18e008c69efe8ca88dde4f9a731bb82fe2f9eb"}, + {file = "zipp-3.10.0-py3-none-any.whl", hash = "sha256:4fcb6f278987a6605757302a6e40e896257570d11c51628968ccb2a47e80c6c1"}, + {file = "zipp-3.10.0.tar.gz", hash = "sha256:7a7262fd930bd3e36c50b9a64897aec3fafff3dfdeec9623ae22b40e93f99bb8"}, ] diff --git a/pyproject.toml b/pyproject.toml index 386dba0d..5c90d2d3 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "data-diff" -version = "0.2.8" +version = "0.3.0rc2" description = "Command-line tool and Python library to efficiently diff rows across two different databases." 
authors = ["Datafold "] license = "MIT" @@ -29,7 +29,7 @@ dsnparse = "*" click = "^8.1" rich = "*" toml = "^0.10.2" -mysql-connector-python = {version="*", optional=true} +mysql-connector-python = {version="8.0.29", optional=true} psycopg2 = {version="*", optional=true} snowflake-connector-python = {version="^2.7.2", optional=true} cryptography = {version="*", optional=true} @@ -38,6 +38,7 @@ presto-python-client = {version="*", optional=true} clickhouse-driver = {version="*", optional=true} [tool.poetry.dev-dependencies] +arrow = "^1.2.3" parameterized = "*" unittest-parallel = "*" preql = "^0.2.19" diff --git a/tests/common.py b/tests/common.py index 2aad4be1..c1ae30a0 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,13 +1,19 @@ -from contextlib import suppress import hashlib import os import string import random +from typing import Callable +import unittest +import logging +import subprocess + +from parameterized import parameterized_class from data_diff import databases as db from data_diff import tracking -import logging -import subprocess +from data_diff import connect +from data_diff.queries.api import table +from data_diff.query_utils import drop_table tracking.disable_tracking() @@ -30,6 +36,7 @@ N_SAMPLES = int(os.environ.get("N_SAMPLES", DEFAULT_N_SAMPLES)) BENCHMARK = os.environ.get("BENCHMARK", False) N_THREADS = int(os.environ.get("N_THREADS", 1)) +TEST_ACROSS_ALL_DBS = os.environ.get("TEST_ACROSS_ALL_DBS", True) # Should we run the full db<->db test suite? def get_git_revision_short_hash() -> str: @@ -43,6 +50,8 @@ def get_git_revision_short_hash() -> str: level = getattr(logging, os.environ["LOG_LEVEL"].upper()) logging.basicConfig(level=level) +logging.getLogger("hashdiff_tables").setLevel(level) +logging.getLogger("joindiff_tables").setLevel(level) logging.getLogger("diff_tables").setLevel(level) logging.getLogger("table_segment").setLevel(level) logging.getLogger("database").setLevel(level) @@ -70,6 +79,14 @@ def get_git_revision_short_hash() -> str: db.Vertica: TEST_VERTICA_CONN_STRING, } +_database_instances = {} + + +def get_conn(cls: type): + if cls not in _database_instances: + _database_instances[cls] = connect(CONN_STRINGS[cls], N_THREADS) + return _database_instances[cls] + def _print_used_dbs(): used = {k.__name__ for k, v in CONN_STRINGS.items() if v is not None} @@ -78,6 +95,10 @@ def _print_used_dbs(): logging.info(f"Testing databases: {', '.join(used)}") if unused: logging.info(f"Connection not configured; skipping tests for: {', '.join(unused)}") + if TEST_ACROSS_ALL_DBS: + logging.info( + f"Full tests enabled (every db<->db). May take very long when many dbs are involved. 
={TEST_ACROSS_ALL_DBS}" + ) _print_used_dbs() @@ -104,12 +125,42 @@ def str_to_checksum(str: str): return int(md5[half_pos:], 16) -def _drop_table_if_exists(conn, table): - with suppress(db.QueryError): - if isinstance(conn, db.Oracle): - conn.query(f"DROP TABLE {table}", None) - conn.query(f"DROP TABLE {table}", None) - else: - conn.query(f"DROP TABLE IF EXISTS {table}", None) - if not isinstance(conn, (db.BigQuery, db.Databricks, db.Clickhouse)): - conn.query("COMMIT", None) +class TestPerDatabase(unittest.TestCase): + db_cls = None + + def setUp(self): + assert self.db_cls, self.db_cls + + self.connection = get_conn(self.db_cls) + + table_suffix = random_table_suffix() + self.table_src_name = f"src{table_suffix}" + self.table_dst_name = f"dst{table_suffix}" + + self.table_src_path = self.connection.parse_table_name(self.table_src_name) + self.table_dst_path = self.connection.parse_table_name(self.table_dst_name) + + self.table_src = ".".join(map(self.connection.dialect.quote, self.table_src_path)) + self.table_dst = ".".join(map(self.connection.dialect.quote, self.table_dst_path)) + + drop_table(self.connection, self.table_src_path) + drop_table(self.connection, self.table_dst_path) + + return super().setUp() + + def tearDown(self): + drop_table(self.connection, self.table_src_path) + drop_table(self.connection, self.table_dst_path) + + +def _parameterized_class_per_conn(test_databases): + test_databases = set(test_databases) + names = [(cls.__name__, cls) for cls in CONN_STRINGS if cls in test_databases] + return parameterized_class(("name", "db_cls"), names) + + +def test_each_database_in_list(databases) -> Callable: + def _test_per_database(cls): + return _parameterized_class_per_conn(databases)(cls) + + return _test_per_database diff --git a/tests/test_api.py b/tests/test_api.py index bac88c84..2c67b481 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,49 +1,50 @@ import unittest -import preql import arrow +from datetime import datetime from data_diff import diff_tables, connect_to_table +from data_diff.databases import MySQL +from data_diff.queries.api import table -from .common import TEST_MYSQL_CONN_STRING +from .common import TEST_MYSQL_CONN_STRING, get_conn + + +def _commit(conn): + conn.query("COMMIT", None) class TestApi(unittest.TestCase): def setUp(self) -> None: - self.preql = preql.Preql(TEST_MYSQL_CONN_STRING) - self.preql( - r""" - table test_api { - datetime: datetime - comment: string - } - commit() - - func add(date, comment) { - new test_api(date, comment) - } - """ - ) - self.now = now = arrow.get(self.preql.now()) - self.preql.add(now, "now") - self.preql.add(now, self.now.shift(seconds=-10)) - self.preql.add(now, self.now.shift(seconds=-7)) - self.preql.add(now, self.now.shift(seconds=-6)) - - self.preql( - r""" - const table test_api_2 = test_api - commit() - """ - ) - - self.preql.add(self.now.shift(seconds=-3), "3 seconds ago") - self.preql.commit() + self.conn = get_conn(MySQL) + table_src_name = "test_api" + table_dst_name = "test_api_2" + self.conn.query(f"drop table if exists {table_src_name}") + self.conn.query(f"drop table if exists {table_dst_name}") + + src_table = table(table_src_name, schema={"id": int, "datetime": datetime, "text_comment": str}) + self.conn.query(src_table.create()) + self.now = now = arrow.get() + + rows = [ + (now, "now"), + (self.now.shift(seconds=-10), "a"), + (self.now.shift(seconds=-7), "b"), + (self.now.shift(seconds=-6), "c"), + ] + + self.conn.query(src_table.insert_rows((i, ts.datetime, s) for i, (ts, s) in 
enumerate(rows))) + _commit(self.conn) + + self.conn.query(f"CREATE TABLE {table_dst_name} AS SELECT * FROM {table_src_name}") + _commit(self.conn) + + self.conn.query(src_table.insert_row(len(rows), self.now.shift(seconds=-3).datetime, "3 seconds ago")) + _commit(self.conn) def tearDown(self) -> None: - self.preql.run_statement("drop table if exists test_api") - self.preql.run_statement("drop table if exists test_api_2") - self.preql.commit() - self.preql.close() + self.conn.query("drop table if exists test_api") + self.conn.query("drop table if exists test_api_2") + _commit(self.conn) return super().tearDown() diff --git a/tests/test_cli.py b/tests/test_cli.py index 4e866680..b63b1c7f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,13 +1,19 @@ import logging import unittest -import preql import arrow import subprocess import sys +from datetime import datetime from data_diff import diff_tables, connect_to_table +from data_diff.databases import MySQL +from data_diff.queries import table -from .common import TEST_MYSQL_CONN_STRING +from .common import TEST_MYSQL_CONN_STRING, get_conn + + +def _commit(conn): + conn.query("COMMIT", None) def run_datadiff_cli(*args): @@ -23,41 +29,39 @@ def run_datadiff_cli(*args): class TestCLI(unittest.TestCase): def setUp(self) -> None: - self.preql = preql.Preql(TEST_MYSQL_CONN_STRING) - self.preql( - r""" - table test_cli { - datetime: datetime - comment: string - } - commit() - - func add(date, comment) { - new test_cli(date, comment) - } - """ - ) - self.now = now = arrow.get(self.preql.now()) - self.preql.add(now, "now") - self.preql.add(now, self.now.shift(seconds=-10)) - self.preql.add(now, self.now.shift(seconds=-7)) - self.preql.add(now, self.now.shift(seconds=-6)) - - self.preql( - r""" - const table test_cli_2 = test_cli - commit() - """ - ) + self.conn = get_conn(MySQL) + self.conn.query("drop table if exists test_cli") + self.conn.query("drop table if exists test_cli_2") + table_src_name = "test_cli" + table_dst_name = "test_cli_2" + + src_table = table(table_src_name, schema={"id": int, "datetime": datetime, "text_comment": str}) + self.conn.query(src_table.create()) + + self.conn.query("SET @@session.time_zone='+00:00'") + db_time = self.conn.query("select now()", datetime) + self.now = now = arrow.get(db_time) + + rows = [ + (now, "now"), + (self.now.shift(seconds=-10), "a"), + (self.now.shift(seconds=-7), "b"), + (self.now.shift(seconds=-6), "c"), + ] + + self.conn.query(src_table.insert_rows((i, ts.datetime, s) for i, (ts, s) in enumerate(rows))) + _commit(self.conn) + + self.conn.query(f"CREATE TABLE {table_dst_name} AS SELECT * FROM {table_src_name}") + _commit(self.conn) - self.preql.add(self.now.shift(seconds=-3), "3 seconds ago") - self.preql.commit() + self.conn.query(src_table.insert_row(len(rows), self.now.shift(seconds=-3).datetime, "3 seconds ago")) + _commit(self.conn) def tearDown(self) -> None: - self.preql.run_statement("drop table if exists test_cli") - self.preql.run_statement("drop table if exists test_cli_2") - self.preql.commit() - self.preql.close() + self.conn.query("drop table if exists test_cli") + self.conn.query("drop table if exists test_cli_2") + _commit(self.conn) return super().tearDown() diff --git a/tests/test_database.py b/tests/test_database.py index a7e34d1d..d309a4ed 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -11,9 +11,9 @@ def setUp(self): def test_connect_to_db(self): self.assertEqual(1, self.mysql.query("SELECT 1", int)) - def test_md5_to_int(self): + def 
test_md5_as_int(self): str = "hello world" - query_fragment = self.mysql.md5_to_int("'{0}'".format(str)) + query_fragment = self.mysql.dialect.md5_as_int("'{0}'".format(str)) query = f"SELECT {query_fragment}" self.assertEqual(str_to_checksum(str), self.mysql.query(query, int)) diff --git a/tests/test_database_types.py b/tests/test_database_types.py index ce273182..250b4537 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -1,3 +1,4 @@ +from contextlib import suppress import unittest import time import json @@ -14,16 +15,19 @@ from data_diff import databases as db from data_diff.databases import postgresql, oracle +from data_diff.query_utils import drop_table from data_diff.utils import number_to_human, accumulate -from data_diff.diff_tables import TableDiffer, TableSegment, DEFAULT_BISECTION_THRESHOLD +from data_diff.hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD +from data_diff.table_segment import TableSegment from .common import ( CONN_STRINGS, N_SAMPLES, N_THREADS, BENCHMARK, GIT_REVISION, + TEST_ACROSS_ALL_DBS, + get_conn, random_table_suffix, - _drop_table_if_exists, ) CONNS = None @@ -34,8 +38,8 @@ def init_conns(): if CONNS is not None: return - CONNS = {k: db.connect.connect(v, N_THREADS) for k, v in CONN_STRINGS.items()} - CONNS[db.MySQL].query("SET @@session.time_zone='+00:00'", None) + CONNS = {cls: get_conn(cls) for cls in CONN_STRINGS} + CONNS[db.MySQL].query("SET @@session.time_zone='+00:00'") oracle.SESSION_TIME_ZONE = postgresql.SESSION_TIME_ZONE = "UTC" @@ -175,6 +179,7 @@ def init_conns(): "numeric", "real", "double precision", + "Number(5, 2)", ], "uuid": [ "CHAR(100)", @@ -414,25 +419,42 @@ def __iter__(self): "uuid": UUID_Faker(N_SAMPLES), } + +def _get_test_db_pairs(): + if str(TEST_ACROSS_ALL_DBS).lower() == "full": + for source_db in DATABASE_TYPES: + for target_db in DATABASE_TYPES: + yield source_db, target_db + elif int(TEST_ACROSS_ALL_DBS): + for db_cls in DATABASE_TYPES: + yield db_cls, db.PostgreSQL + yield db.PostgreSQL, db_cls + yield db_cls, db.Snowflake + yield db.Snowflake, db_cls + else: + yield db.PostgreSQL, db.PostgreSQL + + +def get_test_db_pairs(): + active_pairs = {(db1, db2) for db1, db2 in _get_test_db_pairs() if db1 in CONN_STRINGS and db2 in CONN_STRINGS} + for db1, db2 in active_pairs: + yield db1, DATABASE_TYPES[db1], db2, DATABASE_TYPES[db2] + + type_pairs = [] -for source_db, source_type_categories in DATABASE_TYPES.items(): - for target_db, target_type_categories in DATABASE_TYPES.items(): - for ( - type_category, - source_types, - ) in source_type_categories.items(): # int, datetime, .. - for source_type in source_types: - for target_type in target_type_categories[type_category]: - if CONN_STRINGS.get(source_db, False) and CONN_STRINGS.get(target_db, False): - type_pairs.append( - ( - source_db, - target_db, - source_type, - target_type, - type_category, - ) - ) +for source_db, source_type_categories, target_db, target_type_categories in get_test_db_pairs(): + for type_category, source_types in source_type_categories.items(): # int, datetime, .. 
+ for source_type in source_types: + for target_type in target_type_categories[type_category]: + type_pairs.append( + ( + source_db, + target_db, + source_type, + target_type, + type_category, + ) + ) def sanitize(name): @@ -466,6 +488,17 @@ def expand_params(testcase_func, param_num, param): return name +def _drop_table_if_exists(conn, tbl): + if isinstance(conn, db.Oracle): + with suppress(db.QueryError): + conn.query(f"DROP TABLE {tbl}", None) + conn.query(f"DROP TABLE {tbl}", None) + else: + conn.query(f"DROP TABLE IF EXISTS {tbl}", None) + if not isinstance(conn, (db.BigQuery, db.Databricks, db.Clickhouse)): + conn.query("COMMIT", None) + + def _insert_to_table(conn, table, values, type): current_n_rows = conn.query(f"SELECT COUNT(*) FROM {table}", int) if current_n_rows == N_SAMPLES: @@ -594,8 +627,8 @@ def setUp(self) -> None: def tearDown(self) -> None: if not BENCHMARK: - _drop_table_if_exists(self.src_conn, self.src_table) - _drop_table_if_exists(self.dst_conn, self.dst_table) + drop_table(self.src_conn, self.src_table_path) + drop_table(self.dst_conn, self.dst_table_path) return super().tearDown() @@ -619,14 +652,14 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego src_table_name = f"src_{self._testMethodName[11:]}{table_suffix}" dst_table_name = f"dst_{self._testMethodName[11:]}{table_suffix}" - src_table_path = src_conn.parse_table_name(src_table_name) - dst_table_path = dst_conn.parse_table_name(dst_table_name) - self.src_table = src_table = ".".join(map(src_conn.quote, src_table_path)) - self.dst_table = dst_table = ".".join(map(dst_conn.quote, dst_table_path)) + self.src_table_path = src_table_path = src_conn.parse_table_name(src_table_name) + self.dst_table_path = dst_table_path = dst_conn.parse_table_name(dst_table_name) + self.src_table = src_table = ".".join(map(src_conn.dialect.quote, src_table_path)) + self.dst_table = dst_table = ".".join(map(dst_conn.dialect.quote, dst_table_path)) start = time.monotonic() if not BENCHMARK: - _drop_table_if_exists(src_conn, src_table) + drop_table(src_conn, src_table_path) _create_table_with_indexes(src_conn, src_table, source_type) _insert_to_table(src_conn, src_table, enumerate(sample_values, 1), source_type) insertion_source_duration = time.monotonic() - start @@ -640,17 +673,17 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego start = time.monotonic() if not BENCHMARK: - _drop_table_if_exists(dst_conn, dst_table) + drop_table(dst_conn, dst_table_path) _create_table_with_indexes(dst_conn, dst_table, target_type) _insert_to_table(dst_conn, dst_table, values_in_source, target_type) insertion_target_duration = time.monotonic() - start if type_category == "uuid": - self.table = TableSegment(self.src_conn, src_table_path, "col", None, ("id",), case_sensitive=False) - self.table2 = TableSegment(self.dst_conn, dst_table_path, "col", None, ("id",), case_sensitive=False) + self.table = TableSegment(self.src_conn, src_table_path, ("col",), None, ("id",), case_sensitive=False) + self.table2 = TableSegment(self.dst_conn, dst_table_path, ("col",), None, ("id",), case_sensitive=False) else: - self.table = TableSegment(self.src_conn, src_table_path, "id", None, ("col",), case_sensitive=False) - self.table2 = TableSegment(self.dst_conn, dst_table_path, "id", None, ("col",), case_sensitive=False) + self.table = TableSegment(self.src_conn, src_table_path, ("id",), None, ("col",), case_sensitive=False) + self.table2 = TableSegment(self.dst_conn, dst_table_path, ("id",), None, 
("col",), case_sensitive=False) start = time.monotonic() self.assertEqual(N_SAMPLES, self.table.count()) @@ -667,7 +700,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego ch_factor = min(max(int(N_SAMPLES / 250_000), 2), 128) if BENCHMARK else 2 ch_threshold = min(DEFAULT_BISECTION_THRESHOLD, int(N_SAMPLES / ch_factor)) if BENCHMARK else 3 ch_threads = N_THREADS - differ = TableDiffer( + differ = HashDiffer( bisection_threshold=ch_threshold, bisection_factor=ch_factor, max_threadpool_size=ch_threads, @@ -688,7 +721,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego dl_factor = max(int(N_SAMPLES / 100_000), 2) if BENCHMARK else 2 dl_threshold = int(N_SAMPLES / dl_factor) + 1 if BENCHMARK else math.inf dl_threads = N_THREADS - differ = TableDiffer( + differ = HashDiffer( bisection_threshold=dl_threshold, bisection_factor=dl_factor, max_threadpool_size=dl_threads ) start = time.monotonic() diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index a234833b..a632638a 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -1,87 +1,39 @@ -import datetime -import unittest +from datetime import datetime +from typing import Callable import uuid +import unittest -from parameterized import parameterized_class -import preql import arrow # comes with preql -from data_diff.databases.connect import connect -from data_diff.diff_tables import TableDiffer +from data_diff.queries import table, this, commit + +from data_diff.hashdiff_tables import HashDiffer from data_diff.table_segment import TableSegment, split_space from data_diff import databases as db from data_diff.utils import ArithAlphanumeric, numberToAlphanum -from .common import ( - TEST_MYSQL_CONN_STRING, - str_to_checksum, - random_table_suffix, - _drop_table_if_exists, - CONN_STRINGS, - N_THREADS, -) - -DATABASE_INSTANCES = None -DATABASE_URIS = {k.__name__: v for k, v in CONN_STRINGS.items()} - - -def init_instances(): - global DATABASE_INSTANCES - if DATABASE_INSTANCES is not None: - return - - DATABASE_INSTANCES = {k.__name__: connect(v, N_THREADS) for k, v in CONN_STRINGS.items()} - - -TEST_DATABASES = {x.__name__ for x in (db.MySQL, db.PostgreSQL, db.Oracle, db.Redshift, db.Snowflake, db.BigQuery)} - - -def _class_per_db_dec(filter_name=None): - names = [ - (name, name) - for name in DATABASE_URIS - if (name in TEST_DATABASES) and (filter_name is None or filter_name(name)) - ] - return parameterized_class(("name", "db_name"), names) - - -def test_per_database(cls): - return _class_per_db_dec()(cls) - - -def test_per_database__filter_name(filter_name): - def _test_per_database(cls): - return _class_per_db_dec(filter_name=filter_name)(cls) - - return _test_per_database - - -def _insert_row(conn, table, fields, values): - fields = ", ".join(map(str, fields)) - values = ", ".join(map(str, values)) - conn.query(f"INSERT INTO {table}({fields}) VALUES ({values})", None) - - -def _insert_rows(conn, table, fields, tuple_list): - for t in tuple_list: - _insert_row(conn, table, fields, t) - +from .common import str_to_checksum, test_each_database_in_list, TestPerDatabase -def _commit(conn): - if not isinstance(conn, db.BigQuery): - conn.query("COMMIT", None) +TEST_DATABASES = { + db.MySQL, + db.PostgreSQL, + db.Oracle, + db.Redshift, + db.Snowflake, + db.BigQuery, + db.Presto, + db.Trino, + db.Vertica, +} -def _get_text_type(conn): - if isinstance(conn, db.BigQuery): - return "STRING" - return "varchar(100)" +test_each_database: Callable = 
test_each_database_in_list(TEST_DATABASES) -def _get_float_type(conn): - if isinstance(conn, db.BigQuery): - return "FLOAT64" - return "float" +def _table_segment(database, table_path, key_columns, *args, **kw): + if isinstance(key_columns, str): + key_columns = (key_columns,) + return TableSegment(database, table_path, key_columns, *args, **kw) class TestUtils(unittest.TestCase): @@ -93,92 +45,45 @@ def test_split_space(self): assert len(r) == n, f"split_space({i}, {j+n}, {n}) = {(r)}" -class TestPerDatabase(unittest.TestCase): - db_name = None - with_preql = False - - preql = None - +@test_each_database +class TestDates(TestPerDatabase): def setUp(self): - assert self.db_name - init_instances() - - self.connection = DATABASE_INSTANCES[self.db_name] - if self.with_preql: - self.preql = preql.Preql(DATABASE_URIS[self.db_name]) - - table_suffix = random_table_suffix() - self.table_src_name = f"src{table_suffix}" - self.table_dst_name = f"dst{table_suffix}" - - self.table_src_path = self.connection.parse_table_name(self.table_src_name) - self.table_dst_path = self.connection.parse_table_name(self.table_dst_name) - - self.table_src = ".".join(map(self.connection.quote, self.table_src_path)) - self.table_dst = ".".join(map(self.connection.quote, self.table_dst_path)) - - _drop_table_if_exists(self.connection, self.table_src) - _drop_table_if_exists(self.connection, self.table_dst) - - return super().setUp() - - def tearDown(self): - if self.preql: - self.preql._interp.state.db.rollback() - self.preql.close() - - _drop_table_if_exists(self.connection, self.table_src) - _drop_table_if_exists(self.connection, self.table_dst) + super().setUp() + src_table = table(self.table_src_path, schema={"id": int, "datetime": datetime, "text_comment": str}) + self.connection.query(src_table.create()) + self.now = now = arrow.get() -@test_per_database -class TestDates(TestPerDatabase): - with_preql = True + rows = [ + (now.shift(days=-50), "50 days ago"), + (now.shift(hours=-3), "3 hours ago"), + (now.shift(minutes=-10), "10 mins ago"), + (now.shift(seconds=-1), "1 second ago"), + (now, "now"), + ] - def setUp(self): - super().setUp() - self.preql( - f""" - table {self.table_src_name} {{ - datetime: timestamp - text_comment: string - }} - commit() - - func add(date, text_comment) {{ - new {self.table_src_name}(date, text_comment) - }} - """ - ) - self.now = now = arrow.get(self.preql.now()) - self.preql.add(now.shift(days=-50), "50 days ago") - self.preql.add(now.shift(hours=-3), "3 hours ago") - self.preql.add(now.shift(minutes=-10), "10 mins ago") - self.preql.add(now.shift(seconds=-1), "1 second ago") - self.preql.add(now, "now") - - self.preql( - f""" - const table {self.table_dst_name} = {self.table_src_name} - commit() - """ + self.connection.query( + [ + src_table.insert_rows((i, ts.datetime, s) for i, (ts, s) in enumerate(rows)), + table(self.table_dst_path).create(src_table), + commit, + src_table.insert_row(len(rows), self.now.shift(seconds=-3).datetime, "3 seconds ago"), + commit, + ] ) - self.preql.add(self.now.shift(seconds=-3), "2 seconds ago") - self.preql.commit() - def test_init(self): - a = TableSegment( + a = _table_segment( self.connection, self.table_src_path, "id", "datetime", max_update=self.now.datetime, case_sensitive=False ) self.assertRaises( - ValueError, TableSegment, self.connection, self.table_src_path, "id", max_update=self.now.datetime + ValueError, _table_segment, self.connection, self.table_src_path, "id", max_update=self.now.datetime ) def test_basic(self): - differ = 
TableDiffer(10, 100) - a = TableSegment(self.connection, self.table_src_path, "id", "datetime", case_sensitive=False) - b = TableSegment(self.connection, self.table_dst_path, "id", "datetime", case_sensitive=False) + differ = HashDiffer(bisection_factor=10, bisection_threshold=100) + a = _table_segment(self.connection, self.table_src_path, "id", "datetime", case_sensitive=False) + b = _table_segment(self.connection, self.table_dst_path, "id", "datetime", case_sensitive=False) assert a.count() == 6 assert b.count() == 5 @@ -186,25 +91,33 @@ def test_basic(self): self.assertEqual(len(list(differ.diff_tables(a, b))), 1) def test_offset(self): - differ = TableDiffer(2, 10) - sec1 = self.now.shift(seconds=-1).datetime - a = TableSegment(self.connection, self.table_src_path, "id", "datetime", max_update=sec1, case_sensitive=False) - b = TableSegment(self.connection, self.table_dst_path, "id", "datetime", max_update=sec1, case_sensitive=False) - assert a.count() == 4 + differ = HashDiffer(bisection_factor=2, bisection_threshold=10) + sec1 = self.now.shift(seconds=-2).datetime + a = _table_segment( + self.connection, self.table_src_path, "id", "datetime", max_update=sec1, case_sensitive=False + ) + b = _table_segment( + self.connection, self.table_dst_path, "id", "datetime", max_update=sec1, case_sensitive=False + ) + assert a.count() == 4, a.count() assert b.count() == 3 assert not list(differ.diff_tables(a, a)) self.assertEqual(len(list(differ.diff_tables(a, b))), 1) - a = TableSegment(self.connection, self.table_src_path, "id", "datetime", min_update=sec1, case_sensitive=False) - b = TableSegment(self.connection, self.table_dst_path, "id", "datetime", min_update=sec1, case_sensitive=False) + a = _table_segment( + self.connection, self.table_src_path, "id", "datetime", min_update=sec1, case_sensitive=False + ) + b = _table_segment( + self.connection, self.table_dst_path, "id", "datetime", min_update=sec1, case_sensitive=False + ) assert a.count() == 2 assert b.count() == 2 assert not list(differ.diff_tables(a, b)) day1 = self.now.shift(days=-1).datetime - a = TableSegment( + a = _table_segment( self.connection, self.table_src_path, "id", @@ -213,7 +126,7 @@ def test_offset(self): max_update=sec1, case_sensitive=False, ) - b = TableSegment( + b = _table_segment( self.connection, self.table_dst_path, "id", @@ -228,29 +141,26 @@ def test_offset(self): self.assertEqual(len(list(differ.diff_tables(a, b))), 1) -@test_per_database +@test_each_database class TestDiffTables(TestPerDatabase): - with_preql = True - def setUp(self): super().setUp() - float_type = _get_float_type(self.connection) - - self.connection.query( - f"create table {self.table_src}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", - None, + self.src_table = table( + self.table_src_path, + schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}, ) - self.connection.query( - f"create table {self.table_dst}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", - None, + self.dst_table = table( + self.table_dst_path, + schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}, ) - _commit(self.connection) - self.table = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) - self.table2 = TableSegment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) + self.connection.query([self.src_table.create(), self.dst_table.create(), commit]) + + self.table = 
_table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) + self.table2 = _table_segment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) - self.differ = TableDiffer(3, 4) + self.differ = HashDiffer(bisection_factor=3, bisection_threshold=4) def test_properties_on_empty_table(self): table = self.table.with_schema() @@ -259,12 +169,12 @@ def test_properties_on_empty_table(self): def test_get_values(self): time = "2022-01-01 00:00:00.000000" - time_str = f"timestamp '{time}'" + time_obj = datetime.fromisoformat(time) cols = "id userid movieid rating timestamp".split() - _insert_row(self.connection, self.table_src, cols, [1, 1, 1, 9, time_str]) - _commit(self.connection) - id_ = self.connection.query(f"select id from {self.table_src}", int) + id_ = self.connection.query( + [self.src_table.insert_row(1, 1, 1, 9, time_obj, columns=cols), commit, self.src_table.select(this.id)], int + ) table = self.table.with_schema() @@ -274,12 +184,17 @@ def test_get_values(self): def test_diff_small_tables(self): time = "2022-01-01 00:00:00" - time_str = f"timestamp '{time}'" + time_obj = datetime.fromisoformat(time) cols = "id userid movieid rating timestamp".split() - _insert_rows(self.connection, self.table_src, cols, [[1, 1, 1, 9, time_str], [2, 2, 2, 9, time_str]]) - _insert_rows(self.connection, self.table_dst, cols, [[1, 1, 1, 9, time_str]]) - _commit(self.connection) + self.connection.query( + [ + self.src_table.insert_rows([[1, 1, 1, 9, time_obj], [2, 2, 2, 9, time_obj]], columns=cols), + self.dst_table.insert_rows([[1, 1, 1, 9, time_obj]], columns=cols), + commit, + ] + ) + diff = list(self.differ.diff_tables(self.table, self.table2)) expected = [("-", ("2", time + ".000000"))] self.assertEqual(expected, diff) @@ -287,47 +202,52 @@ def test_diff_small_tables(self): self.assertEqual(1, self.differ.stats["table2_count"]) def test_non_threaded(self): - differ = TableDiffer(3, 4, threaded=False) + differ = HashDiffer(bisection_factor=3, bisection_threshold=4, threaded=False) time = "2022-01-01 00:00:00" - time_str = f"timestamp '{time}'" + time_obj = datetime.fromisoformat(time) cols = "id userid movieid rating timestamp".split() - _insert_row(self.connection, self.table_src, cols, [1, 1, 1, 9, time_str]) - _insert_rows(self.connection, self.table_dst, cols, [[1, 1, 1, 9, time_str]]) - _commit(self.connection) + self.connection.query( + [ + self.src_table.insert_row(1, 1, 1, 9, time_obj, columns=cols), + self.dst_table.insert_row(1, 1, 1, 9, time_obj, columns=cols), + commit, + ] + ) + diff = list(differ.diff_tables(self.table, self.table2)) self.assertEqual(diff, []) def test_diff_table_above_bisection_threshold(self): time = "2022-01-01 00:00:00" - time_str = f"timestamp '{time}'" + time_obj = datetime.fromisoformat(time) cols = "id userid movieid rating timestamp".split() - _insert_rows( - self.connection, - self.table_src, - cols, - [ - [1, 1, 1, 9, time_str], - [2, 2, 2, 9, time_str], - [3, 3, 3, 9, time_str], - [4, 4, 4, 9, time_str], - [5, 5, 5, 9, time_str], - ], - ) - _insert_rows( - self.connection, - self.table_dst, - cols, + self.connection.query( [ - [1, 1, 1, 9, time_str], - [2, 2, 2, 9, time_str], - [3, 3, 3, 9, time_str], - [4, 4, 4, 9, time_str], - ], + self.src_table.insert_rows( + [ + [1, 1, 1, 9, time_obj], + [2, 2, 2, 9, time_obj], + [3, 3, 3, 9, time_obj], + [4, 4, 4, 9, time_obj], + [5, 5, 5, 9, time_obj], + ], + columns=cols, + ), + self.dst_table.insert_rows( + [ + [1, 1, 1, 9, time_obj], + [2, 2, 2, 9, 
time_obj], + [3, 3, 3, 9, time_obj], + [4, 4, 4, 9, time_obj], + ], + columns=cols, + ), + commit, + ] ) - _commit(self.connection) diff = list(self.differ.diff_tables(self.table, self.table2)) expected = [("-", ("5", time + ".000000"))] @@ -337,14 +257,18 @@ def test_diff_table_above_bisection_threshold(self): def test_return_empty_array_when_same(self): time = "2022-01-01 00:00:00" - time_str = f"timestamp '{time}'" + time_obj = datetime.fromisoformat(time) cols = "id userid movieid rating timestamp".split() - _insert_row(self.connection, self.table_src, cols, [1, 1, 1, 9, time_str]) - _insert_row(self.connection, self.table_dst, cols, [1, 1, 1, 9, time_str]) + self.connection.query( + [ + self.src_table.insert_row(1, 1, 1, 9, time_obj, columns=cols), + self.dst_table.insert_row(1, 1, 1, 9, time_obj, columns=cols), + commit, + ] + ) - self.preql.commit() diff = list(self.differ.diff_tables(self.table, self.table2)) self.assertEqual([], diff) @@ -352,39 +276,38 @@ def test_diff_sorted_by_key(self): time = "2022-01-01 00:00:00" time2 = "2021-01-01 00:00:00" - time_str = f"timestamp '{time}'" - time_str2 = f"timestamp '{time2}'" + time_obj = datetime.fromisoformat(time) + time_obj2 = datetime.fromisoformat(time2) cols = "id userid movieid rating timestamp".split() - _insert_rows( - self.connection, - self.table_src, - cols, - [ - [1, 1, 1, 9, time_str], - [2, 2, 2, 9, time_str2], - [3, 3, 3, 9, time_str], - [4, 4, 4, 9, time_str2], - [5, 5, 5, 9, time_str], - ], - ) - - _insert_rows( - self.connection, - self.table_dst, - cols, + self.connection.query( [ - [1, 1, 1, 9, time_str], - [2, 2, 2, 9, time_str], - [3, 3, 3, 9, time_str], - [4, 4, 4, 9, time_str], - [5, 5, 5, 9, time_str], - ], + self.src_table.insert_rows( + [ + [1, 1, 1, 9, time_obj], + [2, 2, 2, 9, time_obj2], + [3, 3, 3, 9, time_obj], + [4, 4, 4, 9, time_obj2], + [5, 5, 5, 9, time_obj], + ], + columns=cols, + ), + self.dst_table.insert_rows( + [ + [1, 1, 1, 9, time_obj], + [2, 2, 2, 9, time_obj], + [3, 3, 3, 9, time_obj], + [4, 4, 4, 9, time_obj], + [5, 5, 5, 9, time_obj], + ], + columns=cols, + ), + commit, + ] ) - _commit(self.connection) - differ = TableDiffer() + differ = HashDiffer() diff = list(differ.diff_tables(self.table, self.table2)) expected = [ ("-", ("2", time2 + ".000000")), @@ -395,105 +318,89 @@ def test_diff_sorted_by_key(self): self.assertEqual(expected, diff) -@test_per_database +@test_each_database class TestDiffTables2(TestPerDatabase): def test_diff_column_names(self): - float_type = _get_float_type(self.connection) - self.connection.query( - f"create table {self.table_src}(id int, rating {float_type}, timestamp timestamp)", - None, - ) - self.connection.query( - f"create table {self.table_dst}(id2 int, rating2 {float_type}, timestamp2 timestamp)", - None, - ) - _commit(self.connection) + self.src_table = table(self.table_src_path, schema={"id": int, "rating": float, "timestamp": datetime}) + self.dst_table = table(self.table_dst_path, schema={"id2": int, "rating2": float, "timestamp2": datetime}) + + self.connection.query([self.src_table.create(), self.dst_table.create(), commit]) time = "2022-01-01 00:00:00" time2 = "2021-01-01 00:00:00" - time_str = f"timestamp '{time}'" - time_str2 = f"timestamp '{time2}'" - _insert_rows( - self.connection, - self.table_src, - ["id", "rating", "timestamp"], - [ - [1, 9, time_str], - [2, 9, time_str2], - [3, 9, time_str], - [4, 9, time_str2], - [5, 9, time_str], - ], - ) + time_obj = datetime.fromisoformat(time) + time_obj2 = datetime.fromisoformat(time2) - 
_insert_rows( - self.connection, - self.table_dst, - ["id2", "rating2", "timestamp2"], + self.connection.query( [ - [1, 9, time_str], - [2, 9, time_str2], - [3, 9, time_str], - [4, 9, time_str2], - [5, 9, time_str], - ], + self.src_table.insert_rows( + [ + [1, 9, time_obj], + [2, 9, time_obj2], + [3, 9, time_obj], + [4, 9, time_obj2], + [5, 9, time_obj], + ], + columns=["id", "rating", "timestamp"], + ), + self.dst_table.insert_rows( + [ + [1, 9, time_obj], + [2, 9, time_obj2], + [3, 9, time_obj], + [4, 9, time_obj2], + [5, 9, time_obj], + ], + columns=["id2", "rating2", "timestamp2"], + ), + ] ) - table1 = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) - table2 = TableSegment(self.connection, self.table_dst_path, "id2", "timestamp2", case_sensitive=False) + table1 = _table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) + table2 = _table_segment(self.connection, self.table_dst_path, "id2", "timestamp2", case_sensitive=False) - differ = TableDiffer() + differ = HashDiffer() diff = list(differ.diff_tables(table1, table2)) assert diff == [] -@test_per_database +@test_each_database class TestUUIDs(TestPerDatabase): def setUp(self): super().setUp() - text_type = _get_text_type(self.connection) - - queries = [ - f"CREATE TABLE {self.table_src}(id {text_type}, text_comment {text_type})", - ] - for i in range(100): - queries.append(f"INSERT INTO {self.table_src} VALUES ('{uuid.uuid1(i)}', '{i}')") - - queries += [ - f"CREATE TABLE {self.table_dst} AS SELECT * FROM {self.table_src}", - ] + self.src_table = src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) self.new_uuid = uuid.uuid1(32132131) - queries.append(f"INSERT INTO {self.table_src} VALUES ('{self.new_uuid}', 'This one is different')") - - # TODO test unexpected values? 
- for query in queries: - self.connection.query(query, None) - - _commit(self.connection) + self.connection.query( + [ + src_table.create(), + src_table.insert_rows((uuid.uuid1(i), str(i)) for i in range(100)), + table(self.table_dst_path).create(src_table), + src_table.insert_row(self.new_uuid, "This one is different"), + commit, + ] + ) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_string_keys(self): - differ = TableDiffer() + differ = HashDiffer() diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, [("-", (str(self.new_uuid), "This one is different"))]) - self.connection.query( - f"INSERT INTO {self.table_src} VALUES ('unexpected', '<-- this bad value should not break us')", None - ) + self.connection.query(self.src_table.insert_row("unexpected", "<-- this bad value should not break us")) self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) def test_where_sampling(self): a = self.a.replace(where="1=1") - differ = TableDiffer() + differ = HashDiffer() diff = list(differ.diff_tables(a, self.b)) self.assertEqual(diff, [("-", (str(self.new_uuid), "This one is different"))]) @@ -501,90 +408,81 @@ def test_where_sampling(self): self.assertRaises(ValueError, list, differ.diff_tables(a_empty, self.b)) -@test_per_database__filter_name(lambda n: n != "MySQL") +@test_each_database_in_list(TEST_DATABASES - {db.MySQL}) class TestAlphanumericKeys(TestPerDatabase): def setUp(self): super().setUp() - text_type = _get_text_type(self.connection) + self.src_table = src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) + self.new_alphanum = "aBcDeFgHiz" - queries = [ - f"CREATE TABLE {self.table_src}(id {text_type}, text_comment {text_type})", - ] + values = [] for i in range(0, 10000, 1000): a = ArithAlphanumeric(numberToAlphanum(i), max_len=10) if not a and isinstance(self.connection, db.Oracle): # Skip empty string, because Oracle treats it as NULL .. continue - queries.append(f"INSERT INTO {self.table_src} VALUES ('{a}', '{i}')") - queries += [ - f"CREATE TABLE {self.table_dst} AS SELECT * FROM {self.table_src}", - ] - - self.new_alphanum = "aBcDeFgHiJ" - queries.append(f"INSERT INTO {self.table_src} VALUES ('{self.new_alphanum}', 'This one is different')") + values.append((str(a), str(i))) - # TODO test unexpected values? 
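+        # Setup queries built with the query builder (create src, seed it, copy to dst, add one differing row);
+        # the existing loop below still runs them one statement at a time.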
+ queries = [ + src_table.create(), + src_table.insert_rows(values), + table(self.table_dst_path).create(src_table), + src_table.insert_row(self.new_alphanum, "This one is different"), + commit, + ] for query in queries: self.connection.query(query, None) - _commit(self.connection) - - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_alphanum_keys(self): - differ = TableDiffer(bisection_factor=2, bisection_threshold=3) + differ = HashDiffer(bisection_factor=2, bisection_threshold=3) diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, [("-", (str(self.new_alphanum), "This one is different"))]) - self.connection.query( - f"INSERT INTO {self.table_src} VALUES ('@@@', '<-- this bad value should not break us')", None - ) - _commit(self.connection) + self.connection.query([self.src_table.insert_row("@@@", "<-- this bad value should not break us"), commit]) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) self.assertRaises(NotImplementedError, list, differ.diff_tables(self.a, self.b)) -@test_per_database__filter_name(lambda n: n != "MySQL") +@test_each_database_in_list(TEST_DATABASES - {db.MySQL}) class TestVaryingAlphanumericKeys(TestPerDatabase): def setUp(self): super().setUp() - text_type = _get_text_type(self.connection) + self.src_table = src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) - queries = [ - f"CREATE TABLE {self.table_src}(id {text_type}, text_comment {text_type})", - ] + values = [] for i in range(0, 10000, 1000): a = ArithAlphanumeric(numberToAlphanum(i * i)) if not a and isinstance(self.connection, db.Oracle): # Skip empty string, because Oracle treats it as NULL .. continue - queries.append(f"INSERT INTO {self.table_src} VALUES ('{a}', '{i}')") - queries += [ - f"CREATE TABLE {self.table_dst} AS SELECT * FROM {self.table_src}", - ] + values.append((str(a), str(i))) self.new_alphanum = "aBcDeFgHiJ" - queries.append(f"INSERT INTO {self.table_src} VALUES ('{self.new_alphanum}', 'This one is different')") - - # TODO test unexpected values? 
- for query in queries: - self.connection.query(query, None) + queries = [ + src_table.create(), + src_table.insert_rows(values), + table(self.table_dst_path).create(src_table), + src_table.insert_row(self.new_alphanum, "This one is different"), + commit, + ] - _commit(self.connection) + self.connection.query(queries) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_varying_alphanum_keys(self): # Test the class itself @@ -596,113 +494,107 @@ def test_varying_alphanum_keys(self): for a in alphanums: assert a - a == 0 - # Test with the differ - differ = TableDiffer(threaded=False) + differ = HashDiffer() diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, [("-", (str(self.new_alphanum), "This one is different"))]) self.connection.query( - f"INSERT INTO {self.table_src} VALUES ('@@@', '<-- this bad value should not break us')", None + self.src_table.insert_row("@@@", "<-- this bad value should not break us"), + commit, ) - _commit(self.connection) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) self.assertRaises(NotImplementedError, list, differ.diff_tables(self.a, self.b)) -@test_per_database +@test_each_database class TestTableSegment(TestPerDatabase): def setUp(self) -> None: super().setUp() - self.table = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) - self.table2 = TableSegment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) + self.table = _table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) + self.table2 = _table_segment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) def test_table_segment(self): - early = datetime.datetime(2021, 1, 1, 0, 0) - late = datetime.datetime(2022, 1, 1, 0, 0) + early = datetime(2021, 1, 1, 0, 0) + late = datetime(2022, 1, 1, 0, 0) self.assertRaises(ValueError, self.table.replace, min_update=late, max_update=early) self.assertRaises(ValueError, self.table.replace, min_key=10, max_key=0) def test_case_awareness(self): - # create table - self.connection.query(f"create table {self.table_src}(id int, userid int, timestamp timestamp)", None) - _commit(self.connection) + src_table = table(self.table_src_path, schema={"id": int, "userid": int, "timestamp": datetime}) - # insert rows cols = "id userid timestamp".split() time = "2022-01-01 00:00:00.000000" - time_str = f"timestamp '{time}'" - _insert_rows(self.connection, self.table_src, cols, [[1, 9, time_str], [2, 2, time_str]]) - _commit(self.connection) + time_obj = datetime.fromisoformat(time) - res = tuple(self.table.replace(key_column="Id", case_sensitive=False).with_schema().query_key_range()) + self.connection.query( + [src_table.create(), src_table.insert_rows([[1, 9, time_obj], [2, 2, 
time_obj]], columns=cols), commit] + ) + + res = tuple(self.table.replace(key_columns=("Id",), case_sensitive=False).with_schema().query_key_range()) assert res == ("1", "2") self.assertRaises( - KeyError, self.table.replace(key_column="Id", case_sensitive=True).with_schema().query_key_range + KeyError, self.table.replace(key_columns=("Id",), case_sensitive=True).with_schema().query_key_range ) -@test_per_database +@test_each_database class TestTableUUID(TestPerDatabase): def setUp(self): super().setUp() - text_type = _get_text_type(self.connection) + src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) - queries = [ - f"CREATE TABLE {self.table_src}(id {text_type}, text_comment {text_type})", - ] + values = [] for i in range(10): uuid_value = uuid.uuid1(i) - queries.append(f"INSERT INTO {self.table_src} VALUES ('{uuid_value}', '{uuid_value}')") + values.append((uuid_value, uuid_value)) self.null_uuid = uuid.uuid1(32132131) - queries += [ - f"CREATE TABLE {self.table_dst} AS SELECT * FROM {self.table_src}", - f"INSERT INTO {self.table_src} VALUES ('{self.null_uuid}', NULL)", - ] - - for query in queries: - self.connection.query(query, None) - _commit(self.connection) + self.connection.query( + [ + src_table.create(), + src_table.insert_rows(values), + table(self.table_dst_path).create(src_table), + src_table.insert_row(self.null_uuid, None), + commit, + ] + ) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_uuid_column_with_nulls(self): - differ = TableDiffer() + differ = HashDiffer() diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, [("-", (str(self.null_uuid), None))]) -@test_per_database +@test_each_database class TestTableNullRowChecksum(TestPerDatabase): def setUp(self): super().setUp() - text_type = _get_text_type(self.connection) + src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) self.null_uuid = uuid.uuid1(1) - queries = [ - f"CREATE TABLE {self.table_src}(id {text_type}, text_comment {text_type})", - f"INSERT INTO {self.table_src} VALUES ('{uuid.uuid1(1)}', '1')", - f"CREATE TABLE {self.table_dst} AS SELECT * FROM {self.table_src}", - # Add a row where a column has NULL value - f"INSERT INTO {self.table_src} VALUES ('{self.null_uuid}', NULL)", - ] - - for query in queries: - self.connection.query(query, None) - - _commit(self.connection) + self.connection.query( + [ + src_table.create(), + src_table.insert_row(uuid.uuid1(1), "1"), + table(self.table_dst_path).create(src_table), + src_table.insert_row(self.null_uuid, None), # Add a row where a column has NULL value + commit, + ] + ) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_uuid_columns_with_nulls(self): """ @@ -725,44 +617,48 @@ def test_uuid_columns_with_nulls(self): diff results, but 
it's not. This test helps to detect such cases. """ - differ = TableDiffer(bisection_factor=2, bisection_threshold=3) + differ = HashDiffer(bisection_factor=2, bisection_threshold=3) diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, [("-", (str(self.null_uuid), None))]) -@test_per_database +@test_each_database class TestConcatMultipleColumnWithNulls(TestPerDatabase): def setUp(self): super().setUp() - text_type = _get_text_type(self.connection) + src_table = table(self.table_src_path, schema={"id": str, "c1": str, "c2": str}) + dst_table = table(self.table_dst_path, schema={"id": str, "c1": str, "c2": str}) - queries = [ - f"CREATE TABLE {self.table_src}(id {text_type}, c1 {text_type}, c2 {text_type})", - f"CREATE TABLE {self.table_dst}(id {text_type}, c1 {text_type}, c2 {text_type})", - ] + src_values = [] + dst_values = [] self.diffs = [] for i in range(0, 8): pk = uuid.uuid1(i) - table_src_c1_val = str(i) - table_dst_c1_val = str(i) + "-different" + src_row = (str(pk), str(i), None) + dst_row = (str(pk), str(i) + "-different", None) - queries.append(f"INSERT INTO {self.table_src} VALUES ('{pk}', '{table_src_c1_val}', NULL)") - queries.append(f"INSERT INTO {self.table_dst} VALUES ('{pk}', '{table_dst_c1_val}', NULL)") + src_values.append(src_row) + dst_values.append(dst_row) - self.diffs.append(("-", (str(pk), table_src_c1_val, None))) - self.diffs.append(("+", (str(pk), table_dst_c1_val, None))) + self.diffs.append(("-", src_row)) + self.diffs.append(("+", dst_row)) - for query in queries: - self.connection.query(query, None) - - _commit(self.connection) + self.connection.query( + [ + src_table.create(), + dst_table.create(), + src_table.insert_rows(src_values), + dst_table.insert_rows(dst_values), + commit, + ] + ) - self.a = TableSegment( + self.a = _table_segment( self.connection, self.table_src_path, "id", extra_columns=("c1", "c2"), case_sensitive=False ) - self.b = TableSegment( + self.b = _table_segment( self.connection, self.table_dst_path, "id", extra_columns=("c1", "c2"), case_sensitive=False ) @@ -789,49 +685,38 @@ def test_tables_are_different(self): value, it may lead that concat(pk_i, i, NULL) == concat(pk_i, i-diff, NULL). This test handle such cases. 
""" - differ = TableDiffer(bisection_factor=2, bisection_threshold=4) + differ = HashDiffer(bisection_factor=2, bisection_threshold=4) diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, self.diffs) -@test_per_database +@test_each_database class TestTableTableEmpty(TestPerDatabase): def setUp(self): super().setUp() - text_type = _get_text_type(self.connection) + self.src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) + self.dst_table = table(self.table_dst_path, schema={"id": str, "text_comment": str}) self.null_uuid = uuid.uuid1(1) - queries = [ - f"CREATE TABLE {self.table_src}(id {text_type}, text_comment {text_type})", - f"CREATE TABLE {self.table_dst}(id {text_type}, text_comment {text_type})", - ] - - self.diffs = [(uuid.uuid1(i), i) for i in range(100)] - for pk, value in self.diffs: - queries.append(f"INSERT INTO {self.table_src} VALUES ('{pk}', '{value}')") - - for query in queries: - self.connection.query(query, None) - _commit(self.connection) + self.diffs = [(uuid.uuid1(i), str(i)) for i in range(100)] - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_right_table_empty(self): - differ = TableDiffer() + self.connection.query( + [self.src_table.create(), self.dst_table.create(), self.src_table.insert_rows(self.diffs), commit] + ) + + differ = HashDiffer() self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) def test_left_table_empty(self): - queries = [ - f"INSERT INTO {self.table_dst} SELECT id, text_comment FROM {self.table_src}", - f"TRUNCATE TABLE {self.table_src}", - ] - for query in queries: - self.connection.query(query, None) - - _commit(self.connection) + self.connection.query( + [self.src_table.create(), self.dst_table.create(), self.dst_table.insert_rows(self.diffs), commit] + ) - differ = TableDiffer() + differ = HashDiffer() self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py new file mode 100644 index 00000000..3f70fa09 --- /dev/null +++ b/tests/test_joindiff.py @@ -0,0 +1,321 @@ +from typing import List +from datetime import datetime + +from data_diff.queries.ast_classes import TablePath +from data_diff.queries import table, commit +from data_diff.table_segment import TableSegment +from data_diff import databases as db +from data_diff.joindiff_tables import JoinDiffer + +from .test_diff_tables import TestPerDatabase + +from .common import ( + random_table_suffix, + test_each_database_in_list, +) + + +TEST_DATABASES = { + db.PostgreSQL, + db.MySQL, + db.Snowflake, + db.BigQuery, + db.Oracle, + db.Redshift, + db.Presto, + db.Trino, + db.Vertica, +} + +test_each_database = test_each_database_in_list(TEST_DATABASES) + + +@test_each_database_in_list({db.Snowflake, db.BigQuery}) +class TestCompositeKey(TestPerDatabase): + def setUp(self): + super().setUp() + + self.src_table = table( + self.table_src_path, + schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}, + ) + self.dst_table = table( + self.table_dst_path, + schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}, + ) + 
+ self.connection.query([self.src_table.create(), self.dst_table.create(), commit]) + + self.differ = JoinDiffer() + + def test_composite_key(self): + time = "2022-01-01 00:00:00" + time_obj = datetime.fromisoformat(time) + + cols = "id userid movieid rating timestamp".split() + + self.connection.query( + [ + self.src_table.insert_rows([[1, 1, 1, 9, time_obj], [2, 2, 2, 9, time_obj]], columns=cols), + self.dst_table.insert_rows([[1, 1, 1, 9, time_obj], [2, 3, 2, 9, time_obj]], columns=cols), + commit, + ] + ) + + # Sanity + table1 = TableSegment( + self.connection, self.table_src_path, ("id",), "timestamp", ("userid",), case_sensitive=False + ) + table2 = TableSegment( + self.connection, self.table_dst_path, ("id",), "timestamp", ("userid",), case_sensitive=False + ) + diff = list(self.differ.diff_tables(table1, table2)) + assert len(diff) == 2 + assert self.differ.stats["exclusive_count"] == 0 + + # Test pks diffed, by checking exclusive_count + table1 = TableSegment(self.connection, self.table_src_path, ("id", "userid"), "timestamp", case_sensitive=False) + table2 = TableSegment(self.connection, self.table_dst_path, ("id", "userid"), "timestamp", case_sensitive=False) + diff = list(self.differ.diff_tables(table1, table2)) + assert len(diff) == 2 + assert self.differ.stats["exclusive_count"] == 2 + + +@test_each_database +class TestJoindiff(TestPerDatabase): + def setUp(self): + super().setUp() + + self.src_table = table( + self.table_src_path, + schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}, + ) + self.dst_table = table( + self.table_dst_path, + schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}, + ) + + self.connection.query([self.src_table.create(), self.dst_table.create(), commit]) + + self.table = TableSegment(self.connection, self.table_src_path, ("id",), "timestamp", case_sensitive=False) + self.table2 = TableSegment(self.connection, self.table_dst_path, ("id",), "timestamp", case_sensitive=False) + + self.differ = JoinDiffer() + + def test_diff_small_tables(self): + time = "2022-01-01 00:00:00" + time_obj = datetime.fromisoformat(time) + + cols = "id userid movieid rating timestamp".split() + + self.connection.query( + [ + self.src_table.insert_rows([[1, 1, 1, 9, time_obj], [2, 2, 2, 9, time_obj]], columns=cols), + self.dst_table.insert_rows([[1, 1, 1, 9, time_obj]], columns=cols), + commit, + ] + ) + + diff = list(self.differ.diff_tables(self.table, self.table2)) + expected_row = ("2", time + ".000000") + expected = [("-", expected_row)] + self.assertEqual(expected, diff) + self.assertEqual(2, self.differ.stats["table1_count"]) + self.assertEqual(1, self.differ.stats["table2_count"]) + self.assertEqual(3, self.differ.stats["table1_sum_id"]) + self.assertEqual(1, self.differ.stats["table2_sum_id"]) + + # Test materialize + materialize_path = self.connection.parse_table_name(f"test_mat_{random_table_suffix()}") + mdiffer = self.differ.replace(materialize_to_table=materialize_path) + diff = list(mdiffer.diff_tables(self.table, self.table2)) + self.assertEqual(expected, diff) + + t = TablePath(materialize_path) + rows = self.connection.query(t.select(), List[tuple]) + # is_xa, is_xb, is_diff1, is_diff2, row1, row2 + assert rows == [(1, 0, 1, 1) + expected_row + (None, None)], rows + self.connection.query(t.drop()) + + # Test materialize all rows + mdiffer = mdiffer.replace(materialize_all_rows=True) + diff = list(mdiffer.diff_tables(self.table, self.table2)) + self.assertEqual(expected, diff) + rows = 
self.connection.query(t.select(), List[tuple]) + assert len(rows) == 2, len(rows) + self.connection.query(t.drop()) + + def test_diff_table_above_bisection_threshold(self): + time = "2022-01-01 00:00:00" + time_obj = datetime.fromisoformat(time) + + cols = "id userid movieid rating timestamp".split() + + self.connection.query( + [ + self.src_table.insert_rows( + [ + [1, 1, 1, 9, time_obj], + [2, 2, 2, 9, time_obj], + [3, 3, 3, 9, time_obj], + [4, 4, 4, 9, time_obj], + [5, 5, 5, 9, time_obj], + ], + columns=cols, + ), + self.dst_table.insert_rows( + [ + [1, 1, 1, 9, time_obj], + [2, 2, 2, 9, time_obj], + [3, 3, 3, 9, time_obj], + [4, 4, 4, 9, time_obj], + ], + columns=cols, + ), + commit, + ] + ) + + diff = list(self.differ.diff_tables(self.table, self.table2)) + expected = [("-", ("5", time + ".000000"))] + self.assertEqual(expected, diff) + self.assertEqual(5, self.differ.stats["table1_count"]) + self.assertEqual(4, self.differ.stats["table2_count"]) + + def test_return_empty_array_when_same(self): + time = "2022-01-01 00:00:00" + time_obj = datetime.fromisoformat(time) + + cols = "id userid movieid rating timestamp".split() + + self.connection.query( + [ + self.src_table.insert_row(1, 1, 1, 9, time_obj, columns=cols), + self.dst_table.insert_row(1, 1, 1, 9, time_obj, columns=cols), + ] + ) + + diff = list(self.differ.diff_tables(self.table, self.table2)) + self.assertEqual([], diff) + + def test_diff_sorted_by_key(self): + time = "2022-01-01 00:00:00" + time2 = "2021-01-01 00:00:00" + + time_obj = datetime.fromisoformat(time) + time_obj2 = datetime.fromisoformat(time2) + + cols = "id userid movieid rating timestamp".split() + + self.connection.query( + [ + self.src_table.insert_rows( + [ + [1, 1, 1, 9, time_obj], + [2, 2, 2, 9, time_obj2], + [3, 3, 3, 9, time_obj], + [4, 4, 4, 9, time_obj2], + [5, 5, 5, 9, time_obj], + ], + columns=cols, + ), + self.dst_table.insert_rows( + [ + [1, 1, 1, 9, time_obj], + [2, 2, 2, 9, time_obj], + [3, 3, 3, 9, time_obj], + [4, 4, 4, 9, time_obj], + [5, 5, 5, 9, time_obj], + ], + columns=cols, + ), + commit, + ] + ) + + diff = list(self.differ.diff_tables(self.table, self.table2)) + expected = [ + ("-", ("2", time2 + ".000000")), + ("+", ("2", time + ".000000")), + ("-", ("4", time2 + ".000000")), + ("+", ("4", time + ".000000")), + ] + self.assertEqual(expected, diff) + + def test_dup_pks(self): + time = "2022-01-01 00:00:00" + time_obj = datetime.fromisoformat(time) + + cols = "id rating timestamp".split() + + self.connection.query( + [ + self.src_table.insert_rows([[1, 9, time_obj], [1, 10, time_obj]], columns=cols), + self.dst_table.insert_row(1, 9, time_obj, columns=cols), + ] + ) + + x = self.differ.diff_tables(self.table, self.table2) + self.assertRaises(ValueError, list, x) + + def test_null_pks(self): + time = "2022-01-01 00:00:00" + time_obj = datetime.fromisoformat(time) + + cols = "id rating timestamp".split() + + self.connection.query( + [ + self.src_table.insert_row(None, 9, time_obj, columns=cols), + self.dst_table.insert_row(1, 9, time_obj, columns=cols), + ] + ) + + x = self.differ.diff_tables(self.table, self.table2) + self.assertRaises(ValueError, list, x) + + +@test_each_database_in_list(d for d in TEST_DATABASES if d.dialect.SUPPORTS_PRIMARY_KEY and d.SUPPORTS_UNIQUE_CONSTAINT) +class TestUniqueConstraint(TestPerDatabase): + def setUp(self): + super().setUp() + + self.src_table = table( + self.table_src_path, + schema={"id": int, "userid": int, "movieid": int, "rating": float}, + ) + self.dst_table = table( + self.table_dst_path, + 
schema={"id": int, "userid": int, "movieid": int, "rating": float}, + ) + + self.connection.query( + [self.src_table.create(primary_keys=["id"]), self.dst_table.create(primary_keys=["id", "userid"]), commit] + ) + + self.differ = JoinDiffer() + + def test_unique_constraint(self): + self.connection.query( + [ + self.src_table.insert_rows([[1, 1, 1, 9], [2, 2, 2, 9]]), + self.dst_table.insert_rows([[1, 1, 1, 9], [2, 2, 2, 9]]), + commit, + ] + ) + + # Test no active validation + table = TableSegment(self.connection, self.table_src_path, ("id",), case_sensitive=False) + table2 = TableSegment(self.connection, self.table_dst_path, ("id",), case_sensitive=False) + + res = list(self.differ.diff_tables(table, table2)) + assert not res + assert "validated_unique_keys" not in self.differ.stats + + # Test active validation + table = TableSegment(self.connection, self.table_src_path, ("userid",), case_sensitive=False) + table2 = TableSegment(self.connection, self.table_dst_path, ("userid",), case_sensitive=False) + + res = list(self.differ.diff_tables(table, table2)) + assert not res + self.assertEqual(self.differ.stats["validated_unique_keys"], [["userid"]]) diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 140421a4..0c57d299 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -1,13 +1,13 @@ import unittest -from data_diff.databases.connect import connect -from data_diff import TableSegment, TableDiffer -from .common import TEST_POSTGRESQL_CONN_STRING, TEST_MYSQL_CONN_STRING, random_table_suffix +from data_diff import TableSegment, HashDiffer +from data_diff import databases as db +from .common import get_conn, random_table_suffix class TestUUID(unittest.TestCase): def setUp(self) -> None: - self.connection = connect(TEST_POSTGRESQL_CONN_STRING) + self.connection = get_conn(db.PostgreSQL) table_suffix = random_table_suffix() @@ -38,16 +38,16 @@ def test_uuid(self): for query in queries: self.connection.query(query, None) - a = TableSegment(self.connection, (self.table_src,), "id", "comment") - b = TableSegment(self.connection, (self.table_dst,), "id", "comment") + a = TableSegment(self.connection, (self.table_src,), ("id",), "comment") + b = TableSegment(self.connection, (self.table_dst,), ("id",), "comment") - differ = TableDiffer() + differ = HashDiffer() diff = list(differ.diff_tables(a, b)) uuid = diff[0][1][0] self.assertEqual(diff, [("-", (uuid, "This one is different"))]) # Compare with MySql - mysql_conn = connect(TEST_MYSQL_CONN_STRING) + mysql_conn = get_conn(db.MySQL) rows = self.connection.query(f"SELECT * FROM {self.table_src}", list) @@ -57,7 +57,7 @@ def test_uuid(self): mysql_conn.query(f"INSERT INTO {self.table_dst}(id, comment) VALUES ('{uuid}', '{comment}')", None) mysql_conn.query(f"COMMIT", None) - c = TableSegment(mysql_conn, (self.table_dst,), "id", "comment") + c = TableSegment(mysql_conn, (self.table_dst,), ("id",), "comment") diff = list(differ.diff_tables(a, c)) assert not diff, diff diff = list(differ.diff_tables(c, a)) diff --git a/tests/test_query.py b/tests/test_query.py new file mode 100644 index 00000000..36792d23 --- /dev/null +++ b/tests/test_query.py @@ -0,0 +1,177 @@ +from datetime import datetime +from typing import List, Optional +import unittest +from data_diff.databases.database_types import AbstractDatabase, AbstractDialect, CaseInsensitiveDict, CaseSensitiveDict + +from data_diff.queries import this, table, Compiler, outerjoin, cte +from data_diff.queries.ast_classes import Random + + +def normalize_spaces(s: str): + 
return " ".join(s.split()) + + +class MockDialect(AbstractDialect): + name = "MockDialect" + + ROUNDS_ON_PREC_LOSS = False + + def quote(self, s: str) -> str: + return s + + def concat(self, l: List[str]) -> str: + s = ", ".join(l) + return f"concat({s})" + + def to_string(self, s: str) -> str: + return f"cast({s} as varchar)" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"{a} is distinct from {b}" + + def random(self) -> str: + return "random()" + + def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): + x = offset and f"offset {offset}", limit and f"limit {limit}" + return " ".join(filter(None, x)) + + def explain_as_text(self, query: str) -> str: + return f"explain {query}" + + def timestamp_value(self, t: datetime) -> str: + return f"timestamp '{t}'" + + parse_type = NotImplemented + + +class MockDatabase(AbstractDatabase): + dialect = MockDialect() + + _query = NotImplemented + query_table_schema = NotImplemented + select_table_schema = NotImplemented + _process_table_schema = NotImplemented + parse_table_name = NotImplemented + close = NotImplemented + _normalize_table_path = NotImplemented + is_autocommit = NotImplemented + + +class TestQuery(unittest.TestCase): + def setUp(self): + pass + + def test_basic(self): + c = Compiler(MockDatabase()) + + t = table("point") + t2 = t.select(x=this.x + 1, y=t["y"] + this.x) + assert c.compile(t2) == "SELECT (x + 1) AS x, (y + x) AS y FROM point" + + t = table("point").where(this.x == 1, this.y == 2) + assert c.compile(t) == "SELECT * FROM point WHERE (x = 1) AND (y = 2)" + + t = table("point").select("x", "y") + assert c.compile(t) == "SELECT x, y FROM point" + + def test_outerjoin(self): + c = Compiler(MockDatabase()) + + a = table("a") + b = table("b") + keys = ["x", "y"] + cols = ["u", "v"] + + j = outerjoin(a, b).on(a[k] == b[k] for k in keys) + + self.assertEqual( + c.compile(j), "SELECT * FROM a tmp1 FULL OUTER JOIN b tmp2 ON (tmp1.x = tmp2.x) AND (tmp1.y = tmp2.y)" + ) + + # diffed = j.select("*", **{f"is_diff_col_{c}": a[c].is_distinct_from(b[c]) for c in cols}) + + # t = diffed.select( + # **{f"total_diff_col_{c}": diffed[f"is_diff_col_{c}"].sum() for c in cols}, + # total_diff=or_(diffed[f"is_diff_col_{c}"] for c in cols).sum(), + # ) + + # print(c.compile(t)) + + # t.group_by(keys=[this.x], values=[this.py]) + + def test_schema(self): + c = Compiler(MockDatabase()) + schema = dict(id="int", comment="varchar") + + # test table + t = table("a", schema=CaseInsensitiveDict(schema)) + q = t.select(this.Id, t["COMMENT"]) + assert c.compile(q) == "SELECT id, comment FROM a" + + t = table("a", schema=CaseSensitiveDict(schema)) + self.assertRaises(KeyError, t.__getitem__, "Id") + self.assertRaises(KeyError, t.select, this.Id) + + # test select + q = t.select(this.id) + self.assertRaises(KeyError, q.__getitem__, "comment") + + # test join + s = CaseInsensitiveDict({"x": int, "y": int}) + a = table("a", schema=s) + b = table("b", schema=s) + keys = ["x", "y"] + j = outerjoin(a, b).on(a[k] == b[k] for k in keys).select(a["x"], b["y"], xsum=a["x"] + b["x"]) + j["x"], j["y"], j["xsum"] + self.assertRaises(KeyError, j.__getitem__, "ysum") + + def test_commutable_select(self): + # c = Compiler(MockDatabase()) + + t = table("a") + q1 = t.select("a").where("b") + q2 = t.where("b").select("a") + assert q1 == q2, (q1, q2) + + def test_cte(self): + c = Compiler(MockDatabase()) + + t = table("a") + + # single cte + t2 = cte(t.select(this.x)) + t3 = t2.select(this.x) + + expected = "WITH tmp1 AS (SELECT x 
FROM a) SELECT x FROM tmp1" + assert normalize_spaces(c.compile(t3)) == expected + + # nested cte + c = Compiler(MockDatabase()) + t4 = cte(t3).select(this.x) + + expected = "WITH tmp1 AS (SELECT x FROM a), tmp2 AS (SELECT x FROM tmp1) SELECT x FROM tmp2" + assert normalize_spaces(c.compile(t4)) == expected + + # parameterized cte + c = Compiler(MockDatabase()) + t2 = cte(t.select(this.x), params=["y"]) + t3 = t2.select(this.y) + + expected = "WITH tmp1(y) AS (SELECT x FROM a) SELECT y FROM tmp1" + assert normalize_spaces(c.compile(t3)) == expected + + def test_funcs(self): + c = Compiler(MockDatabase()) + t = table("a") + + q = c.compile(t.order_by(Random()).limit(10)) + assert q == "SELECT * FROM a ORDER BY random() limit 10" + + def test_union(self): + c = Compiler(MockDatabase()) + a = table("a").select("x") + b = table("b").select("y") + + q = c.compile(a.union(b)) + assert q == "SELECT x FROM a UNION SELECT y FROM b" diff --git a/tests/test_sql.py b/tests/test_sql.py index bc4828c0..0e1e8d13 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -1,10 +1,10 @@ import unittest from data_diff.databases import connect_to_uri -from data_diff.sql import Checksum, Compare, Compiler, Count, Enum, Explain, In, Select, TableName - from .common import TEST_MYSQL_CONN_STRING +from data_diff.queries import Compiler, Count, Explain, Select, table, In, BinOp + class TestSQL(unittest.TestCase): def setUp(self): @@ -18,7 +18,7 @@ def test_compile_int(self): self.assertEqual("1", self.compiler.compile(1)) def test_compile_table_name(self): - self.assertEqual("`marine_mammals`.`walrus`", self.compiler.compile(TableName(("marine_mammals", "walrus")))) + self.assertEqual("`marine_mammals`.`walrus`", self.compiler.compile(table("marine_mammals", "walrus"))) def test_compile_select(self): expected_sql = "SELECT name FROM `marine_mammals`.`walrus`" @@ -26,23 +26,23 @@ def test_compile_select(self): expected_sql, self.compiler.compile( Select( + table("marine_mammals", "walrus"), ["name"], - TableName(("marine_mammals", "walrus")), ) ), ) - def test_enum(self): - expected_sql = "(SELECT *, (row_number() over (ORDER BY id)) as idx FROM `walrus` ORDER BY id) tmp" - self.assertEqual( - expected_sql, - self.compiler.compile( - Enum( - ("walrus",), - "id", - ) - ), - ) + # def test_enum(self): + # expected_sql = "(SELECT *, (row_number() over (ORDER BY id)) as idx FROM `walrus` ORDER BY id) tmp" + # self.assertEqual( + # expected_sql, + # self.compiler.compile( + # Enum( + # ("walrus",), + # "id", + # ) + # ), + # ) # def test_checksum(self): # expected_sql = "SELECT name, sum(cast(conv(substring(md5(concat(cast(id as char), cast(timestamp as char))), 18), 16, 10) as unsigned)) FROM `marine_mammals`.`walrus`" @@ -62,9 +62,9 @@ def test_compare(self): expected_sql, self.compiler.compile( Select( + table("marine_mammals", "walrus"), ["name"], - TableName(("marine_mammals", "walrus")), - [Compare("<=", "id", "1000"), Compare(">", "id", "1")], + [BinOp("<=", ["id", "1000"]), BinOp(">", ["id", "1"])], ) ), ) @@ -73,30 +73,28 @@ def test_in(self): expected_sql = "SELECT name FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" self.assertEqual( expected_sql, - self.compiler.compile(Select(["name"], TableName(("marine_mammals", "walrus")), [In("id", [1, 2, 3])])), + self.compiler.compile(Select(table("marine_mammals", "walrus"), ["name"], [In("id", [1, 2, 3])])), ) def test_count(self): expected_sql = "SELECT count(*) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" self.assertEqual( expected_sql, - 
self.compiler.compile(Select([Count()], TableName(("marine_mammals", "walrus")), [In("id", [1, 2, 3])])), + self.compiler.compile(Select(table("marine_mammals", "walrus"), [Count()], [In("id", [1, 2, 3])])), ) def test_count_with_column(self): expected_sql = "SELECT count(id) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" self.assertEqual( expected_sql, - self.compiler.compile( - Select([Count("id")], TableName(("marine_mammals", "walrus")), [In("id", [1, 2, 3])]) - ), + self.compiler.compile(Select(table("marine_mammals", "walrus"), [Count("id")], [In("id", [1, 2, 3])])), ) def test_explain(self): - expected_sql = "EXPLAIN SELECT count(id) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" + expected_sql = "EXPLAIN FORMAT=TREE SELECT count(id) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" self.assertEqual( expected_sql, self.compiler.compile( - Explain(Select([Count("id")], TableName(("marine_mammals", "walrus")), [In("id", [1, 2, 3])])) + Explain(Select(table("marine_mammals", "walrus"), [Count("id")], [In("id", [1, 2, 3])])) ), ) diff --git a/tests/waiting_for_stack_up.sh b/tests/waiting_for_stack_up.sh index 02ca9cf0..762138de 100644 --- a/tests/waiting_for_stack_up.sh +++ b/tests/waiting_for_stack_up.sh @@ -5,7 +5,7 @@ if [ -n "$DATADIFF_VERTICA_URI" ] echo "Check Vertica DB running..." while true do - if docker logs vertica | tail -n 100 | grep -q -i "vertica is now running" + if docker logs dd-vertica | tail -n 100 | grep -q -i "vertica is now running" then echo "Vertica DB is ready"; break;