From 1d7241fb645281b270fb73dd49c1a4ff86f0dba5 Mon Sep 17 00:00:00 2001
From: jochen
Date: Sat, 10 Jan 2026 15:47:38 +0100
Subject: [PATCH] Add MCP

---
 MCP.md                          | 165 ++++++++++++++++++++++
 datacontract/cli.py             |  43 ++++++
 datacontract/mcp/__init__.py    |  30 ++++
 datacontract/mcp/connections.py | 235 ++++++++++++++++++++++++++++++++
 datacontract/mcp/server.py      | 174 +++++++++++++++++++++++
 datacontract/mcp/sql_utils.py   |  53 +++++++
 pyproject.toml                  |  11 +-
 tests/test_mcp_connections.py   | 185 +++++++++++++++++++++++++
 tests/test_mcp_sql_utils.py     | 137 +++++++++++++++++++
 9 files changed, 1032 insertions(+), 1 deletion(-)
 create mode 100644 MCP.md
 create mode 100644 datacontract/mcp/__init__.py
 create mode 100644 datacontract/mcp/connections.py
 create mode 100644 datacontract/mcp/server.py
 create mode 100644 datacontract/mcp/sql_utils.py
 create mode 100644 tests/test_mcp_connections.py
 create mode 100644 tests/test_mcp_sql_utils.py

diff --git a/MCP.md b/MCP.md
new file mode 100644
index 000000000..20e225d47
--- /dev/null
+++ b/MCP.md
@@ -0,0 +1,165 @@
# Data Contract CLI MCP Server

Data Contract CLI can be started as an MCP (Model Context Protocol) server
to provide AI assistants with data contract-aware SQL execution capabilities.

## Starting the MCP Server

Start the MCP server:

```
datacontract mcp
```

Use `--port` to change the port (default: `8000`) and
`--host` to change the host binding (default: `127.0.0.1`).
If you run the MCP server in a Docker container, bind to `--host 0.0.0.0`.
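For example, to bind to all interfaces on a custom port:

```
datacontract mcp --host 0.0.0.0 --port 9000
```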
## Claude Desktop Configuration

### Local (stdio)

For local usage with Claude Desktop, add to your `claude_desktop_config.json`:

```json
{
  "mcpServers": {
    "datacontract": {
      "command": "datacontract",
      "args": ["mcp"]
    }
  }
}
```

### Remote (Streamable HTTP)

For remote server deployments (MCP spec 2025-03-26):

```json
{
  "mcpServers": {
    "datacontract": {
      "url": "https://your-server.com/mcp",
      "headers": {
        "Authorization": "Bearer YOUR_TOKEN"
      }
    }
  }
}
```

## The execute_sql Tool

The MCP server exposes a single tool `execute_sql` with the following parameters:

| Parameter | Type | Required | Description |
|-----------|------|----------|-------------|
| `sql` | string | Yes | SQL query to execute. Must be a SELECT statement (read-only). |
| `data_contract_path` | string | No | Path to data contract YAML file. |
| `data_contract_yaml` | string | No | Data contract YAML content as string. |
| `server` | string | No | Server name from data contract. Uses first supported server if not specified. |
| `output_format` | string | No | Output format: `markdown` (default), `json`, or `csv`. |

Either `data_contract_path` or `data_contract_yaml` must be provided.

### Example Usage

```
execute_sql(
    sql="SELECT * FROM orders LIMIT 10",
    data_contract_path="/path/to/datacontract.yaml",
    server="production"
)
```
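With the default `markdown` output format, the tool returns the result set as a table followed by
a row-count footer. An illustrative response (column names and data are hypothetical):

```
| order_id | customer | total |
| --- | --- | --- |
| 1001 | Alice | 100.5 |
| 1002 | Bob | 200.0 |

(2 rows from server 'production' [postgres])
```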
+ """ + enable_debug_logging(debug) + + try: + from datacontract.mcp import run_mcp_server + except ImportError as e: + console.print( + "[red]Error:[/red] MCP dependencies not installed. " + "Install with: pip install datacontract-cli[mcp]" + ) + console.print(f"[dim]Details: {e}[/dim]") + raise typer.Exit(code=1) + + console.print(f"Starting MCP server on {host}:{port}...") + run_mcp_server(host=host, port=port) + + def _print_logs(run): console.print("\nLogs:") for log in run.logs: diff --git a/datacontract/mcp/__init__.py b/datacontract/mcp/__init__.py new file mode 100644 index 000000000..3a9618a2b --- /dev/null +++ b/datacontract/mcp/__init__.py @@ -0,0 +1,30 @@ +"""MCP server module for datacontract-cli.""" + + +def create_mcp_server(name: str = "datacontract"): + """Create and configure the MCP server. + + Args: + name: Server name for MCP identification + + Returns: + Configured FastMCP server instance + """ + from datacontract.mcp.server import create_mcp_server as _create_mcp_server + + return _create_mcp_server(name) + + +def run_mcp_server(host: str = "127.0.0.1", port: int = 8000): + """Run the MCP server. + + Args: + host: Host to bind to + port: Port to bind to + """ + from datacontract.mcp.server import run_mcp_server as _run_mcp_server + + _run_mcp_server(host=host, port=port) + + +__all__ = ["create_mcp_server", "run_mcp_server"] diff --git a/datacontract/mcp/connections.py b/datacontract/mcp/connections.py new file mode 100644 index 000000000..09e13de9b --- /dev/null +++ b/datacontract/mcp/connections.py @@ -0,0 +1,235 @@ +"""Database connection utilities for MCP server using SQLAlchemy.""" + +import csv +import io +import json +import os +from typing import Any +from urllib.parse import quote_plus + +from open_data_contract_standard.model import Server +from sqlalchemy import create_engine, text +from sqlalchemy.engine import Engine + +# Supported server types for SQL execution +SUPPORTED_TYPES = {"snowflake", "databricks", "postgres", "bigquery"} + + +def build_connection_url(server: Server) -> str: + """Build SQLAlchemy connection URL for the given server. + + Args: + server: Server configuration from data contract + + Returns: + SQLAlchemy connection URL string + + Raises: + ValueError: If server type is unsupported or required credentials are missing + """ + if server.type == "postgres": + return _build_postgres_url(server) + elif server.type == "snowflake": + return _build_snowflake_url(server) + elif server.type == "databricks": + return _build_databricks_url(server) + elif server.type == "bigquery": + return _build_bigquery_url(server) + else: + raise ValueError(f"Unsupported server type: {server.type}. 
## Security Notes

- Only read-only (SELECT) queries are allowed. INSERT, UPDATE, DELETE, and other modifying statements are rejected.
- SQL queries are validated using sqlglot before execution.
- Always set `DATACONTRACT_MCP_TOKEN` in production to prevent unauthorized access.
- Database credentials are never exposed to MCP clients; they are only used server-side.
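## Example: Calling the Server from Python

As a sketch of how an MCP client can invoke the tool programmatically, the following uses the
official `mcp` Python SDK over Streamable HTTP. The `streamablehttp_client` helper and the result
shape reflect the mcp 1.x SDK; verify against your installed version, and adjust the URL, token,
and contract path to your setup.

```python
import asyncio

from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client


async def main():
    # Pass the bearer token only if DATACONTRACT_MCP_TOKEN is set on the server
    async with streamablehttp_client(
        "http://127.0.0.1:8000/mcp",
        headers={"Authorization": "Bearer your-secret-token"},
    ) as (read_stream, write_stream, _):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            result = await session.call_tool(
                "execute_sql",
                {
                    "sql": "SELECT * FROM orders LIMIT 10",
                    "data_contract_path": "/path/to/datacontract.yaml",
                },
            )
            print(result.content[0].text)


asyncio.run(main())
```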
diff --git a/datacontract/cli.py b/datacontract/cli.py
index 25032bd9b..4facebbfb 100644
--- a/datacontract/cli.py
+++ b/datacontract/cli.py
@@ -507,6 +507,49 @@ def api(
    uvicorn.run(**uvicorn_args)


@app.command()
def mcp(
    port: Annotated[int, typer.Option(help="Port for HTTP server.")] = 8000,
    host: Annotated[
        str, typer.Option(help="Host to bind. Hint: For running in docker, set it to 0.0.0.0")
    ] = "127.0.0.1",
    debug: debug_option = None,
):
    """
    Start an MCP (Model Context Protocol) server.

    The MCP server exposes a tool 'execute_sql' that allows AI assistants
    to execute read-only SQL queries against data sources defined in data contracts.

    Supported server types: snowflake, databricks, postgres, bigquery.

    To authenticate, set the environment variable DATACONTRACT_MCP_TOKEN to a secret token.
    Clients must include the 'Authorization: Bearer <token>' header.

    To connect to servers, set credentials as environment variables:
    - PostgreSQL: DATACONTRACT_POSTGRES_USERNAME, DATACONTRACT_POSTGRES_PASSWORD
    - Snowflake: DATACONTRACT_SNOWFLAKE_USERNAME, DATACONTRACT_SNOWFLAKE_PASSWORD
    - Databricks: DATACONTRACT_DATABRICKS_TOKEN, DATACONTRACT_DATABRICKS_HTTP_PATH
    - BigQuery: DATACONTRACT_BIGQUERY_ACCOUNT_INFO_JSON_PATH or GOOGLE_APPLICATION_CREDENTIALS

    The server uses Streamable HTTP transport (MCP spec 2025-03-26).
    """
    enable_debug_logging(debug)

    try:
        from datacontract.mcp import run_mcp_server
    except ImportError as e:
        console.print(
            "[red]Error:[/red] MCP dependencies not installed. "
            "Install with: pip install datacontract-cli[mcp]"
        )
        console.print(f"[dim]Details: {e}[/dim]")
        raise typer.Exit(code=1)

    console.print(f"Starting MCP server on {host}:{port}...")
    run_mcp_server(host=host, port=port)


def _print_logs(run):
    console.print("\nLogs:")
    for log in run.logs:

diff --git a/datacontract/mcp/__init__.py b/datacontract/mcp/__init__.py
new file mode 100644
index 000000000..3a9618a2b
--- /dev/null
+++ b/datacontract/mcp/__init__.py
@@ -0,0 +1,30 @@
"""MCP server module for datacontract-cli."""


def create_mcp_server(name: str = "datacontract"):
    """Create and configure the MCP server.

    Args:
        name: Server name for MCP identification

    Returns:
        Configured FastMCP server instance
    """
    from datacontract.mcp.server import create_mcp_server as _create_mcp_server

    return _create_mcp_server(name)


def run_mcp_server(host: str = "127.0.0.1", port: int = 8000):
    """Run the MCP server.

    Args:
        host: Host to bind to
        port: Port to bind to
    """
    from datacontract.mcp.server import run_mcp_server as _run_mcp_server

    _run_mcp_server(host=host, port=port)


__all__ = ["create_mcp_server", "run_mcp_server"]

diff --git a/datacontract/mcp/connections.py b/datacontract/mcp/connections.py
new file mode 100644
index 000000000..09e13de9b
--- /dev/null
+++ b/datacontract/mcp/connections.py
@@ -0,0 +1,235 @@
"""Database connection utilities for MCP server using SQLAlchemy."""

import csv
import io
import json
import os
from typing import Any
from urllib.parse import quote_plus

from open_data_contract_standard.model import Server
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine

# Supported server types for SQL execution
SUPPORTED_TYPES = {"snowflake", "databricks", "postgres", "bigquery"}


def build_connection_url(server: Server) -> str:
    """Build SQLAlchemy connection URL for the given server.

    Args:
        server: Server configuration from data contract

    Returns:
        SQLAlchemy connection URL string

    Raises:
        ValueError: If server type is unsupported or required credentials are missing
    """
    if server.type == "postgres":
        return _build_postgres_url(server)
    elif server.type == "snowflake":
        return _build_snowflake_url(server)
    elif server.type == "databricks":
        return _build_databricks_url(server)
    elif server.type == "bigquery":
        return _build_bigquery_url(server)
    else:
        raise ValueError(f"Unsupported server type: {server.type}. Supported: {SUPPORTED_TYPES}")
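# Illustrative example of a built URL (values are hypothetical; assumes
# DATACONTRACT_POSTGRES_USERNAME=myuser and DATACONTRACT_POSTGRES_PASSWORD=mypassword):
#
#   server = Server(type="postgres", host="localhost", port=5432, database="orders", schema_="public")
#   build_connection_url(server)
#   # -> "postgresql+psycopg2://myuser:mypassword@localhost:5432/orders?options=-csearch_path%3Dpublic"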
def _build_postgres_url(server: Server) -> str:
    """Build PostgreSQL connection URL."""
    username = os.getenv("DATACONTRACT_POSTGRES_USERNAME")
    password = os.getenv("DATACONTRACT_POSTGRES_PASSWORD")

    if not username:
        raise ValueError("Missing environment variable: DATACONTRACT_POSTGRES_USERNAME")

    # URL-encode credentials so special characters do not break the connection URL
    username_encoded = quote_plus(username)
    password_encoded = quote_plus(password) if password else ""
    port = server.port or 5432

    url = f"postgresql+psycopg2://{username_encoded}:{password_encoded}@{server.host}:{port}/{server.database}"

    if server.schema_:
        url += f"?options=-csearch_path%3D{server.schema_}"

    return url


def _build_snowflake_url(server: Server) -> str:
    """Build Snowflake connection URL."""
    username = os.getenv("DATACONTRACT_SNOWFLAKE_USERNAME")
    password = os.getenv("DATACONTRACT_SNOWFLAKE_PASSWORD")
    warehouse = os.getenv("DATACONTRACT_SNOWFLAKE_WAREHOUSE")

    if not username:
        raise ValueError("Missing environment variable: DATACONTRACT_SNOWFLAKE_USERNAME")
    if not password:
        raise ValueError("Missing environment variable: DATACONTRACT_SNOWFLAKE_PASSWORD")

    # URL-encode credentials so special characters (e.g. in email-style usernames) do not break the URL
    username_encoded = quote_plus(username)
    password_encoded = quote_plus(password)

    url = f"snowflake://{username_encoded}:{password_encooded}@{server.account}/{server.database}"

    if server.schema_:
        url += f"/{server.schema_}"

    params = []
    if warehouse:
        params.append(f"warehouse={warehouse}")

    if params:
        url += "?" + "&".join(params)

    return url


def _build_databricks_url(server: Server) -> str:
    """Build Databricks connection URL."""
    token = os.getenv("DATACONTRACT_DATABRICKS_TOKEN")
    http_path = os.getenv("DATACONTRACT_DATABRICKS_HTTP_PATH")

    if not token:
        raise ValueError("Missing environment variable: DATACONTRACT_DATABRICKS_TOKEN")
    if not http_path:
        raise ValueError("Missing environment variable: DATACONTRACT_DATABRICKS_HTTP_PATH")

    host = server.host or os.getenv("DATACONTRACT_DATABRICKS_SERVER_HOSTNAME")
    if not host:
        raise ValueError("Server host not configured and DATACONTRACT_DATABRICKS_SERVER_HOSTNAME not set")

    # Remove https:// prefix if present
    host = host.replace("https://", "").replace("http://", "")

    token_encoded = quote_plus(token)
    http_path_encoded = quote_plus(http_path)

    url = f"databricks://token:{token_encoded}@{host}?http_path={http_path_encoded}"

    if server.catalog:
        url += f"&catalog={server.catalog}"
    if server.schema_:
        url += f"&schema={server.schema_}"

    return url


def _build_bigquery_url(server: Server) -> str:
    """Build BigQuery connection URL."""
    # BigQuery uses Application Default Credentials or a service account file
    credentials_path = os.getenv("DATACONTRACT_BIGQUERY_ACCOUNT_INFO_JSON_PATH") or os.getenv(
        "GOOGLE_APPLICATION_CREDENTIALS"
    )

    url = f"bigquery://{server.project}"

    if server.dataset:
        url += f"/{server.dataset}"

    if credentials_path:
        url += f"?credentials_path={quote_plus(credentials_path)}"

    return url


def create_engine_for_server(server: Server) -> Engine:
    """Create SQLAlchemy engine for the given server.

    Args:
        server: Server configuration from data contract

    Returns:
        SQLAlchemy Engine instance
    """
    url = build_connection_url(server)
    return create_engine(url)


def execute_query(engine: Engine, sql: str) -> tuple[list[str], list[dict[str, Any]]]:
    """Execute SQL query and return results.

    Args:
        engine: SQLAlchemy engine
        sql: SQL query to execute

    Returns:
        Tuple of (column_names, rows_as_dicts)
    """
    with engine.connect() as conn:
        result = conn.execute(text(sql))
        columns = list(result.keys())
        rows = [dict(row._mapping) for row in result]
        return columns, rows
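# Illustrative example of the return shape (engine and table are hypothetical):
#
#   engine = create_engine_for_server(server)
#   columns, rows = execute_query(engine, "SELECT id, name FROM users LIMIT 2")
#   # columns -> ["id", "name"]
#   # rows    -> [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]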
def format_results_markdown(columns: list[str], rows: list[dict], server_name: str, server_type: str) -> str:
    """Format query results as a markdown table.

    Args:
        columns: Column names
        rows: List of row dicts
        server_name: Name of the server
        server_type: Type of the server (postgres, snowflake, etc.)

    Returns:
        Markdown formatted string
    """
    if not columns:
        return f"(0 rows from server '{server_name}' [{server_type}])"

    # Build header
    header = "| " + " | ".join(columns) + " |"
    separator = "| " + " | ".join(["---"] * len(columns)) + " |"

    # Build rows
    row_lines = []
    for row in rows:
        values = [str(row.get(col, "")) for col in columns]
        row_lines.append("| " + " | ".join(values) + " |")

    lines = [header, separator] + row_lines
    lines.append("")
    lines.append(f"({len(rows)} rows from server '{server_name}' [{server_type}])")

    return "\n".join(lines)


def format_results_json(columns: list[str], rows: list[dict], server_name: str, server_type: str) -> str:
    """Format query results as JSON.

    Args:
        columns: Column names
        rows: List of row dicts
        server_name: Name of the server
        server_type: Type of the server (postgres, snowflake, etc.)

    Returns:
        JSON formatted string
    """
    result = {
        "columns": columns,
        "rows": rows,
        "row_count": len(rows),
        "server": server_name,
        "server_type": server_type,
    }
    return json.dumps(result, default=str)


def format_results_csv(columns: list[str], rows: list[dict]) -> str:
    """Format query results as CSV.

    Args:
        columns: Column names
        rows: List of row dicts

    Returns:
        CSV formatted string
    """
    output = io.StringIO()
    writer = csv.DictWriter(output, fieldnames=columns)
    writer.writeheader()
    writer.writerows(rows)
    return output.getvalue()

diff --git a/datacontract/mcp/server.py b/datacontract/mcp/server.py
new file mode 100644
index 000000000..aa9721a72
--- /dev/null
+++ b/datacontract/mcp/server.py
@@ -0,0 +1,174 @@
"""MCP server implementation for data contract SQL execution."""

import logging
import os
from contextlib import asynccontextmanager
from typing import Optional

from mcp.server.fastmcp import FastMCP

from datacontract.export.odcs_export_helper import get_server_by_name
from datacontract.lint.resolve import resolve_data_contract
from datacontract.mcp.connections import (
    SUPPORTED_TYPES,
    create_engine_for_server,
    execute_query,
    format_results_csv,
    format_results_json,
    format_results_markdown,
)
from datacontract.mcp.sql_utils import validate_read_only

logger = logging.getLogger(__name__)


@asynccontextmanager
async def server_lifespan(mcp: FastMCP):
    """Lifespan context manager for the MCP server."""
    logger.info("Starting datacontract MCP server")
    try:
        yield {}
    finally:
        logger.info("Shutting down datacontract MCP server")


def create_mcp_server(
    name: str = "datacontract",
    host: str = "127.0.0.1",
    port: int = 8000,
) -> FastMCP:
    """Create and configure the MCP server.

    Args:
        name: Server name for MCP identification
        host: Host to bind to
        port: Port to bind to

    Returns:
        Configured FastMCP server instance
    """
    mcp = FastMCP(name, host=host, port=port, lifespan=server_lifespan)

    @mcp.tool()
    def execute_sql(
        sql: str,
        data_contract_path: Optional[str] = None,
        data_contract_yaml: Optional[str] = None,
        server: Optional[str] = None,
        output_format: Optional[str] = "markdown",
    ) -> str:
        """Execute a read-only SQL query against a data source defined in a data contract.

        Args:
            sql: SQL query to execute. Must be a SELECT statement (read-only).
            data_contract_path: Path to data contract YAML file.
            data_contract_yaml: Data contract YAML content as string (alternative to path).
            server: Server name from data contract. If not specified, uses first supported server.
            output_format: Output format - "markdown" (default), "json", or "csv".

        Returns:
            Query results in the specified format.

        Note:
            Database credentials must be configured via environment variables:
            - PostgreSQL: DATACONTRACT_POSTGRES_USERNAME, DATACONTRACT_POSTGRES_PASSWORD
            - Snowflake: DATACONTRACT_SNOWFLAKE_USERNAME, DATACONTRACT_SNOWFLAKE_PASSWORD
            - Databricks: DATACONTRACT_DATABRICKS_TOKEN, DATACONTRACT_DATABRICKS_HTTP_PATH
            - BigQuery: DATACONTRACT_BIGQUERY_ACCOUNT_INFO_JSON_PATH or GOOGLE_APPLICATION_CREDENTIALS
        """
        # Validate inputs
        if not data_contract_path and not data_contract_yaml:
            raise ValueError("Either data_contract_path or data_contract_yaml must be provided")

        if not sql or not sql.strip():
            raise ValueError("SQL query cannot be empty")

        # Validate SQL is read-only
        validate_read_only(sql)

        # Resolve data contract
        data_contract = resolve_data_contract(
            data_contract_location=data_contract_path,
            data_contract_str=data_contract_yaml,
        )

        # Find appropriate server
        selected_server = _get_supported_server(data_contract, server)

        # Execute query
        engine = create_engine_for_server(selected_server)
        try:
            columns, rows = execute_query(engine, sql)
        finally:
            engine.dispose()

        # Format output
        server_name = selected_server.server or "default"
        server_type = selected_server.type

        if output_format == "json":
            return format_results_json(columns, rows, server_name, server_type)
        elif output_format == "csv":
            return format_results_csv(columns, rows)
        else:  # markdown (default)
            return format_results_markdown(columns, rows, server_name, server_type)

    return mcp


def _get_supported_server(data_contract, server_name: Optional[str] = None):
    """Get a supported server from the data contract.

    Args:
        data_contract: Parsed data contract
        server_name: Optional specific server name

    Returns:
        Selected Server object

    Raises:
        ValueError: If no servers or no supported server found
    """
    if not data_contract.servers or len(data_contract.servers) == 0:
        raise ValueError("Data contract has no servers defined")

    if server_name:
        # Find specific server by name
        server = get_server_by_name(data_contract, server_name)
        if server is None:
            available = [s.server for s in data_contract.servers]
            raise ValueError(f"Server '{server_name}' not found. Available: {available}")
        if server.type not in SUPPORTED_TYPES:
            raise ValueError(
                f"Server type '{server.type}' not supported for SQL execution. Supported types: {SUPPORTED_TYPES}"
            )
        return server

    # Find first server with supported type
    for srv in data_contract.servers:
        if srv.type in SUPPORTED_TYPES:
            logger.info(f"Using server '{srv.server}' (type: {srv.type})")
            return srv

    # No supported server found
    available_types = {s.type for s in data_contract.servers}
    raise ValueError(f"No supported server type found. Available: {available_types}. Supported: {SUPPORTED_TYPES}")


def run_mcp_server(host: str = "127.0.0.1", port: int = 8000):
    """Run the MCP server.

    Args:
        host: Host to bind to
        port: Port to bind to
    """
    # Check for auth token
    token = os.getenv("DATACONTRACT_MCP_TOKEN")
    if not token:
        logger.warning(
            "WARNING: DATACONTRACT_MCP_TOKEN not set. Running without authentication. "
            "This is not recommended for production deployments."
        )

    mcp = create_mcp_server(host=host, port=port)
    mcp.run(transport="streamable-http")

diff --git a/datacontract/mcp/sql_utils.py b/datacontract/mcp/sql_utils.py
new file mode 100644
index 000000000..f0048d176
--- /dev/null
+++ b/datacontract/mcp/sql_utils.py
@@ -0,0 +1,53 @@
"""SQL utility functions for MCP server."""

import sqlglot
from sqlglot import exp


def validate_read_only(sql: str) -> None:
    """Validate that SQL contains only read-only SELECT statements.

    Args:
        sql: SQL query to validate

    Raises:
        ValueError: If SQL is empty, invalid, or contains non-SELECT statements
    """
    if not sql or not sql.strip():
        raise ValueError("SQL query cannot be empty")

    try:
        statements = sqlglot.parse(sql)
    except sqlglot.errors.ParseError as e:
        raise ValueError(f"Invalid SQL syntax: {e}")

    if not statements or all(stmt is None for stmt in statements):
        raise ValueError("No valid SQL statements found")

    # Statement types that modify data or schema
    FORBIDDEN_TYPES = {
        exp.Insert,
        exp.Update,
        exp.Delete,
        exp.Drop,
        exp.Create,
        exp.Alter,
        exp.TruncateTable,
        exp.Merge,
        exp.Grant,
        exp.Revoke,
        exp.Command,  # Catches CALL, EXECUTE, etc.
    }

    for stmt in statements:
        if stmt is None:
            continue

        # Check if statement is a forbidden type
        for forbidden in FORBIDDEN_TYPES:
            if isinstance(stmt, forbidden):
                raise ValueError(f"Only SELECT queries allowed, got: {type(stmt).__name__}")

        # Must be a SELECT statement
        if not isinstance(stmt, exp.Select):
            raise ValueError(f"Only SELECT queries allowed, got: {type(stmt).__name__}")
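# Illustrative usage:
#
#   validate_read_only("SELECT * FROM users")  # passes (returns None)
#   validate_read_only("DROP TABLE users")     # raises ValueError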
+ } + + for stmt in statements: + if stmt is None: + continue + + # Check if statement is a forbidden type + for forbidden in FORBIDDEN_TYPES: + if isinstance(stmt, forbidden): + raise ValueError(f"Only SELECT queries allowed, got: {type(stmt).__name__}") + + # Must be a SELECT statement + if not isinstance(stmt, exp.Select): + raise ValueError(f"Only SELECT queries allowed, got: {type(stmt).__name__}") diff --git a/pyproject.toml b/pyproject.toml index d87a329e6..a8e959ad0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,12 +125,21 @@ api = [ "uvicorn==0.38.0", ] +mcp = [ + "mcp>=1.0.0,<2.0.0", + "sqlalchemy>=2.0.0,<3.0.0", + "snowflake-sqlalchemy>=1.6.0,<2.0.0", + "databricks-sqlalchemy>=2.0.0,<3.0.0", + "sqlalchemy-bigquery>=1.9.0,<2.0.0", + "psycopg2-binary>=2.9.0,<3.0.0", +] + protobuf = [ "grpcio-tools>=1.53", ] all = [ - "datacontract-cli[kafka,bigquery,csv,excel,snowflake,postgres,databricks,sqlserver,s3,athena,trino,dbt,dbml,iceberg,parquet,rdf,api,protobuf,oracle]" + "datacontract-cli[kafka,bigquery,csv,excel,snowflake,postgres,databricks,sqlserver,s3,athena,trino,dbt,dbml,iceberg,parquet,rdf,api,protobuf,oracle,mcp]" ] # for development, we pin all libraries to an exact version diff --git a/tests/test_mcp_connections.py b/tests/test_mcp_connections.py new file mode 100644 index 000000000..a71f812bc --- /dev/null +++ b/tests/test_mcp_connections.py @@ -0,0 +1,185 @@ +"""Unit tests for MCP database connections.""" + + +import pytest +from open_data_contract_standard.model import Server + +from datacontract.mcp.connections import ( + SUPPORTED_TYPES, + build_connection_url, + format_results_csv, + format_results_json, + format_results_markdown, +) + + +class TestBuildConnectionUrl: + """Tests for building SQLAlchemy connection URLs.""" + + def test_postgres_url(self, monkeypatch): + """Test PostgreSQL URL building.""" + monkeypatch.setenv("DATACONTRACT_POSTGRES_USERNAME", "testuser") + monkeypatch.setenv("DATACONTRACT_POSTGRES_PASSWORD", "testpass") + + server = Server(type="postgres", host="localhost", port=5432, database="testdb", schema_="public") + + url = build_connection_url(server) + + assert url.startswith("postgresql+psycopg2://") + assert "testuser" in url + assert "localhost:5432" in url + assert "testdb" in url + + def test_postgres_missing_credentials(self, monkeypatch): + """Test PostgreSQL fails without credentials.""" + monkeypatch.delenv("DATACONTRACT_POSTGRES_USERNAME", raising=False) + monkeypatch.delenv("DATACONTRACT_POSTGRES_PASSWORD", raising=False) + + server = Server(type="postgres", host="localhost", port=5432, database="testdb") + + with pytest.raises(ValueError, match="DATACONTRACT_POSTGRES_USERNAME"): + build_connection_url(server) + + def test_snowflake_url(self, monkeypatch): + """Test Snowflake URL building.""" + monkeypatch.setenv("DATACONTRACT_SNOWFLAKE_USERNAME", "testuser") + monkeypatch.setenv("DATACONTRACT_SNOWFLAKE_PASSWORD", "testpass") + + server = Server(type="snowflake", account="myaccount", database="testdb", schema_="public") + + url = build_connection_url(server) + + assert url.startswith("snowflake://") + assert "testuser" in url + assert "myaccount" in url + assert "testdb" in url + + def test_snowflake_missing_credentials(self, monkeypatch): + """Test Snowflake fails without credentials.""" + monkeypatch.delenv("DATACONTRACT_SNOWFLAKE_USERNAME", raising=False) + monkeypatch.delenv("DATACONTRACT_SNOWFLAKE_PASSWORD", raising=False) + + server = Server(type="snowflake", account="myaccount", database="testdb", schema_="public") + 
        with pytest.raises(ValueError, match="DATACONTRACT_SNOWFLAKE_USERNAME"):
            build_connection_url(server)

    def test_databricks_url(self, monkeypatch):
        """Test Databricks URL building."""
        monkeypatch.setenv("DATACONTRACT_DATABRICKS_TOKEN", "testtoken")
        monkeypatch.setenv("DATACONTRACT_DATABRICKS_HTTP_PATH", "/sql/1.0/endpoints/abc123")

        server = Server(type="databricks", host="myhost.databricks.com", catalog="mycatalog", schema_="myschema")

        url = build_connection_url(server)

        assert url.startswith("databricks://")
        assert "myhost.databricks.com" in url
        assert "testtoken" in url

    def test_databricks_missing_token(self, monkeypatch):
        """Test Databricks fails without token."""
        monkeypatch.delenv("DATACONTRACT_DATABRICKS_TOKEN", raising=False)

        server = Server(type="databricks", host="myhost.databricks.com", catalog="mycatalog", schema_="myschema")

        with pytest.raises(ValueError, match="DATACONTRACT_DATABRICKS_TOKEN"):
            build_connection_url(server)

    def test_bigquery_url(self, monkeypatch):
        """Test BigQuery URL building."""
        server = Server(type="bigquery", project="myproject", dataset="mydataset")

        url = build_connection_url(server)

        assert url.startswith("bigquery://")
        assert "myproject" in url

    def test_unsupported_type(self):
        """Test unsupported server type fails."""
        server = Server(type="mysql", host="localhost", database="testdb")

        with pytest.raises(ValueError, match="Unsupported server type"):
            build_connection_url(server)


class TestSupportedTypes:
    """Tests for supported database types."""

    def test_supported_types_contains_postgres(self):
        assert "postgres" in SUPPORTED_TYPES

    def test_supported_types_contains_snowflake(self):
        assert "snowflake" in SUPPORTED_TYPES

    def test_supported_types_contains_databricks(self):
        assert "databricks" in SUPPORTED_TYPES

    def test_supported_types_contains_bigquery(self):
        assert "bigquery" in SUPPORTED_TYPES


class TestFormatResults:
    """Tests for result formatting functions."""

    def test_format_markdown(self):
        """Test markdown table formatting."""
        columns = ["id", "name", "amount"]
        rows = [
            {"id": 1, "name": "Alice", "amount": 100.50},
            {"id": 2, "name": "Bob", "amount": 200.00},
        ]

        result = format_results_markdown(columns, rows, "production", "postgres")

        assert "| id | name | amount |" in result
        assert "| 1 | Alice | 100.5 |" in result
        assert "| 2 | Bob | 200.0 |" in result
        assert "(2 rows from server 'production' [postgres])" in result

    def test_format_markdown_empty(self):
        """Test markdown with empty results."""
        columns = ["id", "name"]
        rows = []

        result = format_results_markdown(columns, rows, "production", "postgres")

        assert "| id | name |" in result
        assert "(0 rows" in result

    def test_format_json(self):
        """Test JSON formatting."""
        columns = ["id", "name"]
        rows = [{"id": 1, "name": "Alice"}]

        result = format_results_json(columns, rows, "production", "postgres")

        assert '"columns": ["id", "name"]' in result
        assert '"row_count": 1' in result
        assert '"server": "production"' in result
    def test_format_csv(self):
        """Test CSV formatting."""
        columns = ["id", "name", "amount"]
        rows = [
            {"id": 1, "name": "Alice", "amount": 100.50},
            {"id": 2, "name": "Bob", "amount": 200.00},
        ]

        result = format_results_csv(columns, rows)

        # Replace CRLF with LF for cross-platform compatibility
        lines = result.replace("\r\n", "\n").strip().split("\n")
        assert lines[0] == "id,name,amount"
        assert lines[1] == "1,Alice,100.5"
        assert lines[2] == "2,Bob,200.0"

    def test_format_csv_with_comma_in_value(self):
        """Test CSV escaping for values containing commas."""
        columns = ["id", "description"]
        rows = [{"id": 1, "description": "Hello, World"}]

        result = format_results_csv(columns, rows)

        # Value with comma should be quoted
        assert '"Hello, World"' in result

diff --git a/tests/test_mcp_sql_utils.py b/tests/test_mcp_sql_utils.py
new file mode 100644
index 000000000..9b18b2692
--- /dev/null
+++ b/tests/test_mcp_sql_utils.py
@@ -0,0 +1,137 @@
"""Unit tests for MCP SQL utilities."""

import pytest

from datacontract.mcp.sql_utils import validate_read_only


class TestValidateReadOnly:
    """Tests for validate_read_only function."""

    def test_simple_select(self):
        """Simple SELECT should pass."""
        validate_read_only("SELECT * FROM users")

    def test_select_with_where(self):
        """SELECT with WHERE clause should pass."""
        validate_read_only("SELECT id, name FROM users WHERE active = true")

    def test_select_with_join(self):
        """SELECT with JOIN should pass."""
        validate_read_only("""
            SELECT u.name, o.total
            FROM users u
            JOIN orders o ON u.id = o.user_id
        """)

    def test_select_with_cte(self):
        """SELECT with CTE should pass."""
        validate_read_only("""
            WITH active_users AS (
                SELECT * FROM users WHERE active = true
            )
            SELECT * FROM active_users
        """)

    def test_select_with_subquery(self):
        """SELECT with subquery should pass."""
        validate_read_only("""
            SELECT * FROM users
            WHERE id IN (SELECT user_id FROM orders WHERE total > 100)
        """)

    def test_select_with_aggregation(self):
        """SELECT with aggregation should pass."""
        validate_read_only("""
            SELECT user_id, COUNT(*), SUM(total)
            FROM orders
            GROUP BY user_id
            HAVING COUNT(*) > 5
        """)

    def test_insert_rejected(self):
        """INSERT should be rejected."""
        with pytest.raises(ValueError, match="Only SELECT"):
            validate_read_only("INSERT INTO users (name) VALUES ('test')")

    def test_update_rejected(self):
        """UPDATE should be rejected."""
        with pytest.raises(ValueError, match="Only SELECT"):
            validate_read_only("UPDATE users SET name = 'test' WHERE id = 1")

    def test_delete_rejected(self):
        """DELETE should be rejected."""
        with pytest.raises(ValueError, match="Only SELECT"):
            validate_read_only("DELETE FROM users WHERE id = 1")

    def test_drop_table_rejected(self):
        """DROP TABLE should be rejected."""
        with pytest.raises(ValueError, match="Only SELECT"):
            validate_read_only("DROP TABLE users")

    def test_create_table_rejected(self):
        """CREATE TABLE should be rejected."""
        with pytest.raises(ValueError, match="Only SELECT"):
            validate_read_only("CREATE TABLE users (id INT, name VARCHAR(255))")

    def test_alter_table_rejected(self):
        """ALTER TABLE should be rejected."""
        with pytest.raises(ValueError, match="Only SELECT"):
            validate_read_only("ALTER TABLE users ADD COLUMN email VARCHAR(255)")

    def test_truncate_rejected(self):
        """TRUNCATE should be rejected."""
        with pytest.raises(ValueError, match="Only SELECT"):
            validate_read_only("TRUNCATE TABLE users")

    def test_merge_rejected(self):
        """MERGE should be rejected."""
        with pytest.raises(ValueError, match="Only SELECT"):
            validate_read_only("""
                MERGE INTO target t
                USING source s ON t.id = s.id
                WHEN MATCHED THEN UPDATE SET t.name = s.name
            """)

    def test_grant_rejected(self):
        """GRANT should be rejected."""
        with pytest.raises(ValueError, match="Only SELECT"):
validate_read_only("GRANT SELECT ON users TO public") + + def test_revoke_rejected(self): + """REVOKE should be rejected.""" + with pytest.raises(ValueError, match="Only SELECT"): + validate_read_only("REVOKE SELECT ON users FROM public") + + def test_multiple_selects_allowed(self): + """Multiple SELECT statements should pass.""" + validate_read_only("SELECT 1; SELECT 2") + + def test_mixed_select_and_insert_rejected(self): + """Mix of SELECT and INSERT should be rejected.""" + with pytest.raises(ValueError, match="Only SELECT"): + validate_read_only("SELECT * FROM users; INSERT INTO users (name) VALUES ('test')") + + def test_empty_sql_rejected(self): + """Empty SQL should be rejected.""" + with pytest.raises(ValueError): + validate_read_only("") + + def test_whitespace_only_rejected(self): + """Whitespace-only SQL should be rejected.""" + with pytest.raises(ValueError): + validate_read_only(" \n\t ") + + def test_invalid_sql_rejected(self): + """Invalid SQL syntax should be rejected.""" + # sqlglot is lenient with parsing, but result won't be a SELECT + with pytest.raises(ValueError, match="Only SELECT"): + validate_read_only("SELEKT * FORM users") + + def test_cte_with_insert_rejected(self): + """CTE followed by INSERT should be rejected.""" + with pytest.raises(ValueError, match="Only SELECT"): + validate_read_only(""" + WITH active AS (SELECT * FROM users WHERE active = true) + INSERT INTO archive SELECT * FROM active + """)