diff --git a/.circleci/config.yml b/.circleci/config.yml
deleted file mode 100644
index 7a12d3c07d..0000000000
--- a/.circleci/config.yml
+++ /dev/null
@@ -1,115 +0,0 @@
-version: 2.1
-
-setup: true
-
-on_main_or_tag_filter: &on_main_or_tag_filter
- filters:
- branches:
- only: main
- tags:
- only: /^v\d+\.\d+\.\d+/
-
-on_tag_filter: &on_tag_filter
- filters:
- branches:
- ignore: /.*/
- tags:
- only: /^v\d+\.\d+\.\d+/
-
-orbs:
- path-filtering: circleci/path-filtering@1.2.0
-
-jobs:
- publish:
- docker:
- - image: cimg/python:3.10
- resource_class: small
- steps:
- - checkout
- - attach_workspace:
- at: web/client
- - run:
- name: Publish Python package
- command: make publish
- - run:
- name: Update pypirc
- command: ./.circleci/update-pypirc.sh
- - run:
- name: Publish Python Tests package
- command: unset TWINE_USERNAME TWINE_PASSWORD && make publish-tests
- gh-release:
- docker:
- - image: cimg/node:20.19.0
- resource_class: small
- steps:
- - run:
- name: Create release on GitHub
- command: |
- GITHUB_TOKEN="$GITHUB_TOKEN" \
- TARGET_TAG="$CIRCLE_TAG" \
- REPO_OWNER="$CIRCLE_PROJECT_USERNAME" \
- REPO_NAME="$CIRCLE_PROJECT_REPONAME" \
- CONTINUE_ON_ERROR="false" \
- npx https://github.com/TobikoData/circleci-gh-conventional-release
-
- ui-build:
- docker:
- - image: cimg/node:20.19.0
- resource_class: medium
- steps:
- - checkout
- - run:
- name: Install Dependencies
- command: |
- pnpm install
- - run:
- name: Build UI
- command: pnpm --prefix web/client run build
- - persist_to_workspace:
- root: web/client
- paths:
- - dist
- trigger_private_renovate:
- docker:
- - image: cimg/base:2021.11
- resource_class: small
- steps:
- - run:
- name: Trigger private renovate
- command: |
- curl --request POST \
- --url $TOBIKO_PRIVATE_CIRCLECI_URL \
- --header "Circle-Token: $TOBIKO_PRIVATE_CIRCLECI_KEY" \
- --header "content-type: application/json" \
- --data '{
- "branch":"main",
- "parameters":{
- "run_main_pr":false,
- "run_sqlmesh_commit":false,
- "run_renovate":true
- }
- }'
-
-workflows:
- setup-workflow:
- jobs:
- - path-filtering/filter:
- mapping: |
- web/client/.* client true
- (sqlmesh|tests|examples|web/server)/.* python true
- pytest.ini|setup.cfg|setup.py|pyproject.toml python true
- \.circleci/.*|Makefile|\.pre-commit-config\.yaml common true
- vscode/extensions/.* vscode true
- tag: "3.9"
- - gh-release:
- <<: *on_tag_filter
- - ui-build:
- <<: *on_main_or_tag_filter
- - publish:
- <<: *on_main_or_tag_filter
- requires:
- - ui-build
- - trigger_private_renovate:
- <<: *on_tag_filter
- requires:
- - publish
diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml
deleted file mode 100644
index 5a240b85e4..0000000000
--- a/.circleci/continue_config.yml
+++ /dev/null
@@ -1,327 +0,0 @@
-version: 2.1
-
-parameters:
- client:
- type: boolean
- default: false
- common:
- type: boolean
- default: false
- python:
- type: boolean
- default: false
-
-orbs:
- windows: circleci/windows@5.0
-
-commands:
- halt_unless_core:
- steps:
- - unless:
- condition:
- or:
- - << pipeline.parameters.common >>
- - << pipeline.parameters.python >>
- - equal: [main, << pipeline.git.branch >>]
- steps:
- - run: circleci-agent step halt
- halt_unless_client:
- steps:
- - unless:
- condition:
- or:
- - << pipeline.parameters.common >>
- - << pipeline.parameters.client >>
- - equal: [main, << pipeline.git.branch >>]
- steps:
- - run: circleci-agent step halt
-
-jobs:
- vscode_test:
- docker:
- - image: cimg/node:20.19.1-browsers
- resource_class: small
- steps:
- - checkout
- - run:
- name: Install Dependencies
- command: |
- pnpm install
- - run:
- name: Run VSCode extension CI
- command: |
- cd vscode/extension
- pnpm run ci
- doc_tests:
- docker:
- - image: cimg/python:3.10
- resource_class: small
- steps:
- - halt_unless_core
- - checkout
- - run:
- name: Install dependencies
- command: make install-dev install-doc
- - run:
- name: Run doc tests
- command: make doc-test
-
- style_and_cicd_tests:
- parameters:
- python_version:
- type: string
- docker:
- - image: cimg/python:<< parameters.python_version >>
- resource_class: large
- environment:
- PYTEST_XDIST_AUTO_NUM_WORKERS: 8
- steps:
- - halt_unless_core
- - checkout
- - run:
- name: Install OpenJDK
- command: sudo apt-get update && sudo apt-get install default-jdk
- - run:
- name: Install ODBC
- command: sudo apt-get install unixodbc-dev
- - run:
- name: Install SQLMesh dev dependencies
- command: make install-dev
- - run:
- name: Fix Git URL override
- command: git config --global --unset url."ssh://git@github.com".insteadOf
- - run:
- name: Run linters and code style checks
- command: make py-style
- - unless:
- condition:
- equal: ["3.9", << parameters.python_version >>]
- steps:
- - run:
- name: Exercise the benchmarks
- command: make benchmark-ci
- - run:
- name: Run cicd tests
- command: make cicd-test
- - store_test_results:
- path: test-results
-
- cicd_tests_windows:
- executor:
- name: windows/default
- size: large
- steps:
- - halt_unless_core
- - run:
- name: Enable symlinks in git config
- command: git config --global core.symlinks true
- - checkout
- - run:
- name: Install System Dependencies
- command: |
- choco install make which -y
- refreshenv
- - run:
- name: Install SQLMesh dev dependencies
- command: |
- python -m venv venv
- . ./venv/Scripts/activate
- python.exe -m pip install --upgrade pip
- make install-dev
- - run:
- name: Run fast unit tests
- command: |
- . ./venv/Scripts/activate
- which python
- python --version
- make fast-test
- - store_test_results:
- path: test-results
-
- migration_test:
- docker:
- - image: cimg/python:3.10
- resource_class: small
- environment:
- SQLMESH__DISABLE_ANONYMIZED_ANALYTICS: "1"
- steps:
- - halt_unless_core
- - checkout
- - run:
- name: Run the migration test - sushi
- command: ./.circleci/test_migration.sh sushi "--gateway duckdb_persistent"
- - run:
- name: Run the migration test - sushi_dbt
- command: ./.circleci/test_migration.sh sushi_dbt "--config migration_test_config"
-
- ui_style:
- docker:
- - image: cimg/node:20.19.0
- resource_class: small
- steps:
- - checkout
- - restore_cache:
- name: Restore pnpm Package Cache
- keys:
- - pnpm-packages-{{ checksum "pnpm-lock.yaml" }}
- - run:
- name: Install Dependencies
- command: |
- pnpm install
- - save_cache:
- name: Save pnpm Package Cache
- key: pnpm-packages-{{ checksum "pnpm-lock.yaml" }}
- paths:
- - .pnpm-store
- - run:
- name: Run linters and code style checks
- command: pnpm run lint
-
- ui_test:
- docker:
- - image: mcr.microsoft.com/playwright:v1.54.1-jammy
- resource_class: medium
- steps:
- - halt_unless_client
- - checkout
- - restore_cache:
- name: Restore pnpm Package Cache
- keys:
- - pnpm-packages-{{ checksum "pnpm-lock.yaml" }}
- - run:
- name: Install pnpm package manager
- command: |
- npm install --global corepack@latest
- corepack enable
- corepack prepare pnpm@latest-10 --activate
- pnpm config set store-dir .pnpm-store
- - run:
- name: Install Dependencies
- command: |
- pnpm install
- - save_cache:
- name: Save pnpm Package Cache
- key: pnpm-packages-{{ checksum "pnpm-lock.yaml" }}
- paths:
- - .pnpm-store
- - run:
- name: Run tests
- command: npm --prefix web/client run test
-
- engine_tests_docker:
- parameters:
- engine:
- type: string
- machine:
- image: ubuntu-2404:2024.05.1
- docker_layer_caching: true
- resource_class: large
- environment:
- SQLMESH__DISABLE_ANONYMIZED_ANALYTICS: "1"
- steps:
- - halt_unless_core
- - checkout
- - run:
- name: Install OS-level dependencies
- command: ./.circleci/install-prerequisites.sh "<< parameters.engine >>"
- - run:
- name: Run tests
- command: make << parameters.engine >>-test
- no_output_timeout: 20m
- - store_test_results:
- path: test-results
-
- engine_tests_cloud:
- parameters:
- engine:
- type: string
- docker:
- - image: cimg/python:3.12
- resource_class: medium
- environment:
- PYTEST_XDIST_AUTO_NUM_WORKERS: 4
- SQLMESH__DISABLE_ANONYMIZED_ANALYTICS: "1"
- steps:
- - halt_unless_core
- - checkout
- - run:
- name: Install OS-level dependencies
- command: ./.circleci/install-prerequisites.sh "<< parameters.engine >>"
- - run:
- name: Generate database name
- command: |
- UUID=`cat /proc/sys/kernel/random/uuid`
- TEST_DB_NAME="circleci_${UUID:0:8}"
- echo "export TEST_DB_NAME='$TEST_DB_NAME'" >> "$BASH_ENV"
- echo "export SNOWFLAKE_DATABASE='$TEST_DB_NAME'" >> "$BASH_ENV"
- echo "export DATABRICKS_CATALOG='$TEST_DB_NAME'" >> "$BASH_ENV"
- echo "export REDSHIFT_DATABASE='$TEST_DB_NAME'" >> "$BASH_ENV"
- echo "export GCP_POSTGRES_DATABASE='$TEST_DB_NAME'" >> "$BASH_ENV"
- echo "export FABRIC_DATABASE='$TEST_DB_NAME'" >> "$BASH_ENV"
- - run:
- name: Create test database
- command: ./.circleci/manage-test-db.sh << parameters.engine >> "$TEST_DB_NAME" up
- - run:
- name: Run tests
- command: |
- make << parameters.engine >>-test
- no_output_timeout: 20m
- - run:
- name: Tear down test database
- command: ./.circleci/manage-test-db.sh << parameters.engine >> "$TEST_DB_NAME" down
- when: always
- - store_test_results:
- path: test-results
-
-workflows:
- main_pr:
- jobs:
- - doc_tests
- - style_and_cicd_tests:
- matrix:
- parameters:
- python_version:
- - "3.9"
- - "3.10"
- - "3.11"
- - "3.12"
- - "3.13"
- - cicd_tests_windows
- - engine_tests_docker:
- name: engine_<< matrix.engine >>
- matrix:
- parameters:
- engine:
- - duckdb
- - postgres
- - mysql
- - mssql
- - trino
- - spark
- - clickhouse
- - risingwave
- - engine_tests_cloud:
- name: cloud_engine_<< matrix.engine >>
- context:
- - sqlmesh_cloud_database_integration
- requires:
- - engine_tests_docker
- matrix:
- parameters:
- engine:
- - snowflake
- - databricks
- - redshift
- - bigquery
- - clickhouse-cloud
- - athena
- - fabric
- - gcp-postgres
- filters:
- branches:
- only:
- - main
- - ui_style
- - ui_test
- - vscode_test
- - migration_test
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
new file mode 100644
index 0000000000..7585f0ce10
--- /dev/null
+++ b/.github/pull_request_template.md
@@ -0,0 +1,16 @@
+## Description
+
+
+
+## Test Plan
+
+
+
+## Checklist
+
+- [ ] I have run `make style` and fixed any issues
+- [ ] I have added tests for my changes (if applicable)
+- [ ] All existing tests pass (`make fast-test`)
+- [ ] My commits are signed off (`git commit -s`) per the [DCO](DCO)
+
+
diff --git a/.circleci/install-prerequisites.sh b/.github/scripts/install-prerequisites.sh
similarity index 89%
rename from .circleci/install-prerequisites.sh
rename to .github/scripts/install-prerequisites.sh
index 446221dba6..6ab602fc37 100755
--- a/.circleci/install-prerequisites.sh
+++ b/.github/scripts/install-prerequisites.sh
@@ -1,6 +1,6 @@
#!/bin/bash
-# This script is intended to be run by an Ubuntu build agent on CircleCI
+# This script is intended to be run by an Ubuntu CI build agent
# The goal is to install OS-level dependencies that are required before trying to install Python dependencies
set -e
@@ -25,7 +25,7 @@ elif [ "$ENGINE" == "fabric" ]; then
sudo dpkg -i packages-microsoft-prod.deb
rm packages-microsoft-prod.deb
- ENGINE_DEPENDENCIES="msodbcsql18"
+ ENGINE_DEPENDENCIES="msodbcsql18"
fi
ALL_DEPENDENCIES="$COMMON_DEPENDENCIES $ENGINE_DEPENDENCIES"
@@ -39,4 +39,4 @@ if [ "$ENGINE" == "spark" ]; then
java -version
fi
-echo "All done"
\ No newline at end of file
+echo "All done"
diff --git a/.circleci/manage-test-db.sh b/.github/scripts/manage-test-db.sh
similarity index 77%
rename from .circleci/manage-test-db.sh
rename to .github/scripts/manage-test-db.sh
index f90b567ce8..29d11afcc0 100755
--- a/.circleci/manage-test-db.sh
+++ b/.github/scripts/manage-test-db.sh
@@ -25,7 +25,7 @@ function_exists() {
# Snowflake
snowflake_init() {
echo "Installing Snowflake CLI"
- pip install "snowflake-cli-labs<3.8.0"
+ pip install "snowflake-cli"
}
snowflake_up() {
@@ -40,20 +40,6 @@ snowflake_down() {
databricks_init() {
echo "Installing Databricks CLI"
curl -fsSL https://raw.githubusercontent.com/databricks/setup-cli/main/install.sh | sudo sh || true
-
- echo "Writing out Databricks CLI config file"
- echo -e "[DEFAULT]\nhost = $DATABRICKS_SERVER_HOSTNAME\ntoken = $DATABRICKS_ACCESS_TOKEN" > ~/.databrickscfg
-
- # this takes a path like 'sql/protocolv1/o/2934659247569/0723-005339-foobar' and extracts '0723-005339-foobar' from it
- CLUSTER_ID=${DATABRICKS_HTTP_PATH##*/}
-
- echo "Extracted cluster id: $CLUSTER_ID from '$DATABRICKS_HTTP_PATH'"
-
- # Note: the cluster doesnt need to be running to create / drop catalogs, but it does need to be running to run the integration tests
- echo "Ensuring cluster is running"
- # the || true is to prevent the following error from causing an abort:
- # > Error: is in unexpected state Running.
- databricks clusters start $CLUSTER_ID || true
}
databricks_up() {
@@ -82,10 +68,10 @@ redshift_down() {
EXIT_CODE=1
ATTEMPTS=0
while [ $EXIT_CODE -ne 0 ] && [ $ATTEMPTS -lt 5 ]; do
- # note: sometimes this pg_terminate_backend() call can randomly fail with: ERROR: Insufficient privileges
+ # note: sometimes this pg_terminate_backend() call can randomly fail with: ERROR: Insufficient privileges
# if it does, let's proceed with the drop anyway rather than aborting and never attempting the drop
redshift_exec "select pg_terminate_backend(procpid) from pg_stat_activity where datname = '$1'" || true
-
+
# perform drop
redshift_exec "drop database $1;" && EXIT_CODE=$? || EXIT_CODE=$?
if [ $EXIT_CODE -ne 0 ]; then
@@ -117,14 +103,16 @@ clickhouse-cloud_init() {
# GCP Postgres
gcp-postgres_init() {
- # Download and start Cloud SQL Proxy
- curl -fsSL -o cloud-sql-proxy https://storage.googleapis.com/cloud-sql-connectors/cloud-sql-proxy/v2.18.0/cloud-sql-proxy.linux.amd64
- chmod +x cloud-sql-proxy
+ # Download Cloud SQL Proxy if not already present
+ if [ ! -f cloud-sql-proxy ]; then
+ curl -fsSL -o cloud-sql-proxy https://storage.googleapis.com/cloud-sql-connectors/cloud-sql-proxy/v2.18.0/cloud-sql-proxy.linux.amd64
+ chmod +x cloud-sql-proxy
+ fi
echo "$GCP_POSTGRES_KEYFILE_JSON" > /tmp/keyfile.json
- ./cloud-sql-proxy --credentials-file /tmp/keyfile.json $GCP_POSTGRES_INSTANCE_CONNECTION_STRING &
-
- # Wait for proxy to start
- sleep 5
+ if ! pgrep -x cloud-sql-proxy > /dev/null; then
+ ./cloud-sql-proxy --credentials-file /tmp/keyfile.json $GCP_POSTGRES_INSTANCE_CONNECTION_STRING &
+ sleep 5
+ fi
}
gcp-postgres_exec() {
@@ -140,13 +128,13 @@ gcp-postgres_down() {
}
# Fabric
-fabric_init() {
+fabric_init() {
python --version #note: as at 2025-08-20, ms-fabric-cli is pinned to Python >= 3.10, <3.13
pip install ms-fabric-cli
-
+
# to prevent the '[EncryptionFailed] An error occurred with the encrypted cache.' error
# ref: https://microsoft.github.io/fabric-cli/#switch-to-interactive-mode-optional
- fab config set encryption_fallback_enabled true
+ fab config set encryption_fallback_enabled true
echo "Logging in to Fabric"
fab auth login -u $FABRIC_CLIENT_ID -p $FABRIC_CLIENT_SECRET --tenant $FABRIC_TENANT_ID
diff --git a/.circleci/test_migration.sh b/.github/scripts/test_migration.sh
similarity index 82%
rename from .circleci/test_migration.sh
rename to .github/scripts/test_migration.sh
index 9b8fe89e6e..ec45772c73 100755
--- a/.circleci/test_migration.sh
+++ b/.github/scripts/test_migration.sh
@@ -24,17 +24,20 @@ TEST_DIR="$TMP_DIR/$EXAMPLE_NAME"
echo "Running migration test for '$EXAMPLE_NAME' in '$TEST_DIR' for example project '$EXAMPLE_DIR' using options '$SQLMESH_OPTS'"
+# Copy the example project from the *current* checkout so it's stable across old/new SQLMesh versions
+cp -r "$EXAMPLE_DIR" "$TEST_DIR"
+
git checkout $LAST_TAG
# Install dependencies from the previous release.
+uv venv .venv --clear
+source .venv/bin/activate
make install-dev
-cp -r $EXAMPLE_DIR $TEST_DIR
-
# this is only needed temporarily until the released tag for $LAST_TAG includes this config
if [ "$EXAMPLE_NAME" == "sushi_dbt" ]; then
echo 'migration_test_config = sqlmesh_config(Path(__file__).parent, dbt_target_name="duckdb")' >> $TEST_DIR/config.py
-fi
+fi
# Run initial plan
pushd $TEST_DIR
@@ -43,14 +46,16 @@ sqlmesh $SQLMESH_OPTS plan --no-prompts --auto-apply
rm -rf .cache
popd
-# Switch back to the starting state of the repository
+# Switch back to the starting state of the repository
git checkout -
# Install updated dependencies.
+uv venv .venv --clear
+source .venv/bin/activate
make install-dev
# Migrate and make sure the diff is empty
pushd $TEST_DIR
sqlmesh $SQLMESH_OPTS migrate
sqlmesh $SQLMESH_OPTS diff prod
-popd
\ No newline at end of file
+popd
diff --git a/.circleci/update-pypirc.sh b/.github/scripts/update-pypirc.sh
similarity index 100%
rename from .circleci/update-pypirc.sh
rename to .github/scripts/update-pypirc.sh
diff --git a/.circleci/wait-for-db.sh b/.github/scripts/wait-for-db.sh
similarity index 98%
rename from .circleci/wait-for-db.sh
rename to .github/scripts/wait-for-db.sh
index a313320279..07502e3898 100755
--- a/.circleci/wait-for-db.sh
+++ b/.github/scripts/wait-for-db.sh
@@ -80,4 +80,4 @@ while [ $EXIT_CODE -ne 0 ]; do
fi
done
-echo "$ENGINE is ready!"
\ No newline at end of file
+echo "$ENGINE is ready!"
diff --git a/.github/workflows/dco.yml b/.github/workflows/dco.yml
new file mode 100644
index 0000000000..a1c4e07300
--- /dev/null
+++ b/.github/workflows/dco.yml
@@ -0,0 +1,17 @@
+name: DCO check
+on: [pull_request]
+
+jobs:
+ commits_check_job:
+ runs-on: ubuntu-latest
+ name: Commits Check
+ steps:
+ - name: Get PR Commits
+ id: 'get-pr-commits'
+ uses: tim-actions/get-pr-commits@master
+ with:
+ token: ${{ secrets.GITHUB_TOKEN }}
+ - name: DCO Check
+ uses: tim-actions/dco@master
+ with:
+ commits: ${{ steps.get-pr-commits.outputs.commits }}
diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml
index 08ac729206..4395c56313 100644
--- a/.github/workflows/pr.yaml
+++ b/.github/workflows/pr.yaml
@@ -6,11 +6,392 @@ on:
branches:
- main
concurrency:
- group: 'pr-${{ github.event.pull_request.number }}'
+ group: pr-${{ github.event.pull_request.number || github.sha }}
cancel-in-progress: true
permissions:
contents: read
jobs:
+ changes:
+ runs-on: ubuntu-latest
+ outputs:
+ python: ${{ steps.filter.outputs.python }}
+ client: ${{ steps.filter.outputs.client }}
+ ci: ${{ steps.filter.outputs.ci }}
+ steps:
+ - uses: actions/checkout@v5
+ - uses: dorny/paths-filter@v3
+ id: filter
+ with:
+ filters: |
+ python:
+ - 'sqlmesh/**'
+ - 'tests/**'
+ - 'examples/**'
+ - 'web/server/**'
+ - 'pytest.ini'
+ - 'setup.cfg'
+ - 'setup.py'
+ - 'pyproject.toml'
+ client:
+ - 'web/client/**'
+ ci:
+ - '.github/**'
+ - 'Makefile'
+ - '.pre-commit-config.yaml'
+
+ doc-tests:
+ needs: changes
+ if:
+ needs.changes.outputs.python == 'true' || needs.changes.outputs.ci ==
+ 'true' || github.ref == 'refs/heads/main'
+ runs-on: ubuntu-latest
+ env:
+ UV: '1'
+ steps:
+ - uses: actions/checkout@v5
+ - name: Set up Python
+ uses: actions/setup-python@v6
+ with:
+ python-version: '3.10'
+ - name: Install uv
+ uses: astral-sh/setup-uv@v7
+ - name: Install dependencies
+ run: |
+ uv venv .venv
+ source .venv/bin/activate
+ make install-dev install-doc
+ - name: Run doc tests
+ run: |
+ source .venv/bin/activate
+ make doc-test
+
+ style-and-cicd-tests:
+ needs: changes
+ if:
+ needs.changes.outputs.python == 'true' || needs.changes.outputs.ci ==
+ 'true' || github.ref == 'refs/heads/main'
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
+ env:
+ PYTEST_XDIST_AUTO_NUM_WORKERS: 2
+ UV: '1'
+ steps:
+ - uses: actions/checkout@v5
+ with:
+ fetch-depth: 0
+ - name: Set up Python
+ uses: actions/setup-python@v6
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install uv
+ uses: astral-sh/setup-uv@v7
+ - name: Install OpenJDK and ODBC
+ run:
+ sudo apt-get update && sudo apt-get install -y default-jdk
+ unixodbc-dev
+ - name: Install SQLMesh dev dependencies
+ run: |
+ uv venv .venv
+ source .venv/bin/activate
+ make install-dev
+ - name: Fix Git URL override
+ run:
+ git config --global --unset url."ssh://git@github.com".insteadOf ||
+ true
+ - name: Run linters and code style checks
+ run: |
+ source .venv/bin/activate
+ make py-style
+ - name: Exercise the benchmarks
+ if: matrix.python-version != '3.9'
+ run: |
+ source .venv/bin/activate
+ make benchmark-ci
+ - name: Run cicd tests
+ run: |
+ source .venv/bin/activate
+ make cicd-test
+ - name: Upload test results
+ uses: actions/upload-artifact@v5
+ if: ${{ !cancelled() }}
+ with:
+ name: test-results-style-cicd-${{ matrix.python-version }}
+ path: test-results/
+ retention-days: 7
+
+ cicd-tests-windows:
+ needs: changes
+ if:
+ needs.changes.outputs.python == 'true' || needs.changes.outputs.ci ==
+ 'true' || github.ref == 'refs/heads/main'
+ runs-on: windows-latest
+ steps:
+ - name: Enable symlinks in git config
+ run: git config --global core.symlinks true
+ - uses: actions/checkout@v5
+ - name: Install make
+ run: choco install make which -y
+ - name: Set up Python
+ uses: actions/setup-python@v6
+ with:
+ python-version: '3.12'
+ - name: Install SQLMesh dev dependencies
+ run: |
+ python -m venv venv
+ . ./venv/Scripts/activate
+ python.exe -m pip install --upgrade pip
+ make install-dev
+ - name: Run fast unit tests
+ run: |
+ . ./venv/Scripts/activate
+ which python
+ python --version
+ make fast-test
+ - name: Upload test results
+ uses: actions/upload-artifact@v5
+ if: ${{ !cancelled() }}
+ with:
+ name: test-results-windows
+ path: test-results/
+ retention-days: 7
+
+ migration-test:
+ needs: changes
+ if:
+ needs.changes.outputs.python == 'true' || needs.changes.outputs.ci ==
+ 'true' || github.ref == 'refs/heads/main'
+ runs-on: ubuntu-latest
+ env:
+ SQLMESH__DISABLE_ANONYMIZED_ANALYTICS: '1'
+ UV: '1'
+ steps:
+ - uses: actions/checkout@v5
+ with:
+ fetch-depth: 0
+ - name: Set up Python
+ uses: actions/setup-python@v6
+ with:
+ python-version: '3.10'
+ - name: Install uv
+ uses: astral-sh/setup-uv@v7
+ - name: Run migration test - sushi
+ run:
+ ./.github/scripts/test_migration.sh sushi "--gateway
+ duckdb_persistent"
+ - name: Run migration test - sushi_dbt
+ run:
+ ./.github/scripts/test_migration.sh sushi_dbt "--config
+ migration_test_config"
+
+ ui-style:
+ needs: [changes]
+ if:
+ needs.changes.outputs.client == 'true' || needs.changes.outputs.ci ==
+ 'true' || github.ref == 'refs/heads/main'
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v5
+ - uses: actions/setup-node@v6
+ with:
+ node-version: '20'
+ - uses: pnpm/action-setup@v4
+ with:
+ version: latest
+ - name: Get pnpm store directory
+ id: pnpm-cache
+ run: echo "store=$(pnpm store path)" >> $GITHUB_OUTPUT
+ - uses: actions/cache@v4
+ with:
+ path: ${{ steps.pnpm-cache.outputs.store }}
+ key: pnpm-store-${{ hashFiles('pnpm-lock.yaml') }}
+ restore-keys: pnpm-store-
+ - name: Install dependencies
+ run: pnpm install
+ - name: Run linters and code style checks
+ run: pnpm run lint
+
+ ui-test:
+ needs: changes
+ if:
+ needs.changes.outputs.client == 'true' || needs.changes.outputs.ci ==
+ 'true' || github.ref == 'refs/heads/main'
+ runs-on: ubuntu-latest
+ container:
+ image: mcr.microsoft.com/playwright:v1.54.1-jammy
+ steps:
+ - uses: actions/checkout@v5
+ - name: Install pnpm via corepack
+ run: |
+ npm install --global corepack@latest
+ corepack enable
+ corepack prepare pnpm@latest-10 --activate
+ pnpm config set store-dir .pnpm-store
+ - name: Install dependencies
+ run: pnpm install
+ - name: Build UI
+ run: npm --prefix web/client run build
+ - name: Run unit tests
+ run: npm --prefix web/client run test:unit
+ - name: Run e2e tests
+ run: npm --prefix web/client run test:e2e
+ env:
+ PLAYWRIGHT_SKIP_BUILD: '1'
+ HOME: /root
+
+ engine-tests-docker:
+ needs: changes
+ if:
+ needs.changes.outputs.python == 'true' || needs.changes.outputs.ci ==
+ 'true' || github.ref == 'refs/heads/main'
+ runs-on: ubuntu-latest
+ timeout-minutes: 25
+ strategy:
+ fail-fast: false
+ matrix:
+ engine:
+ [duckdb, postgres, mysql, mssql, trino, spark, clickhouse, risingwave]
+ env:
+ PYTEST_XDIST_AUTO_NUM_WORKERS: 2
+ SQLMESH__DISABLE_ANONYMIZED_ANALYTICS: '1'
+ UV: '1'
+ steps:
+ - uses: actions/checkout@v5
+ - name: Set up Python
+ uses: actions/setup-python@v6
+ with:
+ python-version: '3.12'
+ - name: Install uv
+ uses: astral-sh/setup-uv@v7
+ - name: Install SQLMesh dev dependencies
+ run: |
+ uv venv .venv
+ source .venv/bin/activate
+ make install-dev
+ - name: Install OS-level dependencies
+ run: ./.github/scripts/install-prerequisites.sh "${{ matrix.engine }}"
+ - name: Run tests
+ run: |
+ source .venv/bin/activate
+ make ${{ matrix.engine }}-test
+ - name: Upload test results
+ uses: actions/upload-artifact@v5
+ if: ${{ !cancelled() }}
+ with:
+ name: test-results-docker-${{ matrix.engine }}
+ path: test-results/
+ retention-days: 7
+
+ engine-tests-cloud:
+ needs: engine-tests-docker
+ if: github.ref == 'refs/heads/main'
+ runs-on: ubuntu-latest
+ timeout-minutes: 25
+ strategy:
+ fail-fast: false
+ matrix:
+ engine:
+ [
+ snowflake,
+ databricks,
+ redshift,
+ bigquery,
+ clickhouse-cloud,
+ athena,
+ fabric,
+ gcp-postgres,
+ ]
+ env:
+ PYTEST_XDIST_AUTO_NUM_WORKERS: 4
+ SQLMESH__DISABLE_ANONYMIZED_ANALYTICS: '1'
+ UV: '1'
+ SNOWFLAKE_ACCOUNT: ${{ secrets.SNOWFLAKE_ACCOUNT }}
+ SNOWFLAKE_USER: ${{ secrets.SNOWFLAKE_USER }}
+ SNOWFLAKE_WAREHOUSE: ${{ secrets.SNOWFLAKE_WAREHOUSE }}
+ SNOWFLAKE_AUTHENTICATOR: SNOWFLAKE_JWT
+ DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_SERVER_HOSTNAME }}
+ DATABRICKS_HOST: ${{ secrets.DATABRICKS_SERVER_HOSTNAME }}
+ DATABRICKS_HTTP_PATH: ${{ secrets.DATABRICKS_HTTP_PATH }}
+ DATABRICKS_CLIENT_ID: ${{ secrets.DATABRICKS_CLIENT_ID }}
+ DATABRICKS_CLIENT_SECRET: ${{ secrets.DATABRICKS_CLIENT_SECRET }}
+ DATABRICKS_CONNECT_VERSION: ${{ secrets.DATABRICKS_CONNECT_VERSION }}
+ REDSHIFT_HOST: ${{ secrets.REDSHIFT_HOST }}
+ REDSHIFT_PORT: ${{ secrets.REDSHIFT_PORT }}
+ REDSHIFT_USER: ${{ secrets.REDSHIFT_USER }}
+ REDSHIFT_PASSWORD: ${{ secrets.REDSHIFT_PASSWORD }}
+ BIGQUERY_KEYFILE: ${{ secrets.BIGQUERY_KEYFILE }}
+ BIGQUERY_KEYFILE_CONTENTS: ${{ secrets.BIGQUERY_KEYFILE_CONTENTS }}
+ CLICKHOUSE_CLOUD_HOST: ${{ secrets.CLICKHOUSE_CLOUD_HOST }}
+ CLICKHOUSE_CLOUD_USERNAME: ${{ secrets.CLICKHOUSE_CLOUD_USERNAME }}
+ CLICKHOUSE_CLOUD_PASSWORD: ${{ secrets.CLICKHOUSE_CLOUD_PASSWORD }}
+ GCP_POSTGRES_KEYFILE_JSON: ${{ secrets.GCP_POSTGRES_KEYFILE_JSON }}
+ GCP_POSTGRES_INSTANCE_CONNECTION_STRING:
+ ${{ secrets.GCP_POSTGRES_INSTANCE_CONNECTION_STRING }}
+ GCP_POSTGRES_USER: ${{ secrets.GCP_POSTGRES_USER }}
+ GCP_POSTGRES_PASSWORD: ${{ secrets.GCP_POSTGRES_PASSWORD }}
+ ATHENA_S3_WAREHOUSE_LOCATION: ${{ secrets.ATHENA_S3_WAREHOUSE_LOCATION }}
+ ATHENA_WORK_GROUP: ${{ secrets.ATHENA_WORK_GROUP }}
+ AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
+ AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+ AWS_REGION: ${{ secrets.AWS_REGION }}
+ FABRIC_HOST: ${{ secrets.FABRIC_HOST }}
+ FABRIC_CLIENT_ID: ${{ secrets.FABRIC_CLIENT_ID }}
+ FABRIC_CLIENT_SECRET: ${{ secrets.FABRIC_CLIENT_SECRET }}
+ FABRIC_TENANT_ID: ${{ secrets.FABRIC_TENANT_ID }}
+ FABRIC_WORKSPACE_ID: ${{ secrets.FABRIC_WORKSPACE_ID }}
+ steps:
+ - uses: actions/checkout@v5
+ - name: Set up Python
+ uses: actions/setup-python@v6
+ with:
+ python-version: '3.12'
+ - name: Install uv
+ uses: astral-sh/setup-uv@v7
+ - name: Install OS-level dependencies
+ run: ./.github/scripts/install-prerequisites.sh "${{ matrix.engine }}"
+ - name: Install SQLMesh dev dependencies
+ run: |
+ uv venv .venv
+ source .venv/bin/activate
+ make install-dev
+ - name: Generate database name and setup credentials
+ run: |
+ UUID=$(cat /proc/sys/kernel/random/uuid)
+ TEST_DB_NAME="ci_${UUID:0:8}"
+ echo "TEST_DB_NAME=$TEST_DB_NAME" >> $GITHUB_ENV
+ echo "SNOWFLAKE_DATABASE=$TEST_DB_NAME" >> $GITHUB_ENV
+ echo "DATABRICKS_CATALOG=$TEST_DB_NAME" >> $GITHUB_ENV
+ echo "REDSHIFT_DATABASE=$TEST_DB_NAME" >> $GITHUB_ENV
+ echo "GCP_POSTGRES_DATABASE=$TEST_DB_NAME" >> $GITHUB_ENV
+ echo "FABRIC_DATABASE=$TEST_DB_NAME" >> $GITHUB_ENV
+
+ echo "$SNOWFLAKE_PRIVATE_KEY_RAW" | base64 -d > /tmp/snowflake-keyfile.p8
+ echo "SNOWFLAKE_PRIVATE_KEY_FILE=/tmp/snowflake-keyfile.p8" >> $GITHUB_ENV
+ env:
+ SNOWFLAKE_PRIVATE_KEY_RAW: ${{ secrets.SNOWFLAKE_PRIVATE_KEY_RAW }}
+ - name: Create test database
+ run:
+ ./.github/scripts/manage-test-db.sh "${{ matrix.engine }}"
+ "$TEST_DB_NAME" up
+ - name: Run tests
+ run: |
+ source .venv/bin/activate
+ make ${{ matrix.engine }}-test
+ - name: Tear down test database
+ if: always()
+ run:
+ ./.github/scripts/manage-test-db.sh "${{ matrix.engine }}"
+ "$TEST_DB_NAME" down
+ - name: Upload test results
+ uses: actions/upload-artifact@v5
+ if: ${{ !cancelled() }}
+ with:
+ name: test-results-cloud-${{ matrix.engine }}
+ path: test-results/
+ retention-days: 7
+
test-vscode:
env:
PLAYWRIGHT_SKIP_BROWSER_DOWNLOAD: 1
@@ -30,6 +411,8 @@ jobs:
test-vscode-e2e:
runs-on:
labels: [ubuntu-2204-8]
+ # As at 2026-01-12 this job flakes 100% of the time. It needs investigation
+ if: false
steps:
- uses: actions/checkout@v5
- uses: actions/setup-node@v6
@@ -98,30 +481,30 @@ jobs:
if [[ "${{ matrix.dbt-version }}" == "1.3" ]] || \
[[ "${{ matrix.dbt-version }}" == "1.4" ]] || \
[[ "${{ matrix.dbt-version }}" == "1.5" ]]; then
-
+
echo "DBT version is ${{ matrix.dbt-version }} (< 1.6.0), removing semantic_models and metrics sections..."
-
+
schema_file="tests/fixtures/dbt/sushi_test/models/schema.yml"
if [[ -f "$schema_file" ]]; then
echo "Modifying $schema_file..."
-
+
# Create a temporary file
temp_file=$(mktemp)
-
+
# Use awk to remove semantic_models and metrics sections
awk '
/^semantic_models:/ { in_semantic=1; next }
/^metrics:/ { in_metrics=1; next }
- /^[^ ]/ && (in_semantic || in_metrics) {
- in_semantic=0;
- in_metrics=0
+ /^[^ ]/ && (in_semantic || in_metrics) {
+ in_semantic=0;
+ in_metrics=0
}
!in_semantic && !in_metrics { print }
' "$schema_file" > "$temp_file"
-
+
# Move the temp file back
mv "$temp_file" "$schema_file"
-
+
echo "Successfully removed semantic_models and metrics sections"
else
echo "Schema file not found at $schema_file, skipping..."
diff --git a/.github/workflows/private-repo-test.yaml b/.github/workflows/private-repo-test.yaml
deleted file mode 100644
index 9b2365f48a..0000000000
--- a/.github/workflows/private-repo-test.yaml
+++ /dev/null
@@ -1,97 +0,0 @@
-name: Private Repo Testing
-
-on:
- pull_request_target:
- branches:
- - main
-
-concurrency:
- group: 'private-test-${{ github.event.pull_request.number }}'
- cancel-in-progress: true
-
-permissions:
- contents: read
-
-jobs:
- trigger-private-test:
- runs-on: ubuntu-latest
- steps:
- - name: Checkout code
- uses: actions/checkout@v5
- with:
- fetch-depth: 0
- ref: ${{ github.event.pull_request.head.sha || github.ref }}
- - name: Set up Python
- uses: actions/setup-python@v6
- with:
- python-version: '3.12'
- - name: Install uv
- uses: astral-sh/setup-uv@v7
- - name: Set up Node.js for UI build
- uses: actions/setup-node@v6
- with:
- node-version: '20'
- - name: Install pnpm
- uses: pnpm/action-setup@v4
- with:
- version: latest
- - name: Install UI dependencies
- run: pnpm install
- - name: Build UI
- run: pnpm --prefix web/client run build
- - name: Install Python dependencies
- run: |
- python -m venv .venv
- source .venv/bin/activate
- pip install build twine setuptools_scm
- - name: Generate development version
- id: version
- run: |
- source .venv/bin/activate
- # Generate a PEP 440 compliant unique version including run attempt
- BASE_VERSION=$(python .github/scripts/get_scm_version.py)
- COMMIT_SHA=$(git rev-parse --short HEAD)
- # Use PEP 440 compliant format: base.devN+pr.sha.attempt
- UNIQUE_VERSION="${BASE_VERSION}+pr${{ github.event.pull_request.number }}.${COMMIT_SHA}.run${{ github.run_attempt }}"
- echo "version=$UNIQUE_VERSION" >> $GITHUB_OUTPUT
- echo "Generated unique version with run attempt: $UNIQUE_VERSION"
- - name: Build package
- env:
- SETUPTOOLS_SCM_PRETEND_VERSION: ${{ steps.version.outputs.version }}
- run: |
- source .venv/bin/activate
- python -m build
- - name: Configure PyPI for private repository
- env:
- TOBIKO_PRIVATE_PYPI_URL: ${{ secrets.TOBIKO_PRIVATE_PYPI_URL }}
- TOBIKO_PRIVATE_PYPI_KEY: ${{ secrets.TOBIKO_PRIVATE_PYPI_KEY }}
- run: ./.circleci/update-pypirc.sh
- - name: Publish to private PyPI
- run: |
- source .venv/bin/activate
- python -m twine upload -r tobiko-private dist/*
- - name: Publish Python Tests package
- env:
- SETUPTOOLS_SCM_PRETEND_VERSION: ${{ steps.version.outputs.version }}
- run: |
- source .venv/bin/activate
- unset TWINE_USERNAME TWINE_PASSWORD && make publish-tests
- - name: Get GitHub App token
- id: get_token
- uses: actions/create-github-app-token@v2
- with:
- private-key: ${{ secrets.TOBIKO_RENOVATE_BOT_PRIVATE_KEY }}
- app-id: ${{ secrets.TOBIKO_RENOVATE_BOT_APP_ID }}
- owner: ${{ secrets.PRIVATE_REPO_OWNER }}
- - name: Trigger private repository workflow
- uses: convictional/trigger-workflow-and-wait@v1.6.5
- with:
- owner: ${{ secrets.PRIVATE_REPO_OWNER }}
- repo: ${{ secrets.PRIVATE_REPO_NAME }}
- github_token: ${{ steps.get_token.outputs.token }}
- workflow_file_name: ${{ secrets.PRIVATE_WORKFLOW_FILE }}
- client_payload: |
- {
- "package_version": "${{ steps.version.outputs.version }}",
- "pr_number": "${{ github.event.pull_request.number }}"
- }
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
new file mode 100644
index 0000000000..75512ffd72
--- /dev/null
+++ b/.github/workflows/release.yaml
@@ -0,0 +1,71 @@
+name: Release
+on:
+ push:
+ tags:
+ - 'v*.*.*'
+permissions:
+ contents: write
+jobs:
+ ui-build:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v5
+ - uses: actions/setup-node@v6
+ with:
+ node-version: '20'
+ - uses: pnpm/action-setup@v4
+ with:
+ version: latest
+ - name: Install dependencies
+ run: pnpm install
+ - name: Build UI
+ run: pnpm --prefix web/client run build
+ - name: Upload UI build artifact
+ uses: actions/upload-artifact@v5
+ with:
+ name: ui-dist
+ path: web/client/dist/
+ retention-days: 1
+
+ publish:
+ needs: ui-build
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v5
+ - name: Download UI build artifact
+ uses: actions/download-artifact@v5
+ with:
+ name: ui-dist
+ path: web/client/dist/
+ - name: Set up Python
+ uses: actions/setup-python@v6
+ with:
+ python-version: '3.10'
+ - name: Install uv
+ uses: astral-sh/setup-uv@v7
+ - name: Install build dependencies
+ run: pip install build twine setuptools_scm
+ - name: Publish Python package
+ run: make publish
+ env:
+ TWINE_USERNAME: ${{ secrets.TWINE_USERNAME }}
+ TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }}
+ - name: Update pypirc for private repository
+ run: ./.github/scripts/update-pypirc.sh
+ env:
+ TOBIKO_PRIVATE_PYPI_URL: ${{ secrets.TOBIKO_PRIVATE_PYPI_URL }}
+ TOBIKO_PRIVATE_PYPI_KEY: ${{ secrets.TOBIKO_PRIVATE_PYPI_KEY }}
+ - name: Publish Python Tests package
+ run: unset TWINE_USERNAME TWINE_PASSWORD && make publish-tests
+
+ gh-release:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v5
+ with:
+ fetch-depth: 0
+ - name: Create release on GitHub
+ uses: softprops/action-gh-release@v2
+ with:
+ generate_release_notes: true
+ tag_name: ${{ github.ref_name }}
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000..287a87dab5
--- /dev/null
+++ b/CODE_OF_CONDUCT.md
@@ -0,0 +1,5 @@
+# Code of Conduct
+
+SQLMesh follows the [LF Projects Code of Conduct](https://lfprojects.org/policies/code-of-conduct/). All participants in the project are expected to abide by it.
+
+If you believe someone is violating the code of conduct, please report it by following the instructions in the [LF Projects Code of Conduct](https://lfprojects.org/policies/code-of-conduct/).
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000000..0e1d8e1c6e
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,90 @@
+# Contributing to SQLMesh
+
+## Welcome
+
+SQLMesh is a project of the Linux Foundation. We welcome contributions from anyone — whether you're fixing a bug, improving documentation, or proposing a new feature.
+
+## Technical Steering Committee (TSC)
+
+The TSC is responsible for technical oversight of the SQLMesh project, including coordinating technical direction, approving contribution policies, and maintaining community norms.
+
+Initial TSC voting members are the project's Maintainers:
+
+| Name | GitHub Handle | Affiliation | Role |
+|---------------------|---------------|----------------|------------|
+| Alexander Butler | z3z1ma | Harness | TSC Member |
+| Alexander Filipchik | afilipchik | Cloud Kitchens | TSC Member |
+| Reid Hooper | rhooper9711 | Benzinga | TSC Member |
+| Yuki Kakegawa | StuffbyYuki | Jump.ai | TSC Member |
+| Toby Mao | tobymao | Fivetran | TSC Chair |
+| Alex Wilde | alexminerv | Minerva | TSC Member |
+
+
+## Roles
+
+**Contributors**: Anyone who contributes code, documentation, or other technical artifacts to the project.
+
+**Maintainers**: Contributors who have earned the ability to modify source code, documentation, or other technical artifacts. A Contributor may become a Maintainer by majority approval of the TSC. A Maintainer may be removed by majority approval of the TSC.
+
+## How to Contribute
+
+1. Fork the repository on GitHub
+2. Create a branch for your changes
+3. Make your changes and commit them with a sign-off (see DCO section below)
+4. Submit a pull request against the `main` branch
+
+File issues at [github.com/SQLMesh/sqlmesh/issues](https://github.com/SQLMesh/sqlmesh/issues).
+
+## Developer Certificate of Origin (DCO)
+
+All contributions must include a `Signed-off-by` line in the commit message per the [Developer Certificate of Origin](DCO). This certifies that you wrote the contribution or have the right to submit it under the project's open source license.
+
+Use `git commit -s` to add the sign-off automatically:
+
+```bash
+git commit -s -m "Your commit message"
+```
+
+To fix a commit that is missing the sign-off:
+
+```bash
+git commit --amend -s
+```
+
+To add a sign-off to multiple commits:
+
+```bash
+git rebase HEAD~N --signoff
+```
+
+## Development Setup
+
+See [docs/development.md](docs/development.md) for full setup instructions. Key commands:
+
+```bash
+python -m venv .venv
+source .venv/bin/activate
+make install-dev
+make style # Run before submitting
+make fast-test # Quick test suite
+```
+
+## Coding Standards
+
+- Run `make style` before submitting a pull request
+- Follow existing code patterns and conventions in the codebase
+- New files should include an SPDX license header:
+ ```python
+ # SPDX-License-Identifier: Apache-2.0
+ ```
+
+## Pull Request Process
+
+- Describe your changes clearly in the pull request description
+- Ensure all CI checks pass
+- Include a DCO sign-off on all commits (`git commit -s`)
+- Be responsive to review feedback from maintainers
+
+## Licensing
+
+Code contributions are licensed under the [Apache License 2.0](LICENSE). Documentation contributions are licensed under [Creative Commons Attribution 4.0 International (CC-BY-4.0)](https://creativecommons.org/licenses/by/4.0/). See the LICENSE file and the [technical charter](sqlmesh-technical-charter.pdf) for details.
diff --git a/DCO b/DCO
new file mode 100644
index 0000000000..49b8cb0549
--- /dev/null
+++ b/DCO
@@ -0,0 +1,34 @@
+Developer Certificate of Origin
+Version 1.1
+
+Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
+
+Everyone is permitted to copy and distribute verbatim copies of this
+license document, but changing it is not allowed.
+
+
+Developer's Certificate of Origin 1.1
+
+By making a contribution to this project, I certify that:
+
+(a) The contribution was created in whole or in part by me and I
+ have the right to submit it under the open source license
+ indicated in the file; or
+
+(b) The contribution is based upon previous work that, to the best
+ of my knowledge, is covered under an appropriate open source
+ license and I have the right under that license to submit that
+ work with modifications, whether created in whole or in part
+ by me, under the same open source license (unless I am
+ permitted to submit under a different license), as indicated
+ in the file; or
+
+(c) The contribution was provided directly to me by some other
+ person who certified (a), (b) or (c) and I have not modified
+ it.
+
+(d) I understand and agree that this project and the contribution
+ are public and that a record of the contribution (including all
+ personal information I submit with it, including my sign-off) is
+ maintained indefinitely and may be redistributed consistent with
+ this project or the open source license(s) involved.
diff --git a/GOVERNANCE.md b/GOVERNANCE.md
new file mode 100644
index 0000000000..44b6bc9947
--- /dev/null
+++ b/GOVERNANCE.md
@@ -0,0 +1,62 @@
+# SQLMesh Project Governance
+
+## Overview
+
+SQLMesh is a Series of LF Projects, LLC. The project is governed by its [Technical Charter](sqlmesh-technical-charter.pdf) and overseen by the Technical Steering Committee (TSC). SQLMesh is a project of the [Linux Foundation](https://www.linuxfoundation.org/).
+
+## Technical Steering Committee
+
+The TSC is responsible for all technical oversight of the project, including:
+
+- Coordinating the technical direction of the project
+- Approving project or system proposals
+- Organizing sub-projects and removing sub-projects
+- Creating sub-committees or working groups to focus on cross-project technical issues
+- Appointing representatives to work with other open source or open standards communities
+- Establishing community norms, workflows, issuing releases, and security vulnerability reports
+- Approving and implementing policies for contribution requirements
+- Coordinating any marketing, events, or communications regarding the project
+
+## TSC Composition
+
+TSC voting members are initially the project's Maintainers as listed in [CONTRIBUTING.md](CONTRIBUTING.md). The TSC may elect a Chair from among its voting members. The Chair presides over TSC meetings and serves as the primary point of contact with the Linux Foundation.
+
+## Decision Making
+
+The project operates as a consensus-based community. When a formal vote is required:
+
+- Each voting TSC member receives one vote
+- A quorum of 50% of voting members is required to conduct a vote
+- Decisions are made by a majority of those present when quorum is met
+- Electronic votes (e.g., via GitHub issues or mailing list) require a majority of all voting members to pass
+- Votes that do not meet quorum or remain unresolved may be referred to the Series Manager for resolution
+
+## Charter Amendments
+
+The technical charter may be amended by a two-thirds vote of the entire TSC, subject to approval by LF Projects, LLC.
+
+## Reference
+
+The full technical charter is available at [sqlmesh-technical-charter.pdf](sqlmesh-technical-charter.pdf).
+
+# TSC Meeting Minutes
+
+## 2026-03-10 — Initial TSC Meeting
+
+**Members present:** Toby Mao (tobymao)
+
+### Vote 1: Elect Toby Mao as TSC Chair
+- **Motion by:** Toby Mao
+- **Votes:** Toby Mao: Yes
+- **Result:** Approved (1-0-0, yes-no-abstain)
+
+### Vote 2: Elect TSC founding members
+- **Question:** Shall the following members be added to the TSC?
+ - Alexander Butler (z3z1ma)
+ - Alexander Filipchik (afilipchik)
+ - Reid Hooper (rhooper9711)
+ - Yuki Kakegawa (StuffbyYuki)
+ - Alex Wilde (alexminerv)
+- **Motion by:** Toby Mao
+- **Votes:** Toby Mao: Yes
+- **Result:** Approved (1-0-0, yes-no-abstain)
diff --git a/LICENSE b/LICENSE
index eabfad022a..7e95724816 100644
--- a/LICENSE
+++ b/LICENSE
@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.
- Copyright 2024 Tobiko Data Inc.
+ Copyright Contributors to the SQLMesh project
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
diff --git a/Makefile b/Makefile
index 2b3e10cb1b..843beb0624 100644
--- a/Makefile
+++ b/Makefile
@@ -59,6 +59,10 @@ install-dev-dbt-%:
echo "Applying overrides for dbt 1.5.0"; \
$(PIP) install 'dbt-databricks==1.5.6' 'numpy<2' --reinstall; \
fi; \
+ if [ "$$version" = "1.3.0" ]; then \
+ echo "Applying overrides for dbt $$version - upgrading google-cloud-bigquery"; \
+ $(PIP) install 'google-cloud-bigquery>=3.0.0' --upgrade; \
+ fi; \
mv pyproject.toml.backup pyproject.toml; \
echo "Restored original pyproject.toml"
@@ -126,7 +130,7 @@ slow-test:
pytest -n auto -m "(fast or slow) and not cicdonly" && pytest -m "isolated" && pytest -m "registry_isolation" && pytest -m "dialect_isolated"
cicd-test:
- pytest -n auto -m "fast or slow" --junitxml=test-results/junit-cicd.xml && pytest -m "isolated" && pytest -m "registry_isolation" && pytest -m "dialect_isolated"
+ pytest -n auto -m "(fast or slow) and not pyspark" --junitxml=test-results/junit-cicd.xml && pytest -m "pyspark" && pytest -m "isolated" && pytest -m "registry_isolation" && pytest -m "dialect_isolated"
core-fast-test:
pytest -n auto -m "fast and not web and not github and not dbt and not jupyter"
@@ -162,7 +166,7 @@ web-test:
pytest -n auto -m "web"
guard-%:
- @ if [ "${${*}}" = "" ]; then \
+ @ if ! printenv ${*} > /dev/null 2>&1; then \
echo "Environment variable $* not set"; \
exit 1; \
fi
@@ -172,7 +176,7 @@ engine-%-install:
engine-docker-%-up:
docker compose -f ./tests/core/engine_adapter/integration/docker/compose.${*}.yaml up -d
- ./.circleci/wait-for-db.sh ${*}
+ ./.github/scripts/wait-for-db.sh ${*}
engine-%-up: engine-%-install engine-docker-%-up
@echo "Engine '${*}' is up and running"
@@ -212,14 +216,14 @@ risingwave-test: engine-risingwave-up
# Cloud Engines #
#################
-snowflake-test: guard-SNOWFLAKE_ACCOUNT guard-SNOWFLAKE_WAREHOUSE guard-SNOWFLAKE_DATABASE guard-SNOWFLAKE_USER guard-SNOWFLAKE_PASSWORD engine-snowflake-install
+snowflake-test: guard-SNOWFLAKE_ACCOUNT guard-SNOWFLAKE_WAREHOUSE guard-SNOWFLAKE_DATABASE guard-SNOWFLAKE_USER engine-snowflake-install
pytest -n auto -m "snowflake" --reruns 3 --junitxml=test-results/junit-snowflake.xml
bigquery-test: guard-BIGQUERY_KEYFILE engine-bigquery-install
$(PIP) install -e ".[bigframes]"
pytest -n auto -m "bigquery" --reruns 3 --junitxml=test-results/junit-bigquery.xml
-databricks-test: guard-DATABRICKS_CATALOG guard-DATABRICKS_SERVER_HOSTNAME guard-DATABRICKS_HTTP_PATH guard-DATABRICKS_ACCESS_TOKEN guard-DATABRICKS_CONNECT_VERSION engine-databricks-install
+databricks-test: guard-DATABRICKS_CATALOG guard-DATABRICKS_SERVER_HOSTNAME guard-DATABRICKS_HTTP_PATH guard-DATABRICKS_CONNECT_VERSION engine-databricks-install
$(PIP) install 'databricks-connect==${DATABRICKS_CONNECT_VERSION}'
pytest -n auto -m "databricks" --reruns 3 --junitxml=test-results/junit-databricks.xml
diff --git a/README.md b/README.md
index 3215f7cceb..41f78cc138 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,7 @@
+SQLMesh is a project of the Linux Foundation.
SQLMesh is a next-generation data transformation framework designed to ship data quickly, efficiently, and without error. Data teams can run and deploy data transformations written in SQL or Python with visibility and control at any size.
@@ -12,7 +13,7 @@ It is more than just a [dbt alternative](https://tobikodata.com/reduce_costs_wit
## Core Features
-
+
> Get instant SQL impact and context of your changes, both in the CLI and in the [SQLMesh VSCode Extension](https://sqlmesh.readthedocs.io/en/latest/guides/vscode/?h=vs+cod)
@@ -121,19 +122,19 @@ outputs:
* Never build a table [more than once](https://tobikodata.com/simplicity-or-efficiency-how-dbt-makes-you-choose.html)
* Track what dataβs been modified and run only the necessary transformations for [incremental models](https://tobikodata.com/correctly-loading-incremental-data-at-scale.html)
* Run [unit tests](https://tobikodata.com/we-need-even-greater-expectations.html) for free and configure automated audits
-* Run [table diffs](https://sqlmesh.readthedocs.io/en/stable/examples/sqlmesh_cli_crash_course/?h=crash#run-data-diff-against-prod) between prod and dev based on tables/views impacted by a change
+* Run [table diffs](https://sqlmesh.readthedocs.io/en/stable/examples/sqlmesh_cli_crash_course/?h=crash#run-data-diff-against-prod) between prod and dev based on tables/views impacted by a change
Level Up Your SQL
Write SQL in any dialect and SQLMesh will transpile it to your target SQL dialect on the fly before sending it to the warehouse.
-
+
* Debug transformation errors *before* you run them in your warehouse in [10+ different SQL dialects](https://sqlmesh.readthedocs.io/en/stable/integrations/overview/#execution-engines)
* Definitions using [simply SQL](https://sqlmesh.readthedocs.io/en/stable/concepts/models/sql_models/#sql-based-definition) (no need for redundant and confusing `Jinja` + `YAML`)
* See impact of changes before you run them in your warehouse with column-level lineage
-For more information, check out the [website](https://www.tobikodata.com/sqlmesh) and [documentation](https://sqlmesh.readthedocs.io/en/stable/).
+For more information, check out the [documentation](https://sqlmesh.readthedocs.io/en/stable/).
## Getting Started
Install SQLMesh through [pypi](https://pypi.org/project/sqlmesh/) by running:
@@ -169,21 +170,24 @@ sqlmesh init # follow the prompts to get started (choose DuckDB)
Follow the [quickstart guide](https://sqlmesh.readthedocs.io/en/stable/quickstart/cli/) to learn how to use SQLMesh. You already have a head start!
-Follow the [crash course](https://sqlmesh.readthedocs.io/en/stable/examples/sqlmesh_cli_crash_course/) to learn the core movesets and use the easy to reference cheat sheet.
+Follow the [crash course](https://sqlmesh.readthedocs.io/en/stable/examples/sqlmesh_cli_crash_course/) to learn the core movesets and use the easy to reference cheat sheet.
Follow this [example](https://sqlmesh.readthedocs.io/en/stable/examples/incremental_time_full_walkthrough/) to learn how to use SQLMesh in a full walkthrough.
## Join Our Community
-Together, we want to build data transformation without the waste. Connect with us in the following ways:
+Connect with us in the following ways:
* Join the [Tobiko Slack Community](https://tobikodata.com/slack) to ask questions, or just to say hi!
-* File an issue on our [GitHub](https://github.com/TobikoData/sqlmesh/issues/new)
+* File an issue on our [GitHub](https://github.com/SQLMesh/sqlmesh/issues/new)
* Send us an email at [hello@tobikodata.com](mailto:hello@tobikodata.com) with your questions or feedback
* Read our [blog](https://tobikodata.com/blog)
-## Contribution
-Contributions in the form of issues or pull requests (from fork) are greatly appreciated.
+## Contributing
+We welcome contributions! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute, including our DCO sign-off requirement.
-[Read more](https://sqlmesh.readthedocs.io/en/stable/development/) on how to contribute to SQLMesh open source.
+Please review our [Code of Conduct](CODE_OF_CONDUCT.md) and [Governance](GOVERNANCE.md) documents.
-[Watch this video walkthrough](https://www.loom.com/share/2abd0d661c12459693fa155490633126?sid=b65c1c0f-8ef7-4036-ad19-3f85a3b87ff2) to see how our team contributes a feature to SQLMesh.
+[Read more](https://sqlmesh.readthedocs.io/en/stable/development/) on how to set up your development environment.
+
+## License
+This project is licensed under the [Apache License 2.0](LICENSE). Documentation is licensed under [CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/).
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 0000000000..2ffffacea3
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,17 @@
+# Security Policy
+
+## Reporting a Vulnerability
+
+If you discover a security vulnerability in SQLMesh, please report it through [GitHub Security Advisories](https://github.com/SQLMesh/sqlmesh/security/advisories/new). Do not file a public issue for security vulnerabilities.
+
+## Response
+
+We will acknowledge receipt of your report within 72 hours and aim to provide an initial assessment within one week.
+
+## Disclosure
+
+We follow a coordinated disclosure process. We will work with you to understand and address the issue before any public disclosure.
+
+## Supported Versions
+
+Security fixes are generally applied to the latest release. Critical vulnerabilities may be backported to recent prior releases at the discretion of the maintainers.
diff --git a/docs/HOWTO.md b/docs/HOWTO.md
index 9ccefff077..edd7c9833f 100644
--- a/docs/HOWTO.md
+++ b/docs/HOWTO.md
@@ -92,7 +92,7 @@ You will work on the docs in a local copy of the sqlmesh git repository.
If you don't have a copy of the repo on your machine, open a terminal and clone it into a `sqlmesh` directory by executing:
``` bash
-git clone https://github.com/TobikoData/sqlmesh.git
+git clone https://github.com/SQLMesh/sqlmesh.git
```
And navigate to the directory:
diff --git a/docs/cloud/features/scheduler/hybrid_executors_docker_compose.md b/docs/cloud/features/scheduler/hybrid_executors_docker_compose.md
index e3bd072752..8f8f323139 100644
--- a/docs/cloud/features/scheduler/hybrid_executors_docker_compose.md
+++ b/docs/cloud/features/scheduler/hybrid_executors_docker_compose.md
@@ -25,7 +25,7 @@ Both executors must be properly configured with environment variables to connect
1. **Get docker-compose file**:
- Download the [docker-compose.yml](https://raw.githubusercontent.com/TobikoData/sqlmesh/refs/heads/main/docs/cloud/features/scheduler/scheduler/docker-compose.yml) and [.env.example](https://raw.githubusercontent.com/TobikoData/sqlmesh/refs/heads/main/docs/cloud/features/scheduler/scheduler/.env.example) files to a local directory.
+ Download the [docker-compose.yml](https://raw.githubusercontent.com/SQLMesh/sqlmesh/refs/heads/main/docs/cloud/features/scheduler/scheduler/docker-compose.yml) and [.env.example](https://raw.githubusercontent.com/SQLMesh/sqlmesh/refs/heads/main/docs/cloud/features/scheduler/scheduler/.env.example) files to a local directory.
2. **Create your environment file**:
diff --git a/docs/concepts/macros/sqlmesh_macros.md b/docs/concepts/macros/sqlmesh_macros.md
index f28e77e203..c7d967b12c 100644
--- a/docs/concepts/macros/sqlmesh_macros.md
+++ b/docs/concepts/macros/sqlmesh_macros.md
@@ -2111,7 +2111,7 @@ FROM some_table;
Generics can be nested and are resolved recursively allowing for fairly robust type hinting.
-See examples of the coercion function in action in the test suite [here](https://github.com/TobikoData/sqlmesh/blob/main/tests/core/test_macros.py).
+See examples of the coercion function in action in the test suite [here](https://github.com/SQLMesh/sqlmesh/blob/main/tests/core/test_macros.py).
#### Conclusion
diff --git a/docs/concepts/models/sql_models.md b/docs/concepts/models/sql_models.md
index 28bf0fbe78..217cd7a6a2 100644
--- a/docs/concepts/models/sql_models.md
+++ b/docs/concepts/models/sql_models.md
@@ -149,7 +149,8 @@ MODEL (
SELECT
@field_a,
- @{field_b} AS field_b
+ @{field_b} AS field_b,
+ @'prefix_@{field_a}_suffix' AS literal_example
FROM @customer.some_source
```
@@ -163,8 +164,9 @@ MODEL (
);
SELECT
- 'x',
- y AS field_b
+ x,
+ y AS field_b,
+ 'prefix_x_suffix' AS literal_example
FROM customer1.some_source
-- This uses the second variable mapping
@@ -174,14 +176,13 @@ MODEL (
);
SELECT
- 'z',
- w AS field_b
+ z,
+ w AS field_b,
+ 'prefix_z_suffix' AS literal_example
FROM customer2.some_source
```
-Note the use of curly brace syntax `@{field_b} AS field_b` in the model query above. It is used to tell SQLMesh that the rendered variable value should be treated as a SQL identifier instead of a string literal.
-
-You can see the different behavior in the first rendered model. `@field_a` is resolved to the string literal `'x'` (with single quotes) and `@{field_b}` is resolved to the identifier `y` (without quotes). Learn more about the curly brace syntax [here](../../concepts/macros/sqlmesh_macros.md#embedding-variables-in-strings).
+Both `@field_a` and `@{field_b}` resolve blueprint variable values as SQL identifiers. The curly brace syntax is useful when embedding a variable within a larger string where the variable boundary would otherwise be ambiguous (e.g. `@{customer}_suffix`). To produce a string literal with interpolated variables, use the `@'...@{var}...'` syntax as shown with `literal_example` above. Learn more about the curly brace syntax [here](../../concepts/macros/sqlmesh_macros.md#embedding-variables-in-strings).
Blueprint variable mappings can also be constructed dynamically, e.g., by using a macro: `blueprints @gen_blueprints()`. This is useful in cases where the `blueprints` list needs to be sourced from external sources, such as CSV files.
diff --git a/docs/development.md b/docs/development.md
index 662ad17d6c..ff8b250d87 100644
--- a/docs/development.md
+++ b/docs/development.md
@@ -1,6 +1,6 @@
# Contribute to development
-SQLMesh is licensed under [Apache 2.0](https://github.com/TobikoData/sqlmesh/blob/main/LICENSE). We encourage community contribution and would love for you to get involved. The following document outlines the process to contribute to SQLMesh.
+SQLMesh is licensed under [Apache 2.0](https://github.com/SQLMesh/sqlmesh/blob/main/LICENSE). We encourage community contribution and would love for you to get involved. The following document outlines the process to contribute to SQLMesh.
## Prerequisites
diff --git a/docs/examples/incremental_time_full_walkthrough.md b/docs/examples/incremental_time_full_walkthrough.md
index 4e1d577d2c..ffa9def911 100644
--- a/docs/examples/incremental_time_full_walkthrough.md
+++ b/docs/examples/incremental_time_full_walkthrough.md
@@ -689,7 +689,7 @@ In the terminal output, I can see the change displayed like before, but I see so
I leave the [effective date](../concepts/plans.md#effective-date) prompt blank because I do not want to reprocess historical data in `prod` - I only want to apply this new business logic going forward.
-However, I do want to preview the new business logic in my `dev` environment before pushing to `prod`. Because I have [configured SQLMesh to create previews](https://github.com/TobikoData/sqlmesh-demos/blob/e0e3899e173cf7b8447ae707402a9df59911d1c0/config.yaml#L42) for forward-only models in my `config.yaml` file, SQLMesh has created a temporary copy of the `prod` table in my `dev` environment, so I can test the new logic on historical data.
+However, I do want to preview the new business logic in my `dev` environment before pushing to `prod`. Because I have [configured SQLMesh to create previews](https://github.com/SQLMesh/sqlmesh-demos/blob/e0e3899e173cf7b8447ae707402a9df59911d1c0/config.yaml#L42) for forward-only models in my `config.yaml` file, SQLMesh has created a temporary copy of the `prod` table in my `dev` environment, so I can test the new logic on historical data.
I specify the beginning of the preview's historical data window as `2024-10-27` in the preview start date prompt, and I specify the end of the window as now by leaving the preview end date prompt blank.
diff --git a/docs/examples/overview.md b/docs/examples/overview.md
index a252b3f9c2..e7dbc1916d 100644
--- a/docs/examples/overview.md
+++ b/docs/examples/overview.md
@@ -27,16 +27,16 @@ Walkthroughs are easy to follow and provide lots of information in a self-contai
## Projects
-SQLMesh example projects are stored in the [sqlmesh-examples Github repository](https://github.com/TobikoData/sqlmesh-examples). The repository's front page includes additional information about how to download the files and set up the projects.
+SQLMesh example projects are stored in the [sqlmesh-examples Github repository](https://github.com/SQLMesh/sqlmesh-examples). The repository's front page includes additional information about how to download the files and set up the projects.
The two most comprehensive example projects use the SQLMesh `sushi` data, based on a fictional sushi restaurant. ("Tobiko" is the Japanese word for flying fish roe, commonly used in sushi.)
-The `sushi` data is described in an [overview notebook](https://github.com/TobikoData/sqlmesh-examples/blob/main/001_sushi/sushi-overview.ipynb) in the repository.
+The `sushi` data is described in an [overview notebook](https://github.com/SQLMesh/sqlmesh-examples/blob/main/001_sushi/sushi-overview.ipynb) in the repository.
The example repository include two versions of the `sushi` project, at different levels of complexity:
-- The [`simple` project](https://github.com/TobikoData/sqlmesh-examples/tree/main/001_sushi/1_simple) contains four `VIEW` and one `SEED` model
+- The [`simple` project](https://github.com/SQLMesh/sqlmesh-examples/tree/main/001_sushi/1_simple) contains four `VIEW` and one `SEED` model
- The `VIEW` model kind refreshes every run, making it easy to reason about SQLMesh's behavior
-- The [`moderate` project](https://github.com/TobikoData/sqlmesh-examples/tree/main/001_sushi/2_moderate) contains five `INCREMENTAL_BY_TIME_RANGE`, one `FULL`, one `VIEW`, and one `SEED` model
+- The [`moderate` project](https://github.com/SQLMesh/sqlmesh-examples/tree/main/001_sushi/2_moderate) contains five `INCREMENTAL_BY_TIME_RANGE`, one `FULL`, one `VIEW`, and one `SEED` model
- The incremental models allow you to observe how and when new data is transformed by SQLMesh
- Some models, like `customer_revenue_lifetime`, demonstrate more advanced incremental queries like customer lifetime value calculation
diff --git a/docs/guides/custom_materializations.md b/docs/guides/custom_materializations.md
index 58eb64026d..905a3d017e 100644
--- a/docs/guides/custom_materializations.md
+++ b/docs/guides/custom_materializations.md
@@ -24,13 +24,13 @@ A custom materialization must:
- Be written in Python code
- Be a Python class that inherits the SQLMesh `CustomMaterialization` base class
-- Use or override the `insert` method from the SQLMesh [`MaterializableStrategy`](https://github.com/TobikoData/sqlmesh/blob/034476e7f64d261860fd630c3ac56d8a9c9f3e3a/sqlmesh/core/snapshot/evaluator.py#L1146) class/subclasses
+- Use or override the `insert` method from the SQLMesh [`MaterializableStrategy`](https://github.com/SQLMesh/sqlmesh/blob/034476e7f64d261860fd630c3ac56d8a9c9f3e3a/sqlmesh/core/snapshot/evaluator.py#L1146) class/subclasses
- Be loaded or imported by SQLMesh at runtime
A custom materialization may:
-- Use or override methods from the SQLMesh [`MaterializableStrategy`](https://github.com/TobikoData/sqlmesh/blob/034476e7f64d261860fd630c3ac56d8a9c9f3e3a/sqlmesh/core/snapshot/evaluator.py#L1146) class/subclasses
-- Use or override methods from the SQLMesh [`EngineAdapter`](https://github.com/TobikoData/sqlmesh/blob/034476e7f64d261860fd630c3ac56d8a9c9f3e3a/sqlmesh/core/engine_adapter/base.py#L67) class/subclasses
+- Use or override methods from the SQLMesh [`MaterializableStrategy`](https://github.com/SQLMesh/sqlmesh/blob/034476e7f64d261860fd630c3ac56d8a9c9f3e3a/sqlmesh/core/snapshot/evaluator.py#L1146) class/subclasses
+- Use or override methods from the SQLMesh [`EngineAdapter`](https://github.com/SQLMesh/sqlmesh/blob/034476e7f64d261860fd630c3ac56d8a9c9f3e3a/sqlmesh/core/engine_adapter/base.py#L67) class/subclasses
- Execute arbitrary SQL code and fetch results with the engine adapter `execute` and related methods
A custom materialization may perform arbitrary Python processing with Pandas or other libraries, but in most cases that logic should reside in a [Python model](../concepts/models/python_models.md) instead of the materialization.
@@ -157,7 +157,7 @@ class CustomFullMaterialization(CustomMaterialization):
) -> None:
config_value = model.custom_materialization_properties["config_key"]
# Proceed with implementing the insertion logic.
- # Example existing materialization for look and feel: https://github.com/TobikoData/sqlmesh/blob/main/sqlmesh/core/snapshot/evaluator.py
+ # Example existing materialization for look and feel: https://github.com/SQLMesh/sqlmesh/blob/main/sqlmesh/core/snapshot/evaluator.py
```
## Extending `CustomKind`
@@ -292,4 +292,4 @@ setup(
)
```
-Refer to the SQLMesh Github [custom_materializations](https://github.com/TobikoData/sqlmesh/tree/main/examples/custom_materializations) example for more details on Python packaging.
+Refer to the SQLMesh Github [custom_materializations](https://github.com/SQLMesh/sqlmesh/tree/main/examples/custom_materializations) example for more details on Python packaging.
diff --git a/docs/guides/linter.md b/docs/guides/linter.md
index 22cc5077b8..6cdac167ec 100644
--- a/docs/guides/linter.md
+++ b/docs/guides/linter.md
@@ -16,7 +16,7 @@ Some rules validate that a pattern is *not* present, such as not allowing `SELEC
Rules are defined in Python. Each rule is an individual Python class that inherits from SQLMesh's `Rule` base class and defines the logic for validating a pattern.
-We display a portion of the `Rule` base class's code below ([full source code](https://github.com/TobikoData/sqlmesh/blob/main/sqlmesh/core/linter/rule.py)). Its methods and properties illustrate the most important components of the subclassed rules you define.
+We display a portion of the `Rule` base class's code below ([full source code](https://github.com/SQLMesh/sqlmesh/blob/main/sqlmesh/core/linter/rule.py)). Its methods and properties illustrate the most important components of the subclassed rules you define.
Each rule class you create has four vital components:
diff --git a/docs/guides/model_selection.md b/docs/guides/model_selection.md
index 9cc0a4358a..79fd17a18c 100644
--- a/docs/guides/model_selection.md
+++ b/docs/guides/model_selection.md
@@ -78,7 +78,7 @@ NOTE: the `--backfill-model` argument can only be used in development environmen
## Examples
-We now demonstrate the use of `--select-model` and `--backfill-model` with the SQLMesh `sushi` example project, available in the `examples/sushi` directory of the [SQLMesh Github repository](https://github.com/TobikoData/sqlmesh).
+We now demonstrate the use of `--select-model` and `--backfill-model` with the SQLMesh `sushi` example project, available in the `examples/sushi` directory of the [SQLMesh Github repository](https://github.com/SQLMesh/sqlmesh).
### sushi
@@ -242,8 +242,9 @@ Models:
#### Select with git changes
The git-based selector allows you to select models whose files have changed compared to a target branch (default: main). This includes:
+
- Untracked files (new files not in git)
-- Uncommitted changes in working directory
+- Uncommitted changes in working directory (both staged and unstaged)
- Committed changes different from the target branch
For example:
diff --git a/docs/guides/multi_repo.md b/docs/guides/multi_repo.md
index bf34c7d21a..4dae4de57e 100644
--- a/docs/guides/multi_repo.md
+++ b/docs/guides/multi_repo.md
@@ -5,7 +5,7 @@ SQLMesh provides native support for multiple repos and makes it easy to maintain
If you are wanting to separate your systems/data and provide isolation, checkout the [isolated systems guide](https://sqlmesh.readthedocs.io/en/stable/guides/isolated_systems/?h=isolated).
## Bootstrapping multiple projects
-Setting up SQLMesh with multiple repos is quite simple. Copy the contents of this example [multi-repo project](https://github.com/TobikoData/sqlmesh/tree/main/examples/multi).
+Setting up SQLMesh with multiple repos is quite simple. Copy the contents of this example [multi-repo project](https://github.com/SQLMesh/sqlmesh/tree/main/examples/multi).
To bootstrap the project, you can point SQLMesh at both projects.
@@ -196,7 +196,7 @@ $ sqlmesh -p examples/multi/repo_1 migrate
SQLMesh also supports multiple repos for dbt projects, allowing it to correctly detect changes and orchestrate backfills even when changes span multiple dbt projects.
-You can watch a [quick demo](https://www.loom.com/share/69c083428bb348da8911beb2cd4d30b2) of this setup or experiment with the [multi-repo dbt example](https://github.com/TobikoData/sqlmesh/tree/main/examples/multi_dbt) yourself.
+You can watch a [quick demo](https://www.loom.com/share/69c083428bb348da8911beb2cd4d30b2) of this setup or experiment with the [multi-repo dbt example](https://github.com/SQLMesh/sqlmesh/tree/main/examples/multi_dbt) yourself.
## Multi-repo mixed projects
@@ -212,4 +212,4 @@ $ sqlmesh -p examples/multi_hybrid/dbt_repo -p examples/multi_hybrid/sqlmesh_rep
SQLMesh will automatically detect dependencies and lineage across both SQLMesh and dbt projects, even when models are sourcing from different project types.
-For an example of this setup, refer to the [mixed SQLMesh and dbt example](https://github.com/TobikoData/sqlmesh/tree/main/examples/multi_hybrid).
+For an example of this setup, refer to the [mixed SQLMesh and dbt example](https://github.com/SQLMesh/sqlmesh/tree/main/examples/multi_hybrid).
diff --git a/docs/guides/notifications.md b/docs/guides/notifications.md
index 03405b8252..749a71c842 100644
--- a/docs/guides/notifications.md
+++ b/docs/guides/notifications.md
@@ -256,7 +256,7 @@ This example shows an email notification target, where `sushi@example.com` email
In Python configuration files, new notification targets can be configured to send custom messages.
-To customize a notification, create a new notification target class as a subclass of one of the three target classes described above (`SlackWebhookNotificationTarget`, `SlackApiNotificationTarget`, or `BasicSMTPNotificationTarget`). See the definitions of these classes on Github [here](https://github.com/TobikoData/sqlmesh/blob/main/sqlmesh/core/notification_target.py).
+To customize a notification, create a new notification target class as a subclass of one of the three target classes described above (`SlackWebhookNotificationTarget`, `SlackApiNotificationTarget`, or `BasicSMTPNotificationTarget`). See the definitions of these classes on Github [here](https://github.com/SQLMesh/sqlmesh/blob/main/sqlmesh/core/notification_target.py).
Each of those notification target classes is a subclass of `BaseNotificationTarget`, which contains a `notify` function corresponding to each event type. This table lists the notification functions, along with the contextual information available to them at calling time (e.g., the environment name for start/end events):
diff --git a/docs/guides/vscode.md b/docs/guides/vscode.md
index 151e630f27..5ef3cd71ce 100644
--- a/docs/guides/vscode.md
+++ b/docs/guides/vscode.md
@@ -6,7 +6,7 @@
The SQLMesh Visual Studio Code extension is in preview and undergoing active development. You may encounter bugs or API incompatibilities with the SQLMesh version you are running.
- We encourage you to try the extension and [create Github issues](https://github.com/tobikodata/sqlmesh/issues) for any problems you encounter.
+ We encourage you to try the extension and [create Github issues](https://github.com/SQLMesh/sqlmesh/issues) for any problems you encounter.
In this guide, you'll set up the SQLMesh extension in the Visual Studio Code IDE software (which we refer to as "VSCode").
@@ -187,7 +187,7 @@ The most common problem is the extension not using the correct Python interprete
Follow the [setup process described above](#vscode-python-interpreter) to ensure that the extension is using the correct Python interpreter.
-If you have checked the VSCode `sqlmesh` output channel and the extension is still not using the correct Python interpreter, please raise an issue [here](https://github.com/tobikodata/sqlmesh/issues).
+If you have checked the VSCode `sqlmesh` output channel and the extension is still not using the correct Python interpreter, please raise an issue [here](https://github.com/SQLMesh/sqlmesh/issues).
### Missing Python dependencies
@@ -205,4 +205,4 @@ If you are using Tobiko Cloud, make sure `lsp` is included in the list of extras
While the SQLMesh VSCode extension is in preview and the APIs to the underlying SQLMesh version are not stable, we do not guarantee compatibility between the extension and the SQLMesh version you are using.
-If you encounter a problem, please raise an issue [here](https://github.com/tobikodata/sqlmesh/issues).
\ No newline at end of file
+If you encounter a problem, please raise an issue [here](https://github.com/SQLMesh/sqlmesh/issues).
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
index 3e9330f83f..83c1b0a431 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -1,7 +1,7 @@
#
-
+
SQLMesh is a next-generation data transformation framework designed to ship data quickly, efficiently, and without error. Data teams can efficiently run and deploy data transformations written in SQL or Python with visibility and control at any size.
@@ -9,11 +9,11 @@ SQLMesh is a next-generation data transformation framework designed to ship data
It is more than just a [dbt alternative](https://tobikodata.com/reduce_costs_with_cron_and_partitions.html).
-
+
## Core Features
-
+
> Get instant SQL impact analysis of your changes, whether in the CLI or in [SQLMesh Plan Mode](https://sqlmesh.readthedocs.io/en/stable/guides/ui/?h=modes#working-with-an-ide)
@@ -121,7 +121,7 @@ It is more than just a [dbt alternative](https://tobikodata.com/reduce_costs_wit
??? tip "Level Up Your SQL"
Write SQL in any dialect and SQLMesh will transpile it to your target SQL dialect on the fly before sending it to the warehouse.
-
+
* Debug transformation errors *before* you run them in your warehouse in [10+ different SQL dialects](https://sqlmesh.readthedocs.io/en/stable/integrations/overview/#execution-engines)
* Definitions using [simply SQL](https://sqlmesh.readthedocs.io/en/stable/concepts/models/sql_models/#sql-based-definition) (no need for redundant and confusing `Jinja` + `YAML`)
@@ -153,7 +153,7 @@ Follow this [example](https://sqlmesh.readthedocs.io/en/stable/examples/incremen
Together, we want to build data transformation without the waste. Connect with us in the following ways:
* Join the [Tobiko Slack Community](https://tobikodata.com/slack) to ask questions, or just to say hi!
-* File an issue on our [GitHub](https://github.com/TobikoData/sqlmesh/issues/new)
+* File an issue on our [GitHub](https://github.com/SQLMesh/sqlmesh/issues/new)
* Send us an email at [hello@tobikodata.com](mailto:hello@tobikodata.com) with your questions or feedback
* Read our [blog](https://tobikodata.com/blog)
diff --git a/docs/integrations/dbt.md b/docs/integrations/dbt.md
index 7cbef5b8fa..5854236aa2 100644
--- a/docs/integrations/dbt.md
+++ b/docs/integrations/dbt.md
@@ -358,4 +358,4 @@ The dbt jinja methods that are not currently supported are:
## Missing something you need?
-Submit an [issue](https://github.com/TobikoData/sqlmesh/issues), and we'll look into it!
+Submit an [issue](https://github.com/SQLMesh/sqlmesh/issues), and we'll look into it!
diff --git a/docs/integrations/dlt.md b/docs/integrations/dlt.md
index a53dc184ea..7125510de9 100644
--- a/docs/integrations/dlt.md
+++ b/docs/integrations/dlt.md
@@ -70,7 +70,7 @@ SQLMesh will retrieve the data warehouse connection credentials from your dlt pr
### Example
-Generating a SQLMesh project dlt is quite simple. In this example, we'll use the example `sushi_pipeline.py` from the [sushi-dlt project](https://github.com/TobikoData/sqlmesh/tree/main/examples/sushi_dlt).
+Generating a SQLMesh project from dlt is quite simple. In this example, we'll use the example `sushi_pipeline.py` from the [sushi-dlt project](https://github.com/SQLMesh/sqlmesh/tree/main/examples/sushi_dlt).
First, run the pipeline within the project directory:
diff --git a/docs/integrations/engines/bigquery.md b/docs/integrations/engines/bigquery.md
index a454996ecd..b93d6837ed 100644
--- a/docs/integrations/engines/bigquery.md
+++ b/docs/integrations/engines/bigquery.md
@@ -193,6 +193,23 @@ If the `impersonated_service_account` argument is set, SQLMesh will:
The user account must have [sufficient permissions to impersonate the service account](https://cloud.google.com/docs/authentication/use-service-account-impersonation).
+## Query Label
+
+BigQuery supports a `query_label` session variable which is attached to query jobs and can be used for auditing / attribution.
+
+SQLMesh supports setting it via `session_properties.query_label` on a model, as an array (or tuple) of key/value tuples.
+
+Example:
+```sql
+MODEL (
+ name my_project.my_dataset.my_model,
+ dialect 'bigquery',
+ session_properties (
+ query_label = [('team', 'data_platform'), ('env', 'prod')]
+ )
+);
+```
+
## Permissions Required
With any of the above connection methods, ensure these BigQuery permissions are enabled to allow SQLMesh to work correctly.
diff --git a/docs/integrations/engines/trino.md b/docs/integrations/engines/trino.md
index ec1139e20d..db732f0cc1 100644
--- a/docs/integrations/engines/trino.md
+++ b/docs/integrations/engines/trino.md
@@ -90,6 +90,7 @@ hive.metastore.glue.default-warehouse-dir=s3://my-bucket/
| `http_scheme` | The HTTP scheme to use when connecting to your cluster. By default, it's `https` and can only be `http` for no-auth or basic auth. | string | N |
| `port` | The port to connect to your cluster. By default, it's `443` for `https` scheme and `80` for `http` | int | N |
| `roles` | Mapping of catalog name to a role | dict | N |
+| `source` | Value to send as Trino's `source` field for query attribution / auditing. Default: `sqlmesh`. | string | N |
| `http_headers` | Additional HTTP headers to send with each request. | dict | N |
| `session_properties` | Trino session properties. Run `SHOW SESSION` to see all options. | dict | N |
| `retries` | Number of retries to attempt when a request fails. Default: `3` | int | N |
diff --git a/docs/integrations/github.md b/docs/integrations/github.md
index a11d90d044..07903fce56 100644
--- a/docs/integrations/github.md
+++ b/docs/integrations/github.md
@@ -286,21 +286,22 @@ Below is an example of how to define the default config for the bot in either YA
### Configuration Properties
-| Option | Description | Type | Required |
-|---------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:|
-| `invalidate_environment_after_deploy` | Indicates if the PR environment created should be automatically invalidated after changes are deployed. Invalidated environments are cleaned up automatically by the Janitor. Default: `True` | bool | N |
-| `merge_method` | The merge method to use when automatically merging a PR after deploying to prod. Defaults to `None` meaning automatic merge is not done. Options: `merge`, `squash`, `rebase` | string | N |
-| `enable_deploy_command` | Indicates if the `/deploy` command should be enabled in order to allowed synchronized deploys to production. Default: `False` | bool | N |
-| `command_namespace` | The namespace to use for SQLMesh commands. For example if you provide `#SQLMesh` as a value then commands will be expected in the format of `#SQLMesh/`. Default: `None` meaning no namespace is used. | string | N |
-| `auto_categorize_changes` | Auto categorization behavior to use for the bot. If not provided then the project-wide categorization behavior is used. See [Auto-categorize model changes](https://sqlmesh.readthedocs.io/en/stable/guides/configuration/#auto-categorize-model-changes) for details. | dict | N |
-| `default_pr_start` | Default start when creating PR environment plans. If running in a mode where the bot automatically backfills models (based on `auto_categorize_changes` behavior) then this can be used to limit the amount of data backfilled. Defaults to `None` meaning the start date is set to the earliest model's start or to 1 day ago if [data previews](../concepts/plans.md#data-preview) need to be computed.| str | N |
-| `pr_min_intervals` | Intended for use when `default_pr_start` is set to a relative time, eg `1 week ago`. This ensures that at least this many intervals across every model are included for backfill in the PR environment. Without this, models with an interval unit wider than `default_pr_start` (such as `@monthly` models if `default_pr_start` was set to `1 week ago`) will be excluded from backfill entirely. | int | N |
-| `skip_pr_backfill` | Indicates if the bot should skip backfilling models in the PR environment. Default: `True` | bool | N |
-| `pr_include_unmodified` | Indicates whether to include unmodified models in the PR environment. Default to the project's config value (which defaults to `False`) | bool | N |
-| `run_on_deploy_to_prod` | Indicates whether to run latest intervals when deploying to prod. If set to false, the deployment will backfill only the changed models up to the existing latest interval in production, ignoring any missing intervals beyond this point. Default: `False` | bool | N |
-| `pr_environment_name` | The name of the PR environment to create for which a PR number will be appended to. Defaults to the repo name if not provided. Note: The name will be normalized to alphanumeric + underscore and lowercase. | str | N |
-| `prod_branch_name` | The name of the git branch associated with production. Ex: `prod`. Default: `main` or `master` is considered prod | str | N |
-| `forward_only_branch_suffix` | If the git branch has this suffix, trigger a [forward-only](../concepts/plans.md#forward-only-plans) plan instead of a normal plan. Default: `-forward-only` | str | N |
+| Option | Description | Type | Required |
+|---------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:|
+| `invalidate_environment_after_deploy` | Indicates if the PR environment created should be automatically invalidated after changes are deployed. Invalidated environments are cleaned up automatically by the Janitor. Default: `True` | bool | N |
+| `merge_method` | The merge method to use when automatically merging a PR after deploying to prod. Defaults to `None` meaning automatic merge is not done. Options: `merge`, `squash`, `rebase` | string | N |
+| `enable_deploy_command`               | Indicates if the `/deploy` command should be enabled in order to allow synchronized deploys to production. Default: `False`                                                                                                                                                                                                                                                                                 | bool   | N        |
+| `command_namespace` | The namespace to use for SQLMesh commands. For example if you provide `#SQLMesh` as a value then commands will be expected in the format of `#SQLMesh/`. Default: `None` meaning no namespace is used. | string | N |
+| `auto_categorize_changes` | Auto categorization behavior to use for the bot. If not provided then the project-wide categorization behavior is used. See [Auto-categorize model changes](https://sqlmesh.readthedocs.io/en/stable/guides/configuration/#auto-categorize-model-changes) for details. | dict | N |
+| `default_pr_start` | Default start when creating PR environment plans. If running in a mode where the bot automatically backfills models (based on `auto_categorize_changes` behavior) then this can be used to limit the amount of data backfilled. Defaults to `None` meaning the start date is set to the earliest model's start or to 1 day ago if [data previews](../concepts/plans.md#data-preview) need to be computed. | str | N |
+| `pr_min_intervals` | Intended for use when `default_pr_start` is set to a relative time, eg `1 week ago`. This ensures that at least this many intervals across every model are included for backfill in the PR environment. Without this, models with an interval unit wider than `default_pr_start` (such as `@monthly` models if `default_pr_start` was set to `1 week ago`) will be excluded from backfill entirely. | int | N |
+| `skip_pr_backfill` | Indicates if the bot should skip backfilling models in the PR environment. Default: `True` | bool | N |
+| `pr_include_unmodified`               | Indicates whether to include unmodified models in the PR environment. Defaults to the project's config value (which defaults to `False`)                                                                                                                                                                                                                                                                    | bool   | N        |
+| `run_on_deploy_to_prod` | Indicates whether to run latest intervals when deploying to prod. If set to false, the deployment will backfill only the changed models up to the existing latest interval in production, ignoring any missing intervals beyond this point. Default: `False` | bool | N |
+| `pr_environment_name`                 | The name of the PR environment to create, to which the PR number will be appended. Defaults to the repo name if not provided. Note: The name will be normalized to alphanumeric + underscore and lowercase.                                                                                                                                                                                                  | str    | N        |
+| `prod_branch_name` | The name of the git branch associated with production. Ex: `prod`. Default: `main` or `master` is considered prod | str | N |
+| `forward_only_branch_suffix` | If the git branch has this suffix, trigger a [forward-only](../concepts/plans.md#forward-only-plans) plan instead of a normal plan. Default: `-forward-only` | str | N |
+| `check_if_blocked_on_deploy_to_prod` | The bot normally checks if a PR is blocked from merging before deploying to production. Setting this to `False` will skip that check. Default: `True` | bool | N |
Example with all properties defined:
@@ -363,7 +364,7 @@ These are the possible outputs (based on how the bot is configured) that are cre
* `prod_plan_preview`
* `prod_environment_synced`
-[There are many possible conclusions](https://github.com/TobikoData/sqlmesh/blob/main/sqlmesh/integrations/github/cicd/controller.py#L96-L102) so the best use case for this is likely to check for `success` conclusion in order to potentially run follow up steps.
+[There are many possible conclusions](https://github.com/SQLMesh/sqlmesh/blob/main/sqlmesh/integrations/github/cicd/controller.py#L96-L102) so the best use case for this is likely to check for `success` conclusion in order to potentially run follow up steps.
Note that in error cases conclusions may not be set and therefore you will get an empty string.
Example of running a step after pr environment has been synced:
diff --git a/docs/quickstart/cli.md b/docs/quickstart/cli.md
index 7b77b2af1e..a592847470 100644
--- a/docs/quickstart/cli.md
+++ b/docs/quickstart/cli.md
@@ -160,7 +160,7 @@ https://sqlmesh.readthedocs.io/en/stable/quickstart/cli/
Need help?
- Docs: https://sqlmesh.readthedocs.io
- Slack: https://www.tobikodata.com/slack
-- GitHub: https://github.com/TobikoData/sqlmesh/issues
+- GitHub: https://github.com/SQLMesh/sqlmesh/issues
```
??? info "Learn more about the project's configuration: `config.yaml`"
diff --git a/docs/reference/python.md b/docs/reference/python.md
index 14e0da84c8..1c4c9191ff 100644
--- a/docs/reference/python.md
+++ b/docs/reference/python.md
@@ -4,6 +4,6 @@ SQLMesh is built in Python, and its complete Python API reference is located [he
The Python API reference is comprehensive and includes the internal components of SQLMesh. Those components are likely only of interest if you want to modify SQLMesh itself.
-If you want to use SQLMesh via its Python API, the best approach is to study how the SQLMesh [CLI](./cli.md) calls it behind the scenes. The CLI implementation code shows exactly which Python methods are called for each CLI command and can be viewed [on Github](https://github.com/TobikoData/sqlmesh/blob/main/sqlmesh/cli/main.py). For example, the Python code executed by the `plan` command is located [here](https://github.com/TobikoData/sqlmesh/blob/15c8788100fa1cfb8b0cc1879ccd1ad21dc3e679/sqlmesh/cli/main.py#L302).
+If you want to use SQLMesh via its Python API, the best approach is to study how the SQLMesh [CLI](./cli.md) calls it behind the scenes. The CLI implementation code shows exactly which Python methods are called for each CLI command and can be viewed [on Github](https://github.com/SQLMesh/sqlmesh/blob/main/sqlmesh/cli/main.py). For example, the Python code executed by the `plan` command is located [here](https://github.com/SQLMesh/sqlmesh/blob/15c8788100fa1cfb8b0cc1879ccd1ad21dc3e679/sqlmesh/cli/main.py#L302).
Almost all the relevant Python methods are in the [SQLMesh `Context` class](https://sqlmesh.readthedocs.io/en/stable/_readthedocs/html/sqlmesh/core/context.html#Context).
diff --git a/examples/sushi/models/customers.sql b/examples/sushi/models/customers.sql
index f91f1166e8..d2bda09ed3 100644
--- a/examples/sushi/models/customers.sql
+++ b/examples/sushi/models/customers.sql
@@ -42,4 +42,4 @@ LEFT JOIN (
ON o.customer_id = m.customer_id
LEFT JOIN raw.demographics AS d
ON o.customer_id = d.customer_id
-WHERE sushi.orders.customer_id > 0
\ No newline at end of file
+WHERE o.customer_id > 0
\ No newline at end of file
diff --git a/mkdocs.yml b/mkdocs.yml
index 47ddca54e9..86761de9d7 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -1,6 +1,6 @@
site_name: SQLMesh
-repo_url: https://github.com/TobikoData/sqlmesh
-repo_name: TobikoData/sqlmesh
+repo_url: https://github.com/SQLMesh/sqlmesh
+repo_name: SQLMesh/sqlmesh
nav:
- "Overview": index.md
- Get started:
@@ -202,7 +202,7 @@ extra:
- icon: fontawesome/solid/paper-plane
link: mailto:hello@tobikodata.com
- icon: fontawesome/brands/github
- link: https://github.com/TobikoData/sqlmesh/issues/new
+ link: https://github.com/SQLMesh/sqlmesh/issues/new
analytics:
provider: google
property: G-JXQ1R227VS
diff --git a/pdoc/cli.py b/pdoc/cli.py
index 5833c59207..9301ae0444 100755
--- a/pdoc/cli.py
+++ b/pdoc/cli.py
@@ -29,7 +29,7 @@ def mocked_import(*args, **kwargs):
opts.logo_link = "https://tobikodata.com"
opts.footer_text = "Copyright Tobiko Data Inc. 2022"
opts.template_directory = Path(__file__).parent.joinpath("templates").absolute()
- opts.edit_url = ["sqlmesh=https://github.com/TobikoData/sqlmesh/tree/main/sqlmesh/"]
+ opts.edit_url = ["sqlmesh=https://github.com/SQLMesh/sqlmesh/tree/main/sqlmesh/"]
with mock.patch("pdoc.__main__.parser", **{"parse_args.return_value": opts}):
cli()
diff --git a/posts/virtual_data_environments.md b/posts/virtual_data_environments.md
index dc3b2cb46e..5cde9dba51 100644
--- a/posts/virtual_data_environments.md
+++ b/posts/virtual_data_environments.md
@@ -8,7 +8,7 @@ In this post, I'm going to explain why existing approaches to managing developme
I'll introduce [Virtual Data Environments](#virtual-data-environments-1) - a novel approach that provides low-cost, efficient, scalable, and safe data environments that are easy to use and manage. They significantly boost the productivity of anyone who has to create or maintain data pipelines.
-Finally, I’m going to explain how **Virtual Data Environments** are implemented in [SQLMesh](https://github.com/TobikoData/sqlmesh) and share details on each core component involved:
+Finally, I’m going to explain how **Virtual Data Environments** are implemented in [SQLMesh](https://github.com/SQLMesh/sqlmesh) and share details on each core component involved:
- Data [fingerprinting](#fingerprinting)
- [Automatic change categorization](#automatic-change-categorization)
- Decoupling of [physical](#physical-layer) and [virtual](#virtual-layer) layers
@@ -156,6 +156,6 @@ With **Virtual Data Environments**, SQLMesh is able to provide fully **isolated*
- Rolling back a change happens almost instantaneously since no data movement is involved and only views that are part of the **virtual layer** get updated.
- Deploying changes to production is a **virtual layer** operation, which ensures that results observed during development are exactly the same in production and that data and code are always in sync.
-To streamline deploying changes to production, our team is about to release the SQLMesh [CI/CD bot](https://github.com/TobikoData/sqlmesh/blob/main/docs/integrations/github.md), which will help automate this process.
+To streamline deploying changes to production, our team is about to release the SQLMesh [CI/CD bot](https://github.com/SQLMesh/sqlmesh/blob/main/docs/integrations/github.md), which will help automate this process.
Don't miss out - join our [Slack channel](https://tobikodata.com/slack) and stay tuned!
diff --git a/pyproject.toml b/pyproject.toml
index 97c190a290..56d66ecff5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ name = "sqlmesh"
dynamic = ["version"]
description = "Next-generation data transformation framework"
readme = "README.md"
-authors = [{ name = "TobikoData Inc.", email = "engineering@tobikodata.com" }]
+authors = [{ name = "SQLMesh Contributors" }]
license = { file = "LICENSE" }
requires-python = ">= 3.9"
dependencies = [
@@ -18,13 +18,13 @@ dependencies = [
"ipywidgets",
"jinja2",
"packaging",
- "pandas",
+ "pandas<3.0.0",
"pydantic>=2.0.0",
"python-dotenv",
"requests",
"rich[jupyter]",
"ruamel.yaml",
- "sqlglot[rs]~=27.28.0",
+ "sqlglot~=30.0.1",
"tenacity",
"time-machine",
"json-stream"
@@ -56,7 +56,7 @@ dev = [
"agate",
"beautifulsoup4",
"clickhouse-connect",
- "cryptography<46.0.0",
+ "cryptography",
"databricks-sql-connector",
"dbt-bigquery",
"dbt-core",
@@ -111,7 +111,7 @@ duckdb = []
fabric = ["pyodbc>=5.0.0"]
gcppostgres = ["cloud-sql-python-connector[pg8000]>=1.8.0"]
github = ["PyGithub>=2.6.0"]
-motherduck = ["duckdb>=1.2.0"]
+motherduck = ["duckdb>=1.3.2"]
mssql = ["pymssql"]
mssql-odbc = ["pyodbc>=5.0.0"]
mysql = ["pymysql"]
@@ -120,13 +120,13 @@ postgres = ["psycopg2"]
redshift = ["redshift_connector"]
slack = ["slack_sdk"]
snowflake = [
- "cryptography<46.0.0",
+ "cryptography",
"snowflake-connector-python[pandas,secure-local-storage]",
"snowflake-snowpark-python",
]
trino = ["trino"]
web = [
- "fastapi==0.115.5",
+ "fastapi==0.120.1",
"watchfiles>=0.19.0",
"uvicorn[standard]==0.22.0",
"sse-starlette>=0.2.2",
@@ -134,7 +134,7 @@ web = [
]
lsp = [
# Duplicate of web
- "fastapi==0.115.5",
+ "fastapi==0.120.1",
"watchfiles>=0.19.0",
# "uvicorn[standard]==0.22.0",
"sse-starlette>=0.2.2",
@@ -154,8 +154,8 @@ sqlmesh_lsp = "sqlmesh.lsp.main:main"
[project.urls]
Homepage = "https://sqlmesh.com/"
Documentation = "https://sqlmesh.readthedocs.io/en/stable/"
-Repository = "https://github.com/TobikoData/sqlmesh"
-Issues = "https://github.com/TobikoData/sqlmesh/issues"
+Repository = "https://github.com/SQLMesh/sqlmesh"
+Issues = "https://github.com/SQLMesh/sqlmesh/issues"
[build-system]
requires = ["setuptools >= 61.0", "setuptools_scm"]
diff --git a/sqlmesh-technical-charter.pdf b/sqlmesh-technical-charter.pdf
new file mode 100644
index 0000000000..107f015050
Binary files /dev/null and b/sqlmesh-technical-charter.pdf differ
diff --git a/sqlmesh/cli/main.py b/sqlmesh/cli/main.py
index 2f18c0a4b7..ec5acbea59 100644
--- a/sqlmesh/cli/main.py
+++ b/sqlmesh/cli/main.py
@@ -246,7 +246,7 @@ def init(
Need help?
• Docs: https://sqlmesh.readthedocs.io
• Slack: https://www.tobikodata.com/slack
-• GitHub: https://github.com/TobikoData/sqlmesh/issues
+• GitHub: https://github.com/SQLMesh/sqlmesh/issues
""")
@@ -535,7 +535,7 @@ def diff(ctx: click.Context, environment: t.Optional[str] = None) -> None:
)
@click.option(
"--min-intervals",
- default=0,
+ default=None,
help="For every model, ensure at least this many intervals are covered by a missing intervals check regardless of the plan start date",
)
@opt.verbose
diff --git a/sqlmesh/core/_typing.py b/sqlmesh/core/_typing.py
index 8e28312c1a..2bc69e901b 100644
--- a/sqlmesh/core/_typing.py
+++ b/sqlmesh/core/_typing.py
@@ -8,8 +8,8 @@
if t.TYPE_CHECKING:
TableName = t.Union[str, exp.Table]
SchemaName = t.Union[str, exp.Table]
- SessionProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]]
- CustomMaterializationProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]]
+ SessionProperties = t.Dict[str, t.Union[exp.Expr, str, int, float, bool]]
+ CustomMaterializationProperties = t.Dict[str, t.Union[exp.Expr, str, int, float, bool]]
if sys.version_info >= (3, 11):
diff --git a/sqlmesh/core/audit/definition.py b/sqlmesh/core/audit/definition.py
index 9f470872fe..4c90151ee4 100644
--- a/sqlmesh/core/audit/definition.py
+++ b/sqlmesh/core/audit/definition.py
@@ -67,7 +67,7 @@ class AuditMixin(AuditCommonMetaMixin):
"""
query_: ParsableSql
- defaults: t.Dict[str, exp.Expression]
+ defaults: t.Dict[str, exp.Expr]
expressions_: t.Optional[t.List[ParsableSql]]
jinja_macros: JinjaMacroRegistry
formatting: t.Optional[bool]
@@ -77,10 +77,10 @@ def query(self) -> t.Union[exp.Query, d.JinjaQuery]:
return t.cast(t.Union[exp.Query, d.JinjaQuery], self.query_.parse(self.dialect))
@property
- def expressions(self) -> t.List[exp.Expression]:
+ def expressions(self) -> t.List[exp.Expr]:
if not self.expressions_:
return []
- result = []
+ result: t.List[exp.Expr] = []
for e in self.expressions_:
parsed = e.parse(self.dialect)
if not isinstance(parsed, exp.Semicolon):
@@ -95,7 +95,7 @@ def macro_definitions(self) -> t.List[d.MacroDef]:
@field_validator("name", "dialect", mode="before", check_fields=False)
def audit_string_validator(cls: t.Type, v: t.Any) -> t.Optional[str]:
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
return v.name.lower()
return str(v).lower() if v is not None else None
@@ -111,9 +111,7 @@ def audit_map_validator(cls: t.Type, v: t.Any, values: t.Any) -> t.Dict[str, t.A
if isinstance(v, dict):
dialect = get_dialect(values)
return {
- key: value
- if isinstance(value, exp.Expression)
- else d.parse_one(str(value), dialect=dialect)
+ key: value if isinstance(value, exp.Expr) else d.parse_one(str(value), dialect=dialect)
for key, value in v.items()
}
raise_config_error("Defaults must be a tuple of exp.EQ or a dict", error_type=AuditConfigError)
@@ -133,7 +131,7 @@ class ModelAudit(PydanticModel, AuditMixin, DbtInfoMixin, frozen=True):
blocking: bool = True
standalone: t.Literal[False] = False
query_: ParsableSql = Field(alias="query")
- defaults: t.Dict[str, exp.Expression] = {}
+ defaults: t.Dict[str, exp.Expr] = {}
expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions")
jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry()
formatting: t.Optional[bool] = Field(default=None, exclude=True)
@@ -169,7 +167,7 @@ class StandaloneAudit(_Node, AuditMixin):
blocking: bool = False
standalone: t.Literal[True] = True
query_: ParsableSql = Field(alias="query")
- defaults: t.Dict[str, exp.Expression] = {}
+ defaults: t.Dict[str, exp.Expr] = {}
expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions")
jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry()
default_catalog: t.Optional[str] = None
@@ -323,13 +321,13 @@ def render_definition(
include_python: bool = True,
include_defaults: bool = False,
render_query: bool = False,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
"""Returns the original list of sql expressions comprising the model definition.
Args:
include_python: Whether or not to include Python code in the rendered definition.
"""
- expressions: t.List[exp.Expression] = []
+ expressions: t.List[exp.Expr] = []
comment = None
for field_name in sorted(self.meta_fields):
field_value = getattr(self, field_name)
@@ -381,7 +379,7 @@ def meta_fields(self) -> t.Iterable[str]:
return set(AuditCommonMetaMixin.__annotations__) | set(_Node.all_field_infos())
@property
- def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expression]]]:
+ def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expr]]]:
return [(self, {})]
@@ -389,7 +387,7 @@ def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expression]]
def load_audit(
- expressions: t.List[exp.Expression],
+ expressions: t.List[exp.Expr],
*,
path: Path = Path(),
module_path: Path = Path(),
@@ -499,7 +497,7 @@ def load_audit(
def load_multiple_audits(
- expressions: t.List[exp.Expression],
+ expressions: t.List[exp.Expr],
*,
path: Path = Path(),
module_path: Path = Path(),
@@ -510,7 +508,7 @@ def load_multiple_audits(
variables: t.Optional[t.Dict[str, t.Any]] = None,
project: t.Optional[str] = None,
) -> t.Generator[Audit, None, None]:
- audit_block: t.List[exp.Expression] = []
+ audit_block: t.List[exp.Expr] = []
for expression in expressions:
if isinstance(expression, d.Audit):
if audit_block:
@@ -543,7 +541,7 @@ def _raise_config_error(msg: str, path: pathlib.Path) -> None:
# mypy doesn't realize raise_config_error raises an exception
@t.no_type_check
-def _maybe_parse_arg_pair(e: exp.Expression) -> t.Tuple[str, exp.Expression]:
+def _maybe_parse_arg_pair(e: exp.Expr) -> t.Tuple[str, exp.Expr]:
if isinstance(e, exp.EQ):
return e.left.name, e.right
diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py
index d89d896897..7a002faebb 100644
--- a/sqlmesh/core/config/connection.py
+++ b/sqlmesh/core/config/connection.py
@@ -17,6 +17,7 @@
from packaging import version
from sqlglot import exp
from sqlglot.helper import subclasses
+from sqlglot.errors import ParseError
from sqlmesh.core import engine_adapter
from sqlmesh.core.config.base import BaseConfig
@@ -1061,6 +1062,7 @@ class BigQueryConnectionConfig(ConnectionConfig):
job_retry_deadline_seconds: t.Optional[int] = None
priority: t.Optional[BigQueryPriority] = None
maximum_bytes_billed: t.Optional[int] = None
+ reservation: t.Optional[str] = None
concurrent_tasks: int = 1
register_comments: bool = True
@@ -1170,6 +1172,7 @@ def _extra_engine_config(self) -> t.Dict[str, t.Any]:
"job_retry_deadline_seconds",
"priority",
"maximum_bytes_billed",
+ "reservation",
}
}
@@ -1887,9 +1890,11 @@ class TrinoConnectionConfig(ConnectionConfig):
client_certificate: t.Optional[str] = None
client_private_key: t.Optional[str] = None
cert: t.Optional[str] = None
+ source: str = "sqlmesh"
# SQLMesh options
schema_location_mapping: t.Optional[dict[re.Pattern, str]] = None
+ timestamp_mapping: t.Optional[dict[exp.DataType, exp.DataType]] = None
concurrent_tasks: int = 4
register_comments: bool = True
pre_ping: t.Literal[False] = False
@@ -1914,6 +1919,34 @@ def _validate_regex_keys(
)
return compiled
+ @field_validator("timestamp_mapping", mode="before")
+ @classmethod
+ def _validate_timestamp_mapping(
+ cls, value: t.Optional[dict[str, str]]
+ ) -> t.Optional[dict[exp.DataType, exp.DataType]]:
+ if value is None:
+ return value
+
+ result: dict[exp.DataType, exp.DataType] = {}
+ for source_type, target_type in value.items():
+ try:
+ source_datatype = exp.DataType.build(source_type)
+ except ParseError:
+ raise ConfigError(
+ f"Invalid SQL type string in timestamp_mapping: "
+ f"'{source_type}' is not a valid SQL data type."
+ )
+ try:
+ target_datatype = exp.DataType.build(target_type)
+ except ParseError:
+ raise ConfigError(
+ f"Invalid SQL type string in timestamp_mapping: "
+ f"'{target_type}' is not a valid SQL data type."
+ )
+ result[source_datatype] = target_datatype
+
+ return result
+
@model_validator(mode="after")
def _root_validator(self) -> Self:
port = self.port
@@ -1954,6 +1987,7 @@ def _connection_kwargs_keys(self) -> t.Set[str]:
"port",
"catalog",
"roles",
+ "source",
"http_scheme",
"http_headers",
"session_properties",
@@ -1981,7 +2015,17 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
OAuth2Authentication,
)
+ auth: t.Optional[
+ t.Union[
+ BasicAuthentication,
+ KerberosAuthentication,
+ OAuth2Authentication,
+ JWTAuthentication,
+ CertificateAuthentication,
+ ]
+ ] = None
if self.method.is_basic or self.method.is_ldap:
+ assert self.password is not None # for mypy since validator already checks this
auth = BasicAuthentication(self.user, self.password)
elif self.method.is_kerberos:
if self.keytab:
@@ -2000,23 +2044,27 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
elif self.method.is_oauth:
auth = OAuth2Authentication()
elif self.method.is_jwt:
+ assert self.jwt_token is not None
auth = JWTAuthentication(self.jwt_token)
elif self.method.is_certificate:
+ assert self.client_certificate is not None
+ assert self.client_private_key is not None
auth = CertificateAuthentication(self.client_certificate, self.client_private_key)
- else:
- auth = None
return {
"auth": auth,
"user": self.impersonation_user or self.user,
"max_attempts": self.retries,
"verify": self.cert if self.cert is not None else self.verify,
- "source": "sqlmesh",
+ "source": self.source,
}
@property
def _extra_engine_config(self) -> t.Dict[str, t.Any]:
- return {"schema_location_mapping": self.schema_location_mapping}
+ return {
+ "schema_location_mapping": self.schema_location_mapping,
+ "timestamp_mapping": self.timestamp_mapping,
+ }
class ClickhouseConnectionConfig(ConnectionConfig):
@@ -2299,7 +2347,7 @@ def init(cursor: t.Any) -> None:
for tpe in subclasses(
__name__,
ConnectionConfig,
- exclude=(ConnectionConfig, BaseDuckDBConnectionConfig),
+ exclude={ConnectionConfig, BaseDuckDBConnectionConfig},
)
}
@@ -2308,7 +2356,7 @@ def init(cursor: t.Any) -> None:
for tpe in subclasses(
__name__,
ConnectionConfig,
- exclude=(ConnectionConfig, BaseDuckDBConnectionConfig),
+ exclude={ConnectionConfig, BaseDuckDBConnectionConfig},
)
}
@@ -2320,7 +2368,7 @@ def init(cursor: t.Any) -> None:
for tpe in subclasses(
__name__,
ConnectionConfig,
- exclude=(ConnectionConfig, BaseDuckDBConnectionConfig),
+ exclude={ConnectionConfig, BaseDuckDBConnectionConfig},
)
}
diff --git a/sqlmesh/core/config/linter.py b/sqlmesh/core/config/linter.py
index c2a40e09aa..11d700c540 100644
--- a/sqlmesh/core/config/linter.py
+++ b/sqlmesh/core/config/linter.py
@@ -34,7 +34,7 @@ def _validate_rules(cls, v: t.Any) -> t.Set[str]:
v = v.unnest().name
elif isinstance(v, (exp.Tuple, exp.Array)):
v = [e.name for e in v.expressions]
- elif isinstance(v, exp.Expression):
+ elif isinstance(v, exp.Expr):
v = v.name
return {name.lower() for name in ensure_collection(v)}
diff --git a/sqlmesh/core/config/model.py b/sqlmesh/core/config/model.py
index aeefdf2557..ac41d75fe3 100644
--- a/sqlmesh/core/config/model.py
+++ b/sqlmesh/core/config/model.py
@@ -71,9 +71,9 @@ class ModelDefaultsConfig(BaseConfig):
enabled: t.Optional[t.Union[str, bool]] = None
formatting: t.Optional[t.Union[str, bool]] = None
batch_concurrency: t.Optional[int] = None
- pre_statements: t.Optional[t.List[t.Union[str, exp.Expression]]] = None
- post_statements: t.Optional[t.List[t.Union[str, exp.Expression]]] = None
- on_virtual_update: t.Optional[t.List[t.Union[str, exp.Expression]]] = None
+ pre_statements: t.Optional[t.List[t.Union[str, exp.Expr]]] = None
+ post_statements: t.Optional[t.List[t.Union[str, exp.Expr]]] = None
+ on_virtual_update: t.Optional[t.List[t.Union[str, exp.Expr]]] = None
_model_kind_validator = model_kind_validator
_on_destructive_change_validator = on_destructive_change_validator
diff --git a/sqlmesh/core/config/scheduler.py b/sqlmesh/core/config/scheduler.py
index 69adcafe70..970defee62 100644
--- a/sqlmesh/core/config/scheduler.py
+++ b/sqlmesh/core/config/scheduler.py
@@ -146,7 +146,7 @@ def get_default_catalog_per_gateway(self, context: GenericContext) -> t.Dict[str
SCHEDULER_CONFIG_TO_TYPE = {
tpe.all_field_infos()["type_"].default: tpe
- for tpe in subclasses(__name__, BaseConfig, exclude=(BaseConfig,))
+ for tpe in subclasses(__name__, BaseConfig, exclude={BaseConfig})
}
diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py
index 5d28ef9551..dc51aad2a7 100644
--- a/sqlmesh/core/context.py
+++ b/sqlmesh/core/context.py
@@ -234,7 +234,7 @@ def resolve_table(self, model_name: str) -> str:
)
def fetchdf(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> pd.DataFrame:
"""Fetches a dataframe given a sql string or sqlglot expression.
@@ -248,7 +248,7 @@ def fetchdf(
return self.engine_adapter.fetchdf(query, quote_identifiers=quote_identifiers)
def fetch_pyspark_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> PySparkDataFrame:
"""Fetches a PySpark dataframe given a sql string or sqlglot expression.
@@ -692,8 +692,11 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
if snapshot.node.project in self._projects:
uncached.add(snapshot.name)
else:
- store = self._standalone_audits if snapshot.is_audit else self._models
- store[snapshot.name] = snapshot.node # type: ignore
+ local_store = self._standalone_audits if snapshot.is_audit else self._models
+ if snapshot.name in local_store:
+ uncached.add(snapshot.name)
+ else:
+ local_store[snapshot.name] = snapshot.node # type: ignore
for model in self._models.values():
self.dag.add(model.fqn, model.depends_on)
@@ -1102,7 +1105,7 @@ def render(
execution_time: t.Optional[TimeLike] = None,
expand: t.Union[bool, t.Iterable[str]] = False,
**kwargs: t.Any,
- ) -> exp.Expression:
+ ) -> exp.Expr:
"""Renders a model's query, expanding macros with provided kwargs, and optionally expanding referenced models.
Args:
@@ -1556,6 +1559,7 @@ def plan_builder(
run = run or False
diff_rendered = diff_rendered or False
skip_linter = skip_linter or False
+ min_intervals = min_intervals or 0
environment = environment or self.config.default_target_environment
environment = Environment.sanitize_name(environment)
@@ -1856,10 +1860,10 @@ def table_diff(
self,
source: str,
target: str,
- on: t.Optional[t.List[str] | exp.Condition] = None,
+ on: t.Optional[t.List[str] | exp.Expr] = None,
skip_columns: t.Optional[t.List[str]] = None,
select_models: t.Optional[t.Collection[str]] = None,
- where: t.Optional[str | exp.Condition] = None,
+ where: t.Optional[str | exp.Expr] = None,
limit: int = 20,
show: bool = True,
show_sample: bool = True,
@@ -1918,7 +1922,7 @@ def table_diff(
raise SQLMeshError(e)
models_to_diff: t.List[
- t.Tuple[Model, EngineAdapter, str, str, t.Optional[t.List[str] | exp.Condition]]
+ t.Tuple[Model, EngineAdapter, str, str, t.Optional[t.List[str] | exp.Expr]]
] = []
models_without_grain: t.List[Model] = []
source_snapshots_to_name = {
@@ -2037,9 +2041,9 @@ def _model_diff(
target_alias: str,
limit: int,
decimals: int,
- on: t.Optional[t.List[str] | exp.Condition] = None,
+ on: t.Optional[t.List[str] | exp.Expr] = None,
skip_columns: t.Optional[t.List[str]] = None,
- where: t.Optional[str | exp.Condition] = None,
+ where: t.Optional[str | exp.Expr] = None,
show: bool = True,
temp_schema: t.Optional[str] = None,
skip_grain_check: bool = False,
@@ -2079,10 +2083,10 @@ def _table_diff(
limit: int,
decimals: int,
adapter: EngineAdapter,
- on: t.Optional[t.List[str] | exp.Condition] = None,
+ on: t.Optional[t.List[str] | exp.Expr] = None,
model: t.Optional[Model] = None,
skip_columns: t.Optional[t.List[str]] = None,
- where: t.Optional[str | exp.Condition] = None,
+ where: t.Optional[str | exp.Expr] = None,
schema_diff_ignore_case: bool = False,
) -> TableDiff:
if not on:
@@ -2340,7 +2344,7 @@ def audit(
return not errors
@python_api_analytics
- def rewrite(self, sql: str, dialect: str = "") -> exp.Expression:
+ def rewrite(self, sql: str, dialect: str = "") -> exp.Expr:
"""Rewrite a sql expression with semantic references into an executable query.
https://sqlmesh.readthedocs.io/en/latest/concepts/metrics/overview/
@@ -3042,10 +3046,17 @@ def _get_plan_default_start_end(
modified_model_names: t.Set[str],
execution_time: t.Optional[TimeLike] = None,
) -> t.Tuple[t.Optional[int], t.Optional[int]]:
- if not max_interval_end_per_model:
+ # exclude seeds so their stale interval ends do not become the default plan end date
+ # when they're the only ones that contain intervals in this plan
+ non_seed_interval_ends = {
+ model_fqn: end
+ for model_fqn, end in max_interval_end_per_model.items()
+ if model_fqn not in snapshots or not snapshots[model_fqn].is_seed
+ }
+ if not non_seed_interval_ends:
return None, None
- default_end = to_timestamp(max(max_interval_end_per_model.values()))
+ default_end = to_timestamp(max(non_seed_interval_ends.values()))
default_start: t.Optional[int] = None
# Infer the default start by finding the smallest interval start that corresponds to the default end.
for model_name in backfill_models or modified_model_names or max_interval_end_per_model:
diff --git a/sqlmesh/core/context_diff.py b/sqlmesh/core/context_diff.py
index 07d13b1c2f..047e58609a 100644
--- a/sqlmesh/core/context_diff.py
+++ b/sqlmesh/core/context_diff.py
@@ -36,7 +36,7 @@
from sqlmesh.utils.metaprogramming import Executable # noqa
from sqlmesh.core.environment import EnvironmentStatements
-IGNORED_PACKAGES = {"sqlmesh", "sqlglot"}
+IGNORED_PACKAGES = {"sqlmesh", "sqlglot", "sqlglotc"}
class ContextDiff(PydanticModel):
diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py
index 332550d57c..3e8f4fe9a7 100644
--- a/sqlmesh/core/dialect.py
+++ b/sqlmesh/core/dialect.py
@@ -12,8 +12,9 @@
from sqlglot import Dialect, Generator, ParseError, Parser, Tokenizer, TokenType, exp
from sqlglot.dialects.dialect import DialectType
-from sqlglot.dialects import DuckDB, Snowflake
+from sqlglot.dialects import DuckDB, Snowflake, TSQL
import sqlglot.dialects.athena as athena
+from sqlglot.parsers.athena import AthenaTrinoParser
from sqlglot.helper import seq_get
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import quote_identifiers
@@ -52,7 +53,7 @@ class Metric(exp.Expression):
arg_types = {"expressions": True}
-class Jinja(exp.Func):
+class Jinja(exp.Expression, exp.Func):
arg_types = {"this": True}
@@ -76,7 +77,7 @@ class MacroVar(exp.Var):
pass
-class MacroFunc(exp.Func):
+class MacroFunc(exp.Expression, exp.Func):
@property
def name(self) -> str:
return self.this.name
@@ -102,7 +103,7 @@ class DColonCast(exp.Cast):
pass
-class MetricAgg(exp.AggFunc):
+class MetricAgg(exp.Expression, exp.AggFunc):
"""Used for computing metrics."""
arg_types = {"this": True}
@@ -118,7 +119,7 @@ class StagedFilePath(exp.Expression):
arg_types = exp.Table.arg_types.copy()
-def _parse_statement(self: Parser) -> t.Optional[exp.Expression]:
+def _parse_statement(self: Parser) -> t.Optional[exp.Expr]:
if self._curr is None:
return None
@@ -152,7 +153,7 @@ def _parse_statement(self: Parser) -> t.Optional[exp.Expression]:
raise
-def _parse_lambda(self: Parser, alias: bool = False) -> t.Optional[exp.Expression]:
+def _parse_lambda(self: Parser, alias: bool = False) -> t.Optional[exp.Expr]:
node = self.__parse_lambda(alias=alias) # type: ignore
if isinstance(node, exp.Lambda):
node.set("this", self._parse_alias(node.this))
@@ -163,7 +164,7 @@ def _parse_id_var(
self: Parser,
any_token: bool = True,
tokens: t.Optional[t.Collection[TokenType]] = None,
-) -> t.Optional[exp.Expression]:
+) -> t.Optional[exp.Expr]:
if self._prev and self._prev.text == SQLMESH_MACRO_PREFIX and self._match(TokenType.L_BRACE):
identifier = self.__parse_id_var(any_token=any_token, tokens=tokens) # type: ignore
if not self._match(TokenType.R_BRACE):
@@ -207,12 +208,12 @@ def _parse_id_var(
else:
self.raise_error("Expecting }")
- identifier = self.expression(exp.Identifier, this=this, quoted=identifier.quoted)
+ identifier = self.expression(exp.Identifier(this=this, quoted=identifier.quoted))
return identifier
-def _parse_macro(self: Parser, keyword_macro: str = "") -> t.Optional[exp.Expression]:
+def _parse_macro(self: Parser, keyword_macro: str = "") -> t.Optional[exp.Expr]:
if self._prev.text != SQLMESH_MACRO_PREFIX:
return self._parse_parameter()
@@ -220,7 +221,7 @@ def _parse_macro(self: Parser, keyword_macro: str = "") -> t.Optional[exp.Expres
index = self._index
field = self._parse_primary() or self._parse_function(functions={}) or self._parse_id_var()
- def _build_macro(field: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ def _build_macro(field: t.Optional[exp.Expr]) -> t.Optional[exp.Expr]:
if isinstance(field, exp.Func):
macro_name = field.name.upper()
if macro_name != keyword_macro and macro_name in KEYWORD_MACROS:
@@ -230,37 +231,39 @@ def _build_macro(field: t.Optional[exp.Expression]) -> t.Optional[exp.Expression
if isinstance(field, exp.Anonymous):
if macro_name == "DEF":
return self.expression(
- MacroDef,
- this=field.expressions[0],
- expression=field.expressions[1],
+ MacroDef(
+ this=field.expressions[0],
+ expression=field.expressions[1],
+ ),
comments=comments,
)
if macro_name == "SQL":
into = field.expressions[1].this.lower() if len(field.expressions) > 1 else None
return self.expression(
- MacroSQL, this=field.expressions[0], into=into, comments=comments
+ MacroSQL(this=field.expressions[0], into=into), comments=comments
)
else:
field = self.expression(
- exp.Anonymous,
- this=field.sql_name(),
- expressions=list(field.args.values()),
+ exp.Anonymous(
+ this=field.sql_name(),
+ expressions=list(field.args.values()),
+ ),
comments=comments,
)
- return self.expression(MacroFunc, this=field, comments=comments)
+ return self.expression(MacroFunc(this=field), comments=comments)
if field is None:
return None
if field.is_string or (isinstance(field, exp.Identifier) and field.quoted):
return self.expression(
- MacroStrReplace, this=exp.Literal.string(field.this), comments=comments
+ MacroStrReplace(this=exp.Literal.string(field.this)), comments=comments
)
if "@" in field.this:
- return field
- return self.expression(MacroVar, this=field.this, comments=comments)
+ return field # type: ignore[return-value]
+ return self.expression(MacroVar(this=field.this), comments=comments)
if isinstance(field, (exp.Window, exp.IgnoreNulls, exp.RespectNulls)):
field.set("this", _build_macro(field.this))
@@ -273,7 +276,7 @@ def _build_macro(field: t.Optional[exp.Expression]) -> t.Optional[exp.Expression
KEYWORD_MACROS = {"WITH", "JOIN", "WHERE", "GROUP_BY", "HAVING", "ORDER_BY", "LIMIT"}
-def _parse_matching_macro(self: Parser, name: str) -> t.Optional[exp.Expression]:
+def _parse_matching_macro(self: Parser, name: str) -> t.Optional[exp.Expr]:
if not self._match_pair(TokenType.PARAMETER, TokenType.VAR, advance=False) or (
self._next and self._next.text.upper() != name.upper()
):
@@ -283,7 +286,7 @@ def _parse_matching_macro(self: Parser, name: str) -> t.Optional[exp.Expression]
return _parse_macro(self, keyword_macro=name)
-def _parse_body_macro(self: Parser) -> t.Tuple[str, t.Optional[exp.Expression]]:
+def _parse_body_macro(self: Parser) -> t.Tuple[str, t.Optional[exp.Expr]]:
name = self._next and self._next.text.upper()
if name == "JOIN":
@@ -301,7 +304,7 @@ def _parse_body_macro(self: Parser) -> t.Tuple[str, t.Optional[exp.Expression]]:
return ("", None)
-def _parse_with(self: Parser, skip_with_token: bool = False) -> t.Optional[exp.Expression]:
+def _parse_with(self: Parser, skip_with_token: bool = False) -> t.Optional[exp.Expr]:
macro = _parse_matching_macro(self, "WITH")
if not macro:
return self.__parse_with(skip_with_token=skip_with_token) # type: ignore
@@ -312,7 +315,7 @@ def _parse_with(self: Parser, skip_with_token: bool = False) -> t.Optional[exp.E
def _parse_join(
self: Parser, skip_join_token: bool = False, parse_bracket: bool = False
-) -> t.Optional[exp.Expression]:
+) -> t.Optional[exp.Expr]:
index = self._index
method, side, kind = self._parse_join_parts()
macro = _parse_matching_macro(self, "JOIN")
@@ -351,7 +354,7 @@ def _parse_select(
parse_set_operation: bool = True,
consume_pipe: bool = True,
from_: t.Optional[exp.From] = None,
-) -> t.Optional[exp.Expression]:
+) -> t.Optional[exp.Expr]:
select = self.__parse_select( # type: ignore
nested=nested,
table=table,
@@ -372,7 +375,7 @@ def _parse_select(
return select
-def _parse_where(self: Parser, skip_where_token: bool = False) -> t.Optional[exp.Expression]:
+def _parse_where(self: Parser, skip_where_token: bool = False) -> t.Optional[exp.Expr]:
macro = _parse_matching_macro(self, "WHERE")
if not macro:
return self.__parse_where(skip_where_token=skip_where_token) # type: ignore
@@ -381,7 +384,7 @@ def _parse_where(self: Parser, skip_where_token: bool = False) -> t.Optional[exp
return macro
-def _parse_group(self: Parser, skip_group_by_token: bool = False) -> t.Optional[exp.Expression]:
+def _parse_group(self: Parser, skip_group_by_token: bool = False) -> t.Optional[exp.Expr]:
macro = _parse_matching_macro(self, "GROUP_BY")
if not macro:
return self.__parse_group(skip_group_by_token=skip_group_by_token) # type: ignore
@@ -390,7 +393,7 @@ def _parse_group(self: Parser, skip_group_by_token: bool = False) -> t.Optional[
return macro
-def _parse_having(self: Parser, skip_having_token: bool = False) -> t.Optional[exp.Expression]:
+def _parse_having(self: Parser, skip_having_token: bool = False) -> t.Optional[exp.Expr]:
macro = _parse_matching_macro(self, "HAVING")
if not macro:
return self.__parse_having(skip_having_token=skip_having_token) # type: ignore
@@ -400,8 +403,8 @@ def _parse_having(self: Parser, skip_having_token: bool = False) -> t.Optional[e
def _parse_order(
- self: Parser, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False
-) -> t.Optional[exp.Expression]:
+ self: Parser, this: t.Optional[exp.Expr] = None, skip_order_token: bool = False
+) -> t.Optional[exp.Expr]:
macro = _parse_matching_macro(self, "ORDER_BY")
if not macro:
return self.__parse_order(this, skip_order_token=skip_order_token) # type: ignore
@@ -412,10 +415,10 @@ def _parse_order(
def _parse_limit(
self: Parser,
- this: t.Optional[exp.Expression] = None,
+ this: t.Optional[exp.Expr] = None,
top: bool = False,
skip_limit_token: bool = False,
-) -> t.Optional[exp.Expression]:
+) -> t.Optional[exp.Expr]:
macro = _parse_matching_macro(self, "TOP" if top else "LIMIT")
if not macro:
return self.__parse_limit(this, top=top, skip_limit_token=skip_limit_token) # type: ignore
@@ -424,7 +427,7 @@ def _parse_limit(
return macro
-def _parse_value(self: Parser, values: bool = True) -> t.Optional[exp.Expression]:
+def _parse_value(self: Parser, values: bool = True) -> t.Optional[exp.Expr]:
wrapped = self._match(TokenType.L_PAREN, advance=False)
# The base _parse_value method always constructs a Tuple instance. This is problematic when
@@ -438,11 +441,11 @@ def _parse_value(self: Parser, values: bool = True) -> t.Optional[exp.Expression
return expr
-def _parse_macro_or_clause(self: Parser, parser: t.Callable) -> t.Optional[exp.Expression]:
+def _parse_macro_or_clause(self: Parser, parser: t.Callable) -> t.Optional[exp.Expr]:
return _parse_macro(self) if self._match(TokenType.PARAMETER) else parser()
-def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
+def _parse_props(self: Parser) -> t.Optional[exp.Expr]:
key = self._parse_id_var(any_token=True)
if not key:
return None
@@ -460,7 +463,7 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
elif name == "merge_filter":
value = self._parse_conjunction()
elif self._match(TokenType.L_PAREN):
- value = self.expression(exp.Tuple, expressions=self._parse_csv(self._parse_equality))
+ value = self.expression(exp.Tuple(expressions=self._parse_csv(self._parse_equality)))
self._match_r_paren()
else:
value = self._parse_bracket(self._parse_field(any_token=True))
@@ -469,7 +472,7 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
# Make sure if we get a windows path that it is converted to posix
value = exp.Literal.string(value.this.replace("\\", "/")) # type: ignore
- return self.expression(exp.Property, this=name, value=value)
+ return self.expression(exp.Property(this=name, value=value))
def _parse_types(
@@ -477,7 +480,7 @@ def _parse_types(
check_func: bool = False,
schema: bool = False,
allow_identifiers: bool = True,
-) -> t.Optional[exp.Expression]:
+) -> t.Optional[exp.Expr]:
start = self._curr
parsed_type = self.__parse_types( # type: ignore
check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
@@ -496,13 +499,20 @@ def _parse_types(
#
# See: https://docs.snowflake.com/en/user-guide/querying-stage
def _parse_table_parts(
- self: Parser, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False
+ self: Parser,
+ schema: bool = False,
+ is_db_reference: bool = False,
+ wildcard: bool = False,
+ fast: bool = False,
) -> exp.Table | StagedFilePath:
index = self._index
table = self.__parse_table_parts( # type: ignore
- schema=schema, is_db_reference=is_db_reference, wildcard=wildcard
+ schema=schema, is_db_reference=is_db_reference, wildcard=wildcard, fast=fast
)
+ if table is None:
+ return table # type: ignore[return-value]
+
table_arg = table.this
name = table_arg.name if isinstance(table_arg, exp.Var) else ""
@@ -526,7 +536,9 @@ def _parse_table_parts(
)
):
self._retreat(index)
- return Parser._parse_table_parts(self, schema=schema, is_db_reference=is_db_reference)
+ return Parser._parse_table_parts(
+ self, schema=schema, is_db_reference=is_db_reference, fast=fast
+ ) # type: ignore[return-value]
table_arg.replace(MacroVar(this=name[1:]))
return StagedFilePath(**table.args)
@@ -534,7 +546,7 @@ def _parse_table_parts(
return table
-def _parse_if(self: Parser) -> t.Optional[exp.Expression]:
+def _parse_if(self: Parser) -> t.Optional[exp.Expr]:
# If we fail to parse an IF function with expressions as arguments, we then try
# to parse a statement / command to support the macro @IF(condition, statement)
index = self._index
@@ -554,6 +566,10 @@ def _parse_if(self: Parser) -> t.Optional[exp.Expression]:
if last_token.token_type == TokenType.R_PAREN:
self._tokens[-2].comments.extend(last_token.comments)
self._tokens.pop()
+ if hasattr(self, "_tokens_size"):
+ # keep _tokens_size in sync: sqlglot 30.0.3 caches len(_tokens), so
+ # _advance() would try to read tokens[index + 1] past the new end
+ self._tokens_size -= 1
else:
self.raise_error("Expecting )")
@@ -566,11 +582,11 @@ def _parse_if(self: Parser) -> t.Optional[exp.Expression]:
return exp.Anonymous(this="IF", expressions=[cond, stmt])
-def _create_parser(expression_type: t.Type[exp.Expression], table_keys: t.List[str]) -> t.Callable:
- def parse(self: Parser) -> t.Optional[exp.Expression]:
+def _create_parser(expression_type: t.Type[exp.Expr], table_keys: t.List[str]) -> t.Callable:
+ def parse(self: Parser) -> t.Optional[exp.Expr]:
from sqlmesh.core.model.kind import ModelKindName
- expressions: t.List[exp.Expression] = []
+ expressions: t.List[exp.Expr] = []
while True:
prev_property = seq_get(expressions, -1)
@@ -589,7 +605,7 @@ def parse(self: Parser) -> t.Optional[exp.Expression]:
key = key_expression.name.lower()
start = self._curr
- value: t.Optional[exp.Expression | str]
+ value: t.Optional[exp.Expr | str]
if key in table_keys:
value = self._parse_table_parts()
@@ -629,7 +645,7 @@ def parse(self: Parser) -> t.Optional[exp.Expression]:
else:
props = None
- value = self.expression(ModelKind, this=kind.value, expressions=props)
+ value = self.expression(ModelKind(this=kind.value, expressions=props))
elif key == "expression":
value = self._parse_conjunction()
elif key == "partitioned_by":
@@ -641,12 +657,12 @@ def parse(self: Parser) -> t.Optional[exp.Expression]:
else:
value = self._parse_bracket(self._parse_field(any_token=True))
- if isinstance(value, exp.Expression):
+ if isinstance(value, exp.Expr):
value.meta["sql"] = self._find_sql(start, self._prev)
- expressions.append(self.expression(exp.Property, this=key, value=value))
+ expressions.append(self.expression(exp.Property(this=key, value=value)))
- return self.expression(expression_type, expressions=expressions)
+ return self.expression(expression_type(expressions=expressions))
return parse
@@ -658,7 +674,7 @@ def parse(self: Parser) -> t.Optional[exp.Expression]:
}
-def _props_sql(self: Generator, expressions: t.List[exp.Expression]) -> str:
+def _props_sql(self: Generator, expressions: t.List[exp.Expr]) -> str:
props = []
size = len(expressions)
@@ -676,7 +692,7 @@ def _props_sql(self: Generator, expressions: t.List[exp.Expression]) -> str:
return "\n".join(props)
-def _on_virtual_update_sql(self: Generator, expressions: t.List[exp.Expression]) -> str:
+def _on_virtual_update_sql(self: Generator, expressions: t.List[exp.Expr]) -> str:
statements = "\n".join(
self.sql(expression)
if isinstance(expression, JinjaStatement)
@@ -697,7 +713,7 @@ def _model_kind_sql(self: Generator, expression: ModelKind) -> str:
return expression.name.upper()
-def _macro_keyword_func_sql(self: Generator, expression: exp.Expression) -> str:
+def _macro_keyword_func_sql(self: Generator, expression: exp.Expr) -> str:
name = expression.name
keyword = name.replace("_", " ")
*args, clause = expression.expressions
@@ -731,7 +747,7 @@ def _override(klass: t.Type[Tokenizer | Parser], func: t.Callable) -> None:
def format_model_expressions(
- expressions: t.List[exp.Expression],
+ expressions: t.List[exp.Expr],
dialect: t.Optional[str] = None,
rewrite_casts: bool = True,
**kwargs: t.Any,
@@ -752,7 +768,7 @@ def format_model_expressions(
if rewrite_casts:
- def cast_to_colon(node: exp.Expression) -> exp.Expression:
+ def cast_to_colon(node: exp.Expr) -> exp.Expr:
if isinstance(node, exp.Cast) and not any(
# Only convert CAST into :: if it doesn't have additional args set, otherwise this
# conversion could alter the semantics (eg. changing SAFE_CAST in BigQuery to CAST)
@@ -784,8 +800,8 @@ def cast_to_colon(node: exp.Expression) -> exp.Expression:
def text_diff(
- a: t.List[exp.Expression],
- b: t.List[exp.Expression],
+ a: t.List[exp.Expr],
+ b: t.List[exp.Expr],
a_dialect: t.Optional[str] = None,
b_dialect: t.Optional[str] = None,
) -> str:
@@ -803,8 +819,15 @@ def text_diff(
return "\n".join(unified_diff(a_sql, b_sql))
+WS_OR_COMMENT = r"(?:\s|--[^\n]*\n|/\*.*?\*/)"
+HEADER = r"\b(?:model|audit)\b(?=\s*\()"
+KEY_BOUNDARY = r"(?:\(|,)" # key is preceded by either '(' or ','
+DIALECT_VALUE = r"['\"]?(?P[a-z][a-z0-9]*)['\"]?"
+VALUE_BOUNDARY = r"(?=,|\))" # value is followed by comma or closing paren
+
DIALECT_PATTERN = re.compile(
- r"(model|audit).*?\(.*?dialect\s+'?([a-z]*)", re.IGNORECASE | re.DOTALL
+ rf"{HEADER}.*?{KEY_BOUNDARY}{WS_OR_COMMENT}*dialect{WS_OR_COMMENT}+{DIALECT_VALUE}{WS_OR_COMMENT}*{VALUE_BOUNDARY}",
+ re.IGNORECASE | re.DOTALL,
)
@@ -853,7 +876,7 @@ def _is_virtual_statement_end(tokens: t.List[Token], pos: int) -> bool:
return _is_command_statement(ON_VIRTUAL_UPDATE_END, tokens, pos)
-def virtual_statement(statements: t.List[exp.Expression]) -> VirtualUpdateStatement:
+def virtual_statement(statements: t.List[exp.Expr]) -> VirtualUpdateStatement:
return VirtualUpdateStatement(expressions=statements)
@@ -867,7 +890,7 @@ class ChunkType(Enum):
def parse_one(
sql: str, dialect: t.Optional[str] = None, into: t.Optional[exp.IntoType] = None
-) -> exp.Expression:
+) -> exp.Expr:
expressions = parse(sql, default_dialect=dialect, match_dialect=False, into=into)
if not expressions:
raise SQLMeshError(f"No expressions found in '{sql}'")
@@ -881,7 +904,7 @@ def parse(
default_dialect: t.Optional[str] = None,
match_dialect: bool = True,
into: t.Optional[exp.IntoType] = None,
-) -> t.List[exp.Expression]:
+) -> t.List[exp.Expr]:
"""Parse a sql string.
Supports parsing model definition.
@@ -895,7 +918,8 @@ def parse(
A list of the parsed expressions: [Model, *Statements, Query, *Statements]
"""
match = match_dialect and DIALECT_PATTERN.search(sql[:MAX_MODEL_DEFINITION_SIZE])
- dialect = Dialect.get_or_raise(match.group(2) if match else default_dialect)
+ dialect_str = match.group("dialect") if match else None
+ dialect = Dialect.get_or_raise(dialect_str or default_dialect)
tokens = dialect.tokenize(sql)
chunks: t.List[t.Tuple[t.List[Token], ChunkType]] = [([], ChunkType.SQL)]
@@ -944,10 +968,10 @@ def parse(
pos += 1
parser = dialect.parser()
- expressions: t.List[exp.Expression] = []
+ expressions: t.List[exp.Expr] = []
- def parse_sql_chunk(chunk: t.List[Token], meta_sql: bool = True) -> t.List[exp.Expression]:
- parsed_expressions: t.List[t.Optional[exp.Expression]] = (
+ def parse_sql_chunk(chunk: t.List[Token], meta_sql: bool = True) -> t.List[exp.Expr]:
+ parsed_expressions: t.List[t.Optional[exp.Expr]] = (
parser.parse(chunk, sql) if into is None else parser.parse_into(into, chunk, sql)
)
expressions = []
@@ -958,7 +982,7 @@ def parse_sql_chunk(chunk: t.List[Token], meta_sql: bool = True) -> t.List[exp.E
expressions.append(expression)
return expressions
- def parse_jinja_chunk(chunk: t.List[Token], meta_sql: bool = True) -> exp.Expression:
+ def parse_jinja_chunk(chunk: t.List[Token], meta_sql: bool = True) -> exp.Expr:
start, *_, end = chunk
segment = sql[start.end + 2 : end.start - 1]
factory = jinja_query if chunk_type == ChunkType.JINJA_QUERY else jinja_statement
@@ -969,9 +993,9 @@ def parse_jinja_chunk(chunk: t.List[Token], meta_sql: bool = True) -> exp.Expres
def parse_virtual_statement(
chunks: t.List[t.Tuple[t.List[Token], ChunkType]], pos: int
- ) -> t.Tuple[t.List[exp.Expression], int]:
+ ) -> t.Tuple[t.List[exp.Expr], int]:
# For virtual statements we need to handle both SQL and Jinja nested blocks within the chunk
- virtual_update_statements = []
+ virtual_update_statements: t.List[exp.Expr] = []
start = chunks[pos][0][0].start
while (
@@ -1023,7 +1047,7 @@ def extend_sqlglot() -> None:
# so this ensures that the extra ones it defines are also extended
if dialect == athena.Athena:
tokenizers.add(athena._TrinoTokenizer)
- parsers.add(athena._TrinoParser)
+ parsers.add(AthenaTrinoParser)
generators.add(athena._TrinoGenerator)
generators.add(athena._HiveGenerator)
@@ -1093,6 +1117,7 @@ def extend_sqlglot() -> None:
_override(Parser, _parse_value)
_override(Parser, _parse_lambda)
_override(Parser, _parse_types)
+ _override(TSQL.Parser, Parser._parse_if)
_override(Parser, _parse_if)
_override(Parser, _parse_id_var)
_override(Parser, _warn_unsupported)
@@ -1242,7 +1267,7 @@ def normalize_model_name(
def find_tables(
- expression: exp.Expression, default_catalog: t.Optional[str], dialect: DialectType = None
+ expression: exp.Expr, default_catalog: t.Optional[str], dialect: DialectType = None
) -> t.Set[str]:
"""Find all tables referenced in a query.
@@ -1265,10 +1290,10 @@ def find_tables(
return expression.meta[TABLES_META]
-def add_table(node: exp.Expression, table: str) -> exp.Expression:
+def add_table(node: exp.Expr, table: str) -> exp.Expr:
"""Add a table to all columns in an expression."""
- def _transform(node: exp.Expression) -> exp.Expression:
+ def _transform(node: exp.Expr) -> exp.Expr:
if isinstance(node, exp.Column) and not node.table:
return exp.column(node.this, table=table)
if isinstance(node, exp.Identifier):
@@ -1378,7 +1403,7 @@ def normalize_and_quote(
quote_identifiers(query, dialect=dialect)
-def interpret_expression(e: exp.Expression) -> exp.Expression | str | int | float | bool:
+def interpret_expression(e: exp.Expr) -> exp.Expr | str | int | float | bool:
if e.is_int:
return int(e.this)
if e.is_number:
@@ -1390,13 +1415,13 @@ def interpret_expression(e: exp.Expression) -> exp.Expression | str | int | floa
def interpret_key_value_pairs(
e: exp.Tuple,
-) -> t.Dict[str, exp.Expression | str | int | float | bool]:
+) -> t.Dict[str, exp.Expr | str | int | float | bool]:
return {i.this.name: interpret_expression(i.expression) for i in e.expressions}
def extract_func_call(
- v: exp.Expression, allow_tuples: bool = False
-) -> t.Tuple[str, t.Dict[str, exp.Expression]]:
+ v: exp.Expr, allow_tuples: bool = False
+) -> t.Tuple[str, t.Dict[str, exp.Expr]]:
kwargs = {}
if isinstance(v, exp.Anonymous):
@@ -1433,7 +1458,7 @@ def extract_function_calls(func_calls: t.Any, allow_tuples: bool = False) -> t.A
return [extract_func_call(i, allow_tuples=allow_tuples) for i in func_calls.expressions]
if isinstance(func_calls, exp.Paren):
return [extract_func_call(func_calls.this, allow_tuples=allow_tuples)]
- if isinstance(func_calls, exp.Expression):
+ if isinstance(func_calls, exp.Expr):
return [extract_func_call(func_calls, allow_tuples=allow_tuples)]
if isinstance(func_calls, list):
function_calls = []
@@ -1465,9 +1490,7 @@ def is_meta_expression(v: t.Any) -> bool:
return isinstance(v, (Audit, Metric, Model))
-def replace_merge_table_aliases(
- expression: exp.Expression, dialect: t.Optional[str] = None
-) -> exp.Expression:
+def replace_merge_table_aliases(expression: exp.Expr, dialect: t.Optional[str] = None) -> exp.Expr:
"""
Resolves references from the "source" and "target" tables (or their DBT equivalents)
with the corresponding SQLMesh merge aliases (MERGE_SOURCE_ALIAS and MERGE_TARGET_ALIAS)
diff --git a/sqlmesh/core/engine_adapter/athena.py b/sqlmesh/core/engine_adapter/athena.py
index bd84ba5276..338381549b 100644
--- a/sqlmesh/core/engine_adapter/athena.py
+++ b/sqlmesh/core/engine_adapter/athena.py
@@ -158,7 +158,7 @@ def _create_schema(
schema_name: SchemaName,
ignore_if_exists: bool,
warn_on_error: bool,
- properties: t.List[exp.Expression],
+ properties: t.List[exp.Expr],
kind: str,
) -> None:
if location := self._table_location(table_properties=None, table=exp.to_table(schema_name)):
@@ -177,14 +177,14 @@ def _create_schema(
def _build_create_table_exp(
self,
table_name_or_schema: t.Union[exp.Schema, TableName],
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
exists: bool = True,
replace: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
table_description: t.Optional[str] = None,
table_kind: t.Optional[str] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
**kwargs: t.Any,
) -> exp.Create:
exists = False if replace else exists
@@ -235,18 +235,18 @@ def _build_table_properties_exp(
catalog_name: t.Optional[str] = None,
table_format: t.Optional[str] = None,
storage_format: t.Optional[str] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
partition_interval_unit: t.Optional[IntervalUnit] = None,
- clustered_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ clustered_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
table_description: t.Optional[str] = None,
table_kind: t.Optional[str] = None,
table: t.Optional[exp.Table] = None,
- expression: t.Optional[exp.Expression] = None,
+ expression: t.Optional[exp.Expr] = None,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
table_properties = table_properties or {}
is_hive = self._table_type(table_format) == "hive"
@@ -266,7 +266,7 @@ def _build_table_properties_exp(
properties.append(exp.SchemaCommentProperty(this=exp.Literal.string(table_description)))
if partitioned_by:
- schema_expressions: t.List[exp.Expression] = []
+ schema_expressions: t.List[exp.Expr] = []
if is_hive and target_columns_to_types:
# For Hive-style tables, you cannot include the partitioned by columns in the main set of columns
# In the PARTITIONED BY expression, you also cant just include the column names, you need to include the data type as well
@@ -381,7 +381,7 @@ def _is_hive_partitioned_table(self, table: exp.Table) -> bool:
raise e
def _table_location_or_raise(
- self, table_properties: t.Optional[t.Dict[str, exp.Expression]], table: exp.Table
+ self, table_properties: t.Optional[t.Dict[str, exp.Expr]], table: exp.Table
) -> exp.LocationProperty:
location = self._table_location(table_properties, table)
if not location:
@@ -392,7 +392,7 @@ def _table_location_or_raise(
def _table_location(
self,
- table_properties: t.Optional[t.Dict[str, exp.Expression]],
+ table_properties: t.Optional[t.Dict[str, exp.Expr]],
table: exp.Table,
) -> t.Optional[exp.LocationProperty]:
base_uri: str
@@ -402,7 +402,7 @@ def _table_location(
s3_base_location_property = table_properties.pop(
"s3_base_location"
) # pop because it's handled differently and we dont want it to end up in the TBLPROPERTIES clause
- if isinstance(s3_base_location_property, exp.Expression):
+ if isinstance(s3_base_location_property, exp.Expr):
base_uri = s3_base_location_property.name
else:
base_uri = s3_base_location_property
@@ -419,7 +419,7 @@ def _table_location(
return exp.LocationProperty(this=exp.Literal.string(full_uri))
def _find_matching_columns(
- self, partitioned_by: t.List[exp.Expression], columns_to_types: t.Dict[str, exp.DataType]
+ self, partitioned_by: t.List[exp.Expr], columns_to_types: t.Dict[str, exp.DataType]
) -> t.List[t.Tuple[str, exp.DataType]]:
matches = []
for col in partitioned_by:
@@ -557,7 +557,7 @@ def _chunks() -> t.Iterable[t.List[t.List[str]]]:
PartitionsToDelete=[{"Values": v} for v in batch],
)
- def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None:
+ def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None:
table = exp.to_table(table_name)
table_type = self._query_table_type(table)
diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py
index a7a8e2f707..8de7b79398 100644
--- a/sqlmesh/core/engine_adapter/base.py
+++ b/sqlmesh/core/engine_adapter/base.py
@@ -236,7 +236,7 @@ def _casted_columns(
cls,
target_columns_to_types: t.Dict[str, exp.DataType],
source_columns: t.Optional[t.List[str]] = None,
- ) -> t.List[exp.Alias]:
+ ) -> t.List[exp.Expr]:
source_columns_lookup = set(source_columns or target_columns_to_types)
return [
exp.alias_(
@@ -591,7 +591,7 @@ def create_index(
def _pop_creatable_type_from_properties(
self,
- properties: t.Dict[str, exp.Expression],
+ properties: t.Dict[str, exp.Expr],
) -> t.Optional[exp.Property]:
"""Pop out the creatable_type from the properties dictionary (if exists (return it/remove it) else return none).
It also checks that none of the expressions are MATERIALIZE as that conflicts with the `materialize` parameter.
@@ -652,9 +652,9 @@ def create_managed_table(
table_name: TableName,
query: Query,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
- clustered_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
+ clustered_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
table_description: t.Optional[str] = None,
column_descriptions: t.Optional[t.Dict[str, str]] = None,
source_columns: t.Optional[t.List[str]] = None,
@@ -964,7 +964,7 @@ def _create_table_from_source_queries(
def _create_table(
self,
table_name_or_schema: t.Union[exp.Schema, TableName],
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
exists: bool = True,
replace: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -1002,7 +1002,7 @@ def _create_table(
def _build_create_table_exp(
self,
table_name_or_schema: t.Union[exp.Schema, TableName],
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
exists: bool = True,
replace: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -1203,7 +1203,7 @@ def create_view(
materialized_properties: t.Optional[t.Dict[str, t.Any]] = None,
table_description: t.Optional[str] = None,
column_descriptions: t.Optional[t.Dict[str, str]] = None,
- view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
source_columns: t.Optional[t.List[str]] = None,
**create_kwargs: t.Any,
) -> None:
@@ -1382,7 +1382,7 @@ def create_schema(
schema_name: SchemaName,
ignore_if_exists: bool = True,
warn_on_error: bool = True,
- properties: t.Optional[t.List[exp.Expression]] = None,
+ properties: t.Optional[t.List[exp.Expr]] = None,
) -> None:
properties = properties or []
return self._create_schema(
@@ -1398,7 +1398,7 @@ def _create_schema(
schema_name: SchemaName,
ignore_if_exists: bool,
warn_on_error: bool,
- properties: t.List[exp.Expression],
+ properties: t.List[exp.Expr],
kind: str,
) -> None:
"""Create a schema from a name or qualified table name."""
@@ -1423,7 +1423,7 @@ def drop_schema(
schema_name: SchemaName,
ignore_if_not_exists: bool = True,
cascade: bool = False,
- **drop_args: t.Dict[str, exp.Expression],
+ **drop_args: t.Dict[str, exp.Expr],
) -> None:
return self._drop_object(
name=schema_name,
@@ -1494,7 +1494,7 @@ def table_exists(self, table_name: TableName) -> bool:
except Exception:
return False
- def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None:
+ def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None:
self.execute(exp.delete(table_name, where))
def insert_append(
@@ -1552,7 +1552,7 @@ def insert_overwrite_by_partition(
self,
table_name: TableName,
query_or_df: QueryOrDF,
- partitioned_by: t.List[exp.Expression],
+ partitioned_by: t.List[exp.Expr],
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
source_columns: t.Optional[t.List[str]] = None,
) -> None:
@@ -1583,10 +1583,8 @@ def insert_overwrite_by_time_partition(
query_or_df: QueryOrDF,
start: TimeLike,
end: TimeLike,
- time_formatter: t.Callable[
- [TimeLike, t.Optional[t.Dict[str, exp.DataType]]], exp.Expression
- ],
- time_column: TimeColumn | exp.Expression | str,
+ time_formatter: t.Callable[[TimeLike, t.Optional[t.Dict[str, exp.DataType]]], exp.Expr],
+ time_column: TimeColumn | exp.Expr | str,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
source_columns: t.Optional[t.List[str]] = None,
**kwargs: t.Any,
@@ -1726,7 +1724,7 @@ def _merge(
self,
target_table: TableName,
query: Query,
- on: exp.Expression,
+ on: exp.Expr,
whens: exp.Whens,
) -> None:
this = exp.alias_(exp.to_table(target_table), alias=MERGE_TARGET_ALIAS, table=True)
@@ -1741,7 +1739,7 @@ def scd_type_2_by_time(
self,
target_table: TableName,
source_table: QueryOrDF,
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
valid_from_col: exp.Column,
valid_to_col: exp.Column,
execution_time: t.Union[TimeLike, exp.Column],
@@ -1777,11 +1775,11 @@ def scd_type_2_by_column(
self,
target_table: TableName,
source_table: QueryOrDF,
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
valid_from_col: exp.Column,
valid_to_col: exp.Column,
execution_time: t.Union[TimeLike, exp.Column],
- check_columns: t.Union[exp.Star, t.Sequence[exp.Expression]],
+ check_columns: t.Union[exp.Star, t.Sequence[exp.Expr]],
invalidate_hard_deletes: bool = True,
execution_time_as_valid_from: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -1813,13 +1811,13 @@ def _scd_type_2(
self,
target_table: TableName,
source_table: QueryOrDF,
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
valid_from_col: exp.Column,
valid_to_col: exp.Column,
execution_time: t.Union[TimeLike, exp.Column],
invalidate_hard_deletes: bool = True,
updated_at_col: t.Optional[exp.Column] = None,
- check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expression]]] = None,
+ check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expr]]] = None,
updated_at_as_valid_from: bool = False,
execution_time_as_valid_from: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -1908,7 +1906,7 @@ def remove_managed_columns(
raise SQLMeshError(
"Cannot use `updated_at_as_valid_from` without `updated_at_name` for SCD Type 2"
)
- update_valid_from_start: t.Union[str, exp.Expression] = updated_at_col
+ update_valid_from_start: t.Union[str, exp.Expr] = updated_at_col
# If using check_columns and the user doesn't always want execution_time for valid from
# then we only use epoch 0 if we are truncating the table and loading rows for the first time.
# All future new rows should have execution time.
@@ -2207,9 +2205,9 @@ def merge(
target_table: TableName,
source_table: QueryOrDF,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
when_matched: t.Optional[exp.Whens] = None,
- merge_filter: t.Optional[exp.Expression] = None,
+ merge_filter: t.Optional[exp.Expr] = None,
source_columns: t.Optional[t.List[str]] = None,
**kwargs: t.Any,
) -> None:
@@ -2382,7 +2380,7 @@ def get_data_objects(
def fetchone(
self,
- query: t.Union[exp.Expression, str],
+ query: t.Union[exp.Expr, str],
ignore_unsupported_errors: bool = False,
quote_identifiers: bool = False,
) -> t.Optional[t.Tuple]:
@@ -2396,7 +2394,7 @@ def fetchone(
def fetchall(
self,
- query: t.Union[exp.Expression, str],
+ query: t.Union[exp.Expr, str],
ignore_unsupported_errors: bool = False,
quote_identifiers: bool = False,
) -> t.List[t.Tuple]:
@@ -2409,7 +2407,7 @@ def fetchall(
return self.cursor.fetchall()
def _fetch_native_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> DF:
"""Fetches a DataFrame that can be either Pandas or PySpark from the cursor"""
with self.transaction():
@@ -2432,7 +2430,7 @@ def _native_df_to_pandas_df(
raise NotImplementedError(f"Unable to convert {type(query_or_df)} to Pandas")
def fetchdf(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> pd.DataFrame:
"""Fetches a Pandas DataFrame from the cursor"""
import pandas as pd
@@ -2445,7 +2443,7 @@ def fetchdf(
return df
def fetch_pyspark_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> PySparkDataFrame:
"""Fetches a PySpark DataFrame from the cursor"""
raise NotImplementedError(f"Engine does not support PySpark DataFrames: {type(self)}")
@@ -2575,7 +2573,7 @@ def _is_session_active(self) -> bool:
def execute(
self,
- expressions: t.Union[str, exp.Expression, t.Sequence[exp.Expression]],
+ expressions: t.Union[str, exp.Expr, t.Sequence[exp.Expr]],
ignore_unsupported_errors: bool = False,
quote_identifiers: bool = True,
track_rows_processed: bool = False,
@@ -2587,7 +2585,7 @@ def execute(
)
with self.transaction():
for e in ensure_list(expressions):
- if isinstance(e, exp.Expression):
+ if isinstance(e, exp.Expr):
self._check_identifier_length(e)
sql = self._to_sql(e, quote=quote_identifiers, **to_sql_kwargs)
else:
@@ -2597,7 +2595,7 @@ def execute(
self._log_sql(
sql,
- expression=e if isinstance(e, exp.Expression) else None,
+ expression=e if isinstance(e, exp.Expr) else None,
quote_identifiers=quote_identifiers,
)
self._execute(sql, track_rows_processed, **kwargs)
@@ -2610,7 +2608,7 @@ def _attach_correlation_id(self, sql: str) -> str:
def _log_sql(
self,
sql: str,
- expression: t.Optional[exp.Expression] = None,
+ expression: t.Optional[exp.Expr] = None,
quote_identifiers: bool = True,
) -> None:
if not logger.isEnabledFor(self._execute_log_level):
@@ -2702,7 +2700,7 @@ def temp_table(
self.drop_table(table)
def _table_or_view_properties_to_expressions(
- self, table_or_view_properties: t.Optional[t.Dict[str, exp.Expression]] = None
+ self, table_or_view_properties: t.Optional[t.Dict[str, exp.Expr]] = None
) -> t.List[exp.Property]:
"""Converts model properties (either physical or virtual) to a list of property expressions."""
if not table_or_view_properties:
@@ -2714,7 +2712,7 @@ def _table_or_view_properties_to_expressions(
def _build_partitioned_by_exp(
self,
- partitioned_by: t.List[exp.Expression],
+ partitioned_by: t.List[exp.Expr],
*,
partition_interval_unit: t.Optional[IntervalUnit] = None,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -2725,7 +2723,7 @@ def _build_partitioned_by_exp(
def _build_clustered_by_exp(
self,
- clustered_by: t.List[exp.Expression],
+ clustered_by: t.List[exp.Expr],
**kwargs: t.Any,
) -> t.Optional[exp.Cluster]:
return None
@@ -2735,17 +2733,17 @@ def _build_table_properties_exp(
catalog_name: t.Optional[str] = None,
table_format: t.Optional[str] = None,
storage_format: t.Optional[str] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
partition_interval_unit: t.Optional[IntervalUnit] = None,
- clustered_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ clustered_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
table_description: t.Optional[str] = None,
table_kind: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
"""Creates a SQLGlot table properties expression for ddl."""
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
if table_description:
properties.append(
@@ -2764,12 +2762,12 @@ def _build_table_properties_exp(
def _build_view_properties_exp(
self,
- view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
table_description: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
"""Creates a SQLGlot table properties expression for view"""
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
if table_description:
properties.append(
@@ -2791,7 +2789,7 @@ def _truncate_table_comment(self, comment: str) -> str:
def _truncate_column_comment(self, comment: str) -> str:
return self._truncate_comment(comment, self.MAX_COLUMN_COMMENT_LENGTH)
- def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.Any) -> str:
+ def _to_sql(self, expression: exp.Expr, quote: bool = True, **kwargs: t.Any) -> str:
"""
Converts an expression to a SQL string. Has a set of default kwargs to apply, and then default
kwargs defined for the given dialect, and then kwargs provided by the user when defining the engine
@@ -2852,7 +2850,7 @@ def _order_projections_and_filter(
self,
query: Query,
target_columns_to_types: t.Dict[str, exp.DataType],
- where: t.Optional[exp.Expression] = None,
+ where: t.Optional[exp.Expr] = None,
coerce_types: bool = False,
) -> Query:
if not isinstance(query, exp.Query) or (
@@ -2861,9 +2859,9 @@ def _order_projections_and_filter(
return query
query = t.cast(exp.Query, query.copy())
- with_ = query.args.pop("with", None)
+ with_ = query.args.pop("with_", None)
- select_exprs: t.List[exp.Expression] = [
+ select_exprs: t.List[exp.Expr] = [
exp.column(c, quoted=True) for c in target_columns_to_types
]
if coerce_types and columns_to_types_all_known(target_columns_to_types):
@@ -2877,7 +2875,7 @@ def _order_projections_and_filter(
query = query.where(where, copy=False)
if with_:
- query.set("with", with_)
+ query.set("with_", with_)
return query
@@ -2914,7 +2912,7 @@ def _replace_by_key(
target_table: TableName,
source_table: QueryOrDF,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
- key: t.Sequence[exp.Expression],
+ key: t.Sequence[exp.Expr],
is_unique_key: bool,
source_columns: t.Optional[t.List[str]] = None,
) -> None:
@@ -3055,7 +3053,7 @@ def _select_columns(
)
)
- def _check_identifier_length(self, expression: exp.Expression) -> None:
+ def _check_identifier_length(self, expression: exp.Expr) -> None:
if self.MAX_IDENTIFIER_LENGTH is None or not isinstance(expression, exp.DDL):
return
@@ -3147,7 +3145,7 @@ def _apply_grants_config_expr(
table: exp.Table,
grants_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
"""Returns SQLGlot Grant expressions to apply grants to a table.
Args:
@@ -3170,7 +3168,7 @@ def _revoke_grants_config_expr(
table: exp.Table,
grants_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
"""Returns SQLGlot expressions to revoke grants from a table.
Args:
diff --git a/sqlmesh/core/engine_adapter/base_postgres.py b/sqlmesh/core/engine_adapter/base_postgres.py
index 11f56da133..e2347b1263 100644
--- a/sqlmesh/core/engine_adapter/base_postgres.py
+++ b/sqlmesh/core/engine_adapter/base_postgres.py
@@ -110,7 +110,7 @@ def create_view(
materialized_properties: t.Optional[t.Dict[str, t.Any]] = None,
table_description: t.Optional[str] = None,
column_descriptions: t.Optional[t.Dict[str, str]] = None,
- view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
source_columns: t.Optional[t.List[str]] = None,
**create_kwargs: t.Any,
) -> None:
diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py
index 59a56b6ace..d136445114 100644
--- a/sqlmesh/core/engine_adapter/bigquery.py
+++ b/sqlmesh/core/engine_adapter/bigquery.py
@@ -67,7 +67,7 @@ class BigQueryEngineAdapter(ClusteredByMixin, RowDiffMixin, GrantsFromInfoSchema
SUPPORTS_MATERIALIZED_VIEWS = True
SUPPORTS_CLONING = True
SUPPORTS_GRANTS = True
- CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("session_user")
+ CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expr = exp.func("session_user")
SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True
USE_CATALOG_IN_GRANTS = True
GRANT_INFORMATION_SCHEMA_TABLE_NAME = "OBJECT_PRIVILEGES"
@@ -140,8 +140,10 @@ def _job_params(self) -> t.Dict[str, t.Any]:
"priority", BigQueryPriority.INTERACTIVE.bigquery_constant
),
}
- if self._extra_config.get("maximum_bytes_billed"):
+ if self._extra_config.get("maximum_bytes_billed") is not None:
params["maximum_bytes_billed"] = self._extra_config.get("maximum_bytes_billed")
+ if self._extra_config.get("reservation") is not None:
+ params["reservation"] = self._extra_config.get("reservation")
if self.correlation_id:
# BigQuery label keys must be lowercase
key = self.correlation_id.job_type.value.lower()
@@ -288,7 +290,7 @@ def create_schema(
schema_name: SchemaName,
ignore_if_exists: bool = True,
warn_on_error: bool = True,
- properties: t.List[exp.Expression] = [],
+ properties: t.List[exp.Expr] = [],
) -> None:
"""Create a schema from a name or qualified table name."""
from google.api_core.exceptions import Conflict
@@ -433,7 +435,7 @@ def alter_table(
def fetchone(
self,
- query: t.Union[exp.Expression, str],
+ query: t.Union[exp.Expr, str],
ignore_unsupported_errors: bool = False,
quote_identifiers: bool = False,
) -> t.Optional[t.Tuple]:
@@ -453,7 +455,7 @@ def fetchone(
def fetchall(
self,
- query: t.Union[exp.Expression, str],
+ query: t.Union[exp.Expr, str],
ignore_unsupported_errors: bool = False,
quote_identifiers: bool = False,
) -> t.List[t.Tuple]:
@@ -689,7 +691,7 @@ def insert_overwrite_by_partition(
self,
table_name: TableName,
query_or_df: QueryOrDF,
- partitioned_by: t.List[exp.Expression],
+ partitioned_by: t.List[exp.Expr],
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
source_columns: t.Optional[t.List[str]] = None,
) -> None:
@@ -803,7 +805,7 @@ def _table_name(self, table_name: TableName) -> str:
return ".".join(part.name for part in exp.to_table(table_name).parts)
def _fetch_native_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> DF:
self.execute(query, quote_identifiers=quote_identifiers)
query_job = self._query_job
@@ -863,7 +865,7 @@ def _build_description_property_exp(
def _build_partitioned_by_exp(
self,
- partitioned_by: t.List[exp.Expression],
+ partitioned_by: t.List[exp.Expr],
*,
partition_interval_unit: t.Optional[IntervalUnit] = None,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -909,16 +911,16 @@ def _build_table_properties_exp(
catalog_name: t.Optional[str] = None,
table_format: t.Optional[str] = None,
storage_format: t.Optional[str] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
partition_interval_unit: t.Optional[IntervalUnit] = None,
- clustered_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ clustered_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
table_description: t.Optional[str] = None,
table_kind: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
if partitioned_by and (
partitioned_by_prop := self._build_partitioned_by_exp(
@@ -1025,12 +1027,12 @@ def _build_col_comment_exp(
def _build_view_properties_exp(
self,
- view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
table_description: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
"""Creates a SQLGlot table properties expression for view"""
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
if table_description:
properties.append(
@@ -1106,7 +1108,9 @@ def _execute(
else []
)
+ # Create job config
job_config = QueryJobConfig(**self._job_params, connection_properties=connection_properties)
+
self._query_job = self._db_call(
self.client.query,
query=sql,
@@ -1257,10 +1261,10 @@ def _update_clustering_key(self, operation: TableAlterClusterByOperation) -> Non
)
)
- def _normalize_decimal_value(self, col: exp.Expression, precision: int) -> exp.Expression:
+ def _normalize_decimal_value(self, col: exp.Expr, precision: int) -> exp.Expr:
return exp.func("FORMAT", exp.Literal.string(f"%.{precision}f"), col)
- def _normalize_nested_value(self, col: exp.Expression) -> exp.Expression:
+ def _normalize_nested_value(self, col: exp.Expr) -> exp.Expr:
return exp.func("TO_JSON_STRING", col, dialect=self.dialect)
@t.overload
@@ -1338,7 +1342,7 @@ def _get_current_schema(self) -> str:
def _get_bq_dataset_location(self, project: str, dataset: str) -> str:
return self._db_call(self.client.get_dataset, dataset_ref=f"{project}.{dataset}").location
- def _get_grant_expression(self, table: exp.Table) -> exp.Expression:
+ def _get_grant_expression(self, table: exp.Table) -> exp.Expr:
if not table.db:
raise ValueError(
f"Table {table.sql(dialect=self.dialect)} does not have a schema (dataset)"
@@ -1392,8 +1396,8 @@ def _dcl_grants_config_expr(
table: exp.Table,
grants_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
- ) -> t.List[exp.Expression]:
- expressions: t.List[exp.Expression] = []
+ ) -> t.List[exp.Expr]:
+ expressions: t.List[exp.Expr] = []
if not grants_config:
return expressions
diff --git a/sqlmesh/core/engine_adapter/clickhouse.py b/sqlmesh/core/engine_adapter/clickhouse.py
index 45c22a6e55..71a834ecfc 100644
--- a/sqlmesh/core/engine_adapter/clickhouse.py
+++ b/sqlmesh/core/engine_adapter/clickhouse.py
@@ -64,7 +64,7 @@ def cluster(self) -> t.Optional[str]:
# doesn't use the row index at all
def fetchone(
self,
- query: t.Union[exp.Expression, str],
+ query: t.Union[exp.Expr, str],
ignore_unsupported_errors: bool = False,
quote_identifiers: bool = False,
) -> t.Tuple:
@@ -77,13 +77,11 @@ def fetchone(
return self.cursor.fetchall()[0]
def _fetch_native_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> pd.DataFrame:
"""Fetches a Pandas DataFrame from the cursor"""
return self.cursor.client.query_df(
- self._to_sql(query, quote=quote_identifiers)
- if isinstance(query, exp.Expression)
- else query,
+ self._to_sql(query, quote=quote_identifiers) if isinstance(query, exp.Expr) else query,
use_extended_dtypes=True,
)
@@ -168,7 +166,7 @@ def create_schema(
schema_name: SchemaName,
ignore_if_exists: bool = True,
warn_on_error: bool = True,
- properties: t.List[exp.Expression] = [],
+ properties: t.List[exp.Expr] = [],
) -> None:
"""Create a Clickhouse database from a name or qualified table name.
@@ -229,7 +227,7 @@ def _insert_overwrite_by_condition(
# REPLACE BY KEY: extract kwargs if present
dynamic_key = kwargs.get("dynamic_key")
if dynamic_key:
- dynamic_key_exp = t.cast(exp.Expression, kwargs.get("dynamic_key_exp"))
+ dynamic_key_exp = t.cast(exp.Expr, kwargs.get("dynamic_key_exp"))
dynamic_key_unique = t.cast(bool, kwargs.get("dynamic_key_unique"))
try:
@@ -414,7 +412,7 @@ def _replace_by_key(
target_table: TableName,
source_table: QueryOrDF,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
- key: t.Sequence[exp.Expression],
+ key: t.Sequence[exp.Expr],
is_unique_key: bool,
source_columns: t.Optional[t.List[str]] = None,
) -> None:
@@ -440,7 +438,7 @@ def insert_overwrite_by_partition(
self,
table_name: TableName,
query_or_df: QueryOrDF,
- partitioned_by: t.List[exp.Expression],
+ partitioned_by: t.List[exp.Expr],
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
source_columns: t.Optional[t.List[str]] = None,
) -> None:
@@ -487,7 +485,7 @@ def _get_partition_ids(
def _create_table(
self,
table_name_or_schema: t.Union[exp.Schema, TableName],
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
exists: bool = True,
replace: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -595,7 +593,7 @@ def _rename_table(
self.execute(f"RENAME TABLE {old_table_sql} TO {new_table_sql}{self._on_cluster_sql()}")
- def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None:
+ def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None:
delete_expr = exp.delete(table_name, where)
if self.engine_run_mode.is_cluster:
delete_expr.set("cluster", exp.OnCluster(this=exp.to_identifier(self.cluster)))
@@ -649,7 +647,7 @@ def _drop_object(
def _build_partitioned_by_exp(
self,
- partitioned_by: t.List[exp.Expression],
+ partitioned_by: t.List[exp.Expr],
**kwargs: t.Any,
) -> t.Optional[t.Union[exp.PartitionedByProperty, exp.Property]]:
return exp.PartitionedByProperty(
@@ -714,14 +712,14 @@ def use_server_nulls_for_unmatched_after_join(
return query
def _build_settings_property(
- self, key: str, value: exp.Expression | str | int | float
+ self, key: str, value: exp.Expr | str | int | float
) -> exp.SettingsProperty:
return exp.SettingsProperty(
expressions=[
exp.EQ(
this=exp.var(key.lower()),
expression=value
- if isinstance(value, exp.Expression)
+ if isinstance(value, exp.Expr)
else exp.Literal(this=value, is_string=isinstance(value, str)),
)
]
@@ -732,17 +730,17 @@ def _build_table_properties_exp(
catalog_name: t.Optional[str] = None,
table_format: t.Optional[str] = None,
storage_format: t.Optional[str] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
partition_interval_unit: t.Optional[IntervalUnit] = None,
- clustered_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ clustered_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
table_description: t.Optional[str] = None,
table_kind: t.Optional[str] = None,
empty_ctas: bool = False,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
table_engine = self.DEFAULT_TABLE_ENGINE
if storage_format:
@@ -809,9 +807,7 @@ def _build_table_properties_exp(
ttl = table_properties_copy.pop("TTL", None)
if ttl:
properties.append(
- exp.MergeTreeTTL(
- expressions=[ttl if isinstance(ttl, exp.Expression) else exp.var(ttl)]
- )
+ exp.MergeTreeTTL(expressions=[ttl if isinstance(ttl, exp.Expr) else exp.var(ttl)])
)
if (
@@ -845,12 +841,12 @@ def _build_table_properties_exp(
def _build_view_properties_exp(
self,
- view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
table_description: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
"""Creates a SQLGlot table properties expression for view"""
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
view_properties_copy = view_properties.copy() if view_properties else {}
diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py
index 97190492f2..e3d029a17d 100644
--- a/sqlmesh/core/engine_adapter/databricks.py
+++ b/sqlmesh/core/engine_adapter/databricks.py
@@ -78,21 +78,21 @@ def can_access_databricks_connect(cls, disable_databricks_connect: bool) -> bool
def _use_spark_session(self) -> bool:
if self.can_access_spark_session(bool(self._extra_config.get("disable_spark_session"))):
return True
- return (
- self.can_access_databricks_connect(
- bool(self._extra_config.get("disable_databricks_connect"))
- )
- and (
- {
- "databricks_connect_server_hostname",
- "databricks_connect_access_token",
- }.issubset(self._extra_config)
- )
- and (
- "databricks_connect_cluster_id" in self._extra_config
- or "databricks_connect_use_serverless" in self._extra_config
- )
- )
+
+ if self.can_access_databricks_connect(
+ bool(self._extra_config.get("disable_databricks_connect"))
+ ):
+ if self._extra_config.get("databricks_connect_use_serverless"):
+ return True
+
+ if {
+ "databricks_connect_cluster_id",
+ "databricks_connect_server_hostname",
+ "databricks_connect_access_token",
+ }.issubset(self._extra_config):
+ return True
+
+ return False
@property
def is_spark_session_connection(self) -> bool:
@@ -108,7 +108,7 @@ def _set_spark_engine_adapter_if_needed(self) -> None:
connect_kwargs = dict(
host=self._extra_config["databricks_connect_server_hostname"],
- token=self._extra_config["databricks_connect_access_token"],
+ token=self._extra_config.get("databricks_connect_access_token"),
)
if "databricks_connect_use_serverless" in self._extra_config:
connect_kwargs["serverless"] = True
@@ -163,7 +163,7 @@ def _grant_object_kind(table_type: DataObjectType) -> str:
return "MATERIALIZED VIEW"
return "TABLE"
- def _get_grant_expression(self, table: exp.Table) -> exp.Expression:
+ def _get_grant_expression(self, table: exp.Table) -> exp.Expr:
# We only care about explicitly granted privileges and not inherited ones
# if this is removed you would see grants inherited from the catalog get returned
expression = super()._get_grant_expression(table)
@@ -210,7 +210,7 @@ def query_factory() -> Query:
return [SourceQuery(query_factory=query_factory)]
def _fetch_native_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> DF:
"""Fetches a DataFrame that can be either Pandas or PySpark from the cursor"""
if self.is_spark_session_connection:
@@ -223,7 +223,7 @@ def _fetch_native_df(
return self.cursor.fetchall_arrow().to_pandas()
def fetchdf(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> pd.DataFrame:
"""
Returns a Pandas DataFrame from a query or expression.
@@ -364,10 +364,10 @@ def _build_table_properties_exp(
catalog_name: t.Optional[str] = None,
table_format: t.Optional[str] = None,
storage_format: t.Optional[str] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
partition_interval_unit: t.Optional[IntervalUnit] = None,
- clustered_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ clustered_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
table_description: t.Optional[str] = None,
table_kind: t.Optional[str] = None,
diff --git a/sqlmesh/core/engine_adapter/duckdb.py b/sqlmesh/core/engine_adapter/duckdb.py
index 3b057219e0..ebfcaa7901 100644
--- a/sqlmesh/core/engine_adapter/duckdb.py
+++ b/sqlmesh/core/engine_adapter/duckdb.py
@@ -145,7 +145,7 @@ def _get_data_objects(
for row in df.itertuples()
]
- def _normalize_decimal_value(self, col: exp.Expression, precision: int) -> exp.Expression:
+ def _normalize_decimal_value(self, col: exp.Expr, precision: int) -> exp.Expr:
"""
duckdb truncates instead of rounding when casting to decimal.
@@ -163,7 +163,7 @@ def _normalize_decimal_value(self, col: exp.Expression, precision: int) -> exp.E
def _create_table(
self,
table_name_or_schema: t.Union[exp.Schema, TableName],
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
exists: bool = True,
replace: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
diff --git a/sqlmesh/core/engine_adapter/mixins.py b/sqlmesh/core/engine_adapter/mixins.py
index c8ef32b9da..bf4bb970a2 100644
--- a/sqlmesh/core/engine_adapter/mixins.py
+++ b/sqlmesh/core/engine_adapter/mixins.py
@@ -38,9 +38,9 @@ def merge(
target_table: TableName,
source_table: QueryOrDF,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
when_matched: t.Optional[exp.Whens] = None,
- merge_filter: t.Optional[exp.Expression] = None,
+ merge_filter: t.Optional[exp.Expr] = None,
source_columns: t.Optional[t.List[str]] = None,
**kwargs: t.Any,
) -> None:
@@ -58,18 +58,14 @@ def merge(
class PandasNativeFetchDFSupportMixin(EngineAdapter):
def _fetch_native_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> DF:
"""Fetches a Pandas DataFrame from a SQL query."""
from warnings import catch_warnings, filterwarnings
from pandas.io.sql import read_sql_query
- sql = (
- self._to_sql(query, quote=quote_identifiers)
- if isinstance(query, exp.Expression)
- else query
- )
+ sql = self._to_sql(query, quote=quote_identifiers) if isinstance(query, exp.Expr) else query
logger.debug(f"Executing SQL:\n{sql}")
with catch_warnings(), self.transaction():
filterwarnings(
@@ -87,7 +83,7 @@ class HiveMetastoreTablePropertiesMixin(EngineAdapter):
def _build_partitioned_by_exp(
self,
- partitioned_by: t.List[exp.Expression],
+ partitioned_by: t.List[exp.Expr],
*,
catalog_name: t.Optional[str] = None,
**kwargs: t.Any,
@@ -120,16 +116,16 @@ def _build_table_properties_exp(
catalog_name: t.Optional[str] = None,
table_format: t.Optional[str] = None,
storage_format: t.Optional[str] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
partition_interval_unit: t.Optional[IntervalUnit] = None,
- clustered_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ clustered_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
table_description: t.Optional[str] = None,
table_kind: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
if table_format and self.dialect == "spark":
properties.append(exp.FileFormatProperty(this=exp.Var(this=table_format)))
@@ -166,12 +162,12 @@ def _build_table_properties_exp(
def _build_view_properties_exp(
self,
- view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
table_description: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
"""Creates a SQLGlot table properties expression for view"""
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
if table_description:
properties.append(
@@ -194,7 +190,7 @@ def _truncate_comment(self, comment: str, length: t.Optional[int]) -> str:
class GetCurrentCatalogFromFunctionMixin(EngineAdapter):
- CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.func("current_catalog")
+ CURRENT_CATALOG_EXPRESSION: exp.Expr = exp.func("current_catalog")
def get_current_catalog(self) -> t.Optional[str]:
"""Returns the catalog name of the current connection."""
@@ -240,7 +236,7 @@ def _default_precision_to_max(
def _build_create_table_exp(
self,
table_name_or_schema: t.Union[exp.Schema, TableName],
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
exists: bool = True,
replace: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -322,11 +318,11 @@ def is_destructive(self) -> bool:
return False
@property
- def _alter_actions(self) -> t.List[exp.Expression]:
+ def _alter_actions(self) -> t.List[exp.Expr]:
return [exp.Cluster(expressions=self.cluster_key_expressions)]
@property
- def cluster_key_expressions(self) -> t.List[exp.Expression]:
+ def cluster_key_expressions(self) -> t.List[exp.Expr]:
# Note: Assumes `clustering_key` as a string like:
# - "(col_a)"
# - "(col_a, col_b)"
@@ -346,14 +342,14 @@ def is_destructive(self) -> bool:
return False
@property
- def _alter_actions(self) -> t.List[exp.Expression]:
+ def _alter_actions(self) -> t.List[exp.Expr]:
return [exp.Command(this="DROP", expression="CLUSTERING KEY")]
class ClusteredByMixin(EngineAdapter):
def _build_clustered_by_exp(
self,
- clustered_by: t.List[exp.Expression],
+ clustered_by: t.List[exp.Expr],
**kwargs: t.Any,
) -> t.Optional[exp.Cluster]:
return exp.Cluster(expressions=[c.copy() for c in clustered_by])
@@ -410,9 +406,9 @@ def logical_merge(
target_table: TableName,
source_table: QueryOrDF,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
when_matched: t.Optional[exp.Whens] = None,
- merge_filter: t.Optional[exp.Expression] = None,
+ merge_filter: t.Optional[exp.Expr] = None,
source_columns: t.Optional[t.List[str]] = None,
) -> None:
"""
@@ -452,12 +448,12 @@ def concat_columns(
decimal_precision: int = 3,
timestamp_precision: int = MAX_TIMESTAMP_PRECISION,
delimiter: str = ",",
- ) -> exp.Expression:
+ ) -> exp.Expr:
"""
Produce an expression that generates a string version of a record, that is:
- Every column converted to a string representation, joined together into a single string using the specified :delimiter
"""
- expressions_to_concat: t.List[exp.Expression] = []
+ expressions_to_concat: t.List[exp.Expr] = []
for idx, (column, type) in enumerate(columns_to_types.items()):
expressions_to_concat.append(
exp.func(
@@ -475,11 +471,11 @@ def concat_columns(
def normalize_value(
self,
- expr: exp.Expression,
+ expr: exp.Expr,
type: exp.DataType,
decimal_precision: int = 3,
timestamp_precision: int = MAX_TIMESTAMP_PRECISION,
- ) -> exp.Expression:
+ ) -> exp.Expr:
"""
Return an expression that converts the values inside the column `col` to a normalized string
@@ -490,6 +486,7 @@ def normalize_value(
- `boolean` columns -> '1' or '0'
- NULLS -> "" (empty string)
"""
+ value: exp.Expr
if type.is_type(exp.DataType.Type.BOOLEAN):
value = self._normalize_boolean_value(expr)
elif type.is_type(*exp.DataType.INTEGER_TYPES):
@@ -512,12 +509,12 @@ def normalize_value(
return exp.cast(value, to=exp.DataType.build("VARCHAR"))
- def _normalize_nested_value(self, expr: exp.Expression) -> exp.Expression:
+ def _normalize_nested_value(self, expr: exp.Expr) -> exp.Expr:
return expr
def _normalize_timestamp_value(
- self, expr: exp.Expression, type: exp.DataType, precision: int
- ) -> exp.Expression:
+ self, expr: exp.Expr, type: exp.DataType, precision: int
+ ) -> exp.Expr:
if precision > self.MAX_TIMESTAMP_PRECISION:
raise ValueError(
f"Requested timestamp precision '{precision}' exceeds maximum supported precision: {self.MAX_TIMESTAMP_PRECISION}"
@@ -547,18 +544,18 @@ def _normalize_timestamp_value(
return expr
- def _normalize_integer_value(self, expr: exp.Expression) -> exp.Expression:
+ def _normalize_integer_value(self, expr: exp.Expr) -> exp.Expr:
return exp.cast(expr, "BIGINT")
- def _normalize_decimal_value(self, expr: exp.Expression, precision: int) -> exp.Expression:
+ def _normalize_decimal_value(self, expr: exp.Expr, precision: int) -> exp.Expr:
return exp.cast(expr, f"DECIMAL(38,{precision})")
- def _normalize_boolean_value(self, expr: exp.Expression) -> exp.Expression:
+ def _normalize_boolean_value(self, expr: exp.Expr) -> exp.Expr:
return exp.cast(expr, "INT")
class GrantsFromInfoSchemaMixin(EngineAdapter):
- CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("current_user")
+ CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expr = exp.func("current_user")
SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = False
USE_CATALOG_IN_GRANTS = False
GRANT_INFORMATION_SCHEMA_TABLE_NAME = "table_privileges"
@@ -578,8 +575,8 @@ def _dcl_grants_config_expr(
table: exp.Table,
grants_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
- ) -> t.List[exp.Expression]:
- expressions: t.List[exp.Expression] = []
+ ) -> t.List[exp.Expr]:
+ expressions: t.List[exp.Expr] = []
if not grants_config:
return expressions
@@ -617,7 +614,7 @@ def _apply_grants_config_expr(
table: exp.Table,
grants_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
return self._dcl_grants_config_expr(exp.Grant, table, grants_config, table_type)
def _revoke_grants_config_expr(
@@ -625,10 +622,10 @@ def _revoke_grants_config_expr(
table: exp.Table,
grants_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
return self._dcl_grants_config_expr(exp.Revoke, table, grants_config, table_type)
- def _get_grant_expression(self, table: exp.Table) -> exp.Expression:
+ def _get_grant_expression(self, table: exp.Table) -> exp.Expr:
schema_identifier = table.args.get("db") or normalize_identifiers(
exp.to_identifier(self._get_current_schema(), quoted=True), dialect=self.dialect
)
diff --git a/sqlmesh/core/engine_adapter/mssql.py b/sqlmesh/core/engine_adapter/mssql.py
index 359d1f0818..e381c0a198 100644
--- a/sqlmesh/core/engine_adapter/mssql.py
+++ b/sqlmesh/core/engine_adapter/mssql.py
@@ -176,7 +176,7 @@ def drop_schema(
schema_name: SchemaName,
ignore_if_not_exists: bool = True,
cascade: bool = False,
- **drop_args: t.Dict[str, exp.Expression],
+ **drop_args: t.Dict[str, exp.Expr],
) -> None:
"""
MsSql doesn't support CASCADE clause and drops schemas unconditionally.
@@ -205,9 +205,9 @@ def merge(
target_table: TableName,
source_table: QueryOrDF,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
when_matched: t.Optional[exp.Whens] = None,
- merge_filter: t.Optional[exp.Expression] = None,
+ merge_filter: t.Optional[exp.Expr] = None,
source_columns: t.Optional[t.List[str]] = None,
**kwargs: t.Any,
) -> None:
@@ -401,7 +401,7 @@ def _get_data_objects(
for row in dataframe.itertuples()
]
- def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.Any) -> str:
+ def _to_sql(self, expression: exp.Expr, quote: bool = True, **kwargs: t.Any) -> str:
sql = super()._to_sql(expression, quote=quote, **kwargs)
return f"{sql};"
@@ -448,7 +448,7 @@ def _insert_overwrite_by_condition(
**kwargs,
)
- def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None:
+ def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None:
if where == exp.true():
# "A TRUNCATE TABLE operation can be rolled back within a transaction."
# ref: https://learn.microsoft.com/en-us/sql/t-sql/statements/truncate-table-transact-sql?view=sql-server-ver15#remarks
diff --git a/sqlmesh/core/engine_adapter/mysql.py b/sqlmesh/core/engine_adapter/mysql.py
index 31773d6c63..66759dc440 100644
--- a/sqlmesh/core/engine_adapter/mysql.py
+++ b/sqlmesh/core/engine_adapter/mysql.py
@@ -73,7 +73,7 @@ def drop_schema(
schema_name: SchemaName,
ignore_if_not_exists: bool = True,
cascade: bool = False,
- **drop_args: t.Dict[str, exp.Expression],
+ **drop_args: t.Dict[str, exp.Expr],
) -> None:
# MySQL doesn't support CASCADE clause and drops schemas unconditionally.
super().drop_schema(schema_name, ignore_if_not_exists=ignore_if_not_exists, cascade=False)
diff --git a/sqlmesh/core/engine_adapter/postgres.py b/sqlmesh/core/engine_adapter/postgres.py
index 3dd108cf91..6794169322 100644
--- a/sqlmesh/core/engine_adapter/postgres.py
+++ b/sqlmesh/core/engine_adapter/postgres.py
@@ -40,7 +40,7 @@ class PostgresEngineAdapter(
MAX_IDENTIFIER_LENGTH: t.Optional[int] = 63
SUPPORTS_QUERY_EXECUTION_TRACKING = True
GRANT_INFORMATION_SCHEMA_TABLE_NAME = "role_table_grants"
- CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.column("current_role")
+ CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expr = exp.column("current_role")
SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True
SCHEMA_DIFFER_KWARGS = {
"parameterized_type_defaults": {
@@ -73,7 +73,7 @@ class PostgresEngineAdapter(
}
def _fetch_native_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> DF:
"""
`read_sql_query` when using psycopg will result on a hanging transaction that must be committed
@@ -113,9 +113,9 @@ def merge(
target_table: TableName,
source_table: QueryOrDF,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
when_matched: t.Optional[exp.Whens] = None,
- merge_filter: t.Optional[exp.Expression] = None,
+ merge_filter: t.Optional[exp.Expr] = None,
source_columns: t.Optional[t.List[str]] = None,
**kwargs: t.Any,
) -> None:
diff --git a/sqlmesh/core/engine_adapter/redshift.py b/sqlmesh/core/engine_adapter/redshift.py
index 03dc89053e..c2a27954cd 100644
--- a/sqlmesh/core/engine_adapter/redshift.py
+++ b/sqlmesh/core/engine_adapter/redshift.py
@@ -143,7 +143,7 @@ def cursor(self) -> t.Any:
return cursor
def _fetch_native_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> pd.DataFrame:
"""Fetches a Pandas DataFrame from the cursor"""
import pandas as pd
@@ -217,7 +217,7 @@ def create_view(
materialized_properties: t.Optional[t.Dict[str, t.Any]] = None,
table_description: t.Optional[str] = None,
column_descriptions: t.Optional[t.Dict[str, str]] = None,
- view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
source_columns: t.Optional[t.List[str]] = None,
**create_kwargs: t.Any,
) -> None:
@@ -227,7 +227,7 @@ def create_view(
swap tables out from under views. Therefore, we create the view as non-binding.
"""
no_schema_binding = True
- if isinstance(query_or_df, exp.Expression):
+ if isinstance(query_or_df, exp.Expr):
# We can't include NO SCHEMA BINDING if the query has a recursive CTE
has_recursive_cte = any(
w.args.get("recursive", False) for w in query_or_df.find_all(exp.With)
@@ -367,9 +367,9 @@ def merge(
target_table: TableName,
source_table: QueryOrDF,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
when_matched: t.Optional[exp.Whens] = None,
- merge_filter: t.Optional[exp.Expression] = None,
+ merge_filter: t.Optional[exp.Expr] = None,
source_columns: t.Optional[t.List[str]] = None,
**kwargs: t.Any,
) -> None:
@@ -400,12 +400,12 @@ def _merge(
self,
target_table: TableName,
query: Query,
- on: exp.Expression,
+ on: exp.Expr,
whens: exp.Whens,
) -> None:
# Redshift does not support table aliases in the target table of a MERGE statement.
# So we must use the actual table name instead of an alias, as we do with the source table.
- def resolve_target_table(expression: exp.Expression) -> exp.Expression:
+ def resolve_target_table(expression: exp.Expr) -> exp.Expr:
if (
isinstance(expression, exp.Column)
and expression.table.upper() == MERGE_TARGET_ALIAS
@@ -436,7 +436,7 @@ def resolve_target_table(expression: exp.Expression) -> exp.Expression:
track_rows_processed=True,
)
- def _normalize_decimal_value(self, expr: exp.Expression, precision: int) -> exp.Expression:
+ def _normalize_decimal_value(self, expr: exp.Expr, precision: int) -> exp.Expr:
# Redshift is finicky. It truncates when the data is already in a table, but rounds when the data is generated as part of a SELECT.
#
# The following works:
diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py
index a8eabe070d..09c530b8f3 100644
--- a/sqlmesh/core/engine_adapter/snowflake.py
+++ b/sqlmesh/core/engine_adapter/snowflake.py
@@ -83,7 +83,7 @@ class SnowflakeEngineAdapter(
SNOWPARK = "snowpark"
SUPPORTS_QUERY_EXECUTION_TRACKING = True
SUPPORTS_GRANTS = True
- CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("CURRENT_ROLE")
+ CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expr = exp.func("CURRENT_ROLE")
USE_CATALOG_IN_GRANTS = True
@contextlib.contextmanager
@@ -95,7 +95,7 @@ def session(self, properties: SessionProperties) -> t.Iterator[None]:
if isinstance(warehouse, str):
warehouse = exp.to_identifier(warehouse)
- if not isinstance(warehouse, exp.Expression):
+ if not isinstance(warehouse, exp.Expr):
raise SQLMeshError(f"Invalid warehouse: '{warehouse}'")
warehouse_exp = quote_identifiers(
@@ -189,7 +189,7 @@ def _drop_catalog(self, catalog_name: exp.Identifier) -> None:
def _create_table(
self,
table_name_or_schema: t.Union[exp.Schema, TableName],
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
exists: bool = True,
replace: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -225,9 +225,9 @@ def create_managed_table(
table_name: TableName,
query: Query,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
- clustered_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
+ clustered_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
table_description: t.Optional[str] = None,
column_descriptions: t.Optional[t.Dict[str, str]] = None,
source_columns: t.Optional[t.List[str]] = None,
@@ -278,7 +278,7 @@ def create_view(
materialized_properties: t.Optional[t.Dict[str, t.Any]] = None,
table_description: t.Optional[str] = None,
column_descriptions: t.Optional[t.Dict[str, str]] = None,
- view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
source_columns: t.Optional[t.List[str]] = None,
**create_kwargs: t.Any,
) -> None:
@@ -311,16 +311,16 @@ def _build_table_properties_exp(
catalog_name: t.Optional[str] = None,
table_format: t.Optional[str] = None,
storage_format: t.Optional[str] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
partition_interval_unit: t.Optional[IntervalUnit] = None,
- clustered_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ clustered_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
table_description: t.Optional[str] = None,
table_kind: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
# TODO: there is some overlap with the base class and other engine adapters
# we need a better way of filtering table properties relevent to the current engine
@@ -471,7 +471,7 @@ def cleanup() -> None:
return [SourceQuery(query_factory=query_factory, cleanup_func=cleanup)]
def _fetch_native_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> DF:
import pandas as pd
from snowflake.connector.errors import NotSupportedError
@@ -561,7 +561,7 @@ def _get_data_objects(
for row in df.rename(columns={col: col.lower() for col in df.columns}).itertuples()
]
- def _get_grant_expression(self, table: exp.Table) -> exp.Expression:
+ def _get_grant_expression(self, table: exp.Table) -> exp.Expr:
# Upon execute the catalog in table expressions are properly normalized to handle the case where a user provides
# the default catalog in their connection config. This doesn't though update catalogs in strings like when querying
# the information schema. So we need to manually replace those here.
@@ -586,7 +586,7 @@ def set_current_catalog(self, catalog: str) -> None:
def set_current_schema(self, schema: str) -> None:
self.execute(exp.Use(kind="SCHEMA", this=to_schema(schema)))
- def _normalize_catalog(self, expression: exp.Expression) -> exp.Expression:
+ def _normalize_catalog(self, expression: exp.Expr) -> exp.Expr:
# note: important to use self._default_catalog instead of the self.default_catalog property
# otherwise we get RecursionError: maximum recursion depth exceeded
# because it calls get_current_catalog(), which executes a query, which needs the default catalog, which calls get_current_catalog()... etc
@@ -604,7 +604,7 @@ def unquote_and_lower(identifier: str) -> str:
self._default_catalog, dialect=self.dialect
)
- def catalog_rewriter(node: exp.Expression) -> exp.Expression:
+ def catalog_rewriter(node: exp.Expr) -> exp.Expr:
if isinstance(node, exp.Table):
if node.catalog:
# only replace the catalog on the model with the target catalog if the two are functionally equivalent
@@ -621,7 +621,7 @@ def catalog_rewriter(node: exp.Expression) -> exp.Expression:
expression = expression.transform(catalog_rewriter)
return expression
- def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.Any) -> str:
+ def _to_sql(self, expression: exp.Expr, quote: bool = True, **kwargs: t.Any) -> str:
return super()._to_sql(
expression=self._normalize_catalog(expression), quote=quote, **kwargs
)
diff --git a/sqlmesh/core/engine_adapter/spark.py b/sqlmesh/core/engine_adapter/spark.py
index 5216b0a329..9199aa3bcd 100644
--- a/sqlmesh/core/engine_adapter/spark.py
+++ b/sqlmesh/core/engine_adapter/spark.py
@@ -340,12 +340,12 @@ def _get_temp_table(
return table
def fetchdf(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> pd.DataFrame:
return self.fetch_pyspark_df(query, quote_identifiers=quote_identifiers).toPandas()
def fetch_pyspark_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> PySparkDataFrame:
return self._ensure_pyspark_df(
self._fetch_native_df(query, quote_identifiers=quote_identifiers)
@@ -437,7 +437,7 @@ def _native_df_to_pandas_df(
def _create_table(
self,
table_name_or_schema: t.Union[exp.Schema, TableName],
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
exists: bool = True,
replace: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py
index 74df3667ff..00acddb26c 100644
--- a/sqlmesh/core/engine_adapter/trino.py
+++ b/sqlmesh/core/engine_adapter/trino.py
@@ -74,6 +74,32 @@ class TrinoEngineAdapter(
def schema_location_mapping(self) -> t.Optional[t.Dict[re.Pattern, str]]:
return self._extra_config.get("schema_location_mapping")
+ @property
+ def timestamp_mapping(self) -> t.Optional[t.Dict[exp.DataType, exp.DataType]]:
+ return self._extra_config.get("timestamp_mapping")
+
+ def _apply_timestamp_mapping(
+ self, columns_to_types: t.Dict[str, exp.DataType]
+ ) -> t.Tuple[t.Dict[str, exp.DataType], t.Set[str]]:
+ """Apply custom timestamp mapping to column types.
+
+ Returns:
+ A tuple of (mapped_columns_to_types, mapped_column_names) where mapped_column_names
+ contains the names of columns that were found in the mapping.
+ """
+ if not self.timestamp_mapping:
+ return columns_to_types, set()
+
+ result = {}
+ mapped_columns: t.Set[str] = set()
+ for column, column_type in columns_to_types.items():
+ if column_type in self.timestamp_mapping:
+ result[column] = self.timestamp_mapping[column_type]
+ mapped_columns.add(column)
+ else:
+ result[column] = column_type
+ return result, mapped_columns
+
@property
def catalog_support(self) -> CatalogSupport:
return CatalogSupport.FULL_SUPPORT
@@ -103,7 +129,7 @@ def session(self, properties: SessionProperties) -> t.Iterator[None]:
yield
return
- if not isinstance(authorization, exp.Expression):
+ if not isinstance(authorization, exp.Expr):
authorization = exp.Literal.string(authorization)
if not authorization.is_string:
@@ -117,7 +143,7 @@ def session(self, properties: SessionProperties) -> t.Iterator[None]:
try:
yield
finally:
- self.execute(f"RESET SESSION AUTHORIZATION")
+ self.execute("RESET SESSION AUTHORIZATION")
def replace_query(
self,
@@ -286,8 +312,11 @@ def _build_schema_exp(
is_view: bool = False,
materialized: bool = False,
) -> exp.Schema:
+ target_columns_to_types, mapped_columns = self._apply_timestamp_mapping(
+ target_columns_to_types
+ )
if "delta_lake" in self.get_catalog_type_from_table(table):
- target_columns_to_types = self._to_delta_ts(target_columns_to_types)
+ target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns)
return super()._build_schema_exp(
table, target_columns_to_types, column_descriptions, expressions, is_view
@@ -297,13 +326,13 @@ def _scd_type_2(
self,
target_table: TableName,
source_table: QueryOrDF,
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
valid_from_col: exp.Column,
valid_to_col: exp.Column,
execution_time: t.Union[TimeLike, exp.Column],
invalidate_hard_deletes: bool = True,
updated_at_col: t.Optional[exp.Column] = None,
- check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expression]]] = None,
+ check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expr]]] = None,
updated_at_as_valid_from: bool = False,
execution_time_as_valid_from: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -313,10 +342,15 @@ def _scd_type_2(
source_columns: t.Optional[t.List[str]] = None,
**kwargs: t.Any,
) -> None:
+ mapped_columns: t.Set[str] = set()
+ if target_columns_to_types:
+ target_columns_to_types, mapped_columns = self._apply_timestamp_mapping(
+ target_columns_to_types
+ )
if target_columns_to_types and "delta_lake" in self.get_catalog_type_from_table(
target_table
):
- target_columns_to_types = self._to_delta_ts(target_columns_to_types)
+ target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns)
return super()._scd_type_2(
target_table,
@@ -346,18 +380,21 @@ def _scd_type_2(
# - `timestamp(3) with time zone` for timezone-aware
# https://trino.io/docs/current/connector/delta-lake.html#delta-lake-to-trino-type-mapping
def _to_delta_ts(
- self, columns_to_types: t.Dict[str, exp.DataType]
+ self,
+ columns_to_types: t.Dict[str, exp.DataType],
+ skip_columns: t.Optional[t.Set[str]] = None,
) -> t.Dict[str, exp.DataType]:
ts6 = exp.DataType.build("timestamp(6)")
ts3_tz = exp.DataType.build("timestamp(3) with time zone")
+ skip = skip_columns or set()
delta_columns_to_types = {
- k: ts6 if v.is_type(exp.DataType.Type.TIMESTAMP) else v
+ k: ts6 if k not in skip and v.is_type(exp.DataType.Type.TIMESTAMP) else v
for k, v in columns_to_types.items()
}
delta_columns_to_types = {
- k: ts3_tz if v.is_type(exp.DataType.Type.TIMESTAMPTZ) else v
+ k: ts3_tz if k not in skip and v.is_type(exp.DataType.Type.TIMESTAMPTZ) else v
for k, v in delta_columns_to_types.items()
}
@@ -372,7 +409,7 @@ def _create_schema(
schema_name: SchemaName,
ignore_if_exists: bool,
warn_on_error: bool,
- properties: t.List[exp.Expression],
+ properties: t.List[exp.Expr],
kind: str,
) -> None:
if mapped_location := self._schema_location(schema_name):
@@ -389,7 +426,7 @@ def _create_schema(
def _create_table(
self,
table_name_or_schema: t.Union[exp.Schema, TableName],
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
exists: bool = True,
replace: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
diff --git a/sqlmesh/core/lineage.py b/sqlmesh/core/lineage.py
index 777a2a7d9a..8363979034 100644
--- a/sqlmesh/core/lineage.py
+++ b/sqlmesh/core/lineage.py
@@ -16,7 +16,7 @@
from sqlmesh.core.model import Model
-CACHE: t.Dict[str, t.Tuple[int, exp.Expression, Scope]] = {}
+CACHE: t.Dict[str, t.Tuple[int, exp.Expr, Scope]] = {}
def lineage(
@@ -25,8 +25,8 @@ def lineage(
trim_selects: bool = True,
**kwargs: t.Any,
) -> Node:
- query = None
- scope = None
+ query: t.Optional[exp.Expr] = None
+ scope: t.Optional[Scope] = None
if model.name in CACHE:
obj_id, query, scope = CACHE[model.name]
diff --git a/sqlmesh/core/linter/rules/builtin.py b/sqlmesh/core/linter/rules/builtin.py
index c28822a154..4547ac0528 100644
--- a/sqlmesh/core/linter/rules/builtin.py
+++ b/sqlmesh/core/linter/rules/builtin.py
@@ -318,4 +318,4 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]:
return None
-BUILTIN_RULES = RuleSet(subclasses(__name__, Rule, (Rule,)))
+BUILTIN_RULES = RuleSet(subclasses(__name__, Rule, exclude={Rule}))
diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py
index a43f5f28ff..4b7b1bac02 100644
--- a/sqlmesh/core/loader.py
+++ b/sqlmesh/core/loader.py
@@ -840,7 +840,7 @@ def _load_linting_rules(self) -> RuleSet:
if os.path.getsize(path):
self._track_file(path)
module = import_python_file(path, self.config_path)
- module_rules = subclasses(module.__name__, Rule, (Rule,))
+ module_rules = subclasses(module.__name__, Rule, exclude={Rule})
for user_rule in module_rules:
user_rules[user_rule.name] = user_rule
diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py
index af7c344081..888acbb8eb 100644
--- a/sqlmesh/core/macros.py
+++ b/sqlmesh/core/macros.py
@@ -110,7 +110,7 @@ def _macro_sql(sql: str, into: t.Optional[str] = None) -> str:
return f"self.parse_one({', '.join(args)})"
-def _macro_func_sql(self: Generator, e: exp.Expression) -> str:
+def _macro_func_sql(self: Generator, e: exp.Expr) -> str:
func = e.this
if isinstance(func, exp.Anonymous):
@@ -178,7 +178,7 @@ def __init__(
schema: t.Optional[MappingSchema] = None,
runtime_stage: RuntimeStage = RuntimeStage.LOADING,
resolve_table: t.Optional[t.Callable[[str | exp.Table], str]] = None,
- resolve_tables: t.Optional[t.Callable[[exp.Expression], exp.Expression]] = None,
+ resolve_tables: t.Optional[t.Callable[[exp.Expr], exp.Expr]] = None,
snapshots: t.Optional[t.Dict[str, Snapshot]] = None,
default_catalog: t.Optional[str] = None,
path: t.Optional[Path] = None,
@@ -237,7 +237,7 @@ def __init__(
def send(
self, name: str, *args: t.Any, **kwargs: t.Any
- ) -> t.Union[None, exp.Expression, t.List[exp.Expression]]:
+ ) -> t.Union[None, exp.Expr, t.List[exp.Expr]]:
func = self.macros.get(normalize_macro_name(name))
if not callable(func):
@@ -253,14 +253,12 @@ def send(
+ format_evaluated_code_exception(e, self.python_env)
)
- def transform(
- self, expression: exp.Expression
- ) -> exp.Expression | t.List[exp.Expression] | None:
+ def transform(self, expression: exp.Expr) -> exp.Expr | t.List[exp.Expr] | None:
changed = False
def evaluate_macros(
- node: exp.Expression,
- ) -> exp.Expression | t.List[exp.Expression] | None:
+ node: exp.Expr,
+ ) -> exp.Expr | t.List[exp.Expr] | None:
nonlocal changed
if isinstance(node, MacroVar):
@@ -281,14 +279,10 @@ def evaluate_macros(
value = self.locals.get(var_name, variables.get(var_name))
if isinstance(value, list):
return exp.convert(
- tuple(
- self.transform(v) if isinstance(v, exp.Expression) else v for v in value
- )
+ tuple(self.transform(v) if isinstance(v, exp.Expr) else v for v in value)
)
- return exp.convert(
- self.transform(value) if isinstance(value, exp.Expression) else value
- )
+ return exp.convert(self.transform(value) if isinstance(value, exp.Expr) else value)
if isinstance(node, exp.Identifier) and "@" in node.this:
text = self.template(node.this, {})
if node.this != text:
@@ -311,7 +305,7 @@ def evaluate_macros(
self.parse_one(node.sql(dialect=self.dialect, copy=False))
for node in transformed
]
- if isinstance(transformed, exp.Expression):
+ if isinstance(transformed, exp.Expr):
return self.parse_one(transformed.sql(dialect=self.dialect, copy=False))
return transformed
@@ -339,7 +333,7 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str:
}
return MacroStrTemplate(str(text)).safe_substitute(CaseInsensitiveMapping(base_mapping))
- def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | None:
+ def evaluate(self, node: MacroFunc) -> exp.Expr | t.List[exp.Expr] | None:
if isinstance(node, MacroDef):
if isinstance(node.expression, exp.Lambda):
_, fn = _norm_var_arg_lambda(self, node.expression)
@@ -353,7 +347,7 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] |
return node
if isinstance(node, (MacroSQL, MacroStrReplace)):
- result: t.Optional[exp.Expression | t.List[exp.Expression]] = exp.convert(
+ result: t.Optional[exp.Expr | t.List[exp.Expr]] = exp.convert(
self.eval_expression(node)
)
else:
@@ -421,7 +415,7 @@ def eval_expression(self, node: t.Any) -> t.Any:
Returns:
The return value of the evaled Python Code.
"""
- if not isinstance(node, exp.Expression):
+ if not isinstance(node, exp.Expr):
return node
code = node.sql()
try:
@@ -434,8 +428,8 @@ def eval_expression(self, node: t.Any) -> t.Any:
)
def parse_one(
- self, sql: str | exp.Expression, into: t.Optional[exp.IntoType] = None, **opts: t.Any
- ) -> exp.Expression:
+ self, sql: str | exp.Expr, into: t.Optional[exp.IntoType] = None, **opts: t.Any
+ ) -> exp.Expr:
"""Parses the given SQL string and returns a syntax tree for the first
parsed SQL statement.
@@ -497,7 +491,7 @@ def resolve_table(self, table: str | exp.Table) -> str:
)
return self._resolve_table(table)
- def resolve_tables(self, query: exp.Expression) -> exp.Expression:
+ def resolve_tables(self, query: exp.Expr) -> exp.Expr:
"""Resolves queries with references to SQLMesh model names to their physical tables."""
if not self._resolve_tables:
raise SQLMeshError(
@@ -588,7 +582,7 @@ def variables(self) -> t.Dict[str, t.Any]:
**self.locals.get(c.SQLMESH_BLUEPRINT_VARS_METADATA, {}),
}
- def _coerce(self, expr: exp.Expression, typ: t.Any, strict: bool = False) -> t.Any:
+ def _coerce(self, expr: exp.Expr, typ: t.Any, strict: bool = False) -> t.Any:
"""Coerces the given expression to the specified type on a best-effort basis."""
return _coerce(expr, typ, self.dialect, self._path, strict)
@@ -648,8 +642,8 @@ def _norm_var_arg_lambda(
"""
def substitute(
- node: exp.Expression, args: t.Dict[str, exp.Expression]
- ) -> exp.Expression | t.List[exp.Expression] | None:
+ node: exp.Expr, args: t.Dict[str, exp.Expr]
+ ) -> exp.Expr | t.List[exp.Expr] | None:
if isinstance(node, (exp.Identifier, exp.Var)):
if not isinstance(node.parent, exp.Column):
name = node.name.lower()
@@ -798,8 +792,8 @@ def filter_(evaluator: MacroEvaluator, *args: t.Any) -> t.List[t.Any]:
def _optional_expression(
evaluator: MacroEvaluator,
condition: exp.Condition,
- expression: exp.Expression,
-) -> t.Optional[exp.Expression]:
+ expression: exp.Expr,
+) -> t.Optional[exp.Expr]:
"""Inserts expression when the condition is True
The following examples express the usage of this function in the context of the macros which wrap it.
@@ -864,7 +858,7 @@ def star(
suffix: exp.Literal = exp.Literal.string(""),
quote_identifiers: exp.Boolean = exp.true(),
except_: t.Union[exp.Array, exp.Tuple] = exp.Tuple(expressions=[]),
-) -> t.List[exp.Alias]:
+) -> t.List[exp.Expr]:
"""Returns a list of projections for the given relation.
Args:
@@ -939,7 +933,7 @@ def star(
@macro()
def generate_surrogate_key(
evaluator: MacroEvaluator,
- *fields: exp.Expression,
+ *fields: exp.Expr,
hash_function: exp.Literal = exp.Literal.string("MD5"),
) -> exp.Func:
"""Generates a surrogate key (string) for the given fields.
@@ -956,7 +950,7 @@ def generate_surrogate_key(
>>> MacroEvaluator(dialect="bigquery").transform(parse_one(sql, dialect="bigquery")).sql("bigquery")
"SELECT SHA256(CONCAT(COALESCE(CAST(a AS STRING), '_sqlmesh_surrogate_key_null_'), '|', COALESCE(CAST(b AS STRING), '_sqlmesh_surrogate_key_null_'), '|', COALESCE(CAST(c AS STRING), '_sqlmesh_surrogate_key_null_'))) FROM foo"
"""
- string_fields: t.List[exp.Expression] = []
+ string_fields: t.List[exp.Expr] = []
for i, field in enumerate(fields):
if i > 0:
string_fields.append(exp.Literal.string("|"))
@@ -980,7 +974,7 @@ def generate_surrogate_key(
@macro()
-def safe_add(_: MacroEvaluator, *fields: exp.Expression) -> exp.Case:
+def safe_add(_: MacroEvaluator, *fields: exp.Expr) -> exp.Case:
"""Adds numbers together, substitutes nulls for 0s and only returns null if all fields are null.
Example:
@@ -998,7 +992,7 @@ def safe_add(_: MacroEvaluator, *fields: exp.Expression) -> exp.Case:
@macro()
-def safe_sub(_: MacroEvaluator, *fields: exp.Expression) -> exp.Case:
+def safe_sub(_: MacroEvaluator, *fields: exp.Expr) -> exp.Case:
"""Subtract numbers, substitutes nulls for 0s and only returns null if all fields are null.
Example:
@@ -1016,7 +1010,7 @@ def safe_sub(_: MacroEvaluator, *fields: exp.Expression) -> exp.Case:
@macro()
-def safe_div(_: MacroEvaluator, numerator: exp.Expression, denominator: exp.Expression) -> exp.Div:
+def safe_div(_: MacroEvaluator, numerator: exp.Expr, denominator: exp.Expr) -> exp.Div:
"""Divides numbers, returns null if the denominator is 0.
Example:
@@ -1032,7 +1026,7 @@ def safe_div(_: MacroEvaluator, numerator: exp.Expression, denominator: exp.Expr
@macro()
def union(
evaluator: MacroEvaluator,
- *args: exp.Expression,
+ *args: exp.Expr,
) -> exp.Query:
"""Returns a UNION of the given tables. Only choosing columns that have the same name and type.
@@ -1107,10 +1101,10 @@ def union(
@macro()
def haversine_distance(
_: MacroEvaluator,
- lat1: exp.Expression,
- lon1: exp.Expression,
- lat2: exp.Expression,
- lon2: exp.Expression,
+ lat1: exp.Expr,
+ lon1: exp.Expr,
+ lat2: exp.Expr,
+ lon2: exp.Expr,
unit: exp.Literal = exp.Literal.string("mi"),
) -> exp.Mul:
"""Returns the haversine distance between two points.
@@ -1150,17 +1144,17 @@ def haversine_distance(
def pivot(
evaluator: MacroEvaluator,
column: SQL,
- values: t.List[exp.Expression],
+ values: t.List[exp.Expr],
alias: bool = True,
- agg: exp.Expression = exp.Literal.string("SUM"),
- cmp: exp.Expression = exp.Literal.string("="),
- prefix: exp.Expression = exp.Literal.string(""),
- suffix: exp.Expression = exp.Literal.string(""),
+ agg: exp.Expr = exp.Literal.string("SUM"),
+ cmp: exp.Expr = exp.Literal.string("="),
+ prefix: exp.Expr = exp.Literal.string(""),
+ suffix: exp.Expr = exp.Literal.string(""),
then_value: SQL = SQL("1"),
else_value: SQL = SQL("0"),
quote: bool = True,
distinct: bool = False,
-) -> t.List[exp.Expression]:
+) -> t.List[exp.Expr]:
"""Returns a list of projections as a result of pivoting the given column on the given values.
Example:
@@ -1173,14 +1167,14 @@ def pivot(
>>> MacroEvaluator(dialect="bigquery").transform(parse_one(sql)).sql("bigquery")
"SELECT SUM(CASE WHEN a = 'v' THEN tv ELSE 0 END) AS v_sfx"
"""
- aggregates: t.List[exp.Expression] = []
+ aggregates: t.List[exp.Expr] = []
for value in values:
proj = f"{agg.name}("
if distinct:
proj += "DISTINCT "
proj += f"CASE WHEN {column} {cmp.name} {value.sql(evaluator.dialect)} THEN {then_value} ELSE {else_value} END) "
- node = evaluator.parse_one(proj)
+ node: exp.Expr = evaluator.parse_one(proj)
if alias:
node = node.as_(
@@ -1196,7 +1190,7 @@ def pivot(
@macro("AND")
-def and_(evaluator: MacroEvaluator, *expressions: t.Optional[exp.Expression]) -> exp.Condition:
+def and_(evaluator: MacroEvaluator, *expressions: t.Optional[exp.Expr]) -> exp.Condition:
"""Returns an AND statement filtering out any NULL expressions."""
conditions = [e for e in expressions if not isinstance(e, exp.Null)]
@@ -1207,7 +1201,7 @@ def and_(evaluator: MacroEvaluator, *expressions: t.Optional[exp.Expression]) ->
@macro("OR")
-def or_(evaluator: MacroEvaluator, *expressions: t.Optional[exp.Expression]) -> exp.Condition:
+def or_(evaluator: MacroEvaluator, *expressions: t.Optional[exp.Expr]) -> exp.Condition:
"""Returns an OR statement filtering out any NULL expressions."""
conditions = [e for e in expressions if not isinstance(e, exp.Null)]
@@ -1219,8 +1213,8 @@ def or_(evaluator: MacroEvaluator, *expressions: t.Optional[exp.Expression]) ->
@macro("VAR")
def var(
- evaluator: MacroEvaluator, var_name: exp.Expression, default: t.Optional[exp.Expression] = None
-) -> exp.Expression:
+ evaluator: MacroEvaluator, var_name: exp.Expr, default: t.Optional[exp.Expr] = None
+) -> exp.Expr:
"""Returns the value of a variable or the default value if the variable is not set."""
if not var_name.is_string:
raise SQLMeshError(f"Invalid variable name '{var_name.sql()}'. Expected a string literal.")
@@ -1230,8 +1224,8 @@ def var(
@macro("BLUEPRINT_VAR")
def blueprint_var(
- evaluator: MacroEvaluator, var_name: exp.Expression, default: t.Optional[exp.Expression] = None
-) -> exp.Expression:
+ evaluator: MacroEvaluator, var_name: exp.Expr, default: t.Optional[exp.Expr] = None
+) -> exp.Expr:
"""Returns the value of a blueprint variable or the default value if the variable is not set."""
if not var_name.is_string:
raise SQLMeshError(
@@ -1244,8 +1238,8 @@ def blueprint_var(
@macro()
def deduplicate(
evaluator: MacroEvaluator,
- relation: exp.Expression,
- partition_by: t.List[exp.Expression],
+ relation: exp.Expr,
+ partition_by: t.List[exp.Expr],
order_by: t.List[str],
) -> exp.Query:
"""Returns a QUERY to deduplicate rows within a table
@@ -1301,9 +1295,9 @@ def deduplicate(
@macro()
def date_spine(
evaluator: MacroEvaluator,
- datepart: exp.Expression,
- start_date: exp.Expression,
- end_date: exp.Expression,
+ datepart: exp.Expr,
+ start_date: exp.Expr,
+ end_date: exp.Expr,
) -> exp.Select:
"""Returns a query that produces a date spine with the given datepart, and range of start_date and end_date. Useful for joining as a date lookup table.
@@ -1491,7 +1485,7 @@ def _coerce(
"""Coerces the given expression to the specified type on a best-effort basis."""
base_err_msg = f"Failed to coerce expression '{expr}' to type '{typ}'."
try:
- if typ is None or typ is t.Any or not isinstance(expr, exp.Expression):
+ if typ is None or typ is t.Any or not isinstance(expr, exp.Expr):
return expr
base = t.get_origin(typ) or typ
@@ -1503,7 +1497,7 @@ def _coerce(
except Exception:
pass
raise SQLMeshError(base_err_msg)
- if base is SQL and isinstance(expr, exp.Expression):
+ if base is SQL and isinstance(expr, exp.Expr):
return expr.sql(dialect)
if base is t.Literal:
@@ -1528,7 +1522,7 @@ def _coerce(
if isinstance(expr, base):
return expr
- if issubclass(base, exp.Expression):
+ if issubclass(base, exp.Expr):
d = Dialect.get_or_raise(dialect)
into = base if base in d.parser_class.EXPRESSION_PARSERS else None
if into is None:
@@ -1603,7 +1597,7 @@ def _convert_sql(v: t.Any, dialect: DialectType) -> t.Any:
except Exception:
pass
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
if (isinstance(v, exp.Column) and not v.table) or (
isinstance(v, exp.Identifier) or v.is_string
):
diff --git a/sqlmesh/core/metric/definition.py b/sqlmesh/core/metric/definition.py
index dd11cfd38d..70f10b2347 100644
--- a/sqlmesh/core/metric/definition.py
+++ b/sqlmesh/core/metric/definition.py
@@ -16,7 +16,7 @@
def load_metric_ddl(
- expression: exp.Expression, dialect: t.Optional[str], path: Path = Path(), **kwargs: t.Any
+ expression: exp.Expr, dialect: t.Optional[str], path: Path = Path(), **kwargs: t.Any
) -> MetricMeta:
"""Returns a MetricMeta from raw Metric DDL."""
if not isinstance(expression, d.Metric):
@@ -70,7 +70,7 @@ class MetricMeta(PydanticModel, frozen=True):
name: str
dialect: str
- expression: exp.Expression
+ expression: exp.Expr
description: t.Optional[str] = None
owner: t.Optional[str] = None
@@ -87,11 +87,11 @@ def _string_validator(cls, v: t.Any) -> t.Optional[str]:
return str_or_exp_to_str(v)
@field_validator("expression", mode="before")
- def _validate_expression(cls, v: t.Any, info: ValidationInfo) -> exp.Expression:
+ def _validate_expression(cls, v: t.Any, info: ValidationInfo) -> exp.Expr:
if isinstance(v, str):
dialect = info.data.get("dialect")
return d.parse_one(v, dialect=dialect)
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
return v
return v
@@ -139,7 +139,7 @@ def to_metric(
class Metric(MetricMeta, frozen=True):
- expanded: exp.Expression
+ expanded: exp.Expr
@property
def aggs(self) -> t.Dict[exp.AggFunc, MeasureAndDimTables]:
@@ -150,7 +150,7 @@ def aggs(self) -> t.Dict[exp.AggFunc, MeasureAndDimTables]:
return {
t.cast(
exp.AggFunc,
- t.cast(exp.Expression, agg.parent).transform(
+ t.cast(exp.Expr, agg.parent).transform(
lambda node: (
exp.column(node.this, table=remove_namespace(node))
if isinstance(node, exp.Column) and node.table
@@ -162,7 +162,7 @@ def aggs(self) -> t.Dict[exp.AggFunc, MeasureAndDimTables]:
}
@property
- def formula(self) -> exp.Expression:
+ def formula(self) -> exp.Expr:
"""Returns the post aggregation formula of a metric.
For simple metrics it is just the metric name. For derived metrics,
@@ -181,7 +181,7 @@ def _raise_metric_config_error(msg: str, path: Path) -> None:
raise ConfigError(f"{msg}. '{path}'")
-def _get_measure_and_dim_tables(expression: exp.Expression) -> MeasureAndDimTables:
+def _get_measure_and_dim_tables(expression: exp.Expr) -> MeasureAndDimTables:
"""Finds all the table references in a metric definition.
Additionally ensure than the first table returned is the 'measure' or numeric value being aggregated.
@@ -190,7 +190,7 @@ def _get_measure_and_dim_tables(expression: exp.Expression) -> MeasureAndDimTabl
tables = {}
measure_table = None
- def is_measure(node: exp.Expression) -> bool:
+ def is_measure(node: exp.Expr) -> bool:
parent = node.parent
if isinstance(parent, exp.AggFunc) and node.arg_key == "this":
diff --git a/sqlmesh/core/metric/rewriter.py b/sqlmesh/core/metric/rewriter.py
index 3519a77e68..6c9ec429a8 100644
--- a/sqlmesh/core/metric/rewriter.py
+++ b/sqlmesh/core/metric/rewriter.py
@@ -34,13 +34,13 @@ def __init__(
self.join_type = join_type
self.semantic_name = f"{semantic_schema}.{semantic_table}"
- def rewrite(self, expression: exp.Expression) -> exp.Expression:
+ def rewrite(self, expression: exp.Expr) -> exp.Expr:
for select in list(expression.find_all(exp.Select)):
self._expand(select)
return expression
- def _build_sources(self, projections: t.List[exp.Expression]) -> SourceAggsAndJoins:
+ def _build_sources(self, projections: t.List[exp.Expr]) -> SourceAggsAndJoins:
sources: SourceAggsAndJoins = {}
for projection in projections:
@@ -57,7 +57,7 @@ def _build_sources(self, projections: t.List[exp.Expression]) -> SourceAggsAndJo
return sources
def _expand(self, select: exp.Select) -> None:
- base = select.args["from"].this.find(exp.Table)
+ base = select.args["from_"].this.find(exp.Table)
base_alias = base.alias_or_name
base_name = exp.table_name(base)
@@ -78,7 +78,7 @@ def _expand(self, select: exp.Select) -> None:
explicit_joins = {exp.table_name(join.this): join for join in select.args.pop("joins", [])}
for i, (name, (aggs, joins)) in enumerate(sources.items()):
- source: exp.Expression = exp.to_table(name)
+ source: exp.Expr = exp.to_table(name)
table_name = remove_namespace(name)
if not isinstance(source, exp.Select):
@@ -110,7 +110,7 @@ def _expand(self, select: exp.Select) -> None:
copy=False,
)
- for node in find_all_in_scope(query, (exp.Column, exp.TableAlias)):
+ for node in find_all_in_scope(query, exp.Column, exp.TableAlias): # type: ignore[arg-type,var-annotated]
if isinstance(node, exp.Column):
if node.table in mapping:
node.set("table", exp.to_identifier(mapping[node.table]))
@@ -123,7 +123,7 @@ def _add_joins(
source: exp.Select,
name: str,
joins: t.Dict[str, t.Optional[exp.Join]],
- group_by: t.List[exp.Expression],
+ group_by: t.List[exp.Expr],
mapping: t.Dict[str, str],
) -> exp.Select:
grain = [e.copy() for e in group_by]
@@ -177,7 +177,7 @@ def _add_joins(
return source.select(*grain, copy=False).group_by(*grain, copy=False)
-def _replace_table(node: exp.Expression, table: str, base_alias: str) -> exp.Expression:
+def _replace_table(node: exp.Expr, table: str, base_alias: str) -> exp.Expr:
for column in find_all_in_scope(node, exp.Column):
if column.table == base_alias:
column.args["table"] = exp.to_identifier(table)
@@ -185,11 +185,11 @@ def _replace_table(node: exp.Expression, table: str, base_alias: str) -> exp.Exp
def rewrite(
- sql: str | exp.Expression,
+ sql: str | exp.Expr,
graph: ReferenceGraph,
metrics: t.Dict[str, Metric],
dialect: t.Optional[str] = "",
-) -> exp.Expression:
+) -> exp.Expr:
rewriter = Rewriter(graph=graph, metrics=metrics, dialect=dialect)
return optimize(
diff --git a/sqlmesh/core/model/cache.py b/sqlmesh/core/model/cache.py
index 774bfa402b..1f038c5d79 100644
--- a/sqlmesh/core/model/cache.py
+++ b/sqlmesh/core/model/cache.py
@@ -81,7 +81,7 @@ def get(self, name: str, entry_id: str = "") -> t.List[Model]:
@dataclass
class OptimizedQueryCacheEntry:
- optimized_rendered_query: t.Optional[exp.Expression]
+ optimized_rendered_query: t.Optional[exp.Query]
renderer_violations: t.Optional[t.Dict[type[Rule], t.Any]]
diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py
index 9e117b56fb..ccde7624bd 100644
--- a/sqlmesh/core/model/common.py
+++ b/sqlmesh/core/model/common.py
@@ -33,8 +33,8 @@
def make_python_env(
expressions: t.Union[
- exp.Expression,
- t.List[t.Union[exp.Expression, t.Tuple[exp.Expression, bool]]],
+ exp.Expr,
+ t.List[t.Union[exp.Expr, t.Tuple[exp.Expr, bool]]],
],
jinja_macro_references: t.Optional[t.Set[MacroReference]],
module_path: Path,
@@ -71,7 +71,7 @@ def make_python_env(
visited_macro_funcs: t.Set[int] = set()
def _is_metadata_var(
- name: str, expression: exp.Expression, appears_in_metadata_expression: bool
+ name: str, expression: exp.Expr, appears_in_metadata_expression: bool
) -> t.Optional[bool]:
is_metadata_so_far = used_variables.get(name, True)
if is_metadata_so_far is False:
@@ -202,7 +202,7 @@ def _is_metadata_macro(name: str, appears_in_metadata_expression: bool) -> bool:
def _extract_macro_func_variable_references(
- macro_func: exp.Expression,
+ macro_func: exp.Expr,
is_metadata: bool,
) -> t.Tuple[t.Set[str], t.Dict[int, bool], t.Set[int]]:
var_references = set()
@@ -255,7 +255,7 @@ def _add_variables_to_python_env(
# - appear in metadata-only expressions, such as `audits (...)`, virtual statements, etc
# - appear in the ASTs or definitions of metadata-only macros
#
- # See also: https://github.com/TobikoData/sqlmesh/pull/4936#issuecomment-3136339936,
+ # See also: https://github.com/SQLMesh/sqlmesh/pull/4936#issuecomment-3136339936,
# specifically the "Terminology" and "Observations" section.
metadata_used_variables = {
var_name for var_name, is_metadata in used_variables.items() if is_metadata
@@ -275,7 +275,7 @@ def _add_variables_to_python_env(
if overlapping_variables := (non_metadata_used_variables & metadata_used_variables):
raise ConfigError(
f"Variables {', '.join(overlapping_variables)} are both metadata and non-metadata, "
- "which is unexpected. Please file an issue at https://github.com/TobikoData/sqlmesh/issues/new."
+ "which is unexpected. Please file an issue at https://github.com/SQLMesh/sqlmesh/issues/new."
)
metadata_variables = {
@@ -292,12 +292,12 @@ def _add_variables_to_python_env(
if blueprint_variables:
metadata_blueprint_variables = {
- k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
+ k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expr) else v
for k, v in blueprint_variables.items()
if k in metadata_used_variables
}
blueprint_variables = {
- k.lower(): SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
+ k.lower(): SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expr) else v
for k, v in blueprint_variables.items()
if k in non_metadata_used_variables
}
@@ -469,9 +469,9 @@ def single_value_or_tuple(values: t.Sequence) -> exp.Identifier | exp.Tuple:
def parse_expression(
cls: t.Type,
- v: t.Union[t.List[str], t.List[exp.Expression], str, exp.Expression, t.Callable, None],
+ v: t.Union[t.List[str], t.List[exp.Expr], str, exp.Expr, t.Callable, None],
info: t.Optional[ValidationInfo],
-) -> t.List[exp.Expression] | exp.Expression | t.Callable | None:
+) -> t.List[exp.Expr] | exp.Expr | t.Callable | None:
"""Helper method to deserialize SQLGlot expressions in Pydantic Models."""
if v is None:
return None
@@ -483,7 +483,7 @@ def parse_expression(
if isinstance(v, list):
return [
- e if isinstance(e, exp.Expression) else d.parse_one(e, dialect=dialect)
+ e if isinstance(e, exp.Expr) else d.parse_one(e, dialect=dialect) # type: ignore[misc]
for e in v
if not isinstance(e, exp.Semicolon)
]
@@ -498,7 +498,7 @@ def parse_expression(
def parse_bool(v: t.Any) -> bool:
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
if not isinstance(v, exp.Boolean):
from sqlglot.optimizer.simplify import simplify
@@ -524,7 +524,7 @@ def parse_properties(
if isinstance(v, str):
v = d.parse_one(v, dialect=dialect)
if isinstance(v, (exp.Array, exp.Paren, exp.Tuple)):
- eq_expressions: t.List[exp.Expression] = (
+ eq_expressions: t.List[exp.Expr] = (
[v.unnest()] if isinstance(v, exp.Paren) else v.expressions
)
@@ -665,18 +665,18 @@ class ParsableSql(PydanticModel):
sql: str
transaction: t.Optional[bool] = None
- _parsed: t.Optional[exp.Expression] = None
+ _parsed: t.Optional[exp.Expr] = None
_parsed_dialect: t.Optional[str] = None
- def parse(self, dialect: str) -> exp.Expression:
+ def parse(self, dialect: str) -> exp.Expr:
if self._parsed is None or self._parsed_dialect != dialect:
self._parsed = d.parse_one(self.sql, dialect=dialect)
self._parsed_dialect = dialect
- return self._parsed
+ return self._parsed # type: ignore[return-value]
@classmethod
def from_parsed_expression(
- cls, parsed_expression: exp.Expression, dialect: str, use_meta_sql: bool = False
+ cls, parsed_expression: exp.Expr, dialect: str, use_meta_sql: bool = False
) -> ParsableSql:
sql = (
parsed_expression.meta.get("sql") or parsed_expression.sql(dialect=dialect)
@@ -697,7 +697,7 @@ def _validate_parsable_sql(
return v
if isinstance(v, str):
return ParsableSql(sql=v)
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
return ParsableSql.from_parsed_expression(
v, get_dialect(info.data), use_meta_sql=False
)
@@ -707,7 +707,7 @@ def _validate_parsable_sql(
ParsableSql(sql=s)
if isinstance(s, str)
else ParsableSql.from_parsed_expression(s, dialect, use_meta_sql=False)
- if isinstance(s, exp.Expression)
+ if isinstance(s, exp.Expr)
else ParsableSql.parse_obj(s)
for s in v
]
diff --git a/sqlmesh/core/model/decorator.py b/sqlmesh/core/model/decorator.py
index 73452cc165..328b763f9f 100644
--- a/sqlmesh/core/model/decorator.py
+++ b/sqlmesh/core/model/decorator.py
@@ -193,7 +193,7 @@ def model(
)
rendered_name = rendered_fields["name"]
- if isinstance(rendered_name, exp.Expression):
+ if isinstance(rendered_name, exp.Expr):
rendered_fields["name"] = rendered_name.sql(dialect=dialect)
rendered_defaults = (
diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py
index b6ea6d23e1..d4f23b4fc0 100644
--- a/sqlmesh/core/model/definition.py
+++ b/sqlmesh/core/model/definition.py
@@ -34,6 +34,7 @@
)
from sqlmesh.core.model.meta import ModelMeta
from sqlmesh.core.model.kind import (
+ ExternalKind,
ModelKindName,
SeedKind,
ModelKind,
@@ -214,7 +215,7 @@ def render_definition(
include_python: bool = True,
include_defaults: bool = False,
render_query: bool = False,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
"""Returns the original list of sql expressions comprising the model definition.
Args:
@@ -365,7 +366,7 @@ def render_pre_statements(
engine_adapter: t.Optional[EngineAdapter] = None,
inside_transaction: t.Optional[bool] = True,
**kwargs: t.Any,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
"""Renders pre-statements for a model.
Pre-statements are statements that preceded the model's SELECT query.
@@ -412,7 +413,7 @@ def render_post_statements(
engine_adapter: t.Optional[EngineAdapter] = None,
inside_transaction: t.Optional[bool] = True,
**kwargs: t.Any,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
"""Renders post-statements for a model.
Post-statements are statements that follow after the model's SELECT query.
@@ -459,7 +460,7 @@ def render_on_virtual_update(
deployability_index: t.Optional[DeployabilityIndex] = None,
engine_adapter: t.Optional[EngineAdapter] = None,
**kwargs: t.Any,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
return self._render_statements(
self.on_virtual_update,
start=start,
@@ -551,15 +552,15 @@ def render_audit_query(
return rendered_query
@property
- def pre_statements(self) -> t.List[exp.Expression]:
+ def pre_statements(self) -> t.List[exp.Expr]:
return self._get_parsed_statements("pre_statements_")
@property
- def post_statements(self) -> t.List[exp.Expression]:
+ def post_statements(self) -> t.List[exp.Expr]:
return self._get_parsed_statements("post_statements_")
@property
- def on_virtual_update(self) -> t.List[exp.Expression]:
+ def on_virtual_update(self) -> t.List[exp.Expr]:
return self._get_parsed_statements("on_virtual_update_")
@property
@@ -571,7 +572,7 @@ def macro_definitions(self) -> t.List[d.MacroDef]:
if isinstance(s, d.MacroDef)
]
- def _get_parsed_statements(self, attr_name: str) -> t.List[exp.Expression]:
+ def _get_parsed_statements(self, attr_name: str) -> t.List[exp.Expr]:
value = getattr(self, attr_name)
if not value:
return []
@@ -586,9 +587,9 @@ def _get_parsed_statements(self, attr_name: str) -> t.List[exp.Expression]:
def _render_statements(
self,
- statements: t.Iterable[exp.Expression],
+ statements: t.Iterable[exp.Expr],
**kwargs: t.Any,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
rendered = (
self._statement_renderer(statement).render(**kwargs)
for statement in statements
@@ -596,7 +597,7 @@ def _render_statements(
)
return [r for expressions in rendered if expressions for r in expressions]
- def _statement_renderer(self, expression: exp.Expression) -> ExpressionRenderer:
+ def _statement_renderer(self, expression: exp.Expr) -> ExpressionRenderer:
expression_key = id(expression)
if expression_key not in self._statement_renderer_cache:
self._statement_renderer_cache[expression_key] = ExpressionRenderer(
@@ -630,7 +631,7 @@ def render_signals(
The list of rendered expressions.
"""
- def _render(e: exp.Expression) -> str | int | float | bool:
+ def _render(e: exp.Expr) -> str | int | float | bool:
rendered_exprs = (
self._create_renderer(e).render(start=start, end=end, execution_time=execution_time)
or []
@@ -675,7 +676,7 @@ def render_merge_filter(
start: t.Optional[TimeLike] = None,
end: t.Optional[TimeLike] = None,
execution_time: t.Optional[TimeLike] = None,
- ) -> t.Optional[exp.Expression]:
+ ) -> t.Optional[exp.Expr]:
if self.merge_filter is None:
return None
rendered_exprs = (
@@ -689,9 +690,9 @@ def render_merge_filter(
return rendered_exprs[0].transform(d.replace_merge_table_aliases, dialect=self.dialect)
def _render_properties(
- self, properties: t.Dict[str, exp.Expression] | SessionProperties, **render_kwargs: t.Any
+ self, properties: t.Dict[str, exp.Expr] | SessionProperties, **render_kwargs: t.Any
) -> t.Dict[str, t.Any]:
- def _render(expression: exp.Expression) -> exp.Expression | None:
+ def _render(expression: exp.Expr) -> exp.Expr | None:
# note: we use the _statement_renderer instead of _create_renderer because it sets model_fqn which
# in turn makes @this_model available in the evaluation context
rendered_exprs = self._statement_renderer(expression).render(**render_kwargs)
@@ -713,7 +714,7 @@ def _render(expression: exp.Expression) -> exp.Expression | None:
return {
k: rendered
for k, v in properties.items()
- if (rendered := (_render(v) if isinstance(v, exp.Expression) else v))
+ if (rendered := (_render(v) if isinstance(v, exp.Expr) else v))
}
def render_physical_properties(self, **render_kwargs: t.Any) -> t.Dict[str, t.Any]:
@@ -725,7 +726,7 @@ def render_virtual_properties(self, **render_kwargs: t.Any) -> t.Dict[str, t.Any
def render_session_properties(self, **render_kwargs: t.Any) -> t.Dict[str, t.Any]:
return self._render_properties(properties=self.session_properties, **render_kwargs)
- def _create_renderer(self, expression: exp.Expression) -> ExpressionRenderer:
+ def _create_renderer(self, expression: exp.Expr) -> ExpressionRenderer:
return ExpressionRenderer(
expression,
self.dialect,
@@ -752,7 +753,7 @@ def ctas_query(self, **render_kwarg: t.Any) -> exp.Query:
query = self.render_query_or_raise(**render_kwarg).limit(0)
for select_or_set_op in query.find_all(exp.Select, exp.SetOperation):
- if isinstance(select_or_set_op, exp.Select) and select_or_set_op.args.get("from"):
+ if isinstance(select_or_set_op, exp.Select) and select_or_set_op.args.get("from_"):
select_or_set_op.where(exp.false(), copy=False)
if self.managed_columns:
@@ -821,7 +822,7 @@ def set_time_format(self, default_time_format: str = c.DEFAULT_TIME_COLUMN_FORMA
def convert_to_time_column(
self, time: TimeLike, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None
- ) -> exp.Expression:
+ ) -> exp.Expr:
"""Convert a TimeLike object to the same time format and type as the model's time column."""
if self.time_column:
if columns_to_types is None:
@@ -969,7 +970,7 @@ def validate_definition(self) -> None:
col.name
for expr in values
for col in t.cast(
- exp.Expression, exp.maybe_parse(expr, dialect=self.dialect)
+ exp.Expr, exp.maybe_parse(expr, dialect=self.dialect)
).find_all(exp.Column)
]
@@ -1265,7 +1266,7 @@ def _additional_metadata(self) -> t.List[str]:
return additional_metadata
- def _is_metadata_statement(self, statement: exp.Expression) -> bool:
+ def _is_metadata_statement(self, statement: exp.Expr) -> bool:
if isinstance(statement, d.MacroDef):
return True
if isinstance(statement, d.MacroFunc):
@@ -1294,7 +1295,7 @@ def full_depends_on(self) -> t.Set[str]:
return self._full_depends_on
@property
- def partitioned_by(self) -> t.List[exp.Expression]:
+ def partitioned_by(self) -> t.List[exp.Expr]:
"""Columns to partition the model by, including the time column if it is not already included."""
if self.time_column and not self._is_time_column_in_partitioned_by:
# This allows the user to opt out of automatic time_column injection
@@ -1322,7 +1323,7 @@ def partition_interval_unit(self) -> t.Optional[IntervalUnit]:
return None
@property
- def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expression]]]:
+ def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expr]]]:
from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS
audits_by_name = {**BUILT_IN_AUDITS, **self.audit_definitions}
@@ -1421,8 +1422,8 @@ def render_definition(
include_python: bool = True,
include_defaults: bool = False,
render_query: bool = False,
- ) -> t.List[exp.Expression]:
- result = super().render_definition(
+ ) -> t.List[exp.Expr]:
+ result: t.List[exp.Expr] = super().render_definition(
include_python=include_python, include_defaults=include_defaults
)
@@ -1945,7 +1946,7 @@ def render_definition(
include_python: bool = True,
include_defaults: bool = False,
render_query: bool = False,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
# Ignore the provided value for the include_python flag, since the Pyhon model's
# definition without Python code is meaningless.
return super().render_definition(
@@ -1969,6 +1970,7 @@ def _data_hash_values_no_sql(self) -> t.List[str]:
class ExternalModel(_Model):
"""The model definition which represents an external source/table."""
+ kind: ModelKind = ExternalKind()
source_type: t.Literal["external"] = "external"
def is_breaking_change(self, previous: Model) -> t.Optional[bool]:
@@ -1999,7 +2001,7 @@ class AuditResult(PydanticModel):
"""The model this audit is for."""
count: t.Optional[int] = None
"""The number of records returned by the audit query. This could be None if the audit was skipped."""
- query: t.Optional[exp.Expression] = None
+ query: t.Optional[exp.Expr] = None
"""The rendered query used by the audit. This could be None if the audit was skipped."""
skipped: bool = False
"""Whether or not the audit was blocking. This can be overriden by the user."""
@@ -2007,7 +2009,7 @@ class AuditResult(PydanticModel):
class EvaluatableSignals(PydanticModel):
- signals_to_kwargs: t.Dict[str, t.Dict[str, t.Optional[exp.Expression]]]
+ signals_to_kwargs: t.Dict[str, t.Dict[str, t.Optional[exp.Expr]]]
"""A mapping of signal names to the kwargs passed to the signal."""
python_env: t.Dict[str, Executable]
"""The Python environment that should be used to evaluated the rendered signal calls."""
@@ -2052,7 +2054,7 @@ def _extract_blueprint_variables(blueprint: t.Any, path: Path) -> t.Dict[str, t.
def create_models_from_blueprints(
- gateway: t.Optional[str | exp.Expression],
+ gateway: t.Optional[str | exp.Expr],
blueprints: t.Any,
get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]],
loader: t.Callable[..., Model],
@@ -2063,7 +2065,9 @@ def create_models_from_blueprints(
**loader_kwargs: t.Any,
) -> t.List[Model]:
model_blueprints: t.List[Model] = []
+ original_default_catalog = loader_kwargs.get("default_catalog")
for blueprint in _extract_blueprints(blueprints, path):
+ loader_kwargs["default_catalog"] = original_default_catalog
blueprint_variables = _extract_blueprint_variables(blueprint, path)
if gateway:
@@ -2081,12 +2085,15 @@ def create_models_from_blueprints(
else:
gateway_name = None
- if (
- default_catalog_per_gateway
- and gateway_name
- and (catalog := default_catalog_per_gateway.get(gateway_name)) is not None
- ):
- loader_kwargs["default_catalog"] = catalog
+ if default_catalog_per_gateway and gateway_name:
+ catalog = default_catalog_per_gateway.get(gateway_name)
+ if catalog is not None:
+ loader_kwargs["default_catalog"] = catalog
+ else:
+ # Gateway exists but has no entry in the dict (e.g., catalog-unsupported
+ # engines like ClickHouse). Clear the default catalog so the global
+ # default from the primary gateway doesn't leak into this model's name.
+ loader_kwargs["default_catalog"] = None
model_blueprints.append(
loader(
@@ -2103,7 +2110,7 @@ def create_models_from_blueprints(
def load_sql_based_models(
- expressions: t.List[exp.Expression],
+ expressions: t.List[exp.Expr],
get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]],
path: Path = Path(),
module_path: Path = Path(),
@@ -2111,8 +2118,8 @@ def load_sql_based_models(
default_catalog_per_gateway: t.Optional[t.Dict[str, str]] = None,
**loader_kwargs: t.Any,
) -> t.List[Model]:
- gateway: t.Optional[exp.Expression] = None
- blueprints: t.Optional[exp.Expression] = None
+ gateway: t.Optional[exp.Expr] = None
+ blueprints: t.Optional[exp.Expr] = None
model_meta = seq_get(expressions, 0)
for prop in (isinstance(model_meta, d.Model) and model_meta.expressions) or []:
@@ -2158,7 +2165,7 @@ def load_sql_based_models(
def load_sql_based_model(
- expressions: t.List[exp.Expression],
+ expressions: t.List[exp.Expr],
*,
defaults: t.Optional[t.Dict[str, t.Any]] = None,
path: t.Optional[Path] = None,
@@ -2304,7 +2311,7 @@ def load_sql_based_model(
if kind_prop.name.lower() == "merge_filter":
meta_fields["kind"].expressions[idx] = unrendered_merge_filter
- if isinstance(meta_fields.get("dialect"), exp.Expression):
+ if isinstance(meta_fields.get("dialect"), exp.Expr):
meta_fields["dialect"] = meta_fields["dialect"].name
# The name of the model will be inferred from its path relative to `models/`, if it's not explicitly specified
@@ -2365,7 +2372,7 @@ def load_sql_based_model(
def create_sql_model(
name: TableName,
- query: t.Optional[exp.Expression],
+ query: t.Optional[exp.Expr],
**kwargs: t.Any,
) -> Model:
"""Creates a SQL model.
@@ -2490,7 +2497,7 @@ def create_python_model(
)
depends_on = {
dep.sql(dialect=dialect)
- for dep in t.cast(t.List[exp.Expression], depends_on_rendered)[0].expressions
+ for dep in t.cast(t.List[exp.Expr], depends_on_rendered)[0].expressions
}
used_variables = {k: v for k, v in (variables or {}).items() if k in referenced_variables}
@@ -2595,7 +2602,7 @@ def _create_model(
if not issubclass(klass, SqlModel):
defaults.pop("optimize_query", None)
- statements: t.List[t.Union[exp.Expression, t.Tuple[exp.Expression, bool]]] = []
+ statements: t.List[t.Union[exp.Expr, t.Tuple[exp.Expr, bool]]] = []
if "query" in kwargs:
statements.append(kwargs["query"])
@@ -2634,11 +2641,11 @@ def _create_model(
if isinstance(property_values, exp.Tuple):
statements.extend(property_values.expressions)
- if isinstance(getattr(kwargs.get("kind"), "merge_filter", None), exp.Expression):
+ if isinstance(getattr(kwargs.get("kind"), "merge_filter", None), exp.Expr):
statements.append(kwargs["kind"].merge_filter)
jinja_macro_references, referenced_variables = extract_macro_references_and_variables(
- *(gen(e if isinstance(e, exp.Expression) else e[0]) for e in statements)
+ *(gen(e if isinstance(e, exp.Expr) else e[0]) for e in statements)
)
if jinja_macros:
@@ -2685,7 +2692,7 @@ def _create_model(
model.audit_definitions.update(audit_definitions)
# Any macro referenced in audits or signals needs to be treated as metadata-only
- statements.extend((audit.query, True) for audit in audit_definitions.values())
+ statements.extend((audit.query, True) for audit in audit_definitions.values()) # type: ignore[misc]
# Ensure that all audits referenced in the model are defined
from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS
@@ -2741,14 +2748,14 @@ def _create_model(
def _split_sql_model_statements(
- expressions: t.List[exp.Expression],
+ expressions: t.List[exp.Expr],
path: t.Optional[Path],
dialect: t.Optional[str] = None,
) -> t.Tuple[
- t.Optional[exp.Expression],
- t.List[exp.Expression],
- t.List[exp.Expression],
- t.List[exp.Expression],
+ t.Optional[exp.Expr],
+ t.List[exp.Expr],
+ t.List[exp.Expr],
+ t.List[exp.Expr],
UniqueKeyDict[str, ModelAudit],
]:
"""Extracts the SELECT query from a sequence of expressions.
@@ -2809,8 +2816,8 @@ def _split_sql_model_statements(
def _resolve_properties(
default: t.Optional[t.Dict[str, t.Any]],
- provided: t.Optional[exp.Expression | t.Dict[str, t.Any]],
-) -> t.Optional[exp.Expression]:
+ provided: t.Optional[exp.Expr | t.Dict[str, t.Any]],
+) -> t.Optional[exp.Expr]:
if isinstance(provided, dict):
properties = {k: exp.Literal.string(k).eq(v) for k, v in provided.items()}
elif provided:
@@ -2832,7 +2839,7 @@ def _resolve_properties(
return None
-def _list_of_calls_to_exp(value: t.List[t.Tuple[str, t.Dict[str, t.Any]]]) -> exp.Expression:
+def _list_of_calls_to_exp(value: t.List[t.Tuple[str, t.Dict[str, t.Any]]]) -> exp.Expr:
return exp.Tuple(
expressions=[
exp.Anonymous(
@@ -2847,16 +2854,16 @@ def _list_of_calls_to_exp(value: t.List[t.Tuple[str, t.Dict[str, t.Any]]]) -> ex
)
-def _is_projection(expr: exp.Expression) -> bool:
+def _is_projection(expr: exp.Expr) -> bool:
parent = expr.parent
return isinstance(parent, exp.Select) and expr.arg_key == "expressions"
-def _single_expr_or_tuple(values: t.Sequence[exp.Expression]) -> exp.Expression | exp.Tuple:
+def _single_expr_or_tuple(values: t.Sequence[exp.Expr]) -> exp.Expr | exp.Tuple:
return values[0] if len(values) == 1 else exp.Tuple(expressions=values)
-def _refs_to_sql(values: t.Any) -> exp.Expression:
+def _refs_to_sql(values: t.Any) -> exp.Expr:
return exp.Tuple(expressions=values)
@@ -2872,7 +2879,7 @@ def render_meta_fields(
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
) -> t.Dict[str, t.Any]:
def render_field_value(value: t.Any) -> t.Any:
- if isinstance(value, exp.Expression) or (isinstance(value, str) and "@" in value):
+ if isinstance(value, exp.Expr) or (isinstance(value, str) and "@" in value):
expression = exp.maybe_parse(value, dialect=dialect)
rendered_expr = render_expression(
expression=expression,
@@ -3009,7 +3016,7 @@ def parse_defaults_properties(
def render_expression(
- expression: exp.Expression,
+ expression: exp.Expr,
module_path: Path,
path: t.Optional[Path],
jinja_macros: t.Optional[JinjaMacroRegistry] = None,
@@ -3018,7 +3025,7 @@ def render_expression(
variables: t.Optional[t.Dict[str, t.Any]] = None,
default_catalog: t.Optional[str] = None,
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
-) -> t.Optional[t.List[exp.Expression]]:
+) -> t.Optional[t.List[exp.Expr]]:
meta_python_env = make_python_env(
expressions=expression,
jinja_macro_references=None,
@@ -3090,8 +3097,8 @@ def get_model_name(path: Path) -> str:
# function applied to time column when automatically used for partitioning in INCREMENTAL_BY_TIME_RANGE models
def clickhouse_partition_func(
- column: exp.Expression, columns_to_types: t.Optional[t.Dict[str, exp.DataType]]
-) -> exp.Expression:
+ column: exp.Expr, columns_to_types: t.Optional[t.Dict[str, exp.DataType]]
+) -> exp.Expr:
# `toMonday()` function accepts a Date or DateTime type column
col_type = (columns_to_types and columns_to_types.get(column.name)) or exp.DataType.build(
diff --git a/sqlmesh/core/model/kind.py b/sqlmesh/core/model/kind.py
index 9abaa9c650..d7a7bb9579 100644
--- a/sqlmesh/core/model/kind.py
+++ b/sqlmesh/core/model/kind.py
@@ -279,7 +279,7 @@ def model_kind_name(self) -> t.Optional[ModelKindName]:
return self.name
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
kwargs["expressions"] = expressions
return d.ModelKind(this=self.name.value.upper(), **kwargs)
@@ -294,7 +294,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]:
class TimeColumn(PydanticModel):
- column: exp.Expression
+ column: exp.Expr
format: t.Optional[str] = None
@classmethod
@@ -306,7 +306,7 @@ def _time_column_validator(v: t.Any, info: ValidationInfo) -> TimeColumn:
@field_validator("column", mode="before")
@classmethod
- def _column_validator(cls, v: t.Union[str, exp.Expression]) -> exp.Expression:
+ def _column_validator(cls, v: t.Union[str, exp.Expr]) -> exp.Expr:
if not v:
raise ConfigError("Time Column cannot be empty.")
if isinstance(v, str):
@@ -314,14 +314,14 @@ def _column_validator(cls, v: t.Union[str, exp.Expression]) -> exp.Expression:
return v
@property
- def expression(self) -> exp.Expression:
+ def expression(self) -> exp.Expr:
"""Convert this pydantic model into a time_column SQLGlot expression."""
if not self.format:
return self.column
return exp.Tuple(expressions=[self.column, exp.Literal.string(self.format)])
- def to_expression(self, dialect: str) -> exp.Expression:
+ def to_expression(self, dialect: str) -> exp.Expr:
"""Convert this pydantic model into a time_column SQLGlot expression."""
if not self.format:
return self.column
@@ -346,7 +346,7 @@ def create(cls, v: t.Any, dialect: str) -> Self:
exp.column(column_expr) if isinstance(column_expr, exp.Identifier) else column_expr
)
format = v.expressions[1].name if len(v.expressions) > 1 else None
- elif isinstance(v, exp.Expression):
+ elif isinstance(v, exp.Expr):
column = exp.column(v) if isinstance(v, exp.Identifier) else v
format = None
elif isinstance(v, str):
@@ -400,7 +400,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -444,7 +444,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -473,7 +473,7 @@ class IncrementalByTimeRangeKind(_IncrementalBy):
_time_column_validator = TimeColumn.validator()
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -513,7 +513,7 @@ class IncrementalByUniqueKeyKind(_IncrementalBy):
)
unique_key: SQLGlotListOfFields
when_matched: t.Optional[exp.Whens] = None
- merge_filter: t.Optional[exp.Expression] = None
+ merge_filter: t.Optional[exp.Expr] = None
batch_concurrency: t.Literal[1] = 1
@field_validator("when_matched", mode="before")
@@ -543,9 +543,9 @@ def _when_matched_validator(
@field_validator("merge_filter", mode="before")
def _merge_filter_validator(
cls,
- v: t.Optional[exp.Expression],
+ v: t.Optional[exp.Expr],
info: ValidationInfo,
- ) -> t.Optional[exp.Expression]:
+ ) -> t.Optional[exp.Expr]:
if v is None:
return v
@@ -568,7 +568,7 @@ def data_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -590,7 +590,7 @@ class IncrementalByPartitionKind(_Incremental):
disable_restatement: SQLGlotBool = False
@field_validator("forward_only", mode="before")
- def _forward_only_validator(cls, v: t.Union[bool, exp.Expression]) -> t.Literal[True]:
+ def _forward_only_validator(cls, v: t.Union[bool, exp.Expr]) -> t.Literal[True]:
if v is not True:
raise ConfigError(
"Do not specify the `forward_only` configuration key - INCREMENTAL_BY_PARTITION models are always forward_only."
@@ -606,7 +606,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -640,7 +640,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -669,7 +669,7 @@ def supports_python_models(self) -> bool:
return False
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -690,7 +690,7 @@ class SeedKind(_ModelKind):
def _parse_csv_settings(cls, v: t.Any) -> t.Optional[CsvSettings]:
if v is None or isinstance(v, CsvSettings):
return v
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
tuple_exp = parse_properties(cls, v, None)
if not tuple_exp:
return None
@@ -700,7 +700,7 @@ def _parse_csv_settings(cls, v: t.Any) -> t.Optional[CsvSettings]:
return v
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
"""Convert the seed kind into a SQLGlot expression."""
return super().to_expression(
@@ -756,13 +756,16 @@ class _SCDType2Kind(_Incremental):
@field_validator("time_data_type", mode="before")
@classmethod
- def _time_data_type_validator(
- cls, v: t.Union[str, exp.Expression], values: t.Any
- ) -> exp.Expression:
- if isinstance(v, exp.Expression) and not isinstance(v, exp.DataType):
+ def _time_data_type_validator(cls, v: t.Union[str, exp.Expr], values: t.Any) -> exp.Expr:
+ if isinstance(v, exp.Expr) and not isinstance(v, exp.DataType):
v = v.name
dialect = get_dialect(values)
data_type = exp.DataType.build(v, dialect=dialect)
+ # Clear meta["sql"] (set by our parser extension) so the pydantic encoder
+ # uses dialect-aware rendering: e.sql(dialect=meta["dialect"]). Without this,
+ # the raw SQL text takes priority, which can be wrong for dialect-normalized
+ # types (e.g., default "TIMESTAMP" should render as "DATETIME" in BigQuery).
+ data_type.meta.pop("sql", None)
data_type.meta["dialect"] = dialect
return data_type
@@ -795,7 +798,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -835,7 +838,7 @@ def data_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -871,7 +874,7 @@ def data_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -922,7 +925,7 @@ def data_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -1005,7 +1008,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -1142,7 +1145,7 @@ def create_model_kind(v: t.Any, dialect: str, defaults: t.Dict[str, t.Any]) -> M
)
return kind_type(**props)
- name = (v.name if isinstance(v, exp.Expression) else str(v)).upper()
+ name = (v.name if isinstance(v, exp.Expr) else str(v)).upper()
return model_kind_type_from_name(name)(name=name) # type: ignore
diff --git a/sqlmesh/core/model/meta.py b/sqlmesh/core/model/meta.py
index c48b7d1524..a73d6d871a 100644
--- a/sqlmesh/core/model/meta.py
+++ b/sqlmesh/core/model/meta.py
@@ -50,7 +50,7 @@
from sqlmesh.core._typing import CustomMaterializationProperties, SessionProperties
from sqlmesh.core.engine_adapter._typing import GrantsConfig
-FunctionCall = t.Tuple[str, t.Dict[str, exp.Expression]]
+FunctionCall = t.Tuple[str, t.Dict[str, exp.Expr]]
class GrantsTargetLayer(str, Enum):
@@ -92,8 +92,8 @@ class ModelMeta(_Node):
retention: t.Optional[int] = None # not implemented yet
table_format: t.Optional[str] = None
storage_format: t.Optional[str] = None
- partitioned_by_: t.List[exp.Expression] = Field(default=[], alias="partitioned_by")
- clustered_by: t.List[exp.Expression] = []
+ partitioned_by_: t.List[exp.Expr] = Field(default=[], alias="partitioned_by")
+ clustered_by: t.List[exp.Expr] = []
default_catalog: t.Optional[str] = None
depends_on_: t.Optional[t.Set[str]] = Field(default=None, alias="depends_on")
columns_to_types_: t.Optional[t.Dict[str, exp.DataType]] = Field(default=None, alias="columns")
@@ -101,8 +101,8 @@ class ModelMeta(_Node):
default=None, alias="column_descriptions"
)
audits: t.List[FunctionCall] = []
- grains: t.List[exp.Expression] = []
- references: t.List[exp.Expression] = []
+ grains: t.List[exp.Expr] = []
+ references: t.List[exp.Expr] = []
physical_schema_override: t.Optional[str] = None
physical_properties_: t.Optional[exp.Tuple] = Field(default=None, alias="physical_properties")
virtual_properties_: t.Optional[exp.Tuple] = Field(default=None, alias="virtual_properties")
@@ -151,11 +151,11 @@ def _normalize(value: t.Any) -> t.Any:
if isinstance(v, (exp.Tuple, exp.Array)):
return [_normalize(e).name for e in v.expressions]
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
return _normalize(v).name
if isinstance(v, str):
value = _normalize(v)
- return value.name if isinstance(value, exp.Expression) else value
+ return value.name if isinstance(value, exp.Expr) else value
if isinstance(v, (list, tuple)):
return [cls._validate_value_or_tuple(elm, data, normalize=normalize) for elm in v]
@@ -163,7 +163,7 @@ def _normalize(value: t.Any) -> t.Any:
@field_validator("table_format", "storage_format", mode="before")
def _format_validator(cls, v: t.Any, info: ValidationInfo) -> t.Optional[str]:
- if isinstance(v, exp.Expression) and not (isinstance(v, (exp.Literal, exp.Identifier))):
+ if isinstance(v, exp.Expr) and not (isinstance(v, (exp.Literal, exp.Identifier))):
return v.sql(info.data.get("dialect"))
return str_or_exp_to_str(v)
@@ -188,9 +188,7 @@ def _gateway_validator(cls, v: t.Any) -> t.Optional[str]:
return gateway and gateway.lower()
@field_validator("partitioned_by_", "clustered_by", mode="before")
- def _partition_and_cluster_validator(
- cls, v: t.Any, info: ValidationInfo
- ) -> t.List[exp.Expression]:
+ def _partition_and_cluster_validator(cls, v: t.Any, info: ValidationInfo) -> t.List[exp.Expr]:
if (
isinstance(v, list)
and all(isinstance(i, str) for i in v)
@@ -244,9 +242,33 @@ def _columns_validator(
return columns_to_types
if isinstance(v, dict):
- udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES
+ dialect_obj = Dialect.get_or_raise(dialect)
+ udt = dialect_obj.SUPPORTS_USER_DEFINED_TYPES
for k, data_type in v.items():
+ is_string_type = isinstance(data_type, str)
expr = exp.DataType.build(data_type, dialect=dialect, udt=udt)
+ # When deserializing from a string (e.g. JSON roundtrip), normalize the type
+ # through the dialect's type system so that aliases (e.g. INT in BigQuery,
+ # which is an alias for INT64/BIGINT) are resolved to their canonical form.
+ # This ensures stable data hash computation across serialization/deserialization
+ # roundtrips. We skip this for DataType objects passed directly (Python API)
+ # since those should be used as-is.
+ if (
+ is_string_type
+ and dialect
+ and expr.this
+ not in (
+ exp.DataType.Type.USERDEFINED,
+ exp.DataType.Type.UNKNOWN,
+ )
+ ):
+ sql_repr = expr.sql(dialect=dialect)
+ try:
+ normalized = parse_one(sql_repr, read=dialect, into=exp.DataType)
+ if normalized is not None:
+ expr = normalized
+ except Exception:
+ pass
expr.meta["dialect"] = dialect
columns_to_types[normalize_identifiers(k, dialect=dialect).name] = expr
@@ -295,7 +317,7 @@ def _column_descriptions_validator(
return col_descriptions
@field_validator("grains", "references", mode="before")
- def _refs_validator(cls, vs: t.Any, info: ValidationInfo) -> t.List[exp.Expression]:
+ def _refs_validator(cls, vs: t.Any, info: ValidationInfo) -> t.List[exp.Expr]:
dialect = info.data.get("dialect")
if isinstance(vs, exp.Paren):
@@ -349,7 +371,7 @@ def session_properties_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any:
"Invalid value for `session_properties.query_label`. Must be an array or tuple."
)
- label_tuples: t.List[exp.Expression] = (
+ label_tuples: t.List[exp.Expr] = (
[query_label.unnest()]
if isinstance(query_label, exp.Paren)
else query_label.expressions
@@ -449,7 +471,7 @@ def time_column(self) -> t.Optional[TimeColumn]:
return getattr(self.kind, "time_column", None)
@property
- def unique_key(self) -> t.List[exp.Expression]:
+ def unique_key(self) -> t.List[exp.Expr]:
if isinstance(
self.kind, (SCDType2ByTimeKind, SCDType2ByColumnKind, IncrementalByUniqueKeyKind)
):
@@ -485,14 +507,14 @@ def batch_concurrency(self) -> t.Optional[int]:
return getattr(self.kind, "batch_concurrency", None)
@cached_property
- def physical_properties(self) -> t.Dict[str, exp.Expression]:
+ def physical_properties(self) -> t.Dict[str, exp.Expr]:
"""A dictionary of properties that will be applied to the physical layer. It replaces table_properties which is deprecated."""
if self.physical_properties_:
return {e.this.name: e.expression for e in self.physical_properties_.expressions}
return {}
@cached_property
- def virtual_properties(self) -> t.Dict[str, exp.Expression]:
+ def virtual_properties(self) -> t.Dict[str, exp.Expr]:
"""A dictionary of properties that will be applied to the virtual layer."""
if self.virtual_properties_:
return {e.this.name: e.expression for e in self.virtual_properties_.expressions}
@@ -568,7 +590,7 @@ def when_matched(self) -> t.Optional[exp.Whens]:
return None
@property
- def merge_filter(self) -> t.Optional[exp.Expression]:
+ def merge_filter(self) -> t.Optional[exp.Expr]:
if isinstance(self.kind, IncrementalByUniqueKeyKind):
return self.kind.merge_filter
return None
@@ -601,7 +623,7 @@ def on_additive_change(self) -> OnAdditiveChange:
def ignored_rules(self) -> t.Set[str]:
return self.ignored_rules_ or set()
- def _validate_config_expression(self, expr: exp.Expression) -> str:
+ def _validate_config_expression(self, expr: exp.Expr) -> str:
if isinstance(expr, (d.MacroFunc, d.MacroVar)):
raise ConfigError(f"Unresolved macro: {expr.sql(dialect=self.dialect)}")
@@ -614,10 +636,10 @@ def _validate_config_expression(self, expr: exp.Expression) -> str:
return expr.name
return expr.sql(dialect=self.dialect).strip()
- def _validate_nested_config_values(self, value_expr: exp.Expression) -> t.List[str]:
+ def _validate_nested_config_values(self, value_expr: exp.Expr) -> t.List[str]:
result = []
- def flatten_expr(expr: exp.Expression) -> None:
+ def flatten_expr(expr: exp.Expr) -> None:
if isinstance(expr, exp.Array):
for elem in expr.expressions:
flatten_expr(elem)
diff --git a/sqlmesh/core/model/seed.py b/sqlmesh/core/model/seed.py
index fe1aa85204..9fd57fe6d3 100644
--- a/sqlmesh/core/model/seed.py
+++ b/sqlmesh/core/model/seed.py
@@ -49,7 +49,7 @@ def _bool_validator(cls, v: t.Any) -> t.Optional[bool]:
)
@classmethod
def _str_validator(cls, v: t.Any) -> t.Optional[str]:
- if v is None or not isinstance(v, exp.Expression):
+ if v is None or not isinstance(v, exp.Expr):
return v
# SQLGlot parses escape sequences like \t as \\t for dialects that don't treat \ as
@@ -60,7 +60,7 @@ def _str_validator(cls, v: t.Any) -> t.Optional[str]:
@field_validator("na_values", mode="before")
@classmethod
def _na_values_validator(cls, v: t.Any) -> t.Optional[NaValues]:
- if v is None or not isinstance(v, exp.Expression):
+ if v is None or not isinstance(v, exp.Expr):
return v
try:
diff --git a/sqlmesh/core/node.py b/sqlmesh/core/node.py
index 4a3bf2564b..d3b63312f1 100644
--- a/sqlmesh/core/node.py
+++ b/sqlmesh/core/node.py
@@ -215,7 +215,7 @@ def post_init(self) -> Self:
self.alias = None
return self
- def to_expression(self) -> exp.Expression:
+ def to_expression(self) -> exp.Expr:
"""Produce a SQLGlot expression representing this object, for use in things like the model/audit definition renderers"""
return exp.tuple_(
*(
@@ -324,7 +324,7 @@ def copy(self, **kwargs: t.Any) -> Self:
def _name_validator(cls, v: t.Any) -> t.Optional[str]:
if v is None:
return None
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
return v.meta["sql"]
return str(v)
@@ -352,7 +352,7 @@ def _cron_tz_validator(cls, v: t.Any) -> t.Optional[zoneinfo.ZoneInfo]:
@field_validator("start", "end", mode="before")
@classmethod
def _date_validator(cls, v: t.Any) -> t.Optional[TimeLike]:
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
v = v.name
if v and not to_datetime(v):
raise ConfigError(f"'{v}' needs to be time-like: https://pypi.org/project/dateparser")
@@ -555,6 +555,6 @@ def __str__(self) -> str:
def str_or_exp_to_str(v: t.Any) -> t.Optional[str]:
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
return v.name
return str(v) if v is not None else None
diff --git a/sqlmesh/core/plan/builder.py b/sqlmesh/core/plan/builder.py
index 7d753cc330..01834594cd 100644
--- a/sqlmesh/core/plan/builder.py
+++ b/sqlmesh/core/plan/builder.py
@@ -165,7 +165,7 @@ def __init__(
# There may be an significant delay between the PlanBuilder producing a Plan and the Plan actually being run
# so if execution_time=None is passed to the PlanBuilder, then the resulting Plan should also have execution_time=None
# in order to prevent the Plan that was intended to run "as at now" from having "now" fixed to some time in the past
- # ref: https://github.com/TobikoData/sqlmesh/pull/4702#discussion_r2140696156
+ # ref: https://github.com/SQLMesh/sqlmesh/pull/4702#discussion_r2140696156
self._execution_time = execution_time
self._backfill_models = backfill_models
diff --git a/sqlmesh/core/reference.py b/sqlmesh/core/reference.py
index 2bf2c04e98..9e93ce7b38 100644
--- a/sqlmesh/core/reference.py
+++ b/sqlmesh/core/reference.py
@@ -14,7 +14,7 @@
class Reference(PydanticModel, frozen=True):
model_name: str
- expression: exp.Expression
+ expression: exp.Expr
unique: bool = False
_name: str = ""
diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py
index 0cbf9b6e94..7683956064 100644
--- a/sqlmesh/core/renderer.py
+++ b/sqlmesh/core/renderer.py
@@ -48,7 +48,7 @@
class BaseExpressionRenderer:
def __init__(
self,
- expression: exp.Expression,
+ expression: exp.Expr,
dialect: DialectType,
macro_definitions: t.List[d.MacroDef],
path: t.Optional[Path] = None,
@@ -73,7 +73,7 @@ def __init__(
self._normalize_identifiers = normalize_identifiers
self._quote_identifiers = quote_identifiers
self.update_schema({} if schema is None else schema)
- self._cache: t.List[t.Optional[exp.Expression]] = []
+ self._cache: t.List[t.Optional[exp.Expr]] = []
self._model_fqn = model.fqn if model else None
self._optimize_query_flag = optimize_query is not False
self._model = model
@@ -91,7 +91,7 @@ def _render(
deployability_index: t.Optional[DeployabilityIndex] = None,
runtime_stage: RuntimeStage = RuntimeStage.LOADING,
**kwargs: t.Any,
- ) -> t.List[t.Optional[exp.Expression]]:
+ ) -> t.List[t.Optional[exp.Expr]]:
"""Renders a expression, expanding macros with provided kwargs
Args:
@@ -205,7 +205,7 @@ def _resolve_table(table: str | exp.Table) -> str:
if variables:
macro_evaluator.locals.setdefault(c.SQLMESH_VARS, {}).update(variables)
- expressions = [self._expression]
+ expressions: t.List[exp.Expr] = [self._expression]
if isinstance(self._expression, d.Jinja):
try:
jinja_env_kwargs = {
@@ -283,7 +283,7 @@ def _resolve_table(table: str | exp.Table) -> str:
f"Failed to evaluate macro '{definition}'.\n\n{ex}\n", self._path
)
- resolved_expressions: t.List[t.Optional[exp.Expression]] = []
+ resolved_expressions: t.List[t.Optional[exp.Expr]] = []
for expression in expressions:
try:
@@ -294,7 +294,7 @@ def _resolve_table(table: str | exp.Table) -> str:
self._path,
)
- for expression in t.cast(t.List[exp.Expression], transformed_expressions):
+ for expression in t.cast(t.List[exp.Expr], transformed_expressions):
with self._normalize_and_quote(expression) as expression:
if hasattr(expression, "selects"):
for select in expression.selects:
@@ -320,12 +320,12 @@ def _resolve_table(table: str | exp.Table) -> str:
self._cache = resolved_expressions
return resolved_expressions
- def update_cache(self, expression: t.Optional[exp.Expression]) -> None:
+ def update_cache(self, expression: t.Optional[exp.Expr]) -> None:
self._cache = [expression]
def _resolve_table(
self,
- table_name: str | exp.Expression,
+ table_name: str | exp.Expr,
snapshots: t.Optional[t.Dict[str, Snapshot]] = None,
table_mapping: t.Optional[t.Dict[str, str]] = None,
deployability_index: t.Optional[DeployabilityIndex] = None,
@@ -380,7 +380,7 @@ def _resolve_tables(
if snapshot.is_model
}
- def _expand(node: exp.Expression) -> exp.Expression:
+ def _expand(node: exp.Expr) -> exp.Expr:
if isinstance(node, exp.Table) and snapshots:
name = exp.table_name(node, identify=True)
model = model_mapping.get(name)
@@ -449,7 +449,7 @@ def render(
deployability_index: t.Optional[DeployabilityIndex] = None,
expand: t.Iterable[str] = tuple(),
**kwargs: t.Any,
- ) -> t.Optional[t.List[exp.Expression]]:
+ ) -> t.Optional[t.List[exp.Expr]]:
try:
expressions = super()._render(
start=start,
@@ -631,7 +631,7 @@ def render(
def update_cache(
self,
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
violated_rules: t.Optional[t.Dict[type[Rule], t.Any]] = None,
optimized: bool = False,
) -> None:
@@ -690,7 +690,7 @@ def _optimize_query(self, query: exp.Query, all_deps: t.Set[str]) -> exp.Query:
except Exception as ex:
raise_config_error(
- f"Failed to optimize query, please file an issue at https://github.com/TobikoData/sqlmesh/issues/new. {ex}",
+ f"Failed to optimize query, please file an issue at https://github.com/SQLMesh/sqlmesh/issues/new. {ex}",
self._path,
)
diff --git a/sqlmesh/core/schema_diff.py b/sqlmesh/core/schema_diff.py
index e1f9d72a6c..ecf38b18a8 100644
--- a/sqlmesh/core/schema_diff.py
+++ b/sqlmesh/core/schema_diff.py
@@ -37,7 +37,7 @@ def is_additive(self) -> bool:
@property
@abc.abstractmethod
- def _alter_actions(self) -> t.List[exp.Expression]:
+ def _alter_actions(self) -> t.List[exp.Expr]:
pass
@property
@@ -104,7 +104,7 @@ def is_destructive(self) -> bool:
return self.is_part_of_destructive_change
@property
- def _alter_actions(self) -> t.List[exp.Expression]:
+ def _alter_actions(self) -> t.List[exp.Expr]:
column_def = exp.ColumnDef(
this=self.column,
kind=self.column_type,
@@ -127,7 +127,7 @@ def is_destructive(self) -> bool:
return True
@property
- def _alter_actions(self) -> t.List[exp.Expression]:
+ def _alter_actions(self) -> t.List[exp.Expr]:
return [exp.Drop(this=self.column, kind="COLUMN", cascade=self.cascade)]
@@ -145,7 +145,7 @@ def is_destructive(self) -> bool:
return self.is_part_of_destructive_change
@property
- def _alter_actions(self) -> t.List[exp.Expression]:
+ def _alter_actions(self) -> t.List[exp.Expr]:
return [
exp.AlterColumn(
this=self.column,
@@ -363,14 +363,12 @@ class SchemaDiffer(PydanticModel):
coerceable_types_: t.Dict[exp.DataType, t.Set[exp.DataType]] = Field(
default_factory=dict, alias="coerceable_types"
)
- precision_increase_allowed_types: t.Optional[t.Set[exp.DataType.Type]] = None
+ precision_increase_allowed_types: t.Optional[t.Set[exp.DType]] = None
support_coercing_compatible_types: bool = False
drop_cascade: bool = False
- parameterized_type_defaults: t.Dict[
- exp.DataType.Type, t.List[t.Tuple[t.Union[int, float], ...]]
- ] = {}
- max_parameter_length: t.Dict[exp.DataType.Type, t.Union[int, float]] = {}
- types_with_unlimited_length: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
+ parameterized_type_defaults: t.Dict[exp.DType, t.List[t.Tuple[t.Union[int, float], ...]]] = {}
+ max_parameter_length: t.Dict[exp.DType, t.Union[int, float]] = {}
+ types_with_unlimited_length: t.Dict[exp.DType, t.Set[exp.DType]] = {}
treat_alter_data_type_as_destructive: bool = False
_coerceable_types: t.Dict[exp.DataType, t.Set[exp.DataType]] = {}
diff --git a/sqlmesh/core/selector.py b/sqlmesh/core/selector.py
index 3865327acd..9eaf4995c8 100644
--- a/sqlmesh/core/selector.py
+++ b/sqlmesh/core/selector.py
@@ -191,7 +191,7 @@ def expand_model_selections(
models_by_tags.setdefault(tag, set())
models_by_tags[tag].add(model.fqn)
- def evaluate(node: exp.Expression) -> t.Set[str]:
+ def evaluate(node: exp.Expr) -> t.Set[str]:
if isinstance(node, exp.Var):
pattern = node.this
if "*" in pattern:
@@ -400,7 +400,7 @@ class Direction(exp.Expression):
pass
-def parse(selector: str, dialect: DialectType = None) -> exp.Expression:
+def parse(selector: str, dialect: DialectType = None) -> exp.Expr:
tokens = SelectorDialect().tokenize(selector)
i = 0
@@ -444,7 +444,7 @@ def _parse_kind(kind: str) -> bool:
return True
return False
- def _parse_var() -> exp.Expression:
+ def _parse_var() -> exp.Expr:
upstream = _match(TokenType.PLUS)
downstream = None
tag = _parse_kind("tag")
@@ -457,7 +457,7 @@ def _parse_var() -> exp.Expression:
name = _prev().text
rstar = "*" if _match(TokenType.STAR) else ""
downstream = _match(TokenType.PLUS)
- this: exp.Expression = exp.Var(this=f"{lstar}{name}{rstar}")
+ this: exp.Expr = exp.Var(this=f"{lstar}{name}{rstar}")
elif _match(TokenType.L_PAREN):
this = exp.Paren(this=_parse_conjunction())
@@ -483,12 +483,12 @@ def _parse_var() -> exp.Expression:
this = Direction(this=this, **directions)
return this
- def _parse_unary() -> exp.Expression:
+ def _parse_unary() -> exp.Expr:
if _match(TokenType.CARET):
return exp.Not(this=_parse_unary())
return _parse_var()
- def _parse_conjunction() -> exp.Expression:
+ def _parse_conjunction() -> exp.Expr:
this = _parse_unary()
if _match(TokenType.AMP):
diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py
index 1808011854..b1ffd4dc26 100644
--- a/sqlmesh/core/snapshot/evaluator.py
+++ b/sqlmesh/core/snapshot/evaluator.py
@@ -249,7 +249,7 @@ def evaluate_and_fetch(
query_or_df = next(queries_or_dfs)
if isinstance(query_or_df, pd.DataFrame):
return query_or_df.head(limit)
- if not isinstance(query_or_df, exp.Expression):
+ if not isinstance(query_or_df, exp.Expr):
# We assume that if this branch is reached, `query_or_df` is a pyspark / snowpark / bigframe dataframe,
# so we use `limit` instead of `head` to get back a dataframe instead of List[Row]
# https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.DataFrame.head.html#pyspark.sql.DataFrame.head
@@ -714,7 +714,7 @@ def _evaluate_snapshot(
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
is_snapshot_deployable = deployability_index.is_deployable(snapshot)
target_table_name = snapshot.table_name(is_deployable=is_snapshot_deployable)
- # https://github.com/TobikoData/sqlmesh/issues/2609
+ # https://github.com/SQLMesh/sqlmesh/issues/2609
# If there are no existing intervals yet; only consider this a first insert for the first snapshot in the batch
if target_table_exists is None:
target_table_exists = adapter.table_exists(target_table_name)
@@ -940,7 +940,7 @@ def _render_and_insert_snapshot(
snapshots: t.Dict[str, Snapshot],
render_kwargs: t.Dict[str, t.Any],
create_render_kwargs: t.Dict[str, t.Any],
- rendered_physical_properties: t.Dict[str, exp.Expression],
+ rendered_physical_properties: t.Dict[str, exp.Expr],
deployability_index: DeployabilityIndex,
target_table_name: str,
is_first_insert: bool,
@@ -1069,7 +1069,7 @@ def _clone_snapshot_in_dev(
snapshots: t.Dict[str, Snapshot],
deployability_index: DeployabilityIndex,
render_kwargs: t.Dict[str, t.Any],
- rendered_physical_properties: t.Dict[str, exp.Expression],
+ rendered_physical_properties: t.Dict[str, exp.Expr],
allow_destructive_snapshots: t.Set[str],
allow_additive_snapshots: t.Set[str],
run_pre_post_statements: bool = False,
@@ -1186,7 +1186,7 @@ def _migrate_target_table(
snapshots: t.Dict[str, Snapshot],
deployability_index: DeployabilityIndex,
render_kwargs: t.Dict[str, t.Any],
- rendered_physical_properties: t.Dict[str, exp.Expression],
+ rendered_physical_properties: t.Dict[str, exp.Expr],
allow_destructive_snapshots: t.Set[str],
allow_additive_snapshots: t.Set[str],
run_pre_post_statements: bool = False,
@@ -1472,7 +1472,7 @@ def _execute_create(
is_table_deployable: bool,
deployability_index: DeployabilityIndex,
create_render_kwargs: t.Dict[str, t.Any],
- rendered_physical_properties: t.Dict[str, exp.Expression],
+ rendered_physical_properties: t.Dict[str, exp.Expr],
dry_run: bool,
run_pre_post_statements: bool = True,
skip_grants: bool = False,
@@ -3106,7 +3106,7 @@ def create(
query=model.render_query_or_raise(**render_kwargs),
target_columns_to_types=model.columns_to_types,
partitioned_by=model.partitioned_by,
- clustered_by=model.clustered_by,
+ clustered_by=model.clustered_by, # type: ignore[arg-type]
table_properties=kwargs.get("physical_properties", model.physical_properties),
table_description=model.description,
column_descriptions=model.column_descriptions,
@@ -3151,7 +3151,7 @@ def insert(
query=query_or_df, # type: ignore
target_columns_to_types=model.columns_to_types,
partitioned_by=model.partitioned_by,
- clustered_by=model.clustered_by,
+ clustered_by=model.clustered_by, # type: ignore[arg-type]
table_properties=kwargs.get("physical_properties", model.physical_properties),
table_description=model.description,
column_descriptions=model.column_descriptions,
diff --git a/sqlmesh/core/state_sync/common.py b/sqlmesh/core/state_sync/common.py
index 056565b060..d1208c5213 100644
--- a/sqlmesh/core/state_sync/common.py
+++ b/sqlmesh/core/state_sync/common.py
@@ -140,9 +140,9 @@ def all_batch_range(cls) -> ExpiredBatchRange:
def _expanded_tuple_comparison(
cls,
columns: t.List[exp.Column],
- values: t.List[exp.Literal],
- operator: t.Type[exp.Expression],
- ) -> exp.Expression:
+ values: t.List[t.Union[exp.Literal, exp.Neg]],
+ operator: t.Type[exp.Expr],
+ ) -> exp.Condition:
"""Generate expanded tuple comparison that works across all SQL engines.
Converts tuple comparisons like (a, b, c) OP (x, y, z) into an expanded form
@@ -177,8 +177,8 @@ def _expanded_tuple_comparison(
# e.g., (a, b) <= (x, y) becomes: a < x OR (a = x AND b <= y)
# For < and >, we use the strict operator throughout
# e.g., (a, b) > (x, y) becomes: a > x OR (a = x AND b > x)
- strict_operator: t.Type[exp.Expression]
- final_operator: t.Type[exp.Expression]
+ strict_operator: t.Type[exp.Expr]
+ final_operator: t.Type[exp.Expr]
if operator in (exp.LTE, exp.GTE):
# For inclusive operators (<=, >=), use strict form for intermediate columns
@@ -190,7 +190,7 @@ def _expanded_tuple_comparison(
strict_operator = operator
final_operator = operator
- conditions: t.List[exp.Expression] = []
+ conditions: t.List[exp.Expr] = []
for i in range(len(columns)):
# Build equality conditions for all columns before current
equality_conditions = [exp.EQ(this=columns[j], expression=values[j]) for j in range(i)]
@@ -204,10 +204,10 @@ def _expanded_tuple_comparison(
else:
conditions.append(comparison_condition)
- return exp.or_(*conditions) if len(conditions) > 1 else conditions[0]
+ return exp.or_(*conditions) if len(conditions) > 1 else t.cast(exp.Condition, conditions[0])
@property
- def where_filter(self) -> exp.Expression:
+ def where_filter(self) -> exp.Condition:
# Use expanded tuple comparisons for cross-engine compatibility
# Native tuple comparisons like (a, b) > (x, y) don't work reliably across all SQL engines
columns = [
@@ -223,7 +223,7 @@ def where_filter(self) -> exp.Expression:
start_condition = self._expanded_tuple_comparison(columns, start_values, exp.GT)
- range_filter: exp.Expression
+ range_filter: exp.Condition
if isinstance(self.end, RowBoundary):
end_values = [
exp.Literal.number(self.end.updated_ts),
diff --git a/sqlmesh/core/state_sync/db/environment.py b/sqlmesh/core/state_sync/db/environment.py
index e3f1d1ec9e..713ce0193e 100644
--- a/sqlmesh/core/state_sync/db/environment.py
+++ b/sqlmesh/core/state_sync/db/environment.py
@@ -296,7 +296,7 @@ def _environment_summmary_from_row(self, row: t.Tuple[str, ...]) -> EnvironmentS
def _environments_query(
self,
- where: t.Optional[str | exp.Expression] = None,
+ where: t.Optional[str | exp.Expr] = None,
lock_for_update: bool = False,
required_fields: t.Optional[t.List[str]] = None,
) -> exp.Select:
@@ -310,7 +310,7 @@ def _environments_query(
return query.lock(copy=False)
return query
- def _create_expiration_filter_expr(self, current_ts: int) -> exp.Expression:
+ def _create_expiration_filter_expr(self, current_ts: int) -> exp.Expr:
"""Creates a SQLGlot filter expression to find expired environments.
Args:
@@ -322,7 +322,7 @@ def _create_expiration_filter_expr(self, current_ts: int) -> exp.Expression:
)
def _fetch_environment_summaries(
- self, where: t.Optional[str | exp.Expression] = None
+ self, where: t.Optional[str | exp.Expr] = None
) -> t.List[EnvironmentSummary]:
return [
self._environment_summmary_from_row(row)
diff --git a/sqlmesh/core/state_sync/db/migrator.py b/sqlmesh/core/state_sync/db/migrator.py
index ad60c57570..8d73e1d395 100644
--- a/sqlmesh/core/state_sync/db/migrator.py
+++ b/sqlmesh/core/state_sync/db/migrator.py
@@ -195,7 +195,7 @@ def _apply_migrations(
raise SQLMeshError(
f"Number of snapshots before ({snapshot_count_before}) and after "
f"({snapshot_count_after}) applying migration scripts {scripts} does not match. "
- "Please file an issue issue at https://github.com/TobikoData/sqlmesh/issues/new."
+ "Please file an issue issue at https://github.com/SQLMesh/sqlmesh/issues/new."
)
migrate_snapshots_and_environments = (
diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py
index d584c69d65..8ca98f2d48 100644
--- a/sqlmesh/core/state_sync/db/snapshot.py
+++ b/sqlmesh/core/state_sync/db/snapshot.py
@@ -623,7 +623,7 @@ def _get_snapshots_expressions(
self,
snapshot_ids: t.Iterable[SnapshotIdLike],
lock_for_update: bool = False,
- ) -> t.Iterator[exp.Expression]:
+ ) -> t.Iterator[exp.Expr]:
for where in snapshot_id_filter(
self.engine_adapter,
snapshot_ids,
diff --git a/sqlmesh/core/state_sync/db/utils.py b/sqlmesh/core/state_sync/db/utils.py
index 87c259f5d6..b0f321e21f 100644
--- a/sqlmesh/core/state_sync/db/utils.py
+++ b/sqlmesh/core/state_sync/db/utils.py
@@ -123,11 +123,9 @@ def create_batches(l: t.List[T], batch_size: int) -> t.List[t.List[T]]:
return [l[i : i + batch_size] for i in range(0, len(l), batch_size)]
-def fetchone(
- engine_adapter: EngineAdapter, query: t.Union[exp.Expression, str]
-) -> t.Optional[t.Tuple]:
+def fetchone(engine_adapter: EngineAdapter, query: t.Union[exp.Expr, str]) -> t.Optional[t.Tuple]:
return engine_adapter.fetchone(query, ignore_unsupported_errors=True, quote_identifiers=True)
-def fetchall(engine_adapter: EngineAdapter, query: t.Union[exp.Expression, str]) -> t.List[t.Tuple]:
+def fetchall(engine_adapter: EngineAdapter, query: t.Union[exp.Expr, str]) -> t.List[t.Tuple]:
return engine_adapter.fetchall(query, ignore_unsupported_errors=True, quote_identifiers=True)
diff --git a/sqlmesh/core/state_sync/export_import.py b/sqlmesh/core/state_sync/export_import.py
index 3a63351ddb..2461ee50fa 100644
--- a/sqlmesh/core/state_sync/export_import.py
+++ b/sqlmesh/core/state_sync/export_import.py
@@ -29,7 +29,7 @@
class SQLMeshJSONStreamEncoder(JSONStreamEncoder):
def default(self, obj: t.Any) -> t.Any:
- if isinstance(obj, exp.Expression):
+ if isinstance(obj, exp.Expr):
return _expression_encoder(obj)
return super().default(obj)
diff --git a/sqlmesh/core/table_diff.py b/sqlmesh/core/table_diff.py
index bd32cc170f..df99227f89 100644
--- a/sqlmesh/core/table_diff.py
+++ b/sqlmesh/core/table_diff.py
@@ -224,9 +224,9 @@ def __init__(
adapter: EngineAdapter,
source: TableName,
target: TableName,
- on: t.List[str] | exp.Condition,
+ on: t.List[str] | exp.Expr,
skip_columns: t.List[str] | None = None,
- where: t.Optional[str | exp.Condition] = None,
+ where: t.Optional[str | exp.Expr] = None,
limit: int = 20,
source_alias: t.Optional[str] = None,
target_alias: t.Optional[str] = None,
@@ -305,18 +305,18 @@ def key_columns(self) -> t.Tuple[t.List[exp.Column], t.List[exp.Column], t.List[
return s_index, t_index, index_cols
@property
- def source_key_expression(self) -> exp.Expression:
+ def source_key_expression(self) -> exp.Expr:
s_index, _, _ = self.key_columns
return self._key_expression(s_index, self.source_schema)
@property
- def target_key_expression(self) -> exp.Expression:
+ def target_key_expression(self) -> exp.Expr:
_, t_index, _ = self.key_columns
return self._key_expression(t_index, self.target_schema)
def _key_expression(
self, cols: t.List[exp.Column], schema: t.Dict[str, exp.DataType]
- ) -> exp.Expression:
+ ) -> exp.Expr:
# if there is a single column, dont do anything fancy to it in order to allow existing indexes to be hit
if len(cols) == 1:
return exp.to_column(cols[0].name)
@@ -363,7 +363,7 @@ def row_diff(
s_index_names = [c.name for c in s_index]
t_index_names = [t.name for t in t_index]
- def _column_expr(name: str, table: str) -> exp.Expression:
+ def _column_expr(name: str, table: str) -> exp.Expr:
column_type = matched_columns[name]
qualified_column = exp.column(name, table)
@@ -678,9 +678,9 @@ def _column_expr(name: str, table: str) -> exp.Expression:
def _fetch_sample(
self,
sample_table: exp.Table,
- s_selects: t.Dict[str, exp.Alias],
+ s_selects: t.Dict[str, exp.Expr],
s_index: t.List[exp.Column],
- t_selects: t.Dict[str, exp.Alias],
+ t_selects: t.Dict[str, exp.Expr],
t_index: t.List[exp.Column],
limit: int,
) -> pd.DataFrame:
@@ -742,5 +742,5 @@ def _fetch_sample(
return self.adapter.fetchdf(query, quote_identifiers=True)
-def name(e: exp.Expression) -> str:
+def name(e: exp.Expr) -> str:
return e.args["alias"].sql(identify=True)
diff --git a/sqlmesh/core/test/definition.py b/sqlmesh/core/test/definition.py
index 1c9807cfa1..629e8f8d5b 100644
--- a/sqlmesh/core/test/definition.py
+++ b/sqlmesh/core/test/definition.py
@@ -355,11 +355,12 @@ def _to_hashable(x: t.Any) -> t.Any:
for df in _split_df_by_column_pairs(diff)
)
else:
- from pandas import MultiIndex
+ from pandas import DataFrame, MultiIndex
levels = t.cast(MultiIndex, diff.columns).levels[0]
for col in levels:
- col_diff = diff[col]
+ # diff[col] returns a DataFrame when columns is a MultiIndex
+ col_diff = t.cast(DataFrame, diff[col])
if not col_diff.empty:
table = df_to_table(
f"[bold red]Column '{col}' mismatch{failed_subtest}[/bold red]",
@@ -673,7 +674,7 @@ def _add_missing_columns(
class SqlModelTest(ModelTest):
- def test_ctes(self, ctes: t.Dict[str, exp.Expression], recursive: bool = False) -> None:
+ def test_ctes(self, ctes: t.Dict[str, exp.Expr], recursive: bool = False) -> None:
"""Run CTE queries and compare output to expected output"""
for cte_name, values in self.body["outputs"].get("ctes", {}).items():
with self.subTest(cte=cte_name):
@@ -710,7 +711,7 @@ def runTest(self) -> None:
query = self._render_model_query()
sql = query.sql(self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql)
- with_clause = query.args.get("with")
+ with_clause = query.args.get("with_")
if with_clause:
self.test_ctes(
@@ -818,7 +819,7 @@ def _execute_model(self) -> pd.DataFrame:
time_kwargs = {key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables}
df = next(self.model.render(context=self.context, variables=variables, **time_kwargs))
- assert not isinstance(df, exp.Expression)
+ assert not isinstance(df, exp.Expr)
return df if isinstance(df, pd.DataFrame) else df.toPandas()
@@ -904,7 +905,7 @@ def generate_test(
if isinstance(model, SqlModel):
assert isinstance(test, SqlModelTest)
model_query = test._render_model_query()
- with_clause = model_query.args.get("with")
+ with_clause = model_query.args.get("with_")
if with_clause and include_ctes:
ctes = {}
diff --git a/sqlmesh/dbt/column.py b/sqlmesh/dbt/column.py
index 755f574388..80a6ad9325 100644
--- a/sqlmesh/dbt/column.py
+++ b/sqlmesh/dbt/column.py
@@ -42,7 +42,7 @@ def column_types_to_sqlmesh(
)
if column_def.args.get("constraints"):
logger.warning(
- f"Ignoring unsupported constraints for column '{name}' with definition '{column.data_type}'. Please refer to github.com/TobikoData/sqlmesh/issues/4717 for more information."
+ f"Ignoring unsupported constraints for column '{name}' with definition '{column.data_type}'. Please refer to github.com/SQLMesh/sqlmesh/issues/4717 for more information."
)
kind = column_def.kind
if kind:
diff --git a/sqlmesh/dbt/model.py b/sqlmesh/dbt/model.py
index 41cea9b9ae..55994abf85 100644
--- a/sqlmesh/dbt/model.py
+++ b/sqlmesh/dbt/model.py
@@ -485,7 +485,7 @@ def model_kind(self, context: DbtContext) -> ModelKind:
raise ConfigError(f"{materialization.value} materialization not supported.")
- def _big_query_partition_by_expr(self, context: DbtContext) -> exp.Expression:
+ def _big_query_partition_by_expr(self, context: DbtContext) -> exp.Expr:
assert isinstance(self.partition_by, dict)
data_type = self.partition_by["data_type"].lower()
raw_field = self.partition_by["field"]
diff --git a/sqlmesh/integrations/github/cicd/command.py b/sqlmesh/integrations/github/cicd/command.py
index f1b611150a..5506d4917b 100644
--- a/sqlmesh/integrations/github/cicd/command.py
+++ b/sqlmesh/integrations/github/cicd/command.py
@@ -25,12 +25,21 @@
envvar="GITHUB_TOKEN",
help="The Github Token to be used. Pass in `${{ secrets.GITHUB_TOKEN }}` if you want to use the one created by Github actions",
)
+@click.option(
+ "--full-logs",
+ is_flag=True,
+ help="Whether to print all logs in the Github Actions output or only in their relevant GA check",
+)
@click.pass_context
-def github(ctx: click.Context, token: str) -> None:
+def github(ctx: click.Context, token: str, full_logs: bool = False) -> None:
"""Github Action CI/CD Bot. See https://sqlmesh.readthedocs.io/en/stable/integrations/github/ for details"""
# set a larger width because if none is specified, it auto-detects 80 characters when running in GitHub Actions
# which can result in surprise newlines when outputting dates to backfill
- set_console(MarkdownConsole(width=1000, warning_capture_only=True, error_capture_only=True))
+ set_console(
+ MarkdownConsole(
+ width=1000, warning_capture_only=not full_logs, error_capture_only=not full_logs
+ )
+ )
ctx.obj["github"] = GithubController(
paths=ctx.obj["paths"],
token=token,
diff --git a/sqlmesh/integrations/github/cicd/config.py b/sqlmesh/integrations/github/cicd/config.py
index a287bf1af5..7fb3a0f5b6 100644
--- a/sqlmesh/integrations/github/cicd/config.py
+++ b/sqlmesh/integrations/github/cicd/config.py
@@ -36,6 +36,7 @@ class GithubCICDBotConfig(BaseConfig):
forward_only_branch_suffix_: t.Optional[str] = Field(
default=None, alias="forward_only_branch_suffix"
)
+ check_if_blocked_on_deploy_to_prod: bool = True
@model_validator(mode="before")
@classmethod
diff --git a/sqlmesh/integrations/github/cicd/controller.py b/sqlmesh/integrations/github/cicd/controller.py
index d7a9ef8eb8..40102b97e8 100644
--- a/sqlmesh/integrations/github/cicd/controller.py
+++ b/sqlmesh/integrations/github/cicd/controller.py
@@ -448,10 +448,9 @@ def prod_plan_with_gaps(self) -> Plan:
c.PROD,
# this is required to highlight any data gaps between this PR environment and prod (since PR environments may only contain a subset of data)
no_gaps=False,
- # this works because the snapshots were already categorized when applying self.pr_plan so there are no uncategorized local snapshots to trigger a plan error
- no_auto_categorization=True,
skip_tests=True,
skip_linter=True,
+ categorizer_config=self.bot_config.auto_categorize_changes,
run=self.bot_config.run_on_deploy_to_prod,
forward_only=self.forward_only_plan,
)
@@ -773,10 +772,10 @@ def deploy_to_prod(self) -> None:
"PR is already merged and this event was triggered prior to the merge."
)
merge_status = self._get_merge_state_status()
- if merge_status.is_blocked:
+ if self.bot_config.check_if_blocked_on_deploy_to_prod and merge_status.is_blocked:
raise CICDBotError(
"Branch protection or ruleset requirement is likely not satisfied, e.g. missing CODEOWNERS approval. "
- "Please check PR and resolve any issues."
+ "Please check PR and resolve any issues. To disable this check, set `check_if_blocked_on_deploy_to_prod` to false in the bot configuration."
)
if merge_status.is_dirty:
raise CICDBotError(
diff --git a/sqlmesh/lsp/hints.py b/sqlmesh/lsp/hints.py
index a8d56e2f31..611ce8608d 100644
--- a/sqlmesh/lsp/hints.py
+++ b/sqlmesh/lsp/hints.py
@@ -5,7 +5,6 @@
from lsprotocol import types
from sqlglot import exp
-from sqlglot.expressions import Expression
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlmesh.core.model.definition import SqlModel
from sqlmesh.lsp.context import LSPContext, ModelTarget
@@ -60,7 +59,7 @@ def get_hints(
def _get_type_hints_for_select(
- expression: exp.Expression,
+ expression: exp.Expr,
dialect: str,
columns_to_types: t.Dict[str, exp.DataType],
start_line: int,
@@ -113,7 +112,7 @@ def _get_type_hints_for_select(
def _get_type_hints_for_model_from_query(
- query: Expression,
+ query: exp.Expr,
dialect: str,
columns_to_types: t.Dict[str, exp.DataType],
start_line: int,
diff --git a/sqlmesh/lsp/reference.py b/sqlmesh/lsp/reference.py
index 80d401f79c..73c4e5681b 100644
--- a/sqlmesh/lsp/reference.py
+++ b/sqlmesh/lsp/reference.py
@@ -209,7 +209,7 @@ def get_macro_reference(
target: t.Union[Model, StandaloneAudit],
read_file: t.List[str],
config_path: t.Optional[Path],
- node: exp.Expression,
+ node: exp.Expr,
macro_name: str,
) -> t.Optional[Reference]:
# Get the file path where the macro is defined
diff --git a/sqlmesh/migrations/v0092_warn_about_dbt_data_type_diff.py b/sqlmesh/migrations/v0092_warn_about_dbt_data_type_diff.py
index 02e2a5f4c1..5407e5a99a 100644
--- a/sqlmesh/migrations/v0092_warn_about_dbt_data_type_diff.py
+++ b/sqlmesh/migrations/v0092_warn_about_dbt_data_type_diff.py
@@ -5,7 +5,7 @@
doesn't match dbt's behavior. dbt only uses data_type for contracts/validation, not DDL.
This fix may cause diffs if tables were created with incorrect types.
-More context: https://github.com/TobikoData/sqlmesh/pull/5231
+More context: https://github.com/SQLMesh/sqlmesh/pull/5231
"""
import json
@@ -33,7 +33,7 @@ def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore
"tables may have been created with incorrect column types. After this migration, run "
"'sqlmesh diff prod' to check for column type differences, and if any are found, "
"apply a plan to correct the table schemas. For more details, see: "
- "https://github.com/TobikoData/sqlmesh/pull/5231."
+ "https://github.com/SQLMesh/sqlmesh/pull/5231."
)
for (snapshot,) in engine_adapter.fetchall(
diff --git a/sqlmesh/utils/date.py b/sqlmesh/utils/date.py
index c9bb19c835..bdc15125d4 100644
--- a/sqlmesh/utils/date.py
+++ b/sqlmesh/utils/date.py
@@ -168,7 +168,7 @@ def to_datetime(
dt: t.Optional[datetime] = value
elif isinstance(value, date):
dt = datetime(value.year, value.month, value.day)
- elif isinstance(value, exp.Expression):
+ elif isinstance(value, exp.Expr):
return to_datetime(value.name)
else:
try:
@@ -401,7 +401,7 @@ def to_time_column(
dialect: str,
time_column_format: t.Optional[str] = None,
nullable: bool = False,
-) -> exp.Expression:
+) -> exp.Expr:
"""Convert a TimeLike object to the same time format and type as the model's time column."""
if dialect == "clickhouse" and time_column_type.is_type(
*(exp.DataType.TEMPORAL_TYPES - {exp.DataType.Type.DATE, exp.DataType.Type.DATE32})
diff --git a/sqlmesh/utils/git.py b/sqlmesh/utils/git.py
index 00410e776c..cdb9d4e2d5 100644
--- a/sqlmesh/utils/git.py
+++ b/sqlmesh/utils/git.py
@@ -16,7 +16,9 @@ def list_untracked_files(self) -> t.List[Path]:
)
def list_uncommitted_changed_files(self) -> t.List[Path]:
- return self._execute_list_output(["diff", "--name-only", "--diff-filter=d"], self._git_root)
+ return self._execute_list_output(
+ ["diff", "--name-only", "--diff-filter=d", "HEAD"], self._git_root
+ )
def list_committed_changed_files(self, target_branch: str = "main") -> t.List[Path]:
return self._execute_list_output(
diff --git a/sqlmesh/utils/jinja.py b/sqlmesh/utils/jinja.py
index 240b183391..bd82cf225c 100644
--- a/sqlmesh/utils/jinja.py
+++ b/sqlmesh/utils/jinja.py
@@ -12,7 +12,8 @@
from jinja2 import Environment, Template, nodes, UndefinedError
from jinja2.runtime import Macro
-from sqlglot import Dialect, Expression, Parser, TokenType
+from sqlglot import Dialect, Parser, TokenType
+from sqlglot.expressions import Expression
from sqlmesh.core import constants as c
from sqlmesh.core import dialect as d
@@ -78,6 +79,11 @@ def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
self.reset()
self.sql = jinja
self._tokens = Dialect.get_or_raise(dialect).tokenize(jinja)
+
+ # guard for older sqlglot versions (before 30.0.3)
+ if hasattr(self, "_tokens_size"):
+ # keep the cached length in sync
+ self._tokens_size = len(self._tokens)
self._index = -1
self._advance()
diff --git a/sqlmesh/utils/lineage.py b/sqlmesh/utils/lineage.py
index f5b4506c68..f63395708d 100644
--- a/sqlmesh/utils/lineage.py
+++ b/sqlmesh/utils/lineage.py
@@ -70,7 +70,7 @@ class MacroReference(PydanticModel):
def extract_references_from_query(
- query: exp.Expression,
+ query: exp.Expr,
context: t.Union["Context", "GenericContext[t.Any]"],
document_path: Path,
read_file: t.List[str],
@@ -95,7 +95,11 @@ def extract_references_from_query(
# Check if this table reference is a CTE in the current scope
if cte_scope := scope.cte_sources.get(table_name):
+ if cte_scope.expression is None:
+ continue
cte = cte_scope.expression.parent
+ if cte is None:
+ continue
alias = cte.args["alias"]
if isinstance(alias, exp.TableAlias):
identifier = alias.this
diff --git a/sqlmesh/utils/metaprogramming.py b/sqlmesh/utils/metaprogramming.py
index 858e8a50da..cd77c36353 100644
--- a/sqlmesh/utils/metaprogramming.py
+++ b/sqlmesh/utils/metaprogramming.py
@@ -352,7 +352,8 @@ def walk(obj: t.Any, name: str, is_metadata: bool = False) -> None:
walk(base, base.__qualname__, is_metadata)
for k, v in obj.__dict__.items():
- if k.startswith("__"):
+ # skip dunder methods, except __init__, as it may contain user-defined logic with cross-class references
+ if k.startswith("__") and k != "__init__":
continue
# Traverse methods in a class to find global references
@@ -362,10 +363,14 @@ def walk(obj: t.Any, name: str, is_metadata: bool = False) -> None:
if callable(v):
# Walk the method if it's part of the object, else it's a global function and we just store it
if v.__qualname__.startswith(obj.__qualname__):
- for k, v in func_globals(v).items():
- walk(v, k, is_metadata)
- else:
- walk(v, v.__name__, is_metadata)
+ try:
+ for k, v in func_globals(v).items():
+ walk(v, k, is_metadata)
+ except (OSError, TypeError):
+ # __init__ may come from built-ins or wrapped callables
+ pass
+ else:
+ walk(v, k, is_metadata)
elif callable(obj):
for k, v in func_globals(obj).items():
walk(v, k, is_metadata)
@@ -439,6 +444,41 @@ def value(
)
+def _resolve_import_module(obj: t.Any, name: str) -> str:
+ """Resolve the most appropriate module path for importing an object.
+
+ When a callable's ``__module__`` points to a submodule of a known public
+ module (e.g. ``sqlglot.expressions.builders`` is a submodule of
+ ``sqlglot.expressions``), and the object is re-exported from that public
+ parent module, prefer the public parent so that generated import statements
+ remain stable across internal restructurings of third-party packages.
+
+ Args:
+ obj: The callable to resolve.
+ name: The name under which the object will be imported.
+
+ Returns:
+ The module path to use in the ``from <module> import <name>`` statement.
+ """
+ module_name = getattr(obj, "__module__", None) or ""
+ parts = module_name.split(".")
+
+ # Walk from the shallowest ancestor (excluding the top-level package) up to
+ # the immediate parent, returning the shallowest one that re-exports the object.
+ # We skip the top-level package to avoid over-normalizing (e.g. ``sqlglot``
+ # re-exports everything, but callers expect ``sqlglot.expressions``).
+ for i in range(2, len(parts)):
+ parent = ".".join(parts[:i])
+ try:
+ parent_module = sys.modules.get(parent) or importlib.import_module(parent)
+ if getattr(parent_module, name, None) is obj:
+ return parent
+ except Exception:
+ continue
+
+ return module_name
+
+
def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable]:
"""Serializes a python function into a self contained dictionary.
@@ -507,7 +547,7 @@ def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable
)
else:
serialized[k] = Executable(
- payload=f"from {v.__module__} import {name}",
+ payload=f"from {_resolve_import_module(v, name)} import {name}",
kind=ExecutableKind.IMPORT,
is_metadata=is_metadata,
)
diff --git a/sqlmesh/utils/pydantic.py b/sqlmesh/utils/pydantic.py
index 2c9c570e5b..8bc81e2774 100644
--- a/sqlmesh/utils/pydantic.py
+++ b/sqlmesh/utils/pydantic.py
@@ -56,7 +56,7 @@ def get_dialect(values: t.Any) -> str:
return model._dialect if dialect is None else dialect # type: ignore
-def _expression_encoder(e: exp.Expression) -> str:
+def _expression_encoder(e: exp.Expr) -> str:
return e.meta.get("sql") or e.sql(dialect=e.meta.get("dialect"))
@@ -70,7 +70,7 @@ class PydanticModel(pydantic.BaseModel):
# crippled badly. Here we need to enumerate all different ways of how sqlglot expressions
# show up in pydantic models.
json_encoders={
- exp.Expression: _expression_encoder,
+ exp.Expr: _expression_encoder,
exp.DataType: _expression_encoder,
exp.Tuple: _expression_encoder,
AuditQueryTypes: _expression_encoder, # type: ignore
@@ -190,7 +190,7 @@ def validate_list_of_strings(v: t.Any) -> t.List[str]:
def validate_string(v: t.Any) -> str:
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
return v.name
return str(v)
@@ -204,13 +204,13 @@ def validate_expression(expression: E, dialect: str) -> E:
def bool_validator(v: t.Any) -> bool:
if isinstance(v, exp.Boolean):
return v.this
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
return str_to_bool(v.name)
return str_to_bool(str(v or ""))
def positive_int_validator(v: t.Any) -> int:
- if isinstance(v, exp.Expression) and v.is_int:
+ if isinstance(v, exp.Expr) and v.is_int:
v = int(v.name)
if not isinstance(v, int):
raise ValueError(f"Invalid num {v}. Value must be an integer value")
@@ -237,10 +237,10 @@ def _formatted_validation_errors(error: pydantic.ValidationError) -> t.List[str]
def _get_field(
v: t.Any,
values: t.Any,
-) -> exp.Expression:
+) -> exp.Expr:
dialect = get_dialect(values)
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
expression = v
else:
expression = parse_one(v, dialect=dialect)
@@ -257,16 +257,16 @@ def _get_field(
def _get_fields(
v: t.Any,
values: t.Any,
-) -> t.List[exp.Expression]:
+) -> t.List[exp.Expr]:
dialect = get_dialect(values)
if isinstance(v, (exp.Tuple, exp.Array)):
- expressions: t.List[exp.Expression] = v.expressions
- elif isinstance(v, exp.Expression):
+ expressions: t.List[exp.Expr] = v.expressions
+ elif isinstance(v, exp.Expr):
expressions = [v]
else:
expressions = [
- parse_one(entry, dialect=dialect) if isinstance(entry, str) else entry
+ parse_one(entry, dialect=dialect) if isinstance(entry, str) else entry # type: ignore[misc]
for entry in ensure_list(v)
]
@@ -278,7 +278,7 @@ def _get_fields(
return results
-def list_of_fields_validator(v: t.Any, values: t.Any) -> t.List[exp.Expression]:
+def list_of_fields_validator(v: t.Any, values: t.Any) -> t.List[exp.Expr]:
return _get_fields(v, values)
@@ -291,15 +291,15 @@ def column_validator(v: t.Any, values: t.Any) -> exp.Column:
def list_of_fields_or_star_validator(
v: t.Any, values: t.Any
-) -> t.Union[exp.Star, t.List[exp.Expression]]:
+) -> t.Union[exp.Star, t.List[exp.Expr]]:
expressions = _get_fields(v, values)
if len(expressions) == 1 and isinstance(expressions[0], exp.Star):
return t.cast(exp.Star, expressions[0])
- return t.cast(t.List[exp.Expression], expressions)
+ return t.cast(t.List[exp.Expr], expressions)
def cron_validator(v: t.Any) -> str:
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
v = v.name
from croniter import CroniterBadCronError, croniter
@@ -338,7 +338,7 @@ def get_concrete_types_from_typehint(typehint: type[t.Any]) -> set[type[t.Any]]:
SQLGlotBool = bool
SQLGlotPositiveInt = int
SQLGlotColumn = exp.Column
- SQLGlotListOfFields = t.List[exp.Expression]
+ SQLGlotListOfFields = t.List[exp.Expr]
SQLGlotListOfFieldsOrStar = t.Union[SQLGlotListOfFields, exp.Star]
SQLGlotCron = str
else:
@@ -348,10 +348,8 @@ def get_concrete_types_from_typehint(typehint: type[t.Any]) -> set[type[t.Any]]:
SQLGlotString = t.Annotated[str, BeforeValidator(validate_string)]
SQLGlotBool = t.Annotated[bool, BeforeValidator(bool_validator)]
SQLGlotPositiveInt = t.Annotated[int, BeforeValidator(positive_int_validator)]
- SQLGlotColumn = t.Annotated[exp.Expression, BeforeValidator(column_validator)]
- SQLGlotListOfFields = t.Annotated[
- t.List[exp.Expression], BeforeValidator(list_of_fields_validator)
- ]
+ SQLGlotColumn = t.Annotated[exp.Expr, BeforeValidator(column_validator)]
+ SQLGlotListOfFields = t.Annotated[t.List[exp.Expr], BeforeValidator(list_of_fields_validator)]
SQLGlotListOfFieldsOrStar = t.Annotated[
t.Union[SQLGlotListOfFields, exp.Star], BeforeValidator(list_of_fields_or_star_validator)
]
diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py
index 480d186fa1..5e0737e1b6 100644
--- a/tests/cli/test_cli.py
+++ b/tests/cli/test_cli.py
@@ -878,7 +878,6 @@ def test_dlt_pipeline_errors(runner, tmp_path):
assert "Error: Could not attach to pipeline" in result.output
-@time_machine.travel(FREEZE_TIME)
def test_dlt_filesystem_pipeline(tmp_path):
import dlt
@@ -982,7 +981,6 @@ def test_dlt_filesystem_pipeline(tmp_path):
rmtree(storage_path)
-@time_machine.travel(FREEZE_TIME)
def test_dlt_pipeline(runner, tmp_path):
from dlt.common.pipeline import get_dlt_pipelines_dir
@@ -1951,7 +1949,7 @@ def test_init_dbt_template(runner: CliRunner, tmp_path: Path):
def test_init_project_engine_configs(tmp_path):
engine_type_to_config = {
"redshift": "# concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # user: \n # password: \n # database: \n # host: \n # port: \n # source_address: \n # unix_sock: \n # ssl: \n # sslmode: \n # timeout: \n # tcp_keepalive: \n # application_name: \n # preferred_role: \n # principal_arn: \n # credentials_provider: \n # region: \n # cluster_identifier: \n # iam: \n # is_serverless: \n # serverless_acct_id: \n # serverless_work_group: \n # enable_merge: ",
- "bigquery": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # method: oauth\n # project: \n # execution_project: \n # quota_project: \n # location: \n # keyfile: \n # keyfile_json: \n # token: \n # refresh_token: \n # client_id: \n # client_secret: \n # token_uri: \n # scopes: \n # impersonated_service_account: \n # job_creation_timeout_seconds: \n # job_execution_timeout_seconds: \n # job_retries: 1\n # job_retry_deadline_seconds: \n # priority: \n # maximum_bytes_billed: ",
+ "bigquery": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # method: oauth\n # project: \n # execution_project: \n # quota_project: \n # location: \n # keyfile: \n # keyfile_json: \n # token: \n # refresh_token: \n # client_id: \n # client_secret: \n # token_uri: \n # scopes: \n # impersonated_service_account: \n # job_creation_timeout_seconds: \n # job_execution_timeout_seconds: \n # job_retries: 1\n # job_retry_deadline_seconds: \n # priority: \n # maximum_bytes_billed: \n # reservation: ",
"snowflake": "account: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # user: \n # password: \n # warehouse: \n # database: \n # role: \n # authenticator: \n # token: \n # host: \n # port: \n # application: Tobiko_SQLMesh\n # private_key: \n # private_key_path: \n # private_key_passphrase: \n # session_parameters: ",
"databricks": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # server_hostname: \n # http_path: \n # access_token: \n # auth_type: \n # oauth_client_id: \n # oauth_client_secret: \n # catalog: \n # http_headers: \n # session_configuration: \n # databricks_connect_server_hostname: \n # databricks_connect_access_token: \n # databricks_connect_cluster_id: \n # databricks_connect_use_serverless: False\n # force_databricks_connect: False\n # disable_databricks_connect: False\n # disable_spark_session: False",
"postgres": "host: \n user: \n password: \n port: \n database: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: True\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # keepalives_idle: \n # connect_timeout: 10\n # role: \n # sslmode: \n # application_name: ",
diff --git a/tests/conftest.py b/tests/conftest.py
index b18271465d..46086444bd 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -381,7 +381,7 @@ def _make_function(
@pytest.fixture
def assert_exp_eq() -> t.Callable:
def _assert_exp_eq(
- source: exp.Expression | str, expected: exp.Expression | str, dialect: DialectType = None
+ source: exp.Expr | str, expected: exp.Expr | str, dialect: DialectType = None
) -> None:
source_exp = maybe_parse(source, dialect=dialect)
expected_exp = maybe_parse(expected, dialect=dialect)
diff --git a/tests/core/engine_adapter/__init__.py b/tests/core/engine_adapter/__init__.py
index 4761c4100b..a9370b8cc3 100644
--- a/tests/core/engine_adapter/__init__.py
+++ b/tests/core/engine_adapter/__init__.py
@@ -11,7 +11,7 @@ def to_sql_calls(adapter: EngineAdapter, identify: bool = True) -> t.List[str]:
value = call[0][0]
sql = (
value.sql(dialect=adapter.dialect, identify=identify)
- if isinstance(value, exp.Expression)
+ if isinstance(value, exp.Expr)
else str(value)
)
output.append(sql)
diff --git a/tests/core/engine_adapter/integration/__init__.py b/tests/core/engine_adapter/integration/__init__.py
index 49624154e4..47ccdc876a 100644
--- a/tests/core/engine_adapter/integration/__init__.py
+++ b/tests/core/engine_adapter/integration/__init__.py
@@ -276,7 +276,7 @@ def time_formatter(self) -> t.Callable:
return lambda x, _: exp.Literal.string(to_ds(x))
@property
- def partitioned_by(self) -> t.List[exp.Expression]:
+ def partitioned_by(self) -> t.List[exp.Expr]:
return [parse_one(self.time_column)]
@property
@@ -388,8 +388,8 @@ def table(self, table_name: TableName, schema: str = TEST_SCHEMA) -> exp.Table:
)
def physical_properties(
- self, properties_for_dialect: t.Dict[str, t.Dict[str, str | exp.Expression]]
- ) -> t.Dict[str, exp.Expression]:
+ self, properties_for_dialect: t.Dict[str, t.Dict[str, str | exp.Expr]]
+ ) -> t.Dict[str, exp.Expr]:
if props := properties_for_dialect.get(self.dialect):
return {k: exp.Literal.string(v) if isinstance(v, str) else v for k, v in props.items()}
return {}
@@ -756,7 +756,10 @@ def _get_create_user_or_role(
return username, f"CREATE ROLE {username}"
if self.dialect == "databricks":
# Creating an account-level group in Databricks requires making REST API calls so we are going to
- # use a pre-created group instead. We assume the suffix on the name is the unique id
+ # use a pre-created group instead. We assume the suffix on the name is the unique id.
+ # In the Databricks UI, Workspace Settings -> Identity and Access, create the following groups:
+ # - test_user, test_analyst, test_etl_user, test_reader, test_writer, test_admin
+ # (there do not need to be any users assigned to these groups)
return "_".join(username.split("_")[:-1]), None
if self.dialect == "bigquery":
# BigQuery uses IAM service accounts that need to be pre-created
diff --git a/tests/core/engine_adapter/integration/config.yaml b/tests/core/engine_adapter/integration/config.yaml
index 8e87b2c3c8..5635f4e1ba 100644
--- a/tests/core/engine_adapter/integration/config.yaml
+++ b/tests/core/engine_adapter/integration/config.yaml
@@ -128,7 +128,8 @@ gateways:
warehouse: {{ env_var('SNOWFLAKE_WAREHOUSE') }}
database: {{ env_var('SNOWFLAKE_DATABASE') }}
user: {{ env_var('SNOWFLAKE_USER') }}
- password: {{ env_var('SNOWFLAKE_PASSWORD') }}
+ authenticator: SNOWFLAKE_JWT
+ private_key_path: {{ env_var('SNOWFLAKE_PRIVATE_KEY_FILE', 'tests/fixtures/snowflake/rsa_key_no_pass.p8') }}
check_import: false
state_connection:
type: duckdb
@@ -139,7 +140,10 @@ gateways:
catalog: {{ env_var('DATABRICKS_CATALOG') }}
server_hostname: {{ env_var('DATABRICKS_SERVER_HOSTNAME') }}
http_path: {{ env_var('DATABRICKS_HTTP_PATH') }}
- access_token: {{ env_var('DATABRICKS_ACCESS_TOKEN') }}
+ auth_type: {{ env_var('DATABRICKS_AUTH_TYPE', 'databricks-oauth') }}
+ oauth_client_id: {{ env_var('DATABRICKS_CLIENT_ID') }}
+ oauth_client_secret: {{ env_var('DATABRICKS_CLIENT_SECRET') }}
+ databricks_connect_use_serverless: true
check_import: false
inttest_redshift:
diff --git a/tests/core/engine_adapter/integration/conftest.py b/tests/core/engine_adapter/integration/conftest.py
index 308819b671..3fb4bc15f1 100644
--- a/tests/core/engine_adapter/integration/conftest.py
+++ b/tests/core/engine_adapter/integration/conftest.py
@@ -7,7 +7,6 @@
import logging
from pytest import FixtureRequest
-
from sqlmesh import Config, EngineAdapter
from sqlmesh.core.constants import SQLMESH_PATH
from sqlmesh.core.config.connection import (
diff --git a/tests/core/engine_adapter/integration/test_freshness.py b/tests/core/engine_adapter/integration/test_freshness.py
index 5e4c4cf439..e5ee574e7e 100644
--- a/tests/core/engine_adapter/integration/test_freshness.py
+++ b/tests/core/engine_adapter/integration/test_freshness.py
@@ -25,6 +25,16 @@
EVALUATION_SPY = None
+@pytest.fixture(autouse=True)
+def _skip_snowflake(ctx: TestContext):
+ if ctx.dialect == "snowflake":
+ # these tests use callbacks that need to run db queries within a time_travel context that changes the system time to be in the future
+ # this causes invalid JWT's to be generated when the callbacks try to run a db query
+ pytest.skip(
+ "snowflake.connector generates an invalid JWT when time_travel changes the system time"
+ )
+
+
# Mock the snapshot evaluator's evaluate function to count the number of times it is called
@pytest.fixture(autouse=True, scope="function")
def _install_evaluation_spy(mocker: MockerFixture):
diff --git a/tests/core/engine_adapter/integration/test_integration_athena.py b/tests/core/engine_adapter/integration/test_integration_athena.py
index 1c0ece6d78..9d23af206e 100644
--- a/tests/core/engine_adapter/integration/test_integration_athena.py
+++ b/tests/core/engine_adapter/integration/test_integration_athena.py
@@ -378,7 +378,7 @@ def test_insert_overwrite_by_time_partition_date_type(
), # note: columns_to_types_from_df() would infer this as TEXT but we need a DATE type
}
- def time_formatter(time: TimeLike, _: t.Optional[t.Dict[str, exp.DataType]]) -> exp.Expression:
+ def time_formatter(time: TimeLike, _: t.Optional[t.Dict[str, exp.DataType]]) -> exp.Expr:
return exp.cast(exp.Literal.string(to_ds(time)), "date")
engine_adapter.create_table(
@@ -440,7 +440,7 @@ def test_insert_overwrite_by_time_partition_datetime_type(
), # note: columns_to_types_from_df() would infer this as TEXT but we need a DATETIME type
}
- def time_formatter(time: TimeLike, _: t.Optional[t.Dict[str, exp.DataType]]) -> exp.Expression:
+ def time_formatter(time: TimeLike, _: t.Optional[t.Dict[str, exp.DataType]]) -> exp.Expr:
return exp.cast(exp.Literal.string(to_ts(time)), "datetime")
engine_adapter.create_table(
diff --git a/tests/core/engine_adapter/integration/test_integration_clickhouse.py b/tests/core/engine_adapter/integration/test_integration_clickhouse.py
index f09360c673..4420acec71 100644
--- a/tests/core/engine_adapter/integration/test_integration_clickhouse.py
+++ b/tests/core/engine_adapter/integration/test_integration_clickhouse.py
@@ -64,9 +64,7 @@ def _create_table_and_insert_existing_data(
"ds": exp.DataType.build("Date", "clickhouse"),
},
table_name: str = "data_existing",
- partitioned_by: t.Optional[t.List[exp.Expression]] = [
- parse_one("toMonth(ds)", dialect="clickhouse")
- ],
+ partitioned_by: t.Optional[t.List[exp.Expr]] = [parse_one("toMonth(ds)", dialect="clickhouse")],
) -> exp.Table:
existing_data = existing_data
existing_table_name: exp.Table = ctx.table(table_name)
diff --git a/tests/core/engine_adapter/integration/test_integration_snowflake.py b/tests/core/engine_adapter/integration/test_integration_snowflake.py
index f9862c51cb..7f3c38be46 100644
--- a/tests/core/engine_adapter/integration/test_integration_snowflake.py
+++ b/tests/core/engine_adapter/integration/test_integration_snowflake.py
@@ -186,6 +186,7 @@ def _get_data_object(table: exp.Table) -> DataObject:
assert not metadata.is_clustered
+@pytest.mark.skip(reason="External volume LIST privileges not configured for CI test databases")
def test_create_iceberg_table(ctx: TestContext) -> None:
# Note: this test relies on a default Catalog and External Volume being configured in Snowflake
# ref: https://docs.snowflake.com/en/user-guide/tables-iceberg-configure-catalog-integration#set-a-default-catalog-at-the-account-database-or-schema-level
diff --git a/tests/core/engine_adapter/test_athena.py b/tests/core/engine_adapter/test_athena.py
index 66e84ae025..19c92f66ac 100644
--- a/tests/core/engine_adapter/test_athena.py
+++ b/tests/core/engine_adapter/test_athena.py
@@ -81,7 +81,7 @@ def table_diff(adapter: AthenaEngineAdapter) -> TableDiff:
def test_table_location(
adapter: AthenaEngineAdapter,
config_s3_warehouse_location: t.Optional[str],
- table_properties: t.Optional[t.Dict[str, exp.Expression]],
+ table_properties: t.Optional[t.Dict[str, exp.Expr]],
table: exp.Table,
expected_location: t.Optional[str],
) -> None:
diff --git a/tests/core/engine_adapter/test_bigquery.py b/tests/core/engine_adapter/test_bigquery.py
index 047613e47a..134f144df1 100644
--- a/tests/core/engine_adapter/test_bigquery.py
+++ b/tests/core/engine_adapter/test_bigquery.py
@@ -593,7 +593,7 @@ def _to_sql_calls(execute_mock: t.Any, identify: bool = True) -> t.List[str]:
for value in values:
sql = (
value.sql(dialect="bigquery", identify=identify)
- if isinstance(value, exp.Expression)
+ if isinstance(value, exp.Expr)
else str(value)
)
output.append(sql)
@@ -1245,7 +1245,7 @@ def test_sync_grants_config(make_mocked_engine_adapter: t.Callable, mocker: Mock
executed_sql = executed_query.sql(dialect="bigquery")
expected_sql = (
"SELECT privilege_type, grantee FROM `project`.`region-us-central1`.`INFORMATION_SCHEMA.OBJECT_PRIVILEGES` AS OBJECT_PRIVILEGES "
- "WHERE object_schema = 'dataset' AND object_name = 'test_table' AND SPLIT(grantee, ':')[OFFSET(1)] <> session_user()"
+ "WHERE object_schema = 'dataset' AND object_name = 'test_table' AND SPLIT(grantee, ':')[OFFSET(1)] <> SESSION_USER()"
)
assert executed_sql == expected_sql
@@ -1306,7 +1306,7 @@ def test_sync_grants_config_with_overlaps(
executed_sql = executed_query.sql(dialect="bigquery")
expected_sql = (
"SELECT privilege_type, grantee FROM `project`.`region-us-central1`.`INFORMATION_SCHEMA.OBJECT_PRIVILEGES` AS OBJECT_PRIVILEGES "
- "WHERE object_schema = 'dataset' AND object_name = 'test_table' AND SPLIT(grantee, ':')[OFFSET(1)] <> session_user()"
+ "WHERE object_schema = 'dataset' AND object_name = 'test_table' AND SPLIT(grantee, ':')[OFFSET(1)] <> SESSION_USER()"
)
assert executed_sql == expected_sql
diff --git a/tests/core/engine_adapter/test_clickhouse.py b/tests/core/engine_adapter/test_clickhouse.py
index 188ae7f394..7ff971b742 100644
--- a/tests/core/engine_adapter/test_clickhouse.py
+++ b/tests/core/engine_adapter/test_clickhouse.py
@@ -327,16 +327,16 @@ def build_properties_sql(storage_format="", order_by="", primary_key="", propert
assert (
build_properties_sql(
- order_by="ORDER_BY = 'timestamp with fill to toStartOfDay(toDateTime64(\\'2024-07-11\\', 3)) step toIntervalDay(1) interpolate(price as price)',"
+ order_by="ORDER_BY = 'timestamp with fill to dateTrunc(\\'DAY\\', toDateTime64(\\'2024-07-11\\', 3)) step toIntervalDay(1) interpolate(price as price)',"
)
- == "ENGINE=MergeTree ORDER BY (timestamp WITH FILL TO toStartOfDay(toDateTime64('2024-07-11', 3)) STEP toIntervalDay(1) INTERPOLATE (price AS price))"
+ == "ENGINE=MergeTree ORDER BY (timestamp WITH FILL TO dateTrunc('DAY', toDateTime64('2024-07-11', 3)) STEP toIntervalDay(1) INTERPOLATE (price AS price))"
)
assert (
build_properties_sql(
- order_by="ORDER_BY = (\"a\", 'timestamp with fill to toStartOfDay(toDateTime64(\\'2024-07-11\\', 3)) step toIntervalDay(1) interpolate(price as price)'),"
+ order_by="ORDER_BY = (\"a\", 'timestamp with fill to dateTrunc(\\'DAY\\', toDateTime64(\\'2024-07-11\\', 3)) step toIntervalDay(1) interpolate(price as price)'),"
)
- == "ENGINE=MergeTree ORDER BY (\"a\", timestamp WITH FILL TO toStartOfDay(toDateTime64('2024-07-11', 3)) STEP toIntervalDay(1) INTERPOLATE (price AS price))"
+ == "ENGINE=MergeTree ORDER BY (\"a\", timestamp WITH FILL TO dateTrunc('DAY', toDateTime64('2024-07-11', 3)) STEP toIntervalDay(1) INTERPOLATE (price AS price))"
)
assert (
@@ -368,7 +368,7 @@ def test_partitioned_by_expr(make_mocked_engine_adapter: t.Callable):
assert (
model.partitioned_by[0].sql("clickhouse")
- == """toMonday(CAST("ds" AS DateTime64(9, 'UTC')))"""
+ == """dateTrunc('WEEK', CAST("ds" AS DateTime64(9, 'UTC')))"""
)
# user specifies without time column, unknown time column type
@@ -393,7 +393,7 @@ def test_partitioned_by_expr(make_mocked_engine_adapter: t.Callable):
)
assert [p.sql("clickhouse") for p in model.partitioned_by] == [
- """toMonday(CAST("ds" AS DateTime64(9, 'UTC')))""",
+ """dateTrunc('WEEK', CAST("ds" AS DateTime64(9, 'UTC')))""",
'"x"',
]
@@ -417,7 +417,7 @@ def test_partitioned_by_expr(make_mocked_engine_adapter: t.Callable):
)
)
- assert model.partitioned_by[0].sql("clickhouse") == 'toMonday("ds")'
+ assert model.partitioned_by[0].sql("clickhouse") == """dateTrunc('WEEK', "ds")"""
# user doesn't specify, non-conformable time column type
model = load_sql_based_model(
@@ -441,7 +441,7 @@ def test_partitioned_by_expr(make_mocked_engine_adapter: t.Callable):
assert (
model.partitioned_by[0].sql("clickhouse")
- == """CAST(toMonday(CAST("ds" AS DateTime64(9, 'UTC'))) AS String)"""
+ == """CAST(dateTrunc('WEEK', CAST("ds" AS DateTime64(9, 'UTC'))) AS String)"""
)
# user specifies partitioned_by with time column
@@ -993,7 +993,7 @@ def test_insert_overwrite_by_condition_replace_partitioned(
temp_table_mock.return_value = make_temp_table_name(table_name, "abcd")
fetchone_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchone")
- fetchone_mock.return_value = "toMonday(ds)"
+ fetchone_mock.return_value = "dateTrunc('WEEK', ds)"
insert_table_name = make_temp_table_name("new_records", "abcd")
existing_table_name = make_temp_table_name("existing_records", "abcd")
@@ -1069,7 +1069,7 @@ def test_insert_overwrite_by_condition_where_partitioned(
temp_table_mock.return_value = make_temp_table_name(table_name, "abcd")
fetchone_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchone")
- fetchone_mock.return_value = "toMonday(ds)"
+ fetchone_mock.return_value = "dateTrunc('WEEK', ds)"
fetchall_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchall")
fetchall_mock.side_effect = [
@@ -1175,7 +1175,7 @@ def test_insert_overwrite_by_condition_by_key_partitioned(
temp_table_mock.return_value = make_temp_table_name(table_name, "abcd")
fetchone_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchone")
- fetchone_mock.side_effect = ["toMonday(ds)", "toMonday(ds)"]
+ fetchone_mock.side_effect = ["dateTrunc('WEEK', ds)", "dateTrunc('WEEK', ds)"]
fetchall_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchall")
fetchall_mock.side_effect = [
@@ -1240,7 +1240,7 @@ def test_insert_overwrite_by_condition_inc_by_partition(
temp_table_mock.return_value = make_temp_table_name(table_name, "abcd")
fetchone_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchone")
- fetchone_mock.return_value = "toMonday(ds)"
+ fetchone_mock.return_value = "dateTrunc('WEEK', ds)"
fetchall_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchall")
fetchall_mock.return_value = [("1",), ("2",), ("4",)]
@@ -1365,7 +1365,7 @@ def test_exchange_tables(
# The EXCHANGE TABLES call errored, so we RENAME TABLE instead
assert [
quote_identifiers(call.args[0]).sql("clickhouse")
- if isinstance(call.args[0], exp.Expression)
+ if isinstance(call.args[0], exp.Expression)
else call.args[0]
for call in execute_mock.call_args_list
] == [
diff --git a/tests/core/engine_adapter/test_mssql.py b/tests/core/engine_adapter/test_mssql.py
index bf28157d00..ec6a4ba3e8 100644
--- a/tests/core/engine_adapter/test_mssql.py
+++ b/tests/core/engine_adapter/test_mssql.py
@@ -833,7 +833,7 @@ def test_create_table_from_query(make_mocked_engine_adapter: t.Callable, mocker:
columns_mock.assert_called_once_with(exp.table_("__temp_ctas_test_random_id", quoted=True))
# We don't want to drop anything other than LIMIT 0
- # See https://github.com/TobikoData/sqlmesh/issues/4048
+ # See https://github.com/TobikoData/sqlmesh/issues/4048
adapter.ctas(
table_name="test_schema.test_table",
query_or_df=parse_one(
@@ -848,7 +848,7 @@ def test_create_table_from_query(make_mocked_engine_adapter: t.Callable, mocker:
def test_replace_query_strategy(adapter: MSSQLEngineAdapter, mocker: MockerFixture):
- # ref issue 4472: https://github.com/TobikoData/sqlmesh/issues/4472
+ # ref issue 4472: https://github.com/TobikoData/sqlmesh/issues/4472
# The FULL strategy calls EngineAdapter.replace_query() which calls _insert_overwrite_by_condition() should use DELETE+INSERT and not MERGE
expressions = d.parse(
f"""
diff --git a/tests/core/engine_adapter/test_snowflake.py b/tests/core/engine_adapter/test_snowflake.py
index 60f6d38e5f..dcb6820297 100644
--- a/tests/core/engine_adapter/test_snowflake.py
+++ b/tests/core/engine_adapter/test_snowflake.py
@@ -123,7 +123,7 @@ def test_get_data_objects_lowercases_columns(
def test_session(
mocker: MockerFixture,
make_mocked_engine_adapter: t.Callable,
- current_warehouse: t.Union[str, exp.Expression],
+ current_warehouse: t.Union[str, exp.Expression],
current_warehouse_exp: str,
configured_warehouse: t.Optional[str],
configured_warehouse_exp: t.Optional[str],
diff --git a/tests/core/engine_adapter/test_trino.py b/tests/core/engine_adapter/test_trino.py
index bf925c875a..1bfe82b858 100644
--- a/tests/core/engine_adapter/test_trino.py
+++ b/tests/core/engine_adapter/test_trino.py
@@ -404,6 +404,123 @@ def test_delta_timestamps(make_mocked_engine_adapter: t.Callable):
}
+def test_timestamp_mapping():
+ """Test that timestamp_mapping config property is properly defined and accessible."""
+ config = TrinoConnectionConfig(
+ user="user",
+ host="host",
+ catalog="catalog",
+ )
+
+ assert config._connection_factory_with_kwargs.keywords["source"] == "sqlmesh"
+
+ adapter = config.create_engine_adapter()
+ assert adapter.timestamp_mapping is None
+
+ config = TrinoConnectionConfig(
+ user="user",
+ host="host",
+ catalog="catalog",
+ source="my_source",
+ timestamp_mapping={
+ "TIMESTAMP": "TIMESTAMP(6)",
+ "TIMESTAMP(3)": "TIMESTAMP WITH TIME ZONE",
+ },
+ )
+ assert config._connection_factory_with_kwargs.keywords["source"] == "my_source"
+ adapter = config.create_engine_adapter()
+ assert adapter.timestamp_mapping is not None
+ assert adapter.timestamp_mapping[exp.DataType.build("TIMESTAMP")] == exp.DataType.build(
+ "TIMESTAMP(6)"
+ )
+
+
+def test_delta_timestamps_with_custom_mapping(make_mocked_engine_adapter: t.Callable):
+ """Test that _apply_timestamp_mapping + _to_delta_ts respects custom timestamp_mapping."""
+ # Create config with custom timestamp mapping
+ # Mapped columns are skipped by _to_delta_ts
+ config = TrinoConnectionConfig(
+ user="user",
+ host="host",
+ catalog="catalog",
+ timestamp_mapping={
+ "TIMESTAMP": "TIMESTAMP(3)",
+ "TIMESTAMP(1)": "TIMESTAMP(3)",
+ "TIMESTAMP WITH TIME ZONE": "TIMESTAMP(6) WITH TIME ZONE",
+ "TIMESTAMP(1) WITH TIME ZONE": "TIMESTAMP(6) WITH TIME ZONE",
+ },
+ )
+
+ adapter = make_mocked_engine_adapter(
+ TrinoEngineAdapter, timestamp_mapping=config.timestamp_mapping
+ )
+
+ ts3 = exp.DataType.build("timestamp(3)")
+ ts6_tz = exp.DataType.build("timestamp(6) with time zone")
+
+ columns_to_types = {
+ "ts": exp.DataType.build("TIMESTAMP"),
+ "ts_1": exp.DataType.build("TIMESTAMP(1)"),
+ "ts_tz": exp.DataType.build("TIMESTAMP WITH TIME ZONE"),
+ "ts_tz_1": exp.DataType.build("TIMESTAMP(1) WITH TIME ZONE"),
+ }
+
+ # Apply mapping first, then convert to delta types (skipping mapped columns)
+ mapped_columns_to_types, mapped_column_names = adapter._apply_timestamp_mapping(
+ columns_to_types
+ )
+ delta_columns_to_types = adapter._to_delta_ts(mapped_columns_to_types, mapped_column_names)
+
+ # All types were mapped, so _to_delta_ts skips them - they keep their mapped types
+ assert delta_columns_to_types == {
+ "ts": ts3,
+ "ts_1": ts3,
+ "ts_tz": ts6_tz,
+ "ts_tz_1": ts6_tz,
+ }
+
+
+def test_delta_timestamps_with_partial_mapping(make_mocked_engine_adapter: t.Callable):
+ """Test that _apply_timestamp_mapping + _to_delta_ts uses custom mapping for specified types."""
+ config = TrinoConnectionConfig(
+ user="user",
+ host="host",
+ catalog="catalog",
+ timestamp_mapping={
+ "TIMESTAMP": "TIMESTAMP(3)",
+ },
+ )
+
+ adapter = make_mocked_engine_adapter(
+ TrinoEngineAdapter, timestamp_mapping=config.timestamp_mapping
+ )
+
+ ts3 = exp.DataType.build("TIMESTAMP(3)")
+ ts6 = exp.DataType.build("timestamp(6)")
+ ts3_tz = exp.DataType.build("timestamp(3) with time zone")
+
+ columns_to_types = {
+ "ts": exp.DataType.build("TIMESTAMP"),
+ "ts_1": exp.DataType.build("TIMESTAMP(1)"),
+ "ts_tz": exp.DataType.build("TIMESTAMP WITH TIME ZONE"),
+ }
+
+ # Apply mapping first, then convert to delta types (skipping mapped columns)
+ mapped_columns_to_types, mapped_column_names = adapter._apply_timestamp_mapping(
+ columns_to_types
+ )
+ delta_columns_to_types = adapter._to_delta_ts(mapped_columns_to_types, mapped_column_names)
+
+ # TIMESTAMP is in mapping → TIMESTAMP(3), skipped by _to_delta_ts
+ # TIMESTAMP(1) is NOT in mapping, uses default TIMESTAMP → ts6
+ # TIMESTAMP WITH TIME ZONE is NOT in mapping, uses default TIMESTAMPTZ → ts3_tz
+ assert delta_columns_to_types == {
+ "ts": ts3, # Mapped to TIMESTAMP(3), skipped by _to_delta_ts
+ "ts_1": ts6, # Not in mapping, uses default
+ "ts_tz": ts3_tz, # Not in mapping, uses default
+ }
+
+
def test_table_format(trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture):
adapter = trino_mocked_engine_adapter
mocker.patch(
@@ -755,3 +872,77 @@ def test_insert_overwrite_time_partition_iceberg(
'DELETE FROM "my_catalog"."schema"."test_table" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'',
'INSERT INTO "my_catalog"."schema"."test_table" ("a", "b") SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'',
]
+
+
+def test_delta_timestamps_with_non_timestamp_columns(make_mocked_engine_adapter: t.Callable):
+ """Test that _apply_timestamp_mapping + _to_delta_ts handles non-timestamp columns."""
+ config = TrinoConnectionConfig(
+ user="user",
+ host="host",
+ catalog="catalog",
+ timestamp_mapping={
+ "TIMESTAMP": "TIMESTAMP(3)",
+ },
+ )
+
+ adapter = make_mocked_engine_adapter(
+ TrinoEngineAdapter, timestamp_mapping=config.timestamp_mapping
+ )
+
+ ts3 = exp.DataType.build("TIMESTAMP(3)")
+ ts6 = exp.DataType.build("timestamp(6)")
+
+ columns_to_types = {
+ "ts": exp.DataType.build("TIMESTAMP"),
+ "ts_1": exp.DataType.build("TIMESTAMP(1)"),
+ "int_col": exp.DataType.build("INT"),
+ "varchar_col": exp.DataType.build("VARCHAR(100)"),
+ "decimal_col": exp.DataType.build("DECIMAL(10,2)"),
+ }
+
+ # Apply mapping first, then convert to delta types (skipping mapped columns)
+ mapped_columns_to_types, mapped_column_names = adapter._apply_timestamp_mapping(
+ columns_to_types
+ )
+ delta_columns_to_types = adapter._to_delta_ts(mapped_columns_to_types, mapped_column_names)
+
+ # TIMESTAMP is in mapping → TIMESTAMP(3), skipped by _to_delta_ts
+ # TIMESTAMP(1) is NOT in mapping (exact match), uses default TIMESTAMP → ts6
+ # Non-timestamp columns should pass through unchanged
+ assert delta_columns_to_types == {
+ "ts": ts3, # Mapped to TIMESTAMP(3), skipped by _to_delta_ts
+ "ts_1": ts6, # Not in mapping, uses default
+ "int_col": exp.DataType.build("INT"),
+ "varchar_col": exp.DataType.build("VARCHAR(100)"),
+ "decimal_col": exp.DataType.build("DECIMAL(10,2)"),
+ }
+
+
+def test_delta_timestamps_with_empty_mapping(make_mocked_engine_adapter: t.Callable):
+ """Test that _to_delta_ts handles empty custom mapping dictionary."""
+ config = TrinoConnectionConfig(
+ user="user",
+ host="host",
+ catalog="catalog",
+ timestamp_mapping={},
+ )
+
+ adapter = make_mocked_engine_adapter(
+ TrinoEngineAdapter, timestamp_mapping=config.timestamp_mapping
+ )
+
+ ts6 = exp.DataType.build("timestamp(6)")
+ ts3_tz = exp.DataType.build("timestamp(3) with time zone")
+
+ columns_to_types = {
+ "ts": exp.DataType.build("TIMESTAMP"),
+ "ts_tz": exp.DataType.build("TIMESTAMP WITH TIME ZONE"),
+ }
+
+ delta_columns_to_types = adapter._to_delta_ts(columns_to_types)
+
+ # With empty custom mapping, should fall back to defaults
+ assert delta_columns_to_types == {
+ "ts": ts6,
+ "ts_tz": ts3_tz,
+ }
diff --git a/tests/core/integration/test_auto_restatement.py b/tests/core/integration/test_auto_restatement.py
index 70ca227fd3..1bda373a8f 100644
--- a/tests/core/integration/test_auto_restatement.py
+++ b/tests/core/integration/test_auto_restatement.py
@@ -27,7 +27,7 @@ def test_run_auto_restatement(init_and_plan_context: t.Callable):
@macro()
def record_intervals(
- evaluator, name: exp.Expression, start: exp.Expression, end: exp.Expression, **kwargs: t.Any
+ evaluator, name: exp.Expression, start: exp.Expression, end: exp.Expression, **kwargs: t.Any
) -> None:
if evaluator.runtime_stage == "evaluating":
evaluator.engine_adapter.insert_append(
@@ -178,7 +178,7 @@ def test_run_auto_restatement_failure(init_and_plan_context: t.Callable):
context, _ = init_and_plan_context("examples/sushi")
@macro()
- def fail_auto_restatement(evaluator, start: exp.Expression, **kwargs: t.Any) -> None:
+ def fail_auto_restatement(evaluator, start: exp.Expression, **kwargs: t.Any) -> None:
if evaluator.runtime_stage == "evaluating" and start.name != "2023-01-01":
raise Exception("Failed")
diff --git a/tests/core/integration/test_aux_commands.py b/tests/core/integration/test_aux_commands.py
index ecdd3e05fc..326e81e0c1 100644
--- a/tests/core/integration/test_aux_commands.py
+++ b/tests/core/integration/test_aux_commands.py
@@ -287,20 +287,20 @@ def test_destroy(copy_to_temp_path):
# Validate tables have been deleted as well
with pytest.raises(
- Exception, match=r"Catalog Error: Table with name model_two does not exist!"
+ Exception, match=r"Catalog Error: Table with name.*model_two.*does not exist"
):
context.fetchdf("SELECT * FROM db_1.first_schema.model_two")
with pytest.raises(
- Exception, match=r"Catalog Error: Table with name model_one does not exist!"
+ Exception, match=r"Catalog Error: Table with name.*model_one.*does not exist"
):
context.fetchdf("SELECT * FROM db_1.first_schema.model_one")
with pytest.raises(
- Exception, match=r"Catalog Error: Table with name model_two does not exist!"
+ Exception, match=r"Catalog Error: Table with name.*model_two.*does not exist"
):
context.engine_adapters["second"].fetchdf("SELECT * FROM db_2.second_schema.model_two")
with pytest.raises(
- Exception, match=r"Catalog Error: Table with name model_one does not exist!"
+ Exception, match=r"Catalog Error: Table with name.*model_one.*does not exist"
):
context.engine_adapters["second"].fetchdf("SELECT * FROM db_2.second_schema.model_one")
diff --git a/tests/core/integration/test_multi_repo.py b/tests/core/integration/test_multi_repo.py
index 6477b08741..4d72d137b3 100644
--- a/tests/core/integration/test_multi_repo.py
+++ b/tests/core/integration/test_multi_repo.py
@@ -421,6 +421,111 @@ def test_multi_hybrid(mocker):
validate_apply_basics(context, c.PROD, plan.snapshots.values())
+def test_multi_repo_no_project_to_project(copy_to_temp_path):
+ paths = copy_to_temp_path("examples/multi")
+ repo_1_path = f"{paths[0]}/repo_1"
+ repo_1_config_path = f"{repo_1_path}/config.yaml"
+ with open(repo_1_config_path, "r") as f:
+ config_content = f.read()
+ with open(repo_1_config_path, "w") as f:
+ f.write(config_content.replace("project: repo_1\n", ""))
+
+ context = Context(paths=[repo_1_path], gateway="memory")
+ context._new_state_sync().reset(default_catalog=context.default_catalog)
+ plan = context.plan_builder().build()
+ context.apply(plan)
+
+ # initially models in prod have no project
+ prod_snapshots = context.state_reader.get_snapshots(
+ context.state_reader.get_environment(c.PROD).snapshots
+ )
+ for snapshot in prod_snapshots.values():
+ assert snapshot.node.project == ""
+
+ # we now adopt multi project by adding a project name
+ with open(repo_1_config_path, "r") as f:
+ config_content = f.read()
+ with open(repo_1_config_path, "w") as f:
+ f.write("project: repo_1\n" + config_content)
+
+ context_with_project = Context(
+ paths=[repo_1_path],
+ state_sync=context.state_sync,
+ gateway="memory",
+ )
+ context_with_project._engine_adapter = context.engine_adapter
+ del context_with_project.engine_adapters
+
+ # local models should take precedence to pick up the new project name
+ local_model_a = context_with_project.get_model("bronze.a")
+ assert local_model_a.project == "repo_1"
+ local_model_b = context_with_project.get_model("bronze.b")
+ assert local_model_b.project == "repo_1"
+
+ # also verify the plan works
+ plan = context_with_project.plan_builder().build()
+ context_with_project.apply(plan)
+ validate_apply_basics(context_with_project, c.PROD, plan.snapshots.values())
+
+
+def test_multi_repo_local_model_overrides_prod_from_other_project(copy_to_temp_path):
+ paths = copy_to_temp_path("examples/multi")
+ repo_1_path = f"{paths[0]}/repo_1"
+ repo_2_path = f"{paths[0]}/repo_2"
+
+ context = Context(paths=[repo_1_path, repo_2_path], gateway="memory")
+ context._new_state_sync().reset(default_catalog=context.default_catalog)
+ plan = context.plan_builder().build()
+ assert len(plan.new_snapshots) == 5
+ context.apply(plan)
+
+ prod_model_c = context.get_model("silver.c")
+ assert prod_model_c.project == "repo_2"
+
+ with open(f"{repo_1_path}/models/c.sql", "w") as f:
+ f.write(
+ dedent("""\
+ MODEL (
+ name silver.c,
+ kind FULL
+ );
+
+ SELECT DISTINCT col_a, col_b
+ FROM bronze.a
+ """)
+ )
+
+ # silver.c exists locally in repo 1 now AND in prod under repo_2
+ context_repo1 = Context(
+ paths=[repo_1_path],
+ state_sync=context.state_sync,
+ gateway="memory",
+ )
+ context_repo1._engine_adapter = context.engine_adapter
+ del context_repo1.engine_adapters
+
+ # local model should take precedence and its project should reflect the new project name
+ local_model_c = context_repo1.get_model("silver.c")
+ assert local_model_c.project == "repo_1"
+
+ rendered = context_repo1.render("silver.c").sql()
+ assert "col_b" in rendered
+
+ # its downstream dependencies though should still be picked up
+ plan = context_repo1.plan_builder().build()
+ directly_modified_names = {snapshot.name for snapshot in plan.directly_modified}
+ assert '"memory"."silver"."c"' in directly_modified_names
+ assert '"memory"."silver"."d"' in directly_modified_names
+ missing_interval_names = {s.snapshot_id.name for s in plan.missing_intervals}
+ assert '"memory"."silver"."c"' in missing_interval_names
+ assert '"memory"."silver"."d"' in missing_interval_names
+
+ context_repo1.apply(plan)
+ validate_apply_basics(context_repo1, c.PROD, plan.snapshots.values())
+ result = context_repo1.fetchdf("SELECT * FROM memory.silver.c")
+ assert "col_b" in result.columns
+
+
def test_engine_adapters_multi_repo_all_gateways_gathered(copy_to_temp_path):
paths = copy_to_temp_path("examples/multi")
repo_1_path = paths[0] / "repo_1"
diff --git a/tests/core/integration/utils.py b/tests/core/integration/utils.py
index bc731e6cc8..ba233080b5 100644
--- a/tests/core/integration/utils.py
+++ b/tests/core/integration/utils.py
@@ -105,7 +105,10 @@ def apply_to_environment(
def change_data_type(
- context: Context, model_name: str, old_type: DataType.Type, new_type: DataType.Type
+ context: Context,
+ model_name: str,
+ old_type: DataType.Type,
+ new_type: DataType.Type,
) -> None:
model = context.get_model(model_name)
assert model is not None
diff --git a/tests/core/test_audit.py b/tests/core/test_audit.py
index 2ffcbbc4b2..90ac655cc6 100644
--- a/tests/core/test_audit.py
+++ b/tests/core/test_audit.py
@@ -329,7 +329,7 @@ def test_load_with_dictionary_defaults():
audit = load_audit(expressions, dialect="spark")
assert audit.defaults.keys() == {"field1", "field2"}
for value in audit.defaults.values():
- assert isinstance(value, exp.Expression)
+ assert isinstance(value, exp.Expression)
def test_load_with_single_defaults():
@@ -350,7 +350,7 @@ def test_load_with_single_defaults():
audit = load_audit(expressions, dialect="duckdb")
assert audit.defaults.keys() == {"field1"}
for value in audit.defaults.values():
- assert isinstance(value, exp.Expression)
+ assert isinstance(value, exp.Expression)
def test_no_audit_statement():
@@ -397,7 +397,7 @@ def test_no_query():
def test_macro(model: Model):
- expected_query = """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE "a" IS NULL"""
+ expected_query = """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE "a" IS NULL"""
audit = ModelAudit(
name="test_audit",
@@ -456,7 +456,7 @@ def test_not_null_audit(model: Model):
)
assert (
rendered_query_a.sql()
- == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE "a" IS NULL AND TRUE"""
+ == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE "a" IS NULL AND TRUE"""
)
rendered_query_a_and_b = model.render_audit_query(
@@ -465,7 +465,7 @@ def test_not_null_audit(model: Model):
)
assert (
rendered_query_a_and_b.sql()
- == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE ("a" IS NULL OR "b" IS NULL) AND TRUE"""
+ == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE ("a" IS NULL OR "b" IS NULL) AND TRUE"""
)
@@ -476,7 +476,7 @@ def test_not_null_audit_default_catalog(model_default_catalog: Model):
)
assert (
rendered_query_a.sql()
- == """SELECT * FROM (SELECT * FROM "test_catalog"."db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE "a" IS NULL AND TRUE"""
+ == """SELECT * FROM (SELECT * FROM "test_catalog"."db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE "a" IS NULL AND TRUE"""
)
rendered_query_a_and_b = model_default_catalog.render_audit_query(
@@ -485,7 +485,7 @@ def test_not_null_audit_default_catalog(model_default_catalog: Model):
)
assert (
rendered_query_a_and_b.sql()
- == """SELECT * FROM (SELECT * FROM "test_catalog"."db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE ("a" IS NULL OR "b" IS NULL) AND TRUE"""
+ == """SELECT * FROM (SELECT * FROM "test_catalog"."db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE ("a" IS NULL OR "b" IS NULL) AND TRUE"""
)
@@ -495,7 +495,7 @@ def test_unique_values_audit(model: Model):
)
assert (
rendered_query_a.sql()
- == 'SELECT * FROM (SELECT ROW_NUMBER() OVER (PARTITION BY "a" ORDER BY "a") AS "rank_a" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_q_0" WHERE "b" IS NULL) AS "_q_1" WHERE "rank_a" > 1'
+ == 'SELECT * FROM (SELECT ROW_NUMBER() OVER (PARTITION BY "a" ORDER BY "a") AS "rank_a" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE "b" IS NULL) AS "_1" WHERE "rank_a" > 1'
)
rendered_query_a_and_b = model.render_audit_query(
@@ -503,7 +503,7 @@ def test_unique_values_audit(model: Model):
)
assert (
rendered_query_a_and_b.sql()
- == 'SELECT * FROM (SELECT ROW_NUMBER() OVER (PARTITION BY "a" ORDER BY "a") AS "rank_a", ROW_NUMBER() OVER (PARTITION BY "b" ORDER BY "b") AS "rank_b" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_q_0" WHERE TRUE) AS "_q_1" WHERE "rank_a" > 1 OR "rank_b" > 1'
+ == 'SELECT * FROM (SELECT ROW_NUMBER() OVER (PARTITION BY "a" ORDER BY "a") AS "rank_a", ROW_NUMBER() OVER (PARTITION BY "b" ORDER BY "b") AS "rank_b" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE TRUE) AS "_1" WHERE "rank_a" > 1 OR "rank_b" > 1'
)
@@ -515,7 +515,7 @@ def test_accepted_values_audit(model: Model):
)
assert (
rendered_query.sql()
- == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE NOT "a" IN ('value_a', 'value_b') AND TRUE"""
+ == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE NOT "a" IN ('value_a', 'value_b') AND TRUE"""
)
@@ -526,7 +526,7 @@ def test_number_of_rows_audit(model: Model):
)
assert (
rendered_query.sql()
- == """SELECT COUNT(*) FROM (SELECT 1 FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE TRUE LIMIT 0 + 1) AS "_q_1" HAVING COUNT(*) <= 0"""
+ == """SELECT COUNT(*) FROM (SELECT 1 FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE TRUE LIMIT 0 + 1) AS "_1" HAVING COUNT(*) <= 0"""
)
@@ -537,7 +537,7 @@ def test_forall_audit(model: Model):
)
assert (
rendered_query_a.sql()
- == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE NOT ("a" >= "b") AND TRUE"""
+ == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE NOT ("a" >= "b") AND TRUE"""
)
rendered_query_a = model.render_audit_query(
@@ -546,7 +546,7 @@ def test_forall_audit(model: Model):
)
assert (
rendered_query_a.sql()
- == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE (NOT ("a" >= "b") OR NOT ("c" + "d" - "e" < 1.0)) AND TRUE"""
+ == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE (NOT ("a" >= "b") OR NOT ("c" + "d" - "e" < 1.0)) AND TRUE"""
)
rendered_query_a = model.render_audit_query(
@@ -556,7 +556,7 @@ def test_forall_audit(model: Model):
)
assert (
rendered_query_a.sql()
- == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE (NOT ("a" >= "b") OR NOT ("c" + "d" - "e" < 1.0)) AND "f" = 42"""
+ == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE (NOT ("a" >= "b") OR NOT ("c" + "d" - "e" < 1.0)) AND "f" = 42"""
)
@@ -566,21 +566,21 @@ def test_accepted_range_audit(model: Model):
)
assert (
rendered_query.sql()
- == 'SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_q_0" WHERE "a" < 0 AND TRUE'
+ == 'SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE "a" < 0 AND TRUE'
)
rendered_query = model.render_audit_query(
builtin.accepted_range_audit, column=exp.to_column("a"), max_v=100, inclusive=exp.false()
)
assert (
rendered_query.sql()
- == 'SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_q_0" WHERE "a" >= 100 AND TRUE'
+ == 'SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE "a" >= 100 AND TRUE'
)
rendered_query = model.render_audit_query(
builtin.accepted_range_audit, column=exp.to_column("a"), min_v=100, max_v=100
)
assert (
rendered_query.sql()
- == 'SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_q_0" WHERE ("a" < 100 OR "a" > 100) AND TRUE'
+ == 'SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE ("a" < 100 OR "a" > 100) AND TRUE'
)
@@ -591,7 +591,7 @@ def test_at_least_one_audit(model: Model):
)
assert (
rendered_query.sql()
- == 'SELECT 1 AS "1" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_q_0" WHERE TRUE GROUP BY 1 HAVING COUNT("a") = 0'
+ == 'SELECT 1 AS "1" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE TRUE GROUP BY 1 HAVING COUNT("a") = 0'
)
@@ -603,7 +603,7 @@ def test_mutually_exclusive_ranges_audit(model: Model):
)
assert (
rendered_query.sql()
- == '''WITH "window_functions" AS (SELECT "a" AS "lower_bound", "a" AS "upper_bound", LEAD("a") OVER (ORDER BY "a", "a") AS "next_lower_bound", ROW_NUMBER() OVER (ORDER BY "a" DESC, "a" DESC) = 1 AS "is_last_record" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE TRUE), "calc" AS (SELECT *, COALESCE("lower_bound" <= "upper_bound", FALSE) AS "lower_bound_lte_upper_bound", COALESCE("upper_bound" <= "next_lower_bound", "is_last_record", FALSE) AS "upper_bound_lte_next_lower_bound" FROM "window_functions" AS "window_functions"), "validation_errors" AS (SELECT * FROM "calc" AS "calc" WHERE NOT ("lower_bound_lte_upper_bound" AND "upper_bound_lte_next_lower_bound")) SELECT * FROM "validation_errors" AS "validation_errors"'''
+ == '''WITH "window_functions" AS (SELECT "a" AS "lower_bound", "a" AS "upper_bound", LEAD("a") OVER (ORDER BY "a", "a") AS "next_lower_bound", ROW_NUMBER() OVER (ORDER BY "a" DESC, "a" DESC) = 1 AS "is_last_record" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE TRUE), "calc" AS (SELECT *, COALESCE("lower_bound" <= "upper_bound", FALSE) AS "lower_bound_lte_upper_bound", COALESCE("upper_bound" <= "next_lower_bound", "is_last_record", FALSE) AS "upper_bound_lte_next_lower_bound" FROM "window_functions" AS "window_functions"), "validation_errors" AS (SELECT * FROM "calc" AS "calc" WHERE NOT ("lower_bound_lte_upper_bound" AND "upper_bound_lte_next_lower_bound")) SELECT * FROM "validation_errors" AS "validation_errors"'''
)
@@ -614,7 +614,7 @@ def test_sequential_values_audit(model: Model):
)
assert (
rendered_query.sql()
- == '''WITH "windowed" AS (SELECT "a", LAG("a") OVER (ORDER BY "a") AS "prv" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE TRUE), "validation_errors" AS (SELECT * FROM "windowed" AS "windowed" WHERE NOT ("a" = "prv" + 1)) SELECT * FROM "validation_errors" AS "validation_errors"'''
+ == '''WITH "windowed" AS (SELECT "a", LAG("a") OVER (ORDER BY "a") AS "prv" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE TRUE), "validation_errors" AS (SELECT * FROM "windowed" AS "windowed" WHERE NOT ("a" = "prv" + 1)) SELECT * FROM "validation_errors" AS "validation_errors"'''
)
@@ -627,7 +627,7 @@ def test_chi_square_audit(model: Model):
)
assert (
rendered_query.sql()
- == """WITH "samples" AS (SELECT "a" AS "x_a", "b" AS "x_b" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE (NOT "a" IS NULL AND NOT "b" IS NULL) AND TRUE), "contingency_table" AS (SELECT "x_a", "x_b", COUNT(*) AS "observed", (SELECT COUNT(*) FROM "samples" AS "t" WHERE "r"."x_a" = "t"."x_a") AS "tot_a", (SELECT COUNT(*) FROM "samples" AS "t" WHERE "r"."x_b" = "t"."x_b") AS "tot_b", (SELECT COUNT(*) FROM "samples" AS "samples") AS "g_t" /* g_t is the grand total */ FROM "samples" AS "r" GROUP BY "x_a", "x_b") SELECT ((SELECT COUNT(DISTINCT "x_a") FROM "contingency_table" AS "contingency_table") - 1) * ((SELECT COUNT(DISTINCT "x_b") FROM "contingency_table" AS "contingency_table") - 1) AS "degrees_of_freedom", SUM(("observed" - ("tot_a" * "tot_b" / "g_t")) * ("observed" - ("tot_a" * "tot_b" / "g_t")) / ("tot_a" * "tot_b" / "g_t")) AS "chi_square" FROM "contingency_table" AS "contingency_table" /* H0: the two variables are independent */ /* H1: the two variables are dependent */ /* if chi_square > critical_value, reject H0 */ /* if chi_square <= critical_value, fail to reject H0 */ HAVING NOT "chi_square" > 9.48773"""
+ == """WITH "samples" AS (SELECT "a" AS "x_a", "b" AS "x_b" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE (NOT "a" IS NULL AND NOT "b" IS NULL) AND TRUE), "contingency_table" AS (SELECT "x_a", "x_b", COUNT(*) AS "observed", (SELECT COUNT(*) FROM "samples" AS "t" WHERE "r"."x_a" = "t"."x_a") AS "tot_a", (SELECT COUNT(*) FROM "samples" AS "t" WHERE "r"."x_b" = "t"."x_b") AS "tot_b", (SELECT COUNT(*) FROM "samples" AS "samples") AS "g_t" /* g_t is the grand total */ FROM "samples" AS "r" GROUP BY "x_a", "x_b") SELECT ((SELECT COUNT(DISTINCT "x_a") FROM "contingency_table" AS "contingency_table") - 1) * ((SELECT COUNT(DISTINCT "x_b") FROM "contingency_table" AS "contingency_table") - 1) AS "degrees_of_freedom", SUM(("observed" - ("tot_a" * "tot_b" / "g_t")) * ("observed" - ("tot_a" * "tot_b" / "g_t")) / ("tot_a" * "tot_b" / "g_t")) AS "chi_square" FROM "contingency_table" AS "contingency_table" /* H0: the two variables are independent */ /* H1: the two variables are dependent */ /* if chi_square > critical_value, reject H0 */ /* if chi_square <= critical_value, fail to reject H0 */ HAVING NOT "chi_square" > 9.48773"""
)
@@ -639,7 +639,7 @@ def test_pattern_audits(model: Model):
)
assert (
rendered_query.sql()
- == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_q_0" WHERE (NOT REGEXP_LIKE("a", \'^\\d.*\') AND NOT REGEXP_LIKE("a", \'.*!$\')) AND TRUE"""
+ == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE (NOT REGEXP_LIKE("a", \'^\\d.*\') AND NOT REGEXP_LIKE("a", \'.*!$\')) AND TRUE"""
)
rendered_query = model.render_audit_query(
@@ -649,7 +649,7 @@ def test_pattern_audits(model: Model):
)
assert (
rendered_query.sql()
- == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_q_0" WHERE (REGEXP_LIKE("a", \'^\\d.*\') OR REGEXP_LIKE("a", \'.*!$\')) AND TRUE"""
+ == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE (REGEXP_LIKE("a", \'^\\d.*\') OR REGEXP_LIKE("a", \'.*!$\')) AND TRUE"""
)
rendered_query = model.render_audit_query(
@@ -659,7 +659,7 @@ def test_pattern_audits(model: Model):
)
assert (
rendered_query.sql()
- == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_q_0" WHERE (NOT "a" LIKE \'jim%\' AND NOT "a" LIKE \'pam%\') AND TRUE"""
+ == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE (NOT "a" LIKE \'jim%\' AND NOT "a" LIKE \'pam%\') AND TRUE"""
)
rendered_query = model.render_audit_query(
@@ -669,7 +669,7 @@ def test_pattern_audits(model: Model):
)
assert (
rendered_query.sql()
- == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_q_0" WHERE ("a" LIKE \'jim%\' OR "a" LIKE \'pam%\') AND TRUE"""
+ == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE ("a" LIKE \'jim%\' OR "a" LIKE \'pam%\') AND TRUE"""
)
@@ -814,7 +814,7 @@ def test_string_length_between_audit(model: Model):
)
assert (
rendered_query.sql()
- == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE (LENGTH("x") < 1 OR LENGTH("x") > 5) AND TRUE"""
+ == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE (LENGTH("x") < 1 OR LENGTH("x") > 5) AND TRUE"""
)
@@ -824,7 +824,7 @@ def test_not_constant_audit(model: Model):
)
assert (
rendered_query.sql()
- == """SELECT 1 AS "1" FROM (SELECT COUNT(DISTINCT "x") AS "t_cardinality" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE "x" > 1) AS "r" WHERE "r"."t_cardinality" <= 1"""
+ == """SELECT 1 AS "1" FROM (SELECT COUNT(DISTINCT "x") AS "t_cardinality" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE "x" > 1) AS "r" WHERE "r"."t_cardinality" <= 1"""
)
@@ -836,7 +836,7 @@ def test_condition_with_macro_var(model: Model):
)
assert (
rendered_query.sql(dialect="duckdb")
- == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE "x" IS NULL AND "dt" BETWEEN CAST('1970-01-01 00:00:00+00:00' AS TIMESTAMPTZ) AND CAST('1970-01-01 23:59:59.999999+00:00' AS TIMESTAMPTZ)"""
+ == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE "x" IS NULL AND "dt" BETWEEN CAST('1970-01-01 00:00:00+00:00' AS TIMESTAMPTZ) AND CAST('1970-01-01 23:59:59.999999+00:00' AS TIMESTAMPTZ)"""
)
@@ -907,7 +907,7 @@ def test_load_inline_audits(assert_exp_eq):
def test_model_inline_audits(sushi_context: Context):
model_name = "sushi.waiter_names"
- expected_query = 'SELECT * FROM (SELECT * FROM "memory"."sushi"."waiter_names" AS "waiter_names") AS "_q_0" WHERE "id" < 0'
+ expected_query = 'SELECT * FROM (SELECT * FROM "memory"."sushi"."waiter_names" AS "waiter_names") AS "_0" WHERE "id" < 0'
model = sushi_context.get_snapshot(model_name, raise_if_missing=True).node
assert isinstance(model, SeedModel)
diff --git a/tests/core/test_config.py b/tests/core/test_config.py
index d0fad16e76..8c81a90b8d 100644
--- a/tests/core/test_config.py
+++ b/tests/core/test_config.py
@@ -570,7 +570,8 @@ def test_variables():
assert config.get_gateway("local").variables == {"uppercase_var": 2}
with pytest.raises(
- ConfigError, match="Unsupported variable value type: "
+ ConfigError,
+ match=r"Unsupported variable value type: ",
):
Config(variables={"invalid_var": exp.column("sqlglot_expr")})
@@ -862,6 +863,39 @@ def test_trino_schema_location_mapping_syntax(tmp_path):
assert len(conn.schema_location_mapping) == 2
+def test_trino_source_option(tmp_path):
+ config_path = tmp_path / "config_trino_source.yaml"
+ with open(config_path, "w", encoding="utf-8") as fd:
+ fd.write(
+ """
+ gateways:
+ trino:
+ connection:
+ type: trino
+ user: trino
+ host: trino
+ catalog: trino
+ source: my_sqlmesh_source
+
+ default_gateway: trino
+
+ model_defaults:
+ dialect: trino
+ """
+ )
+
+ config = load_config_from_paths(
+ Config,
+ project_paths=[config_path],
+ )
+
+ from sqlmesh.core.config.connection import TrinoConnectionConfig
+
+ conn = config.gateways["trino"].connection
+ assert isinstance(conn, TrinoConnectionConfig)
+ assert conn.source == "my_sqlmesh_source"
+
+
def test_gcp_postgres_ip_and_scopes(tmp_path):
config_path = tmp_path / "config_gcp_postgres.yaml"
with open(config_path, "w", encoding="utf-8") as fd:
@@ -1017,7 +1051,7 @@ def test_environment_statements_config(tmp_path):
]
-# https://github.com/TobikoData/sqlmesh/pull/4049
+# https://github.com/SQLMesh/sqlmesh/pull/4049
def test_pydantic_import_error() -> None:
class TestConfig(DuckDBConnectionConfig):
pass
diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py
index a0d54e03dd..2ff95525f7 100644
--- a/tests/core/test_connection_config.py
+++ b/tests/core/test_connection_config.py
@@ -4,6 +4,7 @@
import pytest
from _pytest.fixtures import FixtureRequest
+from sqlglot import exp
from unittest.mock import patch, MagicMock
from sqlmesh.core.config.connection import (
@@ -444,6 +445,63 @@ def test_trino_catalog_type_override(make_config):
assert config.catalog_type_overrides == {"my_catalog": "iceberg"}
+def test_trino_timestamp_mapping(make_config):
+ required_kwargs = dict(
+ type="trino",
+ user="user",
+ host="host",
+ catalog="catalog",
+ )
+
+ # Test config without timestamp_mapping
+ config = make_config(**required_kwargs)
+ assert config.timestamp_mapping is None
+
+ # Test config with timestamp_mapping
+ config = make_config(
+ **required_kwargs,
+ timestamp_mapping={
+ "TIMESTAMP": "TIMESTAMP(6)",
+ "TIMESTAMP(3)": "TIMESTAMP WITH TIME ZONE",
+ },
+ )
+
+ assert config.timestamp_mapping is not None
+ assert config.timestamp_mapping[exp.DataType.build("TIMESTAMP")] == exp.DataType.build(
+ "TIMESTAMP(6)"
+ )
+
+ # Test with invalid source type
+ with pytest.raises(ConfigError) as exc_info:
+ make_config(
+ **required_kwargs,
+ timestamp_mapping={
+ "INVALID_TYPE": "TIMESTAMP",
+ },
+ )
+ assert "Invalid SQL type string" in str(exc_info.value)
+ assert "INVALID_TYPE" in str(exc_info.value)
+
+ # Test with invalid target type (not a valid SQL type)
+ with pytest.raises(ConfigError) as exc_info:
+ make_config(
+ **required_kwargs,
+ timestamp_mapping={
+ "TIMESTAMP": "INVALID_TARGET_TYPE",
+ },
+ )
+ assert "Invalid SQL type string" in str(exc_info.value)
+ assert "INVALID_TARGET_TYPE" in str(exc_info.value)
+
+ # Test with empty mapping
+ config = make_config(
+ **required_kwargs,
+ timestamp_mapping={},
+ )
+ assert config.timestamp_mapping is not None
+ assert config.timestamp_mapping == {}
+
+
def test_duckdb(make_config):
config = make_config(
type="duckdb",
@@ -1073,6 +1131,27 @@ def test_bigquery(make_config):
assert config.get_catalog() == "project"
assert config.is_recommended_for_state_sync is False
+ # Test reservation
+ config_with_reservation = make_config(
+ type="bigquery",
+ project="project",
+ reservation="projects/my-project/locations/us-central1/reservations/my-reservation",
+ check_import=False,
+ )
+ assert isinstance(config_with_reservation, BigQueryConnectionConfig)
+ assert (
+ config_with_reservation.reservation
+ == "projects/my-project/locations/us-central1/reservations/my-reservation"
+ )
+
+ # Test that reservation is included in _extra_engine_config
+ extra_config = config_with_reservation._extra_engine_config
+ assert "reservation" in extra_config
+ assert (
+ extra_config["reservation"]
+ == "projects/my-project/locations/us-central1/reservations/my-reservation"
+ )
+
with pytest.raises(ConfigError, match="you must also specify the `project` field"):
make_config(type="bigquery", execution_project="execution_project", check_import=False)
diff --git a/tests/core/test_context.py b/tests/core/test_context.py
index 54b8cd891a..c3d88e205e 100644
--- a/tests/core/test_context.py
+++ b/tests/core/test_context.py
@@ -1157,6 +1157,72 @@ def test_plan_start_ahead_of_end(copy_to_temp_path):
context.close()
+@pytest.mark.slow
+def test_plan_seed_model_excluded_from_default_end(copy_to_temp_path: t.Callable):
+ path = copy_to_temp_path("examples/sushi")
+ with time_machine.travel("2024-06-01 00:00:00 UTC"):
+ context = Context(paths=path, gateway="duckdb_persistent")
+ context.plan("prod", no_prompts=True, auto_apply=True)
+ max_ends = context.state_sync.max_interval_end_per_model("prod")
+ seed_fqns = [k for k in max_ends if "waiter_names" in k]
+ assert len(seed_fqns) == 1
+ assert max_ends[seed_fqns[0]] == to_timestamp("2024-06-01")
+ context.close()
+
+ with time_machine.travel("2026-03-01 00:00:00 UTC"):
+ context = Context(paths=path, gateway="duckdb_persistent")
+
+ # a model that depends on this seed but has no interval in prod yet so only the seed would contribute to max_interval_end_per_model
+ context.upsert_model(
+ load_sql_based_model(
+ parse(
+ """
+ MODEL(
+ name sushi.waiter_summary,
+ kind INCREMENTAL_BY_TIME_RANGE (
+ time_column ds
+ ),
+ start '2025-01-01',
+ cron '@daily'
+ );
+
+ SELECT
+ id,
+ name,
+ @start_ds AS ds
+ FROM
+ sushi.waiter_names
+ WHERE
+ @start_ds BETWEEN @start_ds AND @end_ds
+ """
+ ),
+ default_catalog=context.default_catalog,
+ )
+ )
+
+ # the seed's interval end would still be 2024-06-01
+ max_ends = context.state_sync.max_interval_end_per_model("prod")
+ seed_fqns = [k for k in max_ends if "waiter_names" in k]
+ assert len(seed_fqns) == 1
+ assert max_ends[seed_fqns[0]] == to_timestamp("2024-06-01")
+
+ # the plan start date 2025-01-01 is after the seed's end date but shouldn't cause the plan to fail
+ plan = context.plan(
+ "dev", start="2025-01-01", no_prompts=True, select_models=["*waiter_summary"]
+ )
+
+ # the end should fall back to execution_time rather than the seed's end
+ assert plan.models_to_backfill == {
+ '"duckdb"."sushi"."waiter_names"',
+ '"duckdb"."sushi"."waiter_summary"',
+ }
+ assert plan.provided_end is None
+ assert plan.provided_start == "2025-01-01"
+ assert to_timestamp(plan.end) == to_timestamp("2026-03-01")
+ assert to_timestamp(plan.start) == to_timestamp("2025-01-01")
+ context.close()
+
+
@pytest.mark.slow
def test_schema_error_no_default(sushi_context_pre_scheduling) -> None:
context = sushi_context_pre_scheduling
@@ -1506,6 +1572,8 @@ def test_requirements(copy_to_temp_path: t.Callable):
"dev", no_prompts=True, skip_tests=True, skip_backfill=True, auto_apply=True
).environment
requirements = {"ipywidgets", "numpy", "pandas", "test_package"}
+ if IS_WINDOWS:
+ requirements.add("pendulum")
assert environment.requirements["pandas"] == "2.2.2"
assert set(environment.requirements) == requirements
@@ -1513,7 +1581,10 @@ def test_requirements(copy_to_temp_path: t.Callable):
context._excluded_requirements = {"ipywidgets", "ruamel.yaml", "ruamel.yaml.clib"}
diff = context.plan_builder("dev", skip_tests=True, skip_backfill=True).build().context_diff
assert set(diff.previous_requirements) == requirements
- assert set(diff.requirements) == {"numpy", "pandas"}
+ reqs = {"numpy", "pandas"}
+ if IS_WINDOWS:
+ reqs.add("pendulum")
+ assert set(diff.requirements) == reqs
def test_deactivate_automatic_requirement_inference(copy_to_temp_path: t.Callable):
@@ -1985,7 +2056,7 @@ def access_adapter(evaluator):
assert (
model.pre_statements[0].sql()
- == "@IF(@runtime_stage IN ('evaluating', 'creating'), SET VARIABLE stats_model_start = NOW())"
+ == "@IF(@runtime_stage IN ('evaluating', 'creating'), SET stats_model_start = NOW())"
)
assert (
model.post_statements[0].sql()
@@ -2337,13 +2408,13 @@ def test_plan_audit_intervals(tmp_path: pathlib.Path, caplog):
# Case 1: The timestamp audit should be in the inclusive range ['2025-02-01 00:00:00', '2025-02-01 23:59:59.999999']
assert (
- f"""SELECT COUNT(*) FROM (SELECT "timestamp_id" AS "timestamp_id" FROM (SELECT * FROM "sqlmesh__sqlmesh_audit"."sqlmesh_audit__timestamp_example__{timestamp_snapshot.version}" AS "sqlmesh_audit__timestamp_example__{timestamp_snapshot.version}" WHERE "timestamp_id" BETWEEN CAST('2025-02-01 00:00:00' AS TIMESTAMP) AND CAST('2025-02-01 23:59:59.999999' AS TIMESTAMP)) AS "_q_0" WHERE TRUE GROUP BY "timestamp_id" HAVING COUNT(*) > 1) AS "audit\""""
+ f"""SELECT COUNT(*) FROM (SELECT "timestamp_id" AS "timestamp_id" FROM (SELECT * FROM "sqlmesh__sqlmesh_audit"."sqlmesh_audit__timestamp_example__{timestamp_snapshot.version}" AS "sqlmesh_audit__timestamp_example__{timestamp_snapshot.version}" WHERE "timestamp_id" BETWEEN CAST('2025-02-01 00:00:00' AS TIMESTAMP) AND CAST('2025-02-01 23:59:59.999999' AS TIMESTAMP)) AS "_0" WHERE TRUE GROUP BY "timestamp_id" HAVING COUNT(*) > 1) AS "audit\""""
in caplog.text
)
# Case 2: The date audit should be in the inclusive range ['2025-02-01', '2025-02-01']
assert (
- f"""SELECT COUNT(*) FROM (SELECT "date_id" AS "date_id" FROM (SELECT * FROM "sqlmesh__sqlmesh_audit"."sqlmesh_audit__date_example__{date_snapshot.version}" AS "sqlmesh_audit__date_example__{date_snapshot.version}" WHERE "date_id" BETWEEN CAST('2025-02-01' AS DATE) AND CAST('2025-02-01' AS DATE)) AS "_q_0" WHERE TRUE GROUP BY "date_id" HAVING COUNT(*) > 1) AS "audit\""""
+ f"""SELECT COUNT(*) FROM (SELECT "date_id" AS "date_id" FROM (SELECT * FROM "sqlmesh__sqlmesh_audit"."sqlmesh_audit__date_example__{date_snapshot.version}" AS "sqlmesh_audit__date_example__{date_snapshot.version}" WHERE "date_id" BETWEEN CAST('2025-02-01' AS DATE) AND CAST('2025-02-01' AS DATE)) AS "_0" WHERE TRUE GROUP BY "date_id" HAVING COUNT(*) > 1) AS "audit\""""
in caplog.text
)
diff --git a/tests/core/test_macros.py b/tests/core/test_macros.py
index fb10f64b27..e37a7ec05b 100644
--- a/tests/core/test_macros.py
+++ b/tests/core/test_macros.py
@@ -98,7 +98,7 @@ def test_select_macro(evaluator):
@macro()
def test_literal_type(evaluator, a: t.Literal["test_literal_a", "test_literal_b", 1, True]):
- if isinstance(a, exp.Expression):
+ if isinstance(a, exp.Expr):
raise SQLMeshError("Coercion failed")
return f"'{a}'"
@@ -694,8 +694,8 @@ def test_macro_coercion(macro_evaluator: MacroEvaluator, assert_exp_eq):
) == (1, "2", (3.0,))
# Using exp.Expression will always return the input expression
- assert coerce(parse_one("order", into=exp.Column), exp.Expression) == exp.column("order")
- assert coerce(exp.Literal.string("OK"), exp.Expression) == exp.Literal.string("OK")
+ assert coerce(parse_one("order", into=exp.Column), exp.Expr) == exp.column("order")
+ assert coerce(exp.Literal.string("OK"), exp.Expr) == exp.Literal.string("OK")
# Strict flag allows raising errors and is used when recursively coercing expressions
# otherwise, in general, we want to be lenient and just warn the user when something is not possible
@@ -930,12 +930,10 @@ def test_date_spine(assert_exp_eq, dialect, date_part):
FLATTEN(
INPUT => ARRAY_GENERATE_RANGE(
0,
- (
- DATEDIFF(
- {date_part.upper()},
- CAST('2022-01-01' AS DATE),
- CAST('2024-12-31' AS DATE)
- ) + 1 - 1
+ DATEDIFF(
+ {date_part.upper()},
+ CAST('2022-01-01' AS DATE),
+ CAST('2024-12-31' AS DATE)
) + 1
)
)
diff --git a/tests/core/test_model.py b/tests/core/test_model.py
index 3e0f6d40b9..9bdc976b56 100644
--- a/tests/core/test_model.py
+++ b/tests/core/test_model.py
@@ -464,10 +464,10 @@ def test_model_qualification(tmp_path: Path):
ctx.upsert_model(load_sql_based_model(expressions))
ctx.plan_builder("dev")
- assert (
- """Column '"a"' could not be resolved for model '"db"."table"', the column may not exist or is ambiguous."""
- in mock_logger.call_args[0][0]
- )
+ warning_msg = mock_logger.call_args[0][0]
+ assert "ambiguousorinvalidcolumn:" in warning_msg
+ assert "could not be resolved" in warning_msg
+ assert "db.table" in warning_msg
@use_terminal_console
@@ -2727,6 +2727,156 @@ def test_parse(assert_exp_eq):
)
+def test_dialect_pattern():
+ def make_test_sql(text: str) -> str:
+ return f"""
+ MODEL (
+ name test_model,
+ kind INCREMENTAL_BY_TIME_RANGE(
+ time_column ds
+ ),
+ {text}
+ );
+
+ SELECT 1;
+ """
+
+ def assert_match(test_sql: str, expected_value: t.Optional[str] = "duckdb"):
+ match = d.DIALECT_PATTERN.search(test_sql)
+
+ dialect_str: t.Optional[str] = None
+ if expected_value is not None:
+ assert match
+ dialect_str = match.group("dialect")
+
+ assert dialect_str == expected_value
+
+ # single-quoted dialect
+ assert_match(
+ make_test_sql(
+ """
+ dialect 'duckdb',
+ description 'there's a dialect foo in here too!'
+ """
+ )
+ )
+
+ # bare dialect
+ assert_match(
+ make_test_sql(
+ """
+ dialect duckdb,
+ description 'there's a dialect foo in here too!'
+ """
+ )
+ )
+
+ # double-quoted dialect (allowed in BQ)
+ assert_match(
+ make_test_sql(
+ """
+ dialect "duckdb",
+ description 'there's a dialect foo in here too!'
+ """
+ )
+ )
+
+ # no dialect specified, "dialect" in description
+ test_sql = make_test_sql(
+ """
+ description 'there's a dialect foo in here too!'
+ """
+ )
+
+ matches = list(d.DIALECT_PATTERN.finditer(test_sql))
+ assert not matches
+
+ # line comment between properties
+ assert_match(
+ make_test_sql(
+ """
+ tag my_tag, -- comment
+ dialect duckdb
+ """
+ )
+ )
+
+ # block comment between properties
+ assert_match(
+ make_test_sql(
+ """
+ tag my_tag, /* comment */
+ dialect duckdb
+ """
+ )
+ )
+
+ # quoted empty dialect
+ assert_match(
+ make_test_sql(
+ """
+ dialect '',
+ tag my_tag
+ """
+ ),
+ None,
+ )
+
+ # double-quoted empty dialect
+ assert_match(
+ make_test_sql(
+ """
+ dialect "",
+ tag my_tag
+ """
+ ),
+ None,
+ )
+
+ # trailing comment after dialect value
+ assert_match(
+ make_test_sql(
+ """
+ dialect duckdb -- trailing comment
+ """
+ )
+ )
+
+ # dialect value isn't terminated by ',' or ')'
+ test_sql = make_test_sql(
+ """
+ dialect duckdb -- trailing comment
+ tag my_tag
+ """
+ )
+
+ matches = list(d.DIALECT_PATTERN.finditer(test_sql))
+ assert not matches
+
+ # dialect first
+ assert_match(
+ """
+ MODEL(
+ dialect duckdb,
+ name my_name
+ );
+ """
+ )
+
+ # full parse
+ sql = """
+ MODEL (
+ name test_model,
+ description 'this text mentions dialect foo but is not a property'
+ );
+
+ SELECT 1;
+ """
+ expressions = d.parse(sql, default_dialect="duckdb")
+ model = load_sql_based_model(expressions)
+ assert model.dialect == ""
+
+
CONST = "bar"
@@ -3500,7 +3650,7 @@ def test_model_ctas_query():
assert (
load_sql_based_model(expressions, dialect="bigquery").ctas_query().sql()
- == 'WITH RECURSIVE "a" AS (SELECT * FROM (SELECT * FROM (SELECT * FROM "x" AS "x" WHERE FALSE) AS "_q_0" WHERE FALSE) AS "_q_1" WHERE FALSE), "b" AS (SELECT * FROM "a" AS "a" WHERE FALSE UNION ALL SELECT * FROM "a" AS "a" WHERE FALSE) SELECT * FROM "b" AS "b" WHERE FALSE LIMIT 0'
+ == 'WITH RECURSIVE "a" AS (SELECT * FROM (SELECT * FROM (SELECT * FROM "x" AS "x" WHERE FALSE) AS "_0" WHERE FALSE) AS "_1" WHERE FALSE), "b" AS (SELECT * FROM "a" AS "a" WHERE FALSE UNION ALL SELECT * FROM "a" AS "a" WHERE FALSE) SELECT * FROM "b" AS "b" WHERE FALSE LIMIT 0'
)
expressions = d.parse(
@@ -3521,7 +3671,7 @@ def test_model_ctas_query():
assert (
load_sql_based_model(expressions, dialect="bigquery").ctas_query().sql()
- == 'WITH RECURSIVE "a" AS (WITH "nested_a" AS (SELECT * FROM (SELECT * FROM (SELECT * FROM "x" AS "x" WHERE FALSE) AS "_q_0" WHERE FALSE) AS "_q_1" WHERE FALSE) SELECT * FROM "nested_a" AS "nested_a" WHERE FALSE), "b" AS (SELECT * FROM "a" AS "a" WHERE FALSE UNION ALL SELECT * FROM "a" AS "a" WHERE FALSE) SELECT * FROM "b" AS "b" WHERE FALSE LIMIT 0'
+ == 'WITH RECURSIVE "a" AS (WITH "nested_a" AS (SELECT * FROM (SELECT * FROM (SELECT * FROM "x" AS "x" WHERE FALSE) AS "_0" WHERE FALSE) AS "_1" WHERE FALSE) SELECT * FROM "nested_a" AS "nested_a" WHERE FALSE), "b" AS (SELECT * FROM "a" AS "a" WHERE FALSE UNION ALL SELECT * FROM "a" AS "a" WHERE FALSE) SELECT * FROM "b" AS "b" WHERE FALSE LIMIT 0'
)
@@ -4845,7 +4995,7 @@ def test_model_session_properties(sushi_context):
)
)
assert model.session_properties == {
- "query_label": parse_one("[('key1', 'value1'), ('key2', 'value2')]")
+ "query_label": parse_one("[('key1', 'value1'), ('key2', 'value2')]", dialect="bigquery")
}
model = load_sql_based_model(
@@ -5861,7 +6011,7 @@ def test_when_matched_normalization() -> None:
assert isinstance(model.kind, IncrementalByUniqueKeyKind)
assert isinstance(model.kind.when_matched, exp.Whens)
first_expression = model.kind.when_matched.expressions[0]
- assert isinstance(first_expression, exp.Expression)
+ assert isinstance(first_expression, exp.Expr)
assert (
first_expression.sql(dialect="snowflake")
== 'WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."KEY_A" = "__MERGE_SOURCE__"."KEY_A", "__MERGE_TARGET__"."KEY_B" = "__MERGE_SOURCE__"."KEY_B"'
@@ -5889,7 +6039,7 @@ def test_when_matched_normalization() -> None:
assert isinstance(model.kind, IncrementalByUniqueKeyKind)
assert isinstance(model.kind.when_matched, exp.Whens)
first_expression = model.kind.when_matched.expressions[0]
- assert isinstance(first_expression, exp.Expression)
+ assert isinstance(first_expression, exp.Expr)
assert (
first_expression.sql(dialect="snowflake")
== 'WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."kEy_A" = "__MERGE_SOURCE__"."kEy_A", "__MERGE_TARGET__"."kEY_b" = "__MERGE_SOURCE__"."KEY_B"'
@@ -6297,7 +6447,7 @@ def test_end_no_start():
def test_variables():
@macro()
- def test_macro_var(evaluator) -> exp.Expression:
+ def test_macro_var(evaluator) -> exp.Expr:
return exp.convert(evaluator.var("TEST_VAR_D") + 10)
expressions = parse(
@@ -6796,7 +6946,7 @@ def test_unrendered_macros_sql_model(mocker: MockerFixture) -> None:
# merge_filter will stay unrendered as well
assert model.unique_key[0] == exp.column("a", quoted=True)
assert (
- t.cast(exp.Expression, model.merge_filter).sql()
+ t.cast(exp.Expr, model.merge_filter).sql()
== '"__MERGE_SOURCE__"."id" > 0 AND "__MERGE_TARGET__"."updated_at" < @end_ds AND "__MERGE_SOURCE__"."updated_at" > @start_ds AND @merge_filter_var'
)
@@ -6999,7 +7149,7 @@ def test_gateway_macro() -> None:
assert model.render_query_or_raise().sql() == "SELECT 'in_memory' AS \"gateway\""
@macro()
- def macro_uses_gateway(evaluator) -> exp.Expression:
+ def macro_uses_gateway(evaluator) -> exp.Expr:
return exp.convert(evaluator.gateway + "_from_macro")
model = load_sql_based_model(
@@ -8579,7 +8729,7 @@ def test_merge_filter_macro():
def predicate(
evaluator: MacroEvaluator,
cluster_column: exp.Column,
- ) -> exp.Expression:
+ ) -> exp.Expr:
return parse_one(f"source.{cluster_column} > dateadd(day, -7, target.{cluster_column})")
expressions = d.parse(
@@ -9754,7 +9904,7 @@ def entrypoint(evaluator):
{"customer": SqlValue(sql="customer1"), "customer_field": SqlValue(sql="'bar'")}
)
- assert t.cast(exp.Expression, customer1_model.render_query()).sql() == (
+ assert t.cast(exp.Expr, customer1_model.render_query()).sql() == (
"""SELECT 'bar' AS "foo", "bar" AS "foo2", 'bar' AS "foo3" FROM "db"."customer1"."my_source" AS "my_source\""""
)
@@ -9767,7 +9917,7 @@ def entrypoint(evaluator):
{"customer": SqlValue(sql="customer2"), "customer_field": SqlValue(sql="qux")}
)
- assert t.cast(exp.Expression, customer2_model.render_query()).sql() == (
+ assert t.cast(exp.Expr, customer2_model.render_query()).sql() == (
'''SELECT "qux" AS "foo", "qux" AS "foo2", "qux" AS "foo3" FROM "db"."customer2"."my_source" AS "my_source"'''
)
@@ -10553,12 +10703,12 @@ def m4_non_metadata_references_v6(evaluator):
query_with_vars = macro_evaluator.transform(
parse_one("SELECT " + ", ".join(f"@v{var}, @VAR('v{var}')" for var in [1, 2, 3, 6]))
)
- assert t.cast(exp.Expression, query_with_vars).sql() == "SELECT 1, 1, 2, 2, 3, 3, 6, 6"
+ assert t.cast(exp.Expr, query_with_vars).sql() == "SELECT 1, 1, 2, 2, 3, 3, 6, 6"
query_with_blueprint_vars = macro_evaluator.transform(
parse_one("SELECT " + ", ".join(f"@v{var}, @BLUEPRINT_VAR('v{var}')" for var in [4, 5]))
)
- assert t.cast(exp.Expression, query_with_blueprint_vars).sql() == "SELECT 4, 4, 5, 5"
+ assert t.cast(exp.Expr, query_with_blueprint_vars).sql() == "SELECT 4, 4, 5, 5"
def test_variable_mentioned_in_both_metadata_and_non_metadata_macro(tmp_path: Path) -> None:
@@ -12192,3 +12342,170 @@ def test_audits_in_embedded_model():
)
with pytest.raises(ConfigError, match="Audits are not supported for embedded models"):
load_sql_based_model(expression).validate_definition()
+
+
+def test_default_catalog_not_leaked_to_unsupported_gateway():
+ """
+ Regression test for https://github.com/SQLMesh/sqlmesh/issues/5748
+
+ When a model targets a gateway that is NOT in default_catalog_per_gateway,
+ the global default_catalog should be cleared (set to None) instead of
+ leaking through from the default gateway.
+ """
+ from sqlglot import parse
+
+ expressions = parse(
+ """
+ MODEL (
+ name my_schema.my_model,
+ kind FULL,
+ gateway clickhouse_gw,
+ dialect clickhouse,
+ );
+
+ SELECT 1 AS id
+ """,
+ read="clickhouse",
+ )
+
+ default_catalog_per_gateway = {
+ "default_gw": "example_catalog",
+ }
+
+ models = load_sql_based_models(
+ expressions,
+ get_variables=lambda gw: {},
+ dialect="clickhouse",
+ default_catalog_per_gateway=default_catalog_per_gateway,
+ default_catalog="example_catalog",
+ )
+
+ assert len(models) == 1
+ model = models[0]
+
+ assert not model.catalog, (
+ f"Default gateway catalog leaked into catalog-unsupported gateway model. "
+ f"Expected no catalog, got: {model.catalog}"
+ )
+ assert "example_catalog" not in model.fqn, (
+ f"Default gateway catalog found in model FQN: {model.fqn}"
+ )
+
+
+def test_default_catalog_still_applied_to_supported_gateway():
+ """
+ Control test: when a model targets a gateway that IS in default_catalog_per_gateway,
+ the catalog should still be correctly applied.
+ """
+ from sqlglot import parse
+
+ expressions = parse(
+ """
+ MODEL (
+ name my_schema.my_model,
+ kind FULL,
+ gateway other_duckdb,
+ );
+
+ SELECT 1 AS id
+ """,
+ read="duckdb",
+ )
+
+ default_catalog_per_gateway = {
+ "default_gw": "example_catalog",
+ "other_duckdb": "other_db",
+ }
+
+ models = load_sql_based_models(
+ expressions,
+ get_variables=lambda gw: {},
+ dialect="duckdb",
+ default_catalog_per_gateway=default_catalog_per_gateway,
+ default_catalog="example_catalog",
+ )
+
+ assert len(models) == 1
+ model = models[0]
+
+ assert model.catalog == "other_db", f"Expected catalog 'other_db', got: {model.catalog}"
+
+
+def test_no_gateway_uses_global_default_catalog():
+ """
+ Control test: when a model does NOT specify a gateway, the global
+ default_catalog should still be applied as before.
+ """
+ from sqlglot import parse
+
+ expressions = parse(
+ """
+ MODEL (
+ name my_schema.my_model,
+ kind FULL,
+ );
+
+ SELECT 1 AS id
+ """,
+ read="duckdb",
+ )
+
+ model = load_sql_based_model(
+ expressions,
+ default_catalog="example_catalog",
+ dialect="duckdb",
+ )
+
+ assert model.catalog == "example_catalog"
+
+
+def test_blueprint_catalog_not_cross_contaminated():
+ """
+ When blueprints iterate over different gateways, the catalog from one
+ blueprint iteration should not leak into the next. A ClickHouse blueprint
+ setting default_catalog to None should not prevent the following blueprint
+ from getting its correct catalog.
+ """
+ from sqlglot import parse
+
+ expressions = parse(
+ """
+ MODEL (
+ name @{blueprint}.my_model,
+ kind FULL,
+ gateway @{gw},
+ blueprints (
+ (blueprint := ch_schema, gw := clickhouse_gw),
+ (blueprint := db_schema, gw := default_gw),
+ ),
+ );
+
+ SELECT 1 AS id
+ """,
+ read="duckdb",
+ )
+
+ default_catalog_per_gateway = {
+ "default_gw": "example_catalog",
+ }
+
+ models = load_sql_based_models(
+ expressions,
+ get_variables=lambda gw: {},
+ dialect="duckdb",
+ default_catalog_per_gateway=default_catalog_per_gateway,
+ default_catalog="example_catalog",
+ )
+
+ assert len(models) == 2
+
+ ch_model = next(m for m in models if "ch_schema" in m.fqn)
+ db_model = next(m for m in models if "db_schema" in m.fqn)
+
+ assert not ch_model.catalog, (
+ f"Catalog leaked into ClickHouse blueprint. Got: {ch_model.catalog}"
+ )
+
+ assert db_model.catalog == "example_catalog", (
+ f"Catalog lost for DuckDB blueprint after ClickHouse iteration. Got: {db_model.catalog}"
+ )
diff --git a/tests/core/test_plan.py b/tests/core/test_plan.py
index 4b330c376f..590cda01ec 100644
--- a/tests/core/test_plan.py
+++ b/tests/core/test_plan.py
@@ -1795,7 +1795,7 @@ def test_forward_only_models_model_kind_changed(make_snapshot, mocker: MockerFix
)
def test_forward_only_models_model_kind_changed_to_incremental_by_time_range(
make_snapshot,
- partitioned_by: t.List[exp.Expression],
+ partitioned_by: t.List[exp.Expr],
expected_forward_only: bool,
):
snapshot = make_snapshot(
diff --git a/tests/core/test_selector_native.py b/tests/core/test_selector_native.py
index 46d666db64..5889efadda 100644
--- a/tests/core/test_selector_native.py
+++ b/tests/core/test_selector_native.py
@@ -6,6 +6,7 @@
import pytest
from pytest_mock.plugin import MockerFixture
+import subprocess
from sqlmesh.core import dialect as d
from sqlmesh.core.audit import StandaloneAudit
@@ -16,6 +17,7 @@
from sqlmesh.core.snapshot import SnapshotChangeCategory
from sqlmesh.utils import UniqueKeyDict
from sqlmesh.utils.date import now_timestamp
+from sqlmesh.utils.git import GitClient
@pytest.mark.parametrize(
@@ -634,6 +636,92 @@ def test_expand_git_selection(
git_client_mock.list_untracked_files.assert_called_once()
+def test_expand_git_selection_integration(tmp_path: Path, mocker: MockerFixture):
+ repo_path = tmp_path / "test_repo"
+ repo_path.mkdir()
+ subprocess.run(["git", "init", "-b", "main"], cwd=repo_path, check=True, capture_output=True)
+
+ models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
+ model_a_path = repo_path / "model_a.sql"
+ model_a_path.write_text("SELECT 1 AS a")
+ model_a = SqlModel(name="test_model_a", query=d.parse_one("SELECT 1 AS a"))
+ model_a._path = model_a_path
+ models[model_a.fqn] = model_a
+
+ model_b_path = repo_path / "model_b.sql"
+ model_b_path.write_text("SELECT 2 AS b")
+ model_b = SqlModel(name="test_model_b", query=d.parse_one("SELECT 2 AS b"))
+ model_b._path = model_b_path
+ models[model_b.fqn] = model_b
+
+ subprocess.run(["git", "add", "."], cwd=repo_path, check=True, capture_output=True)
+ subprocess.run(
+ [
+ "git",
+ "-c",
+ "user.name=Max",
+ "-c",
+ "user.email=max@rb.com",
+ "commit",
+ "-m",
+ "Initial commit",
+ ],
+ cwd=repo_path,
+ check=True,
+ capture_output=True,
+ )
+
+ # no changes should select nothing
+ git_client = GitClient(repo_path)
+ selector = NativeSelector(mocker.Mock(), models)
+ selector._git_client = git_client
+ assert selector.expand_model_selections([f"git:main"]) == set()
+
+ # modify A but don't stage it; it should be the only model selected
+ model_a_path.write_text("SELECT 10 AS a")
+ assert selector.expand_model_selections([f"git:main"]) == {'"test_model_a"'}
+
+ # stage model A, should still select it
+ subprocess.run(["git", "add", "model_a.sql"], cwd=repo_path, check=True, capture_output=True)
+ assert selector.expand_model_selections([f"git:main"]) == {'"test_model_a"'}
+
+ # now add unstaged change to B and both should be selected
+ model_b_path.write_text("SELECT 20 AS b")
+ assert selector.expand_model_selections([f"git:main"]) == {
+ '"test_model_a"',
+ '"test_model_b"',
+ }
+
+ subprocess.run(
+ ["git", "checkout", "-b", "dev"],
+ cwd=repo_path,
+ check=True,
+ capture_output=True,
+ )
+
+ subprocess.run(
+ [
+ "git",
+ "-c",
+ "user.name=Max",
+ "-c",
+ "user.email=max@rb.com",
+ "commit",
+ "-m",
+ "Update model_a",
+ ],
+ cwd=repo_path,
+ check=True,
+ capture_output=True,
+ )
+
+ # now A is committed on the dev branch and B is unstaged; both should be selected
+ assert selector.expand_model_selections([f"git:main"]) == {
+ '"test_model_a"',
+ '"test_model_b"',
+ }
+
+
def test_select_models_with_external_parent(mocker: MockerFixture):
default_catalog = "test_catalog"
added_model = SqlModel(
diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py
index 9dd645ac15..f3fae15e8a 100644
--- a/tests/core/test_snapshot_evaluator.py
+++ b/tests/core/test_snapshot_evaluator.py
@@ -2131,7 +2131,7 @@ def test_temp_table_includes_schema_for_ignore_changes(
model = SqlModel(
name="test_schema.test_model",
kind=IncrementalByTimeRangeKind(
- time_column="a", on_destructive_change=OnDestructiveChange.IGNORE
+ time_column="ds", on_destructive_change=OnDestructiveChange.IGNORE
),
query=parse_one("SELECT c, a FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"),
)
@@ -2148,6 +2148,7 @@ def columns(table_name):
return {
"c": exp.DataType.build("int"),
"a": exp.DataType.build("int"),
+ "ds": exp.DataType.build("timestamp"),
}
adapter.columns = columns # type: ignore
@@ -3682,7 +3683,7 @@ def test_custom_materialization_strategy_with_custom_properties(adapter_mock, ma
custom_insert_kind = None
class TestCustomKind(CustomKind):
- _primary_key: t.List[exp.Expression] # type: ignore[no-untyped-def]
+ _primary_key: t.List[exp.Expr] # type: ignore[no-untyped-def]
@model_validator(mode="after")
def _validate_model(self) -> Self:
@@ -3694,7 +3695,7 @@ def _validate_model(self) -> Self:
return self
@property
- def primary_key(self) -> t.List[exp.Expression]:
+ def primary_key(self) -> t.List[exp.Expr]:
return self._primary_key
class TestCustomMaterializationStrategy(CustomMaterialization[TestCustomKind]):
@@ -4321,13 +4322,14 @@ def test_multiple_engine_promotion(mocker: MockerFixture, adapter_mock, make_sna
def columns(table_name):
return {
"a": exp.DataType.build("int"),
+ "ds": exp.DataType.build("timestamp"),
}
adapter.columns = columns # type: ignore
model = SqlModel(
name="test_schema.test_model",
- kind=IncrementalByTimeRangeKind(time_column="a"),
+ kind=IncrementalByTimeRangeKind(time_column="ds"),
gateway="secondary",
query=parse_one("SELECT a FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"),
)
@@ -4350,10 +4352,10 @@ def columns(table_name):
cursor_mock.execute.assert_has_calls(
[
call(
- f'DELETE FROM "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}" WHERE "a" BETWEEN 2020-01-01 00:00:00+00:00 AND 2020-01-02 23:59:59.999999+00:00'
+ f'DELETE FROM "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}" WHERE "ds" BETWEEN CAST(\'2020-01-01 00:00:00\' AS TIMESTAMP) AND CAST(\'2020-01-02 23:59:59.999999\' AS TIMESTAMP)'
),
call(
- f'INSERT INTO "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}" ("a") SELECT "a" FROM (SELECT "a" AS "a" FROM "tbl" AS "tbl" WHERE "ds" BETWEEN \'2020-01-01\' AND \'2020-01-02\') AS "_subquery" WHERE "a" BETWEEN 2020-01-01 00:00:00+00:00 AND 2020-01-02 23:59:59.999999+00:00'
+ f'INSERT INTO "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}" ("a", "ds") SELECT "a", "ds" FROM (SELECT "a" AS "a" FROM "tbl" AS "tbl" WHERE "ds" BETWEEN \'2020-01-01\' AND \'2020-01-02\') AS "_subquery" WHERE "ds" BETWEEN CAST(\'2020-01-01 00:00:00\' AS TIMESTAMP) AND CAST(\'2020-01-02 23:59:59.999999\' AS TIMESTAMP)'
),
]
)
diff --git a/tests/core/test_test.py b/tests/core/test_test.py
index 43d0f333c3..d679f09393 100644
--- a/tests/core/test_test.py
+++ b/tests/core/test_test.py
@@ -1718,10 +1718,12 @@ def test_generate_input_data_using_sql(mocker: MockerFixture, tmp_path: Path) ->
)
+@pytest.mark.pyspark
def test_pyspark_python_model(tmp_path: Path) -> None:
spark_connection_config = SparkConnectionConfig(
config={
"spark.master": "local",
+ "spark.driver.memory": "512m",
"spark.sql.warehouse.dir": f"{tmp_path}/data_dir",
"spark.driver.extraJavaOptions": f"-Dderby.system.home={tmp_path}/derby_dir",
},
diff --git a/tests/dbt/cli/test_run.py b/tests/dbt/cli/test_run.py
index 4fdb7a0cdb..c640950a27 100644
--- a/tests/dbt/cli/test_run.py
+++ b/tests/dbt/cli/test_run.py
@@ -1,6 +1,7 @@
import typing as t
import pytest
from pathlib import Path
+import shutil
from click.testing import Result
import time_machine
from sqlmesh_dbt.operations import create
@@ -71,6 +72,10 @@ def test_run_with_changes_and_full_refresh(
if partial_parse_file.exists():
partial_parse_file.unlink()
+ cache_dir = project_path / ".cache"
+ if cache_dir.exists():
+ shutil.rmtree(cache_dir)
+
# run with --full-refresh. this should:
# - fully refresh model_a (pick up the new records from external_table)
# - deploy the local change to model_b (introducing the 'changed' column)
diff --git a/tests/dbt/test_model.py b/tests/dbt/test_model.py
index 6d100e6aa5..a954f98f41 100644
--- a/tests/dbt/test_model.py
+++ b/tests/dbt/test_model.py
@@ -626,11 +626,11 @@ def test_load_microbatch_with_ref(
context = Context(paths=project_dir)
assert (
context.render(microbatch_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql()
- == 'SELECT "cola" AS "cola", "ds_source" AS "ds" FROM (SELECT * FROM "local"."my_source"."my_table" AS "my_table" WHERE "ds_source" >= \'2025-01-01 00:00:00+00:00\' AND "ds_source" < \'2025-01-11 00:00:00+00:00\') AS "_q_0"'
+ == 'SELECT "cola" AS "cola", "ds_source" AS "ds" FROM (SELECT * FROM "local"."my_source"."my_table" AS "my_table" WHERE "ds_source" >= \'2025-01-01 00:00:00+00:00\' AND "ds_source" < \'2025-01-11 00:00:00+00:00\') AS "_0"'
)
assert (
context.render(microbatch_two_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql()
- == 'SELECT "_q_0"."cola" AS "cola", "_q_0"."ds" AS "ds" FROM (SELECT "microbatch"."cola" AS "cola", "microbatch"."ds" AS "ds" FROM "local"."main"."microbatch" AS "microbatch" WHERE "microbatch"."ds" < \'2025-01-11 00:00:00+00:00\' AND "microbatch"."ds" >= \'2025-01-01 00:00:00+00:00\') AS "_q_0"'
+ == 'SELECT "_0"."cola" AS "cola", "_0"."ds" AS "ds" FROM (SELECT "microbatch"."cola" AS "cola", "microbatch"."ds" AS "ds" FROM "local"."main"."microbatch" AS "microbatch" WHERE "microbatch"."ds" < \'2025-01-11 00:00:00+00:00\' AND "microbatch"."ds" >= \'2025-01-01 00:00:00+00:00\') AS "_0"'
)
diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py
index 3b4df916d3..fe6073dfad 100644
--- a/tests/dbt/test_transformation.py
+++ b/tests/dbt/test_transformation.py
@@ -2213,7 +2213,7 @@ def test_clickhouse_properties(mocker: MockerFixture):
]
assert [e.sql("clickhouse") for e in model_to_sqlmesh.partitioned_by] == [
- 'toMonday("ds")',
+ "dateTrunc('WEEK', \"ds\")",
'"partition_col"',
]
assert model_to_sqlmesh.storage_format == "MergeTree()"
diff --git a/tests/engines/spark/conftest.py b/tests/engines/spark/conftest.py
index 933bc7870f..ce6a99ea35 100644
--- a/tests/engines/spark/conftest.py
+++ b/tests/engines/spark/conftest.py
@@ -9,6 +9,7 @@ def spark_session() -> t.Generator[SparkSession, None, None]:
session = (
SparkSession.builder.master("local")
.appName("SQLMesh Test")
+ .config("spark.driver.memory", "512m")
.enableHiveSupport()
.getOrCreate()
)
diff --git a/tests/fixtures/dbt/empty_project/profiles.yml b/tests/fixtures/dbt/empty_project/profiles.yml
index adae09e9c6..712456bffe 100644
--- a/tests/fixtures/dbt/empty_project/profiles.yml
+++ b/tests/fixtures/dbt/empty_project/profiles.yml
@@ -7,7 +7,7 @@ empty_project:
type: duckdb
# database is required for dbt < 1.5 where our adapter deliberately doesnt infer the database from the path and
# defaults it to "main", which raises a "project catalog doesnt match context catalog" error
- # ref: https://github.com/TobikoData/sqlmesh/pull/1109
+ # ref: https://github.com/SQLMesh/sqlmesh/pull/1109
database: empty_project
path: 'empty_project.duckdb'
threads: 4
diff --git a/tests/integrations/github/cicd/test_github_controller.py b/tests/integrations/github/cicd/test_github_controller.py
index 1e114171a3..e4fe10e321 100644
--- a/tests/integrations/github/cicd/test_github_controller.py
+++ b/tests/integrations/github/cicd/test_github_controller.py
@@ -339,7 +339,8 @@ def test_prod_plan_with_gaps(github_client, make_controller):
assert controller.prod_plan_with_gaps.environment.name == c.PROD
assert not controller.prod_plan_with_gaps.skip_backfill
- assert not controller._prod_plan_with_gaps_builder._auto_categorization_enabled
+ # auto_categorization should now be enabled to prevent uncategorized snapshot errors
+ assert controller._prod_plan_with_gaps_builder._auto_categorization_enabled
assert not controller.prod_plan_with_gaps.no_gaps
assert not controller._context.apply.called
assert controller._context._run_plan_tests.call_args == call(skip_tests=True)
@@ -475,6 +476,18 @@ def test_deploy_to_prod_blocked_pr(github_client, make_controller):
controller.deploy_to_prod()
+def test_deploy_to_prod_not_blocked_pr_if_config_set(github_client, make_controller):
+ mock_pull_request = github_client.get_repo().get_pull()
+ mock_pull_request.merged = False
+ controller = make_controller(
+ "tests/fixtures/github/pull_request_synchronized.json",
+ github_client,
+ merge_state_status=MergeStateStatus.BLOCKED,
+ bot_config=GithubCICDBotConfig(check_if_blocked_on_deploy_to_prod=False),
+ )
+ controller.deploy_to_prod()
+
+
def test_deploy_to_prod_dirty_pr(github_client, make_controller):
mock_pull_request = github_client.get_repo().get_pull()
mock_pull_request.merged = False
diff --git a/tests/lsp/test_reference_model_column_prefix.py b/tests/lsp/test_reference_model_column_prefix.py
index 3cd25a080e..082ee9c8e6 100644
--- a/tests/lsp/test_reference_model_column_prefix.py
+++ b/tests/lsp/test_reference_model_column_prefix.py
@@ -41,7 +41,7 @@ def test_model_reference_with_column_prefix():
model_refs = get_all_references(lsp_context, URI.from_path(sushi_customers_path), position)
- assert len(model_refs) >= 7
+ assert len(model_refs) >= 6
# Verify that we have the FROM clause reference
assert any(ref.range.start.line == from_clause_range.start.line for ref in model_refs), (
@@ -65,8 +65,8 @@ def test_column_prefix_references_are_found():
# Find all occurrences of sushi.orders in the file
ranges = find_ranges_from_regex(read_file, r"sushi\.orders")
- # Should find exactly 2: FROM clause and WHERE clause with column prefix
- assert len(ranges) == 2, f"Expected 2 occurrences of 'sushi.orders', found {len(ranges)}"
+ # Should find exactly 1 in FROM clause with column prefix
+ assert len(ranges) == 1, f"Expected 1 occurrence of 'sushi.orders', found {len(ranges)}"
# Verify we have the expected lines
line_contents = [read_file[r.start.line].strip() for r in ranges]
@@ -76,11 +76,6 @@ def test_column_prefix_references_are_found():
"Should find FROM clause with sushi.orders"
)
- # Should find customer_id in WHERE clause with column prefix
- assert any("WHERE sushi.orders.customer_id" in content for content in line_contents), (
- "Should find WHERE clause with sushi.orders.customer_id"
- )
-
def test_quoted_uppercase_table_and_column_references(tmp_path: Path):
# Initialize example project in temporary directory with case sensitive normalization
diff --git a/tests/lsp/test_reference_model_find_all.py b/tests/lsp/test_reference_model_find_all.py
index 7c0077d6cd..cd9c0a3a1c 100644
--- a/tests/lsp/test_reference_model_find_all.py
+++ b/tests/lsp/test_reference_model_find_all.py
@@ -30,8 +30,8 @@ def test_find_references_for_model_usages():
# Click on the model reference
position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 6)
references = get_model_find_all_references(lsp_context, URI.from_path(customers_path), position)
- assert len(references) >= 7, (
- f"Expected at least 7 references to sushi.orders (including column prefix), found {len(references)}"
+ assert len(references) >= 6, (
+ f"Expected at least 6 references to sushi.orders (including column prefix), found {len(references)}"
)
# Verify expected files are present
@@ -53,7 +53,7 @@ def test_find_references_for_model_usages():
# Note: customers file has multiple references due to column prefix support
expected_ranges = {
"orders": [(0, 0, 0, 0)], # the start for the model itself
- "customers": [(30, 7, 30, 19), (44, 6, 44, 18)], # FROM clause and WHERE clause
+ "customers": [(30, 7, 30, 19)], # FROM clause
"waiter_revenue_by_day": [(19, 5, 19, 17)],
"customer_revenue_lifetime": [(38, 7, 38, 19)],
"customer_revenue_by_day": [(33, 5, 33, 17)],
diff --git a/tests/pyproject.toml b/tests/pyproject.toml
index 6f9cd2f9d9..73f143bfde 100644
--- a/tests/pyproject.toml
+++ b/tests/pyproject.toml
@@ -8,8 +8,8 @@ license = { text = "Apache License 2.0" }
[project.urls]
Homepage = "https://sqlmesh.com/"
Documentation = "https://sqlmesh.readthedocs.io/en/stable/"
-Repository = "https://github.com/TobikoData/sqlmesh"
-Issues = "https://github.com/TobikoData/sqlmesh/issues"
+Repository = "https://github.com/SQLMesh/sqlmesh"
+Issues = "https://github.com/SQLMesh/sqlmesh/issues"
[build-system]
requires = ["setuptools", "setuptools_scm", "toml"]
diff --git a/tests/setup.py b/tests/setup.py
index d072cb555b..ab48a3128f 100644
--- a/tests/setup.py
+++ b/tests/setup.py
@@ -7,6 +7,8 @@
sqlmesh_pyproject = Path(__file__).parent / "sqlmesh_pyproject.toml"
parsed = toml.load(sqlmesh_pyproject)["project"]
install_requires = parsed["dependencies"] + parsed["optional-dependencies"]["dev"]
+# remove dbt dependencies
+install_requires = [req for req in install_requires if not req.startswith("dbt")]
# this is just so we can have a dynamic install_requires, everything else is defined in pyproject.toml
setuptools.setup(install_requires=install_requires)
diff --git a/tests/utils/test_cache.py b/tests/utils/test_cache.py
index 0b6d335446..ed19765b8a 100644
--- a/tests/utils/test_cache.py
+++ b/tests/utils/test_cache.py
@@ -106,7 +106,7 @@ def test_optimized_query_cache_macro_def_change(tmp_path: Path, mocker: MockerFi
assert cache.with_optimized_query(model)
assert (
model.render_query_or_raise().sql()
- == 'SELECT "_q_0"."a" AS "a" FROM (SELECT 1 AS "a") AS "_q_0" WHERE "_q_0"."a" = 1'
+ == 'SELECT "_0"."a" AS "a" FROM (SELECT 1 AS "a") AS "_0" WHERE "_0"."a" = 1'
)
# Change the filter_ definition
@@ -129,5 +129,5 @@ def test_optimized_query_cache_macro_def_change(tmp_path: Path, mocker: MockerFi
assert cache.with_optimized_query(new_model)
assert (
new_model.render_query_or_raise().sql()
- == 'SELECT "_q_0"."a" AS "a" FROM (SELECT 1 AS "a") AS "_q_0" WHERE "_q_0"."a" = 2'
+ == 'SELECT "_0"."a" AS "a" FROM (SELECT 1 AS "a") AS "_0" WHERE "_0"."a" = 2'
)
diff --git a/tests/utils/test_git_client.py b/tests/utils/test_git_client.py
new file mode 100644
index 0000000000..13eecf294b
--- /dev/null
+++ b/tests/utils/test_git_client.py
@@ -0,0 +1,173 @@
+import subprocess
+from pathlib import Path
+import pytest
+from sqlmesh.utils.git import GitClient
+
+
+@pytest.fixture
+def git_repo(tmp_path: Path) -> Path:
+ repo_path = tmp_path / "test_repo"
+ repo_path.mkdir()
+ subprocess.run(["git", "init", "-b", "main"], cwd=repo_path, check=True, capture_output=True)
+ return repo_path
+
+
+def test_git_uncommitted_changes(git_repo: Path):
+ git_client = GitClient(git_repo)
+
+ test_file = git_repo / "model.sql"
+ test_file.write_text("SELECT 1 AS a")
+ subprocess.run(["git", "add", "model.sql"], cwd=git_repo, check=True, capture_output=True)
+ subprocess.run(
+ [
+ "git",
+ "-c",
+ "user.name=Max",
+ "-c",
+ "user.email=max@rb.com",
+ "commit",
+ "-m",
+ "Initial commit",
+ ],
+ cwd=git_repo,
+ check=True,
+ capture_output=True,
+ )
+ assert git_client.list_uncommitted_changed_files() == []
+
+ # make an unstaged change and see that it is listed
+ test_file.write_text("SELECT 2 AS a")
+ uncommitted = git_client.list_uncommitted_changed_files()
+ assert len(uncommitted) == 1
+ assert uncommitted[0].name == "model.sql"
+
+ # stage the change and test that it is still detected
+ subprocess.run(["git", "add", "model.sql"], cwd=git_repo, check=True, capture_output=True)
+ uncommitted = git_client.list_uncommitted_changed_files()
+ assert len(uncommitted) == 1
+ assert uncommitted[0].name == "model.sql"
+
+
+def test_git_both_staged_and_unstaged_changes(git_repo: Path):
+ git_client = GitClient(git_repo)
+
+ file1 = git_repo / "model1.sql"
+ file2 = git_repo / "model2.sql"
+ file1.write_text("SELECT 1")
+ file2.write_text("SELECT 2")
+ subprocess.run(["git", "add", "."], cwd=git_repo, check=True, capture_output=True)
+ subprocess.run(
+ [
+ "git",
+ "-c",
+ "user.name=Max",
+ "-c",
+ "user.email=max@rb.com",
+ "commit",
+ "-m",
+ "Initial commit",
+ ],
+ cwd=git_repo,
+ check=True,
+ capture_output=True,
+ )
+
+ # stage file1
+ file1.write_text("SELECT 10")
+ subprocess.run(["git", "add", "model1.sql"], cwd=git_repo, check=True, capture_output=True)
+
+ # modify file2 but don't stage it!
+ file2.write_text("SELECT 20")
+
+ # both should be detected
+ uncommitted = git_client.list_uncommitted_changed_files()
+ assert len(uncommitted) == 2
+ file_names = {f.name for f in uncommitted}
+ assert file_names == {"model1.sql", "model2.sql"}
+
+
+def test_git_untracked_files(git_repo: Path):
+ git_client = GitClient(git_repo)
+ initial_file = git_repo / "initial.sql"
+ initial_file.write_text("SELECT 0")
+ subprocess.run(["git", "add", "initial.sql"], cwd=git_repo, check=True, capture_output=True)
+ subprocess.run(
+ [
+ "git",
+ "-c",
+ "user.name=Max",
+ "-c",
+ "user.email=max@rb.com",
+ "commit",
+ "-m",
+ "Initial commit",
+ ],
+ cwd=git_repo,
+ check=True,
+ capture_output=True,
+ )
+
+ new_file = git_repo / "new_model.sql"
+ new_file.write_text("SELECT 1")
+
+ # untracked file should not appear in uncommitted changes
+ assert git_client.list_uncommitted_changed_files() == []
+
+ # but it should appear in the list of untracked files
+ untracked = git_client.list_untracked_files()
+ assert len(untracked) == 1
+ assert untracked[0].name == "new_model.sql"
+
+
+def test_git_committed_changes(git_repo: Path):
+ git_client = GitClient(git_repo)
+
+ test_file = git_repo / "model.sql"
+ test_file.write_text("SELECT 1")
+ subprocess.run(["git", "add", "model.sql"], cwd=git_repo, check=True, capture_output=True)
+ subprocess.run(
+ [
+ "git",
+ "-c",
+ "user.name=Max",
+ "-c",
+ "user.email=max@rb.com",
+ "commit",
+ "-m",
+ "Initial commit",
+ ],
+ cwd=git_repo,
+ check=True,
+ capture_output=True,
+ )
+
+ subprocess.run(
+ ["git", "checkout", "-b", "feature"],
+ cwd=git_repo,
+ check=True,
+ capture_output=True,
+ )
+
+ test_file.write_text("SELECT 2")
+ subprocess.run(["git", "add", "model.sql"], cwd=git_repo, check=True, capture_output=True)
+ subprocess.run(
+ [
+ "git",
+ "-c",
+ "user.name=Max",
+ "-c",
+ "user.email=max@rb.com",
+ "commit",
+ "-m",
+ "Update on feature branch",
+ ],
+ cwd=git_repo,
+ check=True,
+ capture_output=True,
+ )
+
+ committed = git_client.list_committed_changed_files(target_branch="main")
+ assert len(committed) == 1
+ assert committed[0].name == "model.sql"
+
+ assert git_client.list_uncommitted_changed_files() == []
diff --git a/tests/utils/test_metaprogramming.py b/tests/utils/test_metaprogramming.py
index 19413f68ef..9a6f0c95cd 100644
--- a/tests/utils/test_metaprogramming.py
+++ b/tests/utils/test_metaprogramming.py
@@ -23,6 +23,7 @@
Executable,
ExecutableKind,
_dict_sort,
+ _resolve_import_module,
build_env,
func_globals,
normalize_source,
@@ -49,7 +50,7 @@ def test_print_exception(mocker: MockerFixture):
except Exception as ex:
print_exception(ex, test_env, out_mock)
- expected_message = r""" File ".*?.tests.utils.test_metaprogramming\.py", line 48, in test_print_exception
+ expected_message = r""" File ".*?.tests.utils.test_metaprogramming\.py", line 49, in test_print_exception
eval\("test_fun\(\)", env\).*
File '/test/path.py' \(or imported file\), line 2, in test_fun
@@ -83,7 +84,18 @@ class DataClass:
x: int
+class ReferencedClass:
+ def __init__(self, value: int):
+ self.value = value
+
+ def get_value(self) -> int:
+ return self.value
+
+
class MyClass:
+ def __init__(self, x: int):
+ self.helper = ReferencedClass(x * 2)
+
@staticmethod
def foo():
return KLASS_X
@@ -95,6 +107,13 @@ def bar(cls):
def baz(self):
return KLASS_Z
+ def use_referenced(self, value: int) -> int:
+ ref = ReferencedClass(value)
+ return ref.get_value()
+
+ def compute_with_reference(self) -> int:
+ return self.helper.get_value() + 10
+
def other_func(a: int) -> int:
import sqlglot
@@ -103,7 +122,8 @@ def other_func(a: int) -> int:
pd.DataFrame([{"x": 1}])
to_table("y")
my_lambda() # type: ignore
- return X + a + W
+ obj = MyClass(a)
+ return X + a + W + obj.compute_with_reference()
@contextmanager
@@ -131,7 +151,7 @@ def function_with_custom_decorator():
def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2) -> int:
"""DOC STRING"""
sqlglot.parse_one("1")
- MyClass()
+ MyClass(47)
DataClass(x=y)
normalize_model_name("test" + SQLGLOT_META)
fetch_data()
@@ -177,6 +197,7 @@ def test_func_globals() -> None:
assert func_globals(other_func) == {
"X": 1,
"W": 0,
+ "MyClass": MyClass,
"my_lambda": my_lambda,
"pd": pd,
"to_table": to_table,
@@ -202,7 +223,7 @@ def test_normalize_source() -> None:
== """def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2
):
sqlglot.parse_one('1')
- MyClass()
+ MyClass(47)
DataClass(x=y)
normalize_model_name('test' + SQLGLOT_META)
fetch_data()
@@ -223,7 +244,8 @@ def closure(z: int):
pd.DataFrame([{'x': 1}])
to_table('y')
my_lambda()
- return X + a + W"""
+ obj = MyClass(a)
+ return X + a + W + obj.compute_with_reference()"""
)
@@ -252,7 +274,7 @@ def test_serialize_env() -> None:
payload="""def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2
):
sqlglot.parse_one('1')
- MyClass()
+ MyClass(47)
DataClass(x=y)
normalize_model_name('test' + SQLGLOT_META)
fetch_data()
@@ -295,6 +317,9 @@ class DataClass:
path="test_metaprogramming.py",
payload="""class MyClass:
+ def __init__(self, x: int):
+ self.helper = ReferencedClass(x * 2)
+
@staticmethod
def foo():
return KLASS_X
@@ -304,7 +329,26 @@ def bar(cls):
return KLASS_Y
def baz(self):
- return KLASS_Z""",
+ return KLASS_Z
+
+ def use_referenced(self, value: int):
+ ref = ReferencedClass(value)
+ return ref.get_value()
+
+ def compute_with_reference(self):
+ return self.helper.get_value() + 10""",
+ ),
+ "ReferencedClass": Executable(
+ kind=ExecutableKind.DEFINITION,
+ name="ReferencedClass",
+ path="test_metaprogramming.py",
+ payload="""class ReferencedClass:
+
+ def __init__(self, value: int):
+ self.value = value
+
+ def get_value(self):
+ return self.value""",
),
"dataclass": Executable(
payload="from dataclasses import dataclass", kind=ExecutableKind.IMPORT
@@ -341,7 +385,8 @@ def sample_context_manager():
pd.DataFrame([{'x': 1}])
to_table('y')
my_lambda()
- return X + a + W""",
+ obj = MyClass(a)
+ return X + a + W + obj.compute_with_reference()""",
),
"sample_context_manager": Executable(
payload="""@contextmanager
@@ -424,6 +469,21 @@ def function_with_custom_decorator():
assert all(is_metadata for (_, is_metadata) in env.values())
assert serialized_env == expected_env
+ # Check that class references inside init are captured
+ init_globals = func_globals(MyClass.__init__)
+ assert "ReferencedClass" in init_globals
+
+ env = {}
+ build_env(other_func, env=env, name="other_func_test", path=path)
+ serialized_env = serialize_env(env, path=path)
+
+ assert "MyClass" in serialized_env
+ assert "ReferencedClass" in serialized_env
+
+ prepared_env = prepare_env(serialized_env)
+ result = eval("other_func_test(2)", prepared_env)
+ assert result == 17
+
def test_serialize_env_with_enum_import_appearing_in_two_functions() -> None:
path = Path("tests/utils")
@@ -579,3 +639,18 @@ def test_dict_sort_executable_integration():
# non-deterministic repr should not change the payload
exec3 = Executable.value(variables1)
assert exec3.payload == "{'env': 'dev', 'debug': True, 'timeout': 30}"
+
+
+def test_resolve_import_module():
+ """Test that _resolve_import_module finds the shallowest public re-exporting module."""
+ # to_table lives in sqlglot.expressions.builders but is re-exported from sqlglot.expressions
+ assert _resolve_import_module(to_table, "to_table") == "sqlglot.expressions"
+
+ # Objects whose __module__ is already the public module should be returned as-is
+ assert _resolve_import_module(exp.Column, "Column") == "sqlglot.expressions"
+
+ # Objects not re-exported by any parent should return the original module
+ class _Local:
+ __module__ = "some.deep.internal.module"
+
+ assert _resolve_import_module(_Local, "_Local") == "some.deep.internal.module"
diff --git a/vscode/extension/README.md b/vscode/extension/README.md
index 64f6c3e130..dac6d9cae6 100644
--- a/vscode/extension/README.md
+++ b/vscode/extension/README.md
@@ -77,8 +77,8 @@ If you encounter issues, please refer to the [VSCode Extension Guide](https://sq
We welcome contributions! Please:
-1. [Report bugs](https://github.com/tobikodata/sqlmesh/issues) you encounter
-2. [Request features](https://github.com/tobikodata/sqlmesh/issues) you'd like to see
+1. [Report bugs](https://github.com/SQLMesh/sqlmesh/issues) you encounter
+2. [Request features](https://github.com/SQLMesh/sqlmesh/issues) you'd like to see
3. Share feedback on your experience
## π License
@@ -87,7 +87,7 @@ This extension is licensed under the Apache License 2.0. See [LICENSE](LICENSE)
## π Links
-- [SQLMesh GitHub Repository](https://github.com/tobikodata/sqlmesh)
+- [SQLMesh GitHub Repository](https://github.com/SQLMesh/sqlmesh)
- [Tobiko Data Website](https://tobikodata.com)
- [Extension Marketplace Page](https://marketplace.visualstudio.com/items?itemName=tobikodata.sqlmesh)
diff --git a/vscode/extension/package.json b/vscode/extension/package.json
index 35499ad68f..342096731f 100644
--- a/vscode/extension/package.json
+++ b/vscode/extension/package.json
@@ -6,7 +6,7 @@
"version": "0.0.7",
"repository": {
"type": "git",
- "url": "https://github.com/tobikodata/sqlmesh"
+ "url": "https://github.com/SQLMesh/sqlmesh"
},
"main": "./dist/extension.js",
"icon": "assets/logo.png",
diff --git a/web/client/playwright.config.ts b/web/client/playwright.config.ts
index afaa00c716..c574869b87 100644
--- a/web/client/playwright.config.ts
+++ b/web/client/playwright.config.ts
@@ -50,7 +50,10 @@ export default defineConfig({
/* Run your local dev server before starting the tests */
webServer: {
- command: 'npm run build && npm run preview',
+ command:
+ process.env.PLAYWRIGHT_SKIP_BUILD != null
+ ? 'npm run preview'
+ : 'npm run build && npm run preview',
url: URL,
reuseExistingServer: process.env.CI == null,
timeout: 120000, // Two minutes
diff --git a/web/client/vite.config.ts b/web/client/vite.config.ts
index 206504cf4b..4b98b21c68 100644
--- a/web/client/vite.config.ts
+++ b/web/client/vite.config.ts
@@ -68,5 +68,6 @@ export default defineConfig({
},
preview: {
port: 8005,
+ host: '127.0.0.1',
},
})
diff --git a/web/common/package.json b/web/common/package.json
index 6a0965f19e..924bbaa883 100644
--- a/web/common/package.json
+++ b/web/common/package.json
@@ -101,7 +101,7 @@
"tailwindcss": "3.4.17"
},
"private": false,
- "repository": "TobikoData/sqlmesh",
+ "repository": "SQLMesh/sqlmesh",
"scripts": {
"build": "tsc -p tsconfig.build.json && vite build --base './' && pnpm run build:css",
"build-storybook": "storybook build",
diff --git a/web/server/api/endpoints/table_diff.py b/web/server/api/endpoints/table_diff.py
index d441d49e5a..b0167ed032 100644
--- a/web/server/api/endpoints/table_diff.py
+++ b/web/server/api/endpoints/table_diff.py
@@ -126,7 +126,7 @@ def get_table_diff(
table_diffs = context.table_diff(
source=source,
target=target,
- on=exp.condition(on) if on else None,
+ on=t.cast(exp.Condition, exp.condition(on)) if on else None,
select_models={model_or_snapshot} if model_or_snapshot else None,
where=where,
limit=limit,