diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..b21412b --- /dev/null +++ b/.coveragerc @@ -0,0 +1,8 @@ +[run] +branch = true +omit = + */__init__.py + +[report] +show_missing = true +fail_under = 90 diff --git a/.github/blunderbuss.yml b/.github/blunderbuss.yml new file mode 100644 index 0000000..d922933 --- /dev/null +++ b/.github/blunderbuss.yml @@ -0,0 +1,4 @@ +assign_issues: + - googleapis/llama-index-alloydb +assign_prs: + - googleapis/llama-index-alloydb diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 28d3312..fbd4535 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -34,7 +34,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 - name: Setup Python - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 with: python-version: "3.11" diff --git a/.github/workflows/schedule_reporter.yml b/.github/workflows/schedule_reporter.yml new file mode 100644 index 0000000..ab846ef --- /dev/null +++ b/.github/workflows/schedule_reporter.yml @@ -0,0 +1,25 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: Schedule Reporter + +on: + schedule: + - cron: '0 6 * * *' # Runs at 6 AM every morning + +jobs: + run_reporter: + uses: googleapis/langchain-google-alloydb-pg-python/.github/workflows/cloud_build_failure_reporter.yml@main + with: + trigger_names: "integration-test-nightly,continuous-test-on-merge" diff --git a/.kokoro/docker/docs/requirements.txt b/.kokoro/docker/docs/requirements.txt index 56381b8..43b2594 100644 --- a/.kokoro/docker/docs/requirements.txt +++ b/.kokoro/docker/docs/requirements.txt @@ -32,7 +32,7 @@ platformdirs==4.2.1 \ --hash=sha256:031cd18d4ec63ec53e82dceaac0417d218a6863f7745dfcc9efe7793b7039bdf \ --hash=sha256:17d5a1161b3fd67b390023cb2d3b026bbd40abde6fdb052dfbd3a29c3ba22ee1 # via virtualenv -virtualenv==20.26.0 \ - --hash=sha256:0846377ea76e818daaa3e00a4365c018bc3ac9760cbb3544de542885aad61fb3 \ - --hash=sha256:ec25a9671a5102c8d2657f62792a27b48f016664c6873f6beed3800008577210 - # via nox \ No newline at end of file +virtualenv==20.26.6 \ + --hash=sha256:280aede09a2a5c317e409a00102e7077c6432c5a38f0ef938e643805a7ad2c48 \ + --hash=sha256:7345cc5b25405607a624d8418154577459c3e0277f5466dd79c49d5e492995f2 + # via nox diff --git a/.kokoro/requirements.txt b/.kokoro/requirements.txt index 88fb726..b5a1d9a 100644 --- a/.kokoro/requirements.txt +++ b/.kokoro/requirements.txt @@ -277,9 +277,9 @@ jeepney==0.8.0 \ # via # keyring # secretstorage -jinja2==3.1.4 \ - --hash=sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369 \ - --hash=sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d +jinja2==3.1.5 \ + --hash=sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb \ + --hash=sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb # via gcp-releasetool keyring==24.3.1 \ --hash=sha256:c3327b6ffafc0e8befbdb597cacdb4928ffe5c1212f7645f186e6d9957a898db \ @@ -509,9 +509,9 @@ urllib3==2.2.2 \ # via # requests # twine -virtualenv==20.25.1 \ - 
--hash=sha256:961c026ac520bac5f69acb8ea063e8a4f071bcc9457b9c1f28f6b085c511583a \ - --hash=sha256:e08e13ecdca7a0bd53798f356d5831434afa5b07b93f0abdf0797b7a06ffe197 +virtualenv==20.26.6 \ + --hash=sha256:280aede09a2a5c317e409a00102e7077c6432c5a38f0ef938e643805a7ad2c48 \ + --hash=sha256:7345cc5b25405607a624d8418154577459c3e0277f5466dd79c49d5e492995f2 # via nox wheel==0.43.0 \ --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ @@ -525,4 +525,4 @@ zipp==3.19.1 \ # WARNING: The following packages were not pinned, but pip requires them to be # pinned when the requirements file includes hashes and the requirement is not # satisfied by a package already installed. Consider using the --allow-unsafe flag. -# setuptools \ No newline at end of file +# setuptools diff --git a/CHANGELOG.md b/CHANGELOG.md index a224203..657d2b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,24 @@ # Changelog +## [0.2.0](https://github.com/googleapis/llama-index-alloydb-pg-python/compare/v0.1.0...v0.2.0) (2025-01-30) + + +### Features + +* Adding AlloyDB Chat Store ([#37](https://github.com/googleapis/llama-index-alloydb-pg-python/issues/37)) ([320b448](https://github.com/googleapis/llama-index-alloydb-pg-python/commit/320b448fc60b2a41c4b3e1b90084d799319260eb)) +* Adding AlloyDB Reader ([#57](https://github.com/googleapis/llama-index-alloydb-pg-python/issues/57)) ([7314d83](https://github.com/googleapis/llama-index-alloydb-pg-python/commit/7314d835e62ccd7e8fe59b35f37dccaaee6aed36)) +* Adding Async AlloyDB Reader ([#55](https://github.com/googleapis/llama-index-alloydb-pg-python/issues/55)) ([56e6479](https://github.com/googleapis/llama-index-alloydb-pg-python/commit/56e64790c8eb85979d60b87366adb46596232e24)) +* Adding Async Chat Store ([#35](https://github.com/googleapis/llama-index-alloydb-pg-python/issues/35)) ([dd98771](https://github.com/googleapis/llama-index-alloydb-pg-python/commit/dd987718f0482177d03c84eee6334703613461d0)) +* Adding chat store init 
methods. ([#29](https://github.com/googleapis/llama-index-alloydb-pg-python/issues/29)) ([de53006](https://github.com/googleapis/llama-index-alloydb-pg-python/commit/de53006d00fe1edd5b3e5c1349613e82f0c94794)) + + +### Bug Fixes + +* Change default metadata_json_column default value ([#66](https://github.com/googleapis/llama-index-alloydb-pg-python/issues/66)) ([ecb53c8](https://github.com/googleapis/llama-index-alloydb-pg-python/commit/ecb53c80d311deb9232f0f8844761a816fc01bc0)) +* Programming error while setting multiple query option ([#47](https://github.com/googleapis/llama-index-alloydb-pg-python/issues/47)) ([5f1405e](https://github.com/googleapis/llama-index-alloydb-pg-python/commit/5f1405ed7ba7941c9c9a4370a428c720d857e6af)) +* Query and return only selected metadata columns ([#52](https://github.com/googleapis/llama-index-alloydb-pg-python/issues/52)) ([dff623b](https://github.com/googleapis/llama-index-alloydb-pg-python/commit/dff623bf8d340811ed88271e59b11d0f996cc811)) +* Update lazy_load_data return type to Iterable. 
([#61](https://github.com/googleapis/llama-index-alloydb-pg-python/issues/61)) ([98f6c65](https://github.com/googleapis/llama-index-alloydb-pg-python/commit/98f6c65fc77cbb7f25b22d7118bbb89f3c674b2f)) + ## 0.1.0 (2024-12-03) diff --git a/integration.cloudbuild.yaml b/integration.cloudbuild.yaml index 383aa79..3749777 100644 --- a/integration.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -39,7 +39,7 @@ steps: - "-c" - | /workspace/alloydb-auth-proxy --port ${_DATABASE_PORT} ${_INSTANCE_CONNECTION_NAME} & sleep 2; - python -m pytest tests/ + python -m pytest --cov=llama_index_alloydb_pg --cov-config=.coveragerc tests/ env: - "PROJECT_ID=$PROJECT_ID" - "INSTANCE_ID=$_INSTANCE_ID" diff --git a/pyproject.toml b/pyproject.toml index 7cb81ae..04d7c1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,11 +37,11 @@ Changelog = "https://github.com/googleapis/llama-index-alloydb-pg-python/blob/ma [project.optional-dependencies] test = [ - "black[jupyter]==24.8.0", - "isort==5.13.2", - "mypy==1.11.2", - "pytest-asyncio==0.24.0", - "pytest==8.3.3", + "black[jupyter]==25.1.0", + "isort==6.0.0", + "mypy==1.14.1", + "pytest-asyncio==0.25.3", + "pytest==8.3.4", "pytest-cov==6.0.0", "pytest-depends==1.0.1", ] @@ -50,6 +50,9 @@ test = [ requires = ["setuptools"] build-backend = "setuptools.build_meta" +[tool.pytest.ini_options] +asyncio_default_fixture_loop_scope = "class" + [tool.black] target-version = ['py39'] diff --git a/requirements.txt b/requirements.txt index 5ff2916..ee5ab04 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -google-cloud-alloydb-connector[asyncpg]==1.4.0 -llama-index-core==0.12.0 +google-cloud-alloydb-connector[asyncpg]==1.7.0 +llama-index-core==0.12.14 pgvector==0.3.6 -SQLAlchemy[asyncio]==2.0.36 +SQLAlchemy[asyncio]==2.0.37 diff --git a/samples/llama_index_chat_store.ipynb b/samples/llama_index_chat_store.ipynb new file mode 100644 index 0000000..157ed60 --- /dev/null +++ b/samples/llama_index_chat_store.ipynb @@ -0,0 +1,419 
@@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Google AlloyDB for PostgreSQL - `AlloyDBChatStore`\n", + "\n", + "> [AlloyDB](https://cloud.google.com/alloydb) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. AlloyDB is 100% compatible with PostgreSQL. Extend your database application to build AI-powered experiences leveraging AlloyDB's LlamaIndex integrations.\n", + "\n", + "This notebook goes over how to use `AlloyDB for PostgreSQL` to store chat history with `AlloyDBChatStore` class.\n", + "\n", + "Learn more about the package on [GitHub](https://github.com/googleapis/llama-index-alloydb-pg-python/).\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/llama-index-alloydb-pg-python/blob/main/samples/llama_index_chat_store.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Before you begin\n", + "\n", + "To run this notebook, you will need to do the following:\n", + "\n", + " * [Create a Google Cloud Project](https://developers.google.com/workspace/guides/create-project)\n", + " * [Enable the AlloyDB API](https://console.cloud.google.com/flows/enableapi?apiid=alloydb.googleapis.com)\n", + " * [Create a AlloyDB cluster and instance.](https://cloud.google.com/alloydb/docs/cluster-create)\n", + " * [Create a AlloyDB database.](https://cloud.google.com/alloydb/docs/quickstart/create-and-connect)\n", + " * [Add a User to the database.](https://cloud.google.com/alloydb/docs/database-users/about)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### πŸ¦™ Library Installation\n", + "Install the integration library, `llama-index-alloydb-pg`, and the library for the embedding service, `llama-index-embeddings-vertex`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install --upgrade --quiet llama-index-alloydb-pg llama-index-llms-vertex llama-index" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Colab only:** Uncomment the following cell to restart the kernel or use the button to restart the kernel. For Vertex AI Workbench you can restart the terminal using the button on top." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# # Automatically restart kernel after installs so that your environment can access the new packages\n", + "# import IPython\n", + "\n", + "# app = IPython.Application.instance()\n", + "# app.kernel.do_shutdown(True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### πŸ” Authentication\n", + "Authenticate to Google Cloud as the IAM user logged into this notebook in order to access your Google Cloud Project.\n", + "\n", + "* If you are using Colab to run this notebook, use the cell below and continue.\n", + "* If you are using Vertex AI Workbench, check out the setup instructions [here](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/setup-env)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from google.colab import auth\n", + "\n", + "auth.authenticate_user()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### ☁ Set Your Google Cloud Project\n", + "Set your Google Cloud project so that you can leverage Google Cloud resources within this notebook.\n", + "\n", + "If you don't know your project ID, try the following:\n", + "\n", + "* Run `gcloud config list`.\n", + "* Run `gcloud projects list`.\n", + "* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @markdown Please fill in the value below with your Google Cloud project ID and then run the cell.\n", + "\n", + "PROJECT_ID = \"my-project-id\" # @param {type:\"string\"}\n", + "\n", + "# Set the project id\n", + "!gcloud config set project {PROJECT_ID}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic Usage" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Set AlloyDB database values\n", + "Find your database values, in the [AlloyDB Instances page](https://console.cloud.google.com/alloydb/clusters)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title Set Your Values Here { display-mode: \"form\" }\n", + "REGION = \"us-central1\" # @param {type: \"string\"}\n", + "CLUSTER = \"my-cluster\" # @param {type: \"string\"}\n", + "INSTANCE = \"my-primary\" # @param {type: \"string\"}\n", + "DATABASE = \"my-database\" # @param {type: \"string\"}\n", + "TABLE_NAME = \"chat_store\" # @param {type: \"string\"}\n", + "USER = \"postgres\" # @param {type: \"string\"}\n", + "PASSWORD = \"my-password\" # @param {type: \"string\"}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### AlloyDBEngine Connection Pool\n", + "\n", + "One of the requirements and arguments to establish AlloyDB as a chat store is a `AlloyDBEngine` object. The `AlloyDBEngine` configures a connection pool to your AlloyDB database, enabling successful connections from your application and following industry best practices.\n", + "\n", + "To create a `AlloyDBEngine` using `AlloyDBEngine.from_instance()` you need to provide only 5 things:\n", + "\n", + "1. `project_id` : Project ID of the Google Cloud Project where the AlloyDB instance is located.\n", + "1. `region` : Region where the AlloyDB instance is located.\n", + "1. 
`cluster`: The name of the AlloyDB cluster.\n", + "1. `instance` : The name of the AlloyDB instance.\n", + "1. `database` : The name of the database to connect to on the AlloyDB instance.\n", + "\n", + "By default, [IAM database authentication](https://cloud.google.com/alloydb/docs/connect-iam) will be used as the method of database authentication. This library uses the IAM principal belonging to the [Application Default Credentials (ADC)](https://cloud.google.com/docs/authentication/application-default-credentials) sourced from the environment.\n", + "\n", + "Optionally, [built-in database authentication](https://cloud.google.com/alloydb/docs/database-users/about) using a username and password to access the AlloyDB database can also be used. Just provide the optional `user` and `password` arguments to `AlloyDBEngine.from_instance()`:\n", + "\n", + "* `user` : Database user to use for built-in database authentication and login\n", + "* `password` : Database password to use for built-in database authentication and login.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Note:** This tutorial demonstrates the async interface. All async methods have corresponding sync methods." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index_alloydb_pg import AlloyDBEngine\n", + "\n", + "engine = await AlloyDBEngine.afrom_instance(\n", + " project_id=PROJECT_ID,\n", + " region=REGION,\n", + " cluster=CLUSTER,\n", + " instance=INSTANCE,\n", + " database=DATABASE,\n", + " user=USER,\n", + " password=PASSWORD,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### AlloyDBEngine for AlloyDB Omni\n", + "To create an `AlloyDBEngine` for AlloyDB Omni, you will need a connection url. `AlloyDBEngine.from_connection_string` first creates an async engine and then turns it into an `AlloyDBEngine`. 
Here is an example connection with the `asyncpg` driver:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Replace with your own AlloyDB Omni info\n", + "OMNI_USER = \"my-omni-user\"\n", + "OMNI_PASSWORD = \"\"\n", + "OMNI_HOST = \"127.0.0.1\"\n", + "OMNI_PORT = \"5432\"\n", + "OMNI_DATABASE = \"my-omni-db\"\n", + "\n", + "connstring = f\"postgresql+asyncpg://{OMNI_USER}:{OMNI_PASSWORD}@{OMNI_HOST}:{OMNI_PORT}/{OMNI_DATABASE}\"\n", + "engine = AlloyDBEngine.from_connection_string(connstring)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initialize a table\n", + "The `AlloyDBChatStore` class requires a database table. The `AlloyDBEngine` engine has a helper method `ainit_chat_store_table()` that can be used to create a table with the proper schema for you." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "await engine.ainit_chat_store_table(table_name=TABLE_NAME)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Optional Tip: πŸ’‘\n", + "You can also specify a schema name by passing `schema_name` wherever you pass `table_name`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "SCHEMA_NAME = \"my_schema\"\n", + "\n", + "await engine.ainit_chat_store_table(\n", + " table_name=TABLE_NAME,\n", + " schema_name=SCHEMA_NAME,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initialize a default AlloyDBChatStore" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index_alloydb_pg import AlloyDBChatStore\n", + "\n", + "chat_store = await AlloyDBChatStore.create(\n", + " engine=engine,\n", + " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create a ChatMemoryBuffer\n", + "The `ChatMemoryBuffer` stores a history of recent chat messages, enabling the LLM to access relevant context from prior interactions.\n", + "\n", + "By passing our chat store into the `ChatMemoryBuffer`, it can automatically retrieve and update messages associated with a specific session ID or `chat_store_key`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core.memory import ChatMemoryBuffer\n", + "\n", + "memory = ChatMemoryBuffer.from_defaults(\n", + " token_limit=3000,\n", + " chat_store=chat_store,\n", + " chat_store_key=\"user1\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create an LLM class instance\n", + "\n", + "You can use any of the [LLMs compatible with LlamaIndex](https://docs.llamaindex.ai/en/stable/module_guides/models/llms/modules/).\n", + "You may need to enable Vertex AI API to use `Vertex`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.llms.vertex import Vertex\n", + "\n", + "llm = Vertex(model=\"gemini-1.5-flash-002\", project=PROJECT_ID)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Use the AlloyDBChatStore without a storage context" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Create a Simple Chat Engine" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core.chat_engine import SimpleChatEngine\n", + "\n", + "chat_engine = SimpleChatEngine(memory=memory, llm=llm, prefix_messages=[])\n", + "\n", + "response = chat_engine.chat(\"Hello\")\n", + "\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Use the AlloyDBChatStore with a storage context" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Create a LlamaIndex `Index`\n", + "\n", + "An `Index` is allows us to quickly retrieve relevant context for a user query.\n", + "They are used to build `QueryEngines` and `ChatEngines`.\n", + "For a list of indexes that can be built in LlamaIndex, see [Index Guide](https://docs.llamaindex.ai/en/stable/module_guides/indexing/index_guide/).\n", + "\n", + "A `VectorStoreIndex`, can be built using the `AlloyDBVectorStore`. 
See the detailed guide on how to use the `AlloyDBVectorStore` to build an index [here](https://github.com/googleapis/llama-index-alloydb-pg-python/blob/main/samples/llama_index_vector_store.ipynb).\n", + "\n", + "You can also use the `AlloyDBDocumentStore` and `AlloyDBIndexStore` to persist documents and index metadata.\n", + "These modules can be used to build other `Indexes`.\n", + "For a detailed python notebook on this, see [LlamaIndex Doc Store Guide](https://github.com/googleapis/llama-index-alloydb-pg-python/blob/main/samples/llama_index_doc_store.ipynb)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Create and use the Chat Engine" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create an `index` here\n", + "\n", + "chat_engine = index.as_chat_engine(llm=llm, chat_mode=\"context\", memory=memory) # type: ignore\n", + "response = chat_engine.chat(\"What did the author do?\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "senseAIenv", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/samples/llama_index_doc_store.ipynb b/samples/llama_index_doc_store.ipynb index bc14192..149dd2f 100644 --- a/samples/llama_index_doc_store.ipynb +++ b/samples/llama_index_doc_store.ipynb @@ -40,7 +40,7 @@ "id": "IR54BmgvdHT_" }, "source": [ - "### πŸ¦œπŸ”— Library Installation\n", + "### πŸ¦™ Library Installation\n", "Install the integration library, `llama-index-alloydb-pg`, and the library for the embedding service, `llama-index-embeddings-vertex`." 
] }, diff --git a/samples/llama_index_reader.ipynb b/samples/llama_index_reader.ipynb new file mode 100644 index 0000000..7a4f1e8 --- /dev/null +++ b/samples/llama_index_reader.ipynb @@ -0,0 +1,378 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Google AlloyDB for PostgreSQL - `AlloyDBReader`\n", + "\n", + "> [AlloyDB](https://cloud.google.com/alloydb) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. AlloyDB is 100% compatible with PostgreSQL. Extend your database application to build AI-powered experiences leveraging AlloyDB's LlamaIndex integrations.\n", + "\n", + "This notebook goes over how to use `AlloyDB for PostgreSQL` to retrieve data as documents with the `AlloyDBReader` class.\n", + "\n", + "Learn more about the package on [GitHub](https://github.com/googleapis/llama-index-alloydb-pg-python/).\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/llama-index-alloydb-pg-python/blob/main/samples/llama_index_reader.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Before you begin\n", + "\n", + "To run this notebook, you will need to do the following:\n", + "\n", + " * [Create a Google Cloud Project](https://developers.google.com/workspace/guides/create-project)\n", + " * [Enable the AlloyDB API](https://console.cloud.google.com/flows/enableapi?apiid=alloydb.googleapis.com)\n", + " * [Create a AlloyDB cluster and instance.](https://cloud.google.com/alloydb/docs/cluster-create)\n", + " * [Create a AlloyDB database.](https://cloud.google.com/alloydb/docs/quickstart/create-and-connect)\n", + " * [Add a User to the database.](https://cloud.google.com/alloydb/docs/database-users/about)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### πŸ¦™ Library Installation\n", + "Install the integration 
library, `llama-index-alloydb-pg`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Colab only:** Uncomment the following cell to restart the kernel or use the button to restart the kernel. For Vertex AI Workbench you can restart the terminal using the button on top." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# # Automatically restart kernel after installs so that your environment can access the new packages\n", + "# import IPython\n", + "\n", + "# app = IPython.Application.instance()\n", + "# app.kernel.do_shutdown(True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### πŸ” Authentication\n", + "Authenticate to Google Cloud as the IAM user logged into this notebook in order to access your Google Cloud Project.\n", + "\n", + "* If you are using Colab to run this notebook, use the cell below and continue.\n", + "* If you are using Vertex AI Workbench, check out the setup instructions [here](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/setup-env)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from google.colab import auth\n", + "\n", + "auth.authenticate_user()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### ☁ Set Your Google Cloud Project\n", + "Set your Google Cloud project so that you can leverage Google Cloud resources within this notebook.\n", + "\n", + "If you don't know your project ID, try the following:\n", + "\n", + "* Run `gcloud config list`.\n", + "* Run `gcloud projects list`.\n", + "* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @markdown Please fill in the value below with your Google Cloud project ID and then run the cell.\n", + "\n", + "PROJECT_ID = \"my-project-id\" # @param {type:\"string\"}\n", + "\n", + "# Set the project id\n", + "!gcloud config set project {PROJECT_ID}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic Usage" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Set AlloyDB database values\n", + "Find your database values, in the [AlloyDB Instances page](https://console.cloud.google.com/alloydb/clusters)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title Set Your Values Here { display-mode: \"form\" }\n", + "REGION = \"us-central1\" # @param {type: \"string\"}\n", + "CLUSTER = \"my-cluster\" # @param {type: \"string\"}\n", + "INSTANCE = \"my-primary\" # @param {type: \"string\"}\n", + "DATABASE = \"my-database\" # @param {type: \"string\"}\n", + "TABLE_NAME = \"document_store\" # @param {type: \"string\"}\n", + "USER = \"postgres\" # @param {type: \"string\"}\n", + "PASSWORD = \"my-password\" # @param {type: \"string\"}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### AlloyDBEngine Connection Pool\n", + "\n", + "One of the requirements and arguments to establish AlloyDB Reader is a `AlloyDBEngine` object. The `AlloyDBEngine` configures a connection pool to your AlloyDB database, enabling successful connections from your application and following industry best practices.\n", + "\n", + "To create a `AlloyDBEngine` using `AlloyDBEngine.from_instance()` you need to provide only 5 things:\n", + "\n", + "1. `project_id` : Project ID of the Google Cloud Project where the AlloyDB instance is located.\n", + "1. `region` : Region where the AlloyDB instance is located.\n", + "1. 
`cluster`: The name of the AlloyDB cluster.\n", + "1. `instance` : The name of the AlloyDB instance.\n", + "1. `database` : The name of the database to connect to on the AlloyDB instance.\n", + "\n", + "By default, [IAM database authentication](https://cloud.google.com/alloydb/docs/connect-iam) will be used as the method of database authentication. This library uses the IAM principal belonging to the [Application Default Credentials (ADC)](https://cloud.google.com/docs/authentication/application-default-credentials) sourced from the environment.\n", + "\n", + "Optionally, [built-in database authentication](https://cloud.google.com/alloydb/docs/database-users/about) using a username and password to access the AlloyDB database can also be used. Just provide the optional `user` and `password` arguments to `AlloyDBEngine.from_instance()`:\n", + "\n", + "* `user` : Database user to use for built-in database authentication and login\n", + "* `password` : Database password to use for built-in database authentication and login.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Note:** This tutorial demonstrates the async interface. All async methods have corresponding sync methods." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index_alloydb_pg import AlloyDBEngine\n", + "\n", + "engine = await AlloyDBEngine.afrom_instance(\n", + " project_id=PROJECT_ID,\n", + " region=REGION,\n", + " cluster=CLUSTER,\n", + " instance=INSTANCE,\n", + " database=DATABASE,\n", + " user=USER,\n", + " password=PASSWORD,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create AlloyDBReader" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When creating an `AlloyDBReader` for fetching data from AlloyDB, you have two main options to specify the data you want to load:\n", + "* using the table_name argument - When you specify the table_name argument, you're telling the reader to fetch all the data from the given table.\n", + "* using the query argument - When you specify the query argument, you can provide a custom SQL query to fetch the data. This allows you to have full control over the SQL query, including selecting specific columns, applying filters, sorting, joining tables, etc.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load Documents using the `table_name` argument" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Load Documents via default table\n", + "The reader returns a list of Documents from the table using the first column as text and all other columns as metadata. The default table will have the first column as\n", + "text and the second column as metadata (JSON). Each row becomes a document." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index_alloydb_pg import AlloyDBReader\n", + "\n", + "# Creating a basic AlloyDBReader object\n", + "reader = await AlloyDBReader.create(\n", + " engine,\n", + " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Load documents via custom table/metadata or custom page content columns" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reader = await AlloyDBReader.create(\n", + " engine,\n", + " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", + " content_columns=[\"product_name\"], # Optional\n", + " metadata_columns=[\"id\"], # Optional\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load Documents using a SQL query\n", + "The query parameter allows users to specify a custom SQL query which can include filters to load specific documents from a database." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "table_name = \"products\"\n", + "content_columns = [\"product_name\", \"description\"]\n", + "metadata_columns = [\"id\", \"content\"]\n", + "\n", + "reader = AlloyDBReader.create(\n", + " engine=engine,\n", + " query=f\"SELECT * FROM {table_name};\",\n", + " content_columns=content_columns,\n", + " metadata_columns=metadata_columns,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Note**: If the `content_columns` and `metadata_columns` are not specified, the reader will automatically treat the first returned column as the document’s `text` and all subsequent columns as `metadata`." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Set page content format\n", + "The reader returns a list of Documents, with one document per row, with page content in specified string format, i.e. text (space separated concatenation), JSON, YAML, CSV, etc. JSON and YAML formats include headers, while text and CSV do not include field headers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reader = await AlloyDBReader.create(\n", + " engine,\n", + " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", + " content_columns=[\"product_name\", \"description\"],\n", + " format=\"YAML\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load the documents\n", + "\n", + "You can choose to load the documents in two ways:\n", + "* Load all the data at once\n", + "* Lazy load data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Load data all at once" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "docs = await reader.aload_data()\n", + "\n", + "print(docs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Lazy Load the data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "docs_iterable = reader.alazy_load_data()\n", + "\n", + "docs = []\n", + "async for doc in docs_iterable:\n", + " docs.append(doc)\n", + "\n", + "print(docs)" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/samples/llama_index_vector_store.ipynb b/samples/llama_index_vector_store.ipynb index 40a67f3..873665a 100644 --- a/samples/llama_index_vector_store.ipynb +++ b/samples/llama_index_vector_store.ipynb @@ -14,7 +14,7 @@ "\n", "Learn more about the package on 
[GitHub](https://github.com/googleapis/llama-index-alloydb-pg-python/).\n", "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/llama-index-alloydb-pg-python/blob/main/docs/vector_store.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/llama-index-alloydb-pg-python/blob/main/samples/llama_index_vector_store.ipynb)" ] }, { @@ -40,7 +40,7 @@ "id": "IR54BmgvdHT_" }, "source": [ - "### πŸ¦œπŸ”— Library Installation\n", + "### πŸ¦™ Library Installation\n", "Install the integration library, `llama-index-alloydb-pg`, and the library for the embedding service, `llama-index-embeddings-vertex`." ] }, diff --git a/src/llama_index_alloydb_pg/__init__.py b/src/llama_index_alloydb_pg/__init__.py index b695ab3..91c28b8 100644 --- a/src/llama_index_alloydb_pg/__init__.py +++ b/src/llama_index_alloydb_pg/__init__.py @@ -12,16 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .chat_store import AlloyDBChatStore from .document_store import AlloyDBDocumentStore from .engine import AlloyDBEngine, Column from .index_store import AlloyDBIndexStore +from .reader import AlloyDBReader from .vector_store import AlloyDBVectorStore from .version import __version__ _all = [ + "AlloyDBChatStore", "AlloyDBDocumentStore", "AlloyDBEngine", "AlloyDBIndexStore", + "AlloyDBReader", "AlloyDBVectorStore", "Column", "__version__", diff --git a/src/llama_index_alloydb_pg/async_chat_store.py b/src/llama_index_alloydb_pg/async_chat_store.py new file mode 100644 index 0000000..da24402 --- /dev/null +++ b/src/llama_index_alloydb_pg/async_chat_store.py @@ -0,0 +1,295 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +from typing import List, Optional + +from llama_index.core.llms import ChatMessage +from llama_index.core.storage.chat_store.base import BaseChatStore +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine + +from .engine import AlloyDBEngine + + +class AsyncAlloyDBChatStore(BaseChatStore): + """Chat Store Table stored in an AlloyDB for PostgreSQL database.""" + + __create_key = object() + + def __init__( + self, + key: object, + engine: AsyncEngine, + table_name: str, + schema_name: str = "public", + ): + """AsyncAlloyDBChatStore constructor. + + Args: + key (object): Key to prevent direct constructor usage. + engine (AlloyDBEngine): Database connection pool. + table_name (str): Table name that stores the chat store. + schema_name (str): The schema name where the table is located. + Defaults to "public" + + Raises: + Exception: If constructor is directly called by the user. + """ + if key != AsyncAlloyDBChatStore.__create_key: + raise Exception("Only create class through 'create' method!") + + # Delegate to Pydantic's __init__ + super().__init__() + self._engine = engine + self._table_name = table_name + self._schema_name = schema_name + + @classmethod + async def create( + cls, + engine: AlloyDBEngine, + table_name: str, + schema_name: str = "public", + ) -> AsyncAlloyDBChatStore: + """Create a new AsyncAlloyDBChatStore instance. + + Args: + engine (AlloyDBEngine): AlloyDB engine to use. + table_name (str): Table name that stores the chat store. 
+ schema_name (str): The schema name where the table is located. + Defaults to "public" + + Raises: + ValueError: If the table provided does not contain required schema. + + Returns: + AsyncAlloyDBChatStore: A newly created instance of AsyncAlloyDBChatStore. + """ + table_schema = await engine._aload_table_schema(table_name, schema_name) + column_names = table_schema.columns.keys() + + required_columns = ["id", "key", "message"] + + if not (all(x in column_names for x in required_columns)): + raise ValueError( + f"Table '{schema_name}'.'{table_name}' has an incorrect schema.\n" + f"Expected column names: {required_columns}\n" + f"Provided column names: {column_names}\n" + "Please create the table with the following schema:\n" + f"CREATE TABLE {schema_name}.{table_name} (\n" + " id SERIAL PRIMARY KEY,\n" + " key VARCHAR NOT NULL,\n" + " message JSON NOT NULL\n" + ");" + ) + + return cls(cls.__create_key, engine._pool, table_name, schema_name) + + async def __aexecute_query(self, query, params=None): + async with self._engine.connect() as conn: + await conn.execute(text(query), params) + await conn.commit() + + async def __afetch_query(self, query): + async with self._engine.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + results = result_map.fetchall() + await conn.commit() + return results + + @classmethod + def class_name(cls) -> str: + """Get class name.""" + return "AsyncAlloyDBChatStore" + + async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None: + """Asynchronously sets the chat messages for a specific key. + + Args: + key (str): A unique identifier for the chat. + messages (List[ChatMessage]): A list of `ChatMessage` objects to upsert. 
+
+        Returns:
+            None
+
+        """
+        query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE key = '{key}'; """
+        await self.__aexecute_query(query)
+        insert_query = f"""
+            INSERT INTO "{self._schema_name}"."{self._table_name}" (key, message)
+            VALUES (:key, :message);"""
+
+        params = [
+            {
+                "key": key,
+                "message": json.dumps(message.model_dump()),
+            }
+            for message in messages
+        ]
+
+        await self.__aexecute_query(insert_query, params)
+
+    async def aget_messages(self, key: str) -> List[ChatMessage]:
+        """Asynchronously retrieves the chat messages associated with a specific key.
+
+        Args:
+            key (str): A unique identifier for which the messages are to be retrieved.
+
+        Returns:
+            List[ChatMessage]: A list of `ChatMessage` objects associated with the provided key.
+                If no messages are found, an empty list is returned.
+        """
+        query = f"""SELECT message from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id;"""
+        results = await self.__afetch_query(query)
+        if results:
+            return [
+                ChatMessage.model_validate(result.get("message")) for result in results
+            ]
+        return []
+
+    async def async_add_message(self, key: str, message: ChatMessage) -> None:
+        """Asynchronously adds a new chat message to the specified key.
+
+        Args:
+            key (str): A unique identifier for the chat to which the message is added.
+            message (ChatMessage): The `ChatMessage` object that is to be added.
+
+        Returns:
+            None
+        """
+        insert_query = f"""
+            INSERT INTO "{self._schema_name}"."{self._table_name}" (key, message)
+            VALUES (:key, :message);"""
+        params = {"key": key, "message": json.dumps(message.model_dump())}
+
+        await self.__aexecute_query(insert_query, params)
+
+    async def adelete_messages(self, key: str) -> Optional[List[ChatMessage]]:
+        """Asynchronously deletes the chat messages associated with a specific key.
+
+        Args:
+            key (str): A unique identifier for the chat whose messages are to be deleted.
+ + Returns: + Optional[List[ChatMessage]]: A list of `ChatMessage` objects that were deleted, or `None` if no messages + were associated with the key or could be deleted. + """ + query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' RETURNING *; """ + results = await self.__afetch_query(query) + if results: + return [ + ChatMessage.model_validate(result.get("message")) for result in results + ] + return None + + async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + """Asynchronously deletes a specific chat message by index from the messages associated with a given key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + idx (int): The index of the `ChatMessage` to be deleted from the list of messages. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + query = f"""SELECT * from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id;""" + results = await self.__afetch_query(query) + if results: + if idx >= len(results): + return None + id_to_be_deleted = results[idx].get("id") + delete_query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE id = '{id_to_be_deleted}' RETURNING *;""" + result = await self.__afetch_query(delete_query) + result = result[0] + if result: + return ChatMessage.model_validate(result.get("message")) + return None + return None + + async def adelete_last_message(self, key: str) -> Optional[ChatMessage]: + """Asynchronously deletes the last chat message associated with a given key. + + Args: + key (str): A unique identifier for the chat whose message is to be deleted. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. 
+ """ + query = f"""SELECT * from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id DESC LIMIT 1;""" + results = await self.__afetch_query(query) + if results: + id_to_be_deleted = results[0].get("id") + delete_query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE id = '{id_to_be_deleted}' RETURNING *;""" + result = await self.__afetch_query(delete_query) + result = result[0] + if result: + return ChatMessage.model_validate(result.get("message")) + return None + return None + + async def aget_keys(self) -> List[str]: + """Asynchronously retrieves a list of all keys. + + Returns: + Optional[str]: A list of strings representing the keys. If no keys are found, an empty list is returned. + """ + query = ( + f"""SELECT distinct key from "{self._schema_name}"."{self._table_name}";""" + ) + results = await self.__afetch_query(query) + keys = [] + if results: + keys = [row.get("key") for row in results] + return keys + + def set_messages(self, key: str, messages: List[ChatMessage]) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBChatStore . Use AlloyDBChatStore interface instead." + ) + + def get_messages(self, key: str) -> List[ChatMessage]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBChatStore . Use AlloyDBChatStore interface instead." + ) + + def add_message(self, key: str, message: ChatMessage) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBChatStore . Use AlloyDBChatStore interface instead." + ) + + def delete_messages(self, key: str) -> Optional[List[ChatMessage]]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBChatStore . Use AlloyDBChatStore interface instead." + ) + + def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBChatStore . 
Use AlloyDBChatStore interface instead." + ) + + def delete_last_message(self, key: str) -> Optional[ChatMessage]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBChatStore . Use AlloyDBChatStore interface instead." + ) + + def get_keys(self) -> List[str]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBChatStore . Use AlloyDBChatStore interface instead." + ) diff --git a/src/llama_index_alloydb_pg/async_document_store.py b/src/llama_index_alloydb_pg/async_document_store.py index b45b9dc..f72a9bd 100644 --- a/src/llama_index_alloydb_pg/async_document_store.py +++ b/src/llama_index_alloydb_pg/async_document_store.py @@ -16,7 +16,7 @@ import json import warnings -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Optional, Sequence from llama_index.core.constants import DATA_KEY from llama_index.core.schema import BaseNode @@ -119,13 +119,13 @@ async def __afetch_query(self, query): return results async def _put_all_doc_hashes_to_table( - self, rows: List[Tuple[str, str]], batch_size: int = int(DEFAULT_BATCH_SIZE) + self, rows: list[tuple[str, str]], batch_size: int = int(DEFAULT_BATCH_SIZE) ) -> None: """Puts a multiple rows of node ids with their doc_hash into the document table. Incase a row with the id already exists, it updates the row with the new doc_hash. Args: - rows (List[Tuple[str, str]]): List of tuples of id and doc_hash + rows (list[tuple[str, str]]): List of tuples of id and doc_hash batch_size (int): batch_size to insert the rows. Defaults to 1. Returns: @@ -173,7 +173,7 @@ async def async_add_documents( """Adds a document to the store. Args: - docs (List[BaseDocument]): documents + docs (list[BaseDocument]): documents allow_update (bool): allow update of docstore from document batch_size (int): batch_size to insert the rows. Defaults to 1. store_text (bool): allow the text content of the node to stored. 
@@ -225,7 +225,7 @@ async def async_add_documents( await self.__aexecute_query(query, batch) @property - async def adocs(self) -> Dict[str, BaseNode]: + async def adocs(self) -> dict[str, BaseNode]: """Get all documents. Returns: @@ -300,12 +300,12 @@ async def aget_ref_doc_info(self, ref_doc_id: str) -> Optional[RefDocInfo]: return RefDocInfo(node_ids=node_ids, metadata=merged_metadata) - async def aget_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + async def aget_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: """Get a mapping of ref_doc_id -> RefDocInfo for all ingested documents. Returns: Optional[ - Dict[ + dict[ str, #Ref_doc_id RefDocInfo, #Ref_doc_info of the id ] @@ -356,14 +356,14 @@ async def adocument_exists(self, doc_id: str) -> bool: async def _get_ref_doc_child_node_ids( self, ref_doc_id: str - ) -> Optional[Dict[str, List[str]]]: + ) -> Optional[dict[str, list[str]]]: """Helper function to find the child node mappings of a ref_doc_id. Returns: Optional[ - Dict[ + dict[ str, # Ref_doc_id - List # List of all nodes that refer to ref_doc_id + list # List of all nodes that refer to ref_doc_id ] ]""" query = f"""select id from "{self._schema_name}"."{self._table_name}" where ref_doc_id = '{ref_doc_id}';""" @@ -442,11 +442,11 @@ async def aset_document_hash(self, doc_id: str, doc_hash: str) -> None: await self._put_all_doc_hashes_to_table(rows=[(doc_id, doc_hash)]) - async def aset_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + async def aset_document_hashes(self, doc_hashes: dict[str, str]) -> None: """Set the hash for a given doc_id. Args: - doc_hashes (Dict[str, str]): Dictionary with doc_id as key and doc_hash as value. + doc_hashes (dict[str, str]): Dictionary with doc_id as key and doc_hash as value. 
Returns: None @@ -473,11 +473,11 @@ async def aget_document_hash(self, doc_id: str) -> Optional[str]: else: return None - async def aget_all_document_hashes(self) -> Dict[str, str]: + async def aget_all_document_hashes(self) -> dict[str, str]: """Get the stored hash for all documents. Returns: - Dict[ + dict[ str, # doc_hash str # doc_id ] @@ -498,11 +498,11 @@ async def aget_all_document_hashes(self) -> Dict[str, str]: return hashes @property - def docs(self) -> Dict[str, BaseNode]: + def docs(self) -> dict[str, BaseNode]: """Get all documents. Returns: - Dict[str, BaseDocument]: documents + dict[str, BaseDocument]: documents """ raise NotImplementedError( @@ -547,7 +547,7 @@ def set_document_hash(self, doc_id: str, doc_hash: str) -> None: "Sync methods are not implemented for AsyncAlloyDBDocumentStore. Use AlloyDBDocumentStore interface instead." ) - def set_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + def set_document_hashes(self, doc_hashes: dict[str, str]) -> None: raise NotImplementedError( "Sync methods are not implemented for AsyncAlloyDBDocumentStore. Use AlloyDBDocumentStore interface instead." ) @@ -557,12 +557,12 @@ def get_document_hash(self, doc_id: str) -> Optional[str]: "Sync methods are not implemented for AsyncAlloyDBDocumentStore. Use AlloyDBDocumentStore interface instead." ) - def get_all_document_hashes(self) -> Dict[str, str]: + def get_all_document_hashes(self) -> dict[str, str]: raise NotImplementedError( "Sync methods are not implemented for AsyncAlloyDBDocumentStore. Use AlloyDBDocumentStore interface instead." ) - def get_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + def get_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: raise NotImplementedError( "Sync methods are not implemented for AsyncAlloyDBDocumentStore. Use AlloyDBDocumentStore interface instead." 
) diff --git a/src/llama_index_alloydb_pg/async_index_store.py b/src/llama_index_alloydb_pg/async_index_store.py index a93255e..09999f3 100644 --- a/src/llama_index_alloydb_pg/async_index_store.py +++ b/src/llama_index_alloydb_pg/async_index_store.py @@ -16,9 +16,8 @@ import json import warnings -from typing import List, Optional +from typing import Optional -from llama_index.core.constants import DATA_KEY from llama_index.core.data_structs.data_structs import IndexStruct from llama_index.core.storage.index_store.types import BaseIndexStore from llama_index.core.storage.index_store.utils import ( @@ -113,11 +112,11 @@ async def __afetch_query(self, query): await conn.commit() return results - async def aindex_structs(self) -> List[IndexStruct]: + async def aindex_structs(self) -> list[IndexStruct]: """Get all index structs. Returns: - List[IndexStruct]: index structs + list[IndexStruct]: index structs """ query = f"""SELECT * from "{self._schema_name}"."{self._table_name}";""" @@ -190,7 +189,7 @@ async def aget_index_struct( return json_to_index_struct(index_data) return None - def index_structs(self) -> List[IndexStruct]: + def index_structs(self) -> list[IndexStruct]: raise NotImplementedError( "Sync methods are not implemented for AsyncAlloyDBIndexStore . Use AlloyDBIndexStore interface instead." ) diff --git a/src/llama_index_alloydb_pg/async_reader.py b/src/llama_index_alloydb_pg/async_reader.py new file mode 100644 index 0000000..b233ef2 --- /dev/null +++ b/src/llama_index_alloydb_pg/async_reader.py @@ -0,0 +1,270 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +from typing import Any, AsyncIterable, Callable, Iterable, Iterator, List, Optional + +from llama_index.core.bridge.pydantic import ConfigDict +from llama_index.core.readers.base import BasePydanticReader +from llama_index.core.schema import Document +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine + +from .engine import AlloyDBEngine + +DEFAULT_METADATA_COL = "li_metadata" + + +def text_formatter(row: dict, content_columns: list[str]) -> str: + """txt document formatter.""" + return " ".join(str(row[column]) for column in content_columns if column in row) + + +def csv_formatter(row: dict, content_columns: list[str]) -> str: + """CSV document formatter.""" + return ", ".join(str(row[column]) for column in content_columns if column in row) + + +def yaml_formatter(row: dict, content_columns: list[str]) -> str: + """YAML document formatter.""" + return "\n".join( + f"{column}: {str(row[column])}" for column in content_columns if column in row + ) + + +def json_formatter(row: dict, content_columns: list[str]) -> str: + """JSON document formatter.""" + dictionary = {} + for column in content_columns: + if column in row: + dictionary[column] = row[column] + return json.dumps(dictionary) + + +def _parse_doc_from_row( + content_columns: Iterable[str], + metadata_columns: Iterable[str], + row: dict, + formatter: Callable = text_formatter, + metadata_json_column: Optional[str] = DEFAULT_METADATA_COL, +) -> Document: + """Parse row into document.""" + text = formatter(row, 
content_columns) + metadata: dict[str, Any] = {} + # unnest metadata from li_metadata column + if metadata_json_column and row.get(metadata_json_column): + for k, v in row[metadata_json_column].items(): + metadata[k] = v + # load metadata from other columns + for column in metadata_columns: + if column in row and column != metadata_json_column: + metadata[column] = row[column] + + return Document(text=text, extra_info=metadata) + + +class AsyncAlloyDBReader(BasePydanticReader): + """Load documents from AlloyDB. + + Each document represents one row of the result. The `content_columns` are + written into the `text` of the document. The `metadata_columns` are written + into the `metadata` of the document. By default, first columns is written into + the `text` and everything else into the `metadata`. + """ + + __create_key = object() + is_remote: bool = True + + def __init__( + self, + key: object, + pool: AsyncEngine, + query: str, + content_columns: list[str], + metadata_columns: list[str], + formatter: Callable, + metadata_json_column: Optional[str] = None, + is_remote: bool = True, + ) -> None: + """AsyncAlloyDBReader constructor. + + Args: + key (object): Prevent direct constructor usage. + engine (AlloyDBEngine): AsyncEngine with pool connection to the alloydb database + query (Optional[str], optional): SQL query. Defaults to None. + content_columns (Optional[list[str]], optional): Column that represent a Document's page_content. Defaults to the first column. + metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None. + formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None. + metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata". + is_remote (bool): Whether the data is loaded from a remote API or a local file. + + Raises: + Exception: If called directly by user. 
+        """
+        if key != AsyncAlloyDBReader.__create_key:
+            raise Exception("Only create class through 'create' method!")
+
+        super().__init__(is_remote=is_remote)
+
+        self._pool = pool
+        self._query = query
+        self._content_columns = content_columns
+        self._metadata_columns = metadata_columns
+        self._formatter = formatter
+        self._metadata_json_column = metadata_json_column
+
+    @classmethod
+    async def create(
+        cls: type[AsyncAlloyDBReader],
+        engine: AlloyDBEngine,
+        query: Optional[str] = None,
+        table_name: Optional[str] = None,
+        schema_name: str = "public",
+        content_columns: Optional[list[str]] = None,
+        metadata_columns: Optional[list[str]] = None,
+        metadata_json_column: Optional[str] = None,
+        format: Optional[str] = None,
+        formatter: Optional[Callable] = None,
+        is_remote: bool = True,
+    ) -> AsyncAlloyDBReader:
+        """Create an AsyncAlloyDBReader instance.
+
+        Args:
+            engine (AlloyDBEngine): AsyncEngine with pool connection to the alloydb database
+            query (Optional[str], optional): SQL query. Defaults to None.
+            table_name (Optional[str], optional): Name of table to query. Defaults to None.
+            schema_name (str, optional): Name of the schema where table is located. Defaults to "public".
+            content_columns (Optional[list[str]], optional): Column that represent a Document's page_content. Defaults to the first column.
+            metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
+            metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata".
+            format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'.
+            formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None.
+            is_remote (bool): Whether the data is loaded from a remote API or a local file.
+
+
+        Returns:
+            AsyncAlloyDBReader: A newly created instance of AsyncAlloyDBReader.
+ """ + if table_name and query: + raise ValueError("Only one of 'table_name' or 'query' should be specified.") + if not table_name and not query: + raise ValueError( + "At least one of the parameters 'table_name' or 'query' needs to be provided" + ) + if format and formatter: + raise ValueError("Only one of 'format' or 'formatter' should be specified.") + + if format and format not in ["csv", "text", "JSON", "YAML"]: + raise ValueError("format must be type: 'csv', 'text', 'JSON', 'YAML'") + if formatter: + formatter = formatter + elif format == "csv": + formatter = csv_formatter + elif format == "YAML": + formatter = yaml_formatter + elif format == "JSON": + formatter = json_formatter + else: + formatter = text_formatter + + if not query: + query = f'SELECT * FROM "{schema_name}"."{table_name}"' + + async with engine._pool.connect() as connection: + result_proxy = await connection.execute(text(query)) + column_names = list(result_proxy.keys()) + # Select content or default to first column + content_columns = content_columns or [column_names[0]] + # Select metadata columns + metadata_columns = metadata_columns or [ + col for col in column_names if col not in content_columns + ] + + # Check validity of metadata json column + if metadata_json_column and metadata_json_column not in column_names: + raise ValueError( + f"Column {metadata_json_column} not found in query result {column_names}." + ) + + if metadata_json_column and metadata_json_column in column_names: + metadata_json_column = metadata_json_column + elif DEFAULT_METADATA_COL in column_names: + metadata_json_column = DEFAULT_METADATA_COL + else: + metadata_json_column = None + + # check validity of other column + all_names = content_columns + metadata_columns + for name in all_names: + if name not in column_names: + raise ValueError( + f"Column {name} not found in query result {column_names}." 
+ ) + return cls( + key=cls.__create_key, + pool=engine._pool, + query=query, + content_columns=content_columns, + metadata_columns=metadata_columns, + formatter=formatter, + metadata_json_column=metadata_json_column, + is_remote=is_remote, + ) + + @classmethod + def class_name(cls) -> str: + return "AsyncAlloyDBReader" + + async def aload_data(self) -> list[Document]: + """Asynchronously load AlloyDB data into Document objects.""" + return [doc async for doc in self.alazy_load_data()] + + async def alazy_load_data(self) -> AsyncIterable[Document]: # type: ignore + """Asynchronously load AlloyDB data into Document objects lazily.""" + async with self._pool.connect() as connection: + result_proxy = await connection.execute(text(self._query)) + # load document one by one + while True: + row = result_proxy.fetchone() + if not row: + break + + row_data = {} + column_names = self._content_columns + self._metadata_columns + column_names += ( + [self._metadata_json_column] if self._metadata_json_column else [] + ) + for column in column_names: + value = getattr(row, column) + row_data[column] = value + + yield _parse_doc_from_row( + self._content_columns, + self._metadata_columns, + row_data, + self._formatter, + self._metadata_json_column, + ) + + def lazy_load_data(self) -> Iterable[Document]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBReader. Use AlloyDBReader interface instead." + ) + + def load_data(self) -> List[Document]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBReader. Use AlloyDBReader interface instead." 
+ ) diff --git a/src/llama_index_alloydb_pg/async_vector_store.py b/src/llama_index_alloydb_pg/async_vector_store.py index 8be4854..7d2c16d 100644 --- a/src/llama_index_alloydb_pg/async_vector_store.py +++ b/src/llama_index_alloydb_pg/async_vector_store.py @@ -15,14 +15,10 @@ # TODO: Remove below import when minimum supported Python version is 3.10 from __future__ import annotations -import base64 import json -import re -import uuid import warnings -from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type +from typing import Any, Optional, Sequence -import numpy as np from llama_index.core.schema import BaseNode, MetadataMode, NodeRelationship, TextNode from llama_index.core.vector_stores.types import ( BasePydanticVectorStore, @@ -31,7 +27,6 @@ MetadataFilter, MetadataFilters, VectorStoreQuery, - VectorStoreQueryMode, VectorStoreQueryResult, ) from llama_index.core.vector_stores.utils import ( @@ -71,7 +66,7 @@ def __init__( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -89,7 +84,7 @@ def __init__( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". 
stores_text (bool): Whether the table stores text. Defaults to "True". @@ -121,7 +116,7 @@ def __init__( @classmethod async def create( - cls: Type[AsyncAlloyDBVectorStore], + cls: type[AsyncAlloyDBVectorStore], engine: AlloyDBEngine, table_name: str, schema_name: str = "public", @@ -129,7 +124,7 @@ async def create( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -147,7 +142,7 @@ async def create( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". 
@@ -234,7 +229,7 @@ def client(self) -> Any: """Get client.""" return self._engine - async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> List[str]: + async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> list[str]: """Asynchronously add nodes to the table.""" ids = [] metadata_col_names = ( @@ -293,14 +288,14 @@ async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: async def adelete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: """Asynchronously delete a set of nodes from the table matching the provided nodes and filters.""" if not node_ids and not filters: return - all_filters: List[MetadataFilter | MetadataFilters] = [] + all_filters: list[MetadataFilter | MetadataFilters] = [] if node_ids: all_filters.append( MetadataFilter( @@ -332,9 +327,9 @@ async def aclear(self) -> None: async def aget_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: """Asynchronously get nodes from the table matching the provided nodes and filters.""" query = VectorStoreQuery( node_ids=node_ids, filters=filters, similarity_top_k=-1 @@ -366,7 +361,7 @@ async def aquery( similarities.append(row["distance"]) return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids) - def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> List[str]: + def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]: raise NotImplementedError( "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." 
) @@ -378,7 +373,7 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: def delete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: @@ -393,9 +388,9 @@ def clear(self) -> None: def get_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: raise NotImplementedError( "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." ) @@ -495,7 +490,7 @@ async def __query_columns( **kwargs: Any, ) -> Sequence[RowMapping]: """Perform search query on database.""" - filters: List[MetadataFilter | MetadataFilters] = [] + filters: list[MetadataFilter | MetadataFilters] = [] if query.doc_ids: filters.append( MetadataFilter( @@ -545,13 +540,25 @@ async def __query_columns( f" LIMIT {query.similarity_top_k} " if query.similarity_top_k >= 1 else "" ) - query_stmt = f'SELECT * {scoring_stmt} FROM "{self._schema_name}"."{self._table_name}" {filters_stmt} {order_stmt} {limit_stmt}' + columns = self._metadata_columns + [ + self._id_column, + self._text_column, + self._embedding_column, + self._ref_doc_id_column, + self._node_column, + ] + if self._metadata_json_column: + columns.append(self._metadata_json_column) + + column_names = ", ".join(f'"{col}"' for col in columns) + + query_stmt = f'SELECT {column_names} {scoring_stmt} FROM "{self._schema_name}"."{self._table_name}" {filters_stmt} {order_stmt} {limit_stmt}' async with self._engine.connect() as conn: if self._index_query_options: - query_options_stmt = ( - f"SET LOCAL {self._index_query_options.to_string()};" - ) - await conn.execute(text(query_options_stmt)) + # Set each query option individually + for query_option in self._index_query_options.to_parameter(): + query_options_stmt = f"SET LOCAL {query_option};" + await 
conn.execute(text(query_options_stmt)) result = await conn.execute(text(query_stmt)) result_map = result.mappings() results = result_map.fetchall() diff --git a/src/llama_index_alloydb_pg/chat_store.py b/src/llama_index_alloydb_pg/chat_store.py new file mode 100644 index 0000000..1615ca4 --- /dev/null +++ b/src/llama_index_alloydb_pg/chat_store.py @@ -0,0 +1,289 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import List, Optional + +from llama_index.core.llms import ChatMessage +from llama_index.core.storage.chat_store.base import BaseChatStore + +from .async_chat_store import AsyncAlloyDBChatStore +from .engine import AlloyDBEngine + + +class AlloyDBChatStore(BaseChatStore): + """Chat Store Table stored in an AlloyDB for PostgreSQL database.""" + + __create_key = object() + + def __init__( + self, key: object, engine: AlloyDBEngine, chat_store: AsyncAlloyDBChatStore + ): + """AlloyDBChatStore constructor. + + Args: + key (object): Key to prevent direct constructor usage. + engine (AlloyDBEngine): Database connection pool. + chat_store (AsyncAlloyDBChatStore): The async only ChatStore implementation + + Raises: + Exception: If constructor is directly called by the user. + """ + if key != AlloyDBChatStore.__create_key: + raise Exception( + "Only create class through 'create' or 'create_sync' methods!" 
+ ) + + # Delegate to Pydantic's __init__ + super().__init__() + self._engine = engine + self.__chat_store = chat_store + + @classmethod + async def create( + cls, + engine: AlloyDBEngine, + table_name: str, + schema_name: str = "public", + ) -> AlloyDBChatStore: + """Create a new AlloyDBChatStore instance. + + Args: + engine (AlloyDBEngine): AlloyDB engine to use. + table_name (str): Table name that stores the chat store. + schema_name (str): The schema name where the table is located. Defaults to "public" + + Raises: + ValueError: If the table provided does not contain required schema. + + Returns: + AlloyDBChatStore: A newly created instance of AlloyDBChatStore. + """ + coro = AsyncAlloyDBChatStore.create(engine, table_name, schema_name) + chat_store = await engine._run_as_async(coro) + return cls(cls.__create_key, engine, chat_store) + + @classmethod + def create_sync( + cls, + engine: AlloyDBEngine, + table_name: str, + schema_name: str = "public", + ) -> AlloyDBChatStore: + """Create a new AlloyDBChatStore sync instance. + + Args: + engine (AlloyDBEngine): AlloyDB engine to use. + table_name (str): Table name that stores the chat store. + schema_name (str): The schema name where the table is located. Defaults to "public" + + Raises: + ValueError: If the table provided does not contain required schema. + + Returns: + AlloyDBChatStore: A newly created instance of AlloyDBChatStore. + """ + coro = AsyncAlloyDBChatStore.create(engine, table_name, schema_name) + chat_store = engine._run_as_sync(coro) + return cls(cls.__create_key, engine, chat_store) + + @classmethod + def class_name(cls) -> str: + """Get class name.""" + return "AlloyDBChatStore" + + async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None: + """Asynchronously sets the chat messages for a specific key. + + Args: + key (str): A unique identifier for the chat. + messages (List[ChatMessage]): A list of `ChatMessage` objects to upsert. 
+ + Returns: + None + + """ + return await self._engine._run_as_async( + self.__chat_store.aset_messages(key=key, messages=messages) + ) + + async def aget_messages(self, key: str) -> List[ChatMessage]: + """Asynchronously retrieves the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for which the messages are to be retrieved. + + Returns: + List[ChatMessage]: A list of `ChatMessage` objects associated with the provided key. + If no messages are found, an empty list is returned. + """ + return await self._engine._run_as_async( + self.__chat_store.aget_messages(key=key) + ) + + async def async_add_message(self, key: str, message: ChatMessage) -> None: + """Asynchronously adds a new chat message to the specified key. + + Args: + key (str): A unique identifierfor the chat to which the message is added. + message (ChatMessage): The `ChatMessage` object that is to be added. + + Returns: + None + """ + return await self._engine._run_as_async( + self.__chat_store.async_add_message(key=key, message=message) + ) + + async def adelete_messages(self, key: str) -> Optional[List[ChatMessage]]: + """Asynchronously deletes the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + + Returns: + Optional[List[ChatMessage]]: A list of `ChatMessage` objects that were deleted, or `None` if no messages + were associated with the key or could be deleted. + """ + return await self._engine._run_as_async( + self.__chat_store.adelete_messages(key=key) + ) + + async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + """Asynchronously deletes a specific chat message by index from the messages associated with a given key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + idx (int): The index of the `ChatMessage` to be deleted from the list of messages. 
+ + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + return await self._engine._run_as_async( + self.__chat_store.adelete_message(key=key, idx=idx) + ) + + async def adelete_last_message(self, key: str) -> Optional[ChatMessage]: + """Asynchronously deletes the last chat message associated with a given key. + + Args: + key (str): A unique identifier for the chat whose message is to be deleted. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + return await self._engine._run_as_async( + self.__chat_store.adelete_last_message(key=key) + ) + + async def aget_keys(self) -> List[str]: + """Asynchronously retrieves a list of all keys. + + Returns: + Optional[str]: A list of strings representing the keys. If no keys are found, an empty list is returned. + """ + return await self._engine._run_as_async(self.__chat_store.aget_keys()) + + def set_messages(self, key: str, messages: List[ChatMessage]) -> None: + """Synchronously sets the chat messages for a specific key. + + Args: + key (str): A unique identifier for the chat. + messages (List[ChatMessage]): A list of `ChatMessage` objects to upsert. + + Returns: + None + + """ + return self._engine._run_as_sync( + self.__chat_store.aset_messages(key=key, messages=messages) + ) + + def get_messages(self, key: str) -> List[ChatMessage]: + """Synchronously retrieves the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for which the messages are to be retrieved. + + Returns: + List[ChatMessage]: A list of `ChatMessage` objects associated with the provided key. + If no messages are found, an empty list is returned. 
+ """ + return self._engine._run_as_sync(self.__chat_store.aget_messages(key=key)) + + def add_message(self, key: str, message: ChatMessage) -> None: + """Synchronously adds a new chat message to the specified key. + + Args: + key (str): A unique identifierfor the chat to which the message is added. + message (ChatMessage): The `ChatMessage` object that is to be added. + + Returns: + None + """ + return self._engine._run_as_sync( + self.__chat_store.async_add_message(key=key, message=message) + ) + + def delete_messages(self, key: str) -> Optional[List[ChatMessage]]: + """Synchronously deletes the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + + Returns: + Optional[List[ChatMessage]]: A list of `ChatMessage` objects that were deleted, or `None` if no messages + were associated with the key or could be deleted. + """ + return self._engine._run_as_sync(self.__chat_store.adelete_messages(key=key)) + + def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + """Synchronously deletes a specific chat message by index from the messages associated with a given key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + idx (int): The index of the `ChatMessage` to be deleted from the list of messages. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + return self._engine._run_as_sync( + self.__chat_store.adelete_message(key=key, idx=idx) + ) + + def delete_last_message(self, key: str) -> Optional[ChatMessage]: + """Synchronously deletes the last chat message associated with a given key. + + Args: + key (str): A unique identifier for the chat whose message is to be deleted. 
+ + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + return self._engine._run_as_sync( + self.__chat_store.adelete_last_message(key=key) + ) + + def get_keys(self) -> List[str]: + """Synchronously retrieves a list of all keys. + + Returns: + Optional[str]: A list of strings representing the keys. If no keys are found, an empty list is returned. + """ + return self._engine._run_as_sync(self.__chat_store.aget_keys()) diff --git a/src/llama_index_alloydb_pg/document_store.py b/src/llama_index_alloydb_pg/document_store.py index cc86297..033c052 100644 --- a/src/llama_index_alloydb_pg/document_store.py +++ b/src/llama_index_alloydb_pg/document_store.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Dict, List, Optional, Sequence, Type +from typing import Optional, Sequence from llama_index.core.schema import BaseNode from llama_index.core.storage.docstore import BaseDocumentStore @@ -55,7 +55,7 @@ def __init__( @classmethod async def create( - cls: Type[AlloyDBDocumentStore], + cls: type[AlloyDBDocumentStore], engine: AlloyDBEngine, table_name: str, schema_name: str = "public", @@ -83,7 +83,7 @@ async def create( @classmethod def create_sync( - cls: Type[AlloyDBDocumentStore], + cls: type[AlloyDBDocumentStore], engine: AlloyDBEngine, table_name: str, schema_name: str = "public", @@ -110,11 +110,11 @@ def create_sync( return cls(cls.__create_key, engine, document_store) @property - def docs(self) -> Dict[str, BaseNode]: + def docs(self) -> dict[str, BaseNode]: """Get all documents. 
Returns: - Dict[str, BaseDocument]: documents + dict[str, BaseDocument]: documents """ return self._engine._run_as_sync(self.__document_store.adocs) @@ -291,11 +291,11 @@ def set_document_hash(self, doc_id: str, doc_hash: str) -> None: self.__document_store.aset_document_hash(doc_id, doc_hash) ) - async def aset_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + async def aset_document_hashes(self, doc_hashes: dict[str, str]) -> None: """Set the hash for a given doc_id. Args: - doc_hashes (Dict[str, str]): Dictionary with doc_id as key and doc_hash as value. + doc_hashes (dict[str, str]): Dictionary with doc_id as key and doc_hash as value. Returns: None @@ -304,11 +304,11 @@ async def aset_document_hashes(self, doc_hashes: Dict[str, str]) -> None: self.__document_store.aset_document_hashes(doc_hashes) ) - def set_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + def set_document_hashes(self, doc_hashes: dict[str, str]) -> None: """Set the hash for a given doc_id. Args: - doc_hashes (Dict[str, str]): Dictionary with doc_id as key and doc_hash as value. + doc_hashes (dict[str, str]): Dictionary with doc_id as key and doc_hash as value. Returns: None @@ -343,11 +343,11 @@ def get_document_hash(self, doc_id: str) -> Optional[str]: self.__document_store.aget_document_hash(doc_id) ) - async def aget_all_document_hashes(self) -> Dict[str, str]: + async def aget_all_document_hashes(self) -> dict[str, str]: """Get the stored hash for all documents. Returns: - Dict[ + dict[ str, # doc_hash str # doc_id ] @@ -356,11 +356,11 @@ async def aget_all_document_hashes(self) -> Dict[str, str]: self.__document_store.aget_all_document_hashes() ) - def get_all_document_hashes(self) -> Dict[str, str]: + def get_all_document_hashes(self) -> dict[str, str]: """Get the stored hash for all documents. 
Returns: - Dict[ + dict[ str, # doc_hash str # doc_id ] @@ -369,12 +369,12 @@ def get_all_document_hashes(self) -> Dict[str, str]: self.__document_store.aget_all_document_hashes() ) - async def aget_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + async def aget_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: """Get a mapping of ref_doc_id -> RefDocInfo for all ingested documents. Returns: Optional[ - Dict[ + dict[ str, #Ref_doc_id RefDocInfo, #Ref_doc_info of the id ] @@ -384,12 +384,12 @@ async def aget_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: self.__document_store.aget_all_ref_doc_info() ) - def get_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + def get_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: """Get a mapping of ref_doc_id -> RefDocInfo for all ingested documents. Returns: Optional[ - Dict[ + dict[ str, #Ref_doc_id RefDocInfo, #Ref_doc_info of the id ] diff --git a/src/llama_index_alloydb_pg/engine.py b/src/llama_index_alloydb_pg/engine.py index f9984d0..f004335 100644 --- a/src/llama_index_alloydb_pg/engine.py +++ b/src/llama_index_alloydb_pg/engine.py @@ -17,17 +17,7 @@ from concurrent.futures import Future from dataclasses import dataclass from threading import Thread -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Dict, - List, - Optional, - Type, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, Awaitable, Optional, TypeVar, Union import aiohttp import google.auth # type: ignore @@ -76,7 +66,7 @@ async def _get_iam_principal_email( url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}" async with aiohttp.ClientSession() as client: response = await client.get(url, raise_for_status=True) - response_json: Dict = await response.json() + response_json: dict = await response.json() email = response_json.get("email") if email is None: raise ValueError( @@ -179,7 +169,7 @@ def __start_background_loop( @classmethod def from_instance( - cls: 
Type[AlloyDBEngine], + cls: type[AlloyDBEngine], project_id: str, region: str, cluster: str, @@ -221,7 +211,7 @@ def from_instance( @classmethod async def _create( - cls: Type[AlloyDBEngine], + cls: type[AlloyDBEngine], project_id: str, region: str, cluster: str, @@ -305,7 +295,7 @@ async def getconn() -> asyncpg.Connection: @classmethod async def afrom_instance( - cls: Type[AlloyDBEngine], + cls: type[AlloyDBEngine], project_id: str, region: str, cluster: str, @@ -347,7 +337,7 @@ async def afrom_instance( @classmethod def from_engine( - cls: Type[AlloyDBEngine], + cls: type[AlloyDBEngine], engine: AsyncEngine, loop: Optional[asyncio.AbstractEventLoop] = None, ) -> AlloyDBEngine: @@ -512,7 +502,7 @@ async def _ainit_vector_store_table( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[Column] = [], + metadata_columns: list[Column] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -528,7 +518,7 @@ async def _ainit_vector_store_table( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". 
@@ -585,7 +575,7 @@ async def ainit_vector_store_table( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[Column] = [], + metadata_columns: list[Column] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -601,7 +591,7 @@ async def ainit_vector_store_table( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -636,7 +626,7 @@ def init_vector_store_table( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[Column] = [], + metadata_columns: list[Column] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -652,7 +642,7 @@ def init_vector_store_table( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. 
+ metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -767,6 +757,97 @@ def init_index_store_table( ) ) + async def _ainit_chat_store_table( + self, + table_name: str, + schema_name: str = "public", + overwrite_existing: bool = False, + ) -> None: + """ + Create an AlloyDB table to save chat store. + + Args: + table_name (str): The table name to store chat history. + schema_name (str): The schema name to store the chat store table. + Default: "public". + overwrite_existing (bool): Whether to drop existing table. + Default: False. + + Returns: + None + """ + if overwrite_existing: + async with self._pool.connect() as conn: + await conn.execute( + text(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"') + ) + await conn.commit() + + create_table_query = f"""CREATE TABLE "{schema_name}"."{table_name}"( + id SERIAL PRIMARY KEY, + key VARCHAR NOT NULL, + message JSON NOT NULL + );""" + create_index_query = f"""CREATE INDEX "{table_name}_idx_key" ON "{schema_name}"."{table_name}" (key);""" + async with self._pool.connect() as conn: + await conn.execute(text(create_table_query)) + await conn.execute(text(create_index_query)) + await conn.commit() + + async def ainit_chat_store_table( + self, + table_name: str, + schema_name: str = "public", + overwrite_existing: bool = False, + ) -> None: + """ + Create an AlloyDB table to save chat store. + + Args: + table_name (str): The table name to store chat store. + schema_name (str): The schema name to store the chat store table. + Default: "public". + overwrite_existing (bool): Whether to drop existing table. + Default: False. 
+ + Returns: + None + """ + await self._run_as_async( + self._ainit_chat_store_table( + table_name, + schema_name, + overwrite_existing, + ) + ) + + def init_chat_store_table( + self, + table_name: str, + schema_name: str = "public", + overwrite_existing: bool = False, + ) -> None: + """ + Create an AlloyDB table to save chat store. + + Args: + table_name (str): The table name to store chat store. + schema_name (str): The schema name to store the chat store table. + Default: "public". + overwrite_existing (bool): Whether to drop existing table. + Default: False. + + Returns: + None + """ + self._run_as_sync( + self._ainit_chat_store_table( + table_name, + schema_name, + overwrite_existing, + ) + ) + async def _aload_table_schema( self, table_name: str, schema_name: str = "public" ) -> Table: diff --git a/src/llama_index_alloydb_pg/index_store.py b/src/llama_index_alloydb_pg/index_store.py index f99e054..a90f0cc 100644 --- a/src/llama_index_alloydb_pg/index_store.py +++ b/src/llama_index_alloydb_pg/index_store.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import List, Optional +from typing import Optional from llama_index.core.data_structs.data_structs import IndexStruct from llama_index.core.storage.index_store.types import BaseIndexStore @@ -96,20 +96,20 @@ def create_sync( index_store = engine._run_as_sync(coro) return cls(cls.__create_key, engine, index_store) - async def aindex_structs(self) -> List[IndexStruct]: + async def aindex_structs(self) -> list[IndexStruct]: """Get all index structs. Returns: - List[IndexStruct]: index structs + list[IndexStruct]: index structs """ return await self._engine._run_as_async(self.__index_store.aindex_structs()) - def index_structs(self) -> List[IndexStruct]: + def index_structs(self) -> list[IndexStruct]: """Get all index structs. 
Returns: - List[IndexStruct]: index structs + list[IndexStruct]: index structs """ return self._engine._run_as_sync(self.__index_store.aindex_structs()) diff --git a/src/llama_index_alloydb_pg/indexes.py b/src/llama_index_alloydb_pg/indexes.py index 6c69ffa..20bdfec 100644 --- a/src/llama_index_alloydb_pg/indexes.py +++ b/src/llama_index_alloydb_pg/indexes.py @@ -13,9 +13,10 @@ # limitations under the License. import enum +import warnings from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import List, Optional +from typing import Optional @dataclass @@ -45,7 +46,7 @@ class BaseIndex(ABC): distance_strategy: DistanceStrategy = field( default_factory=lambda: DistanceStrategy.COSINE_DISTANCE ) - partial_indexes: Optional[List[str]] = None + partial_indexes: Optional[list[str]] = None @abstractmethod def index_options(self) -> str: @@ -62,6 +63,11 @@ class ExactNearestNeighbor(BaseIndex): @dataclass class QueryOptions(ABC): + @abstractmethod + def to_parameter(self) -> list[str]: + """Convert index attributes to list of configurations.""" + raise NotImplementedError("to_parameter method must be implemented by subclass") + @abstractmethod def to_string(self) -> str: """Convert index attributes to string.""" @@ -83,8 +89,16 @@ def index_options(self) -> str: class HNSWQueryOptions(QueryOptions): ef_search: int = 40 + def to_parameter(self) -> list[str]: + """Convert index attributes to list of configurations.""" + return [f"hnsw.ef_search = {self.ef_search}"] + def to_string(self) -> str: """Convert index attributes to string.""" + warnings.warn( + "to_string is deprecated, use to_parameter instead.", + DeprecationWarning, + ) return f"hnsw.ef_search = {self.ef_search}" @@ -102,8 +116,16 @@ def index_options(self) -> str: class IVFFlatQueryOptions(QueryOptions): probes: int = 1 + def to_parameter(self) -> list[str]: + """Convert index attributes to list of configurations.""" + return [f"ivfflat.probes = {self.probes}"] + def 
to_string(self) -> str: """Convert index attributes to string.""" + warnings.warn( + "to_string is deprecated, use to_parameter instead.", + DeprecationWarning, + ) return f"ivfflat.probes = {self.probes}" @@ -124,8 +146,16 @@ def index_options(self) -> str: class IVFQueryOptions(QueryOptions): probes: int = 1 + def to_parameter(self) -> list[str]: + """Convert index attributes to list of configurations.""" + return [f"ivf.probes = {self.probes}"] + def to_string(self) -> str: """Convert index attributes to string.""" + warnings.warn( + "to_string is deprecated, use to_parameter instead.", + DeprecationWarning, + ) return f"ivf.probes = {self.probes}" @@ -147,6 +177,17 @@ class ScaNNQueryOptions(QueryOptions): num_leaves_to_search: int = 1 pre_reordering_num_neighbors: int = -1 + def to_parameter(self) -> list[str]: + """Convert index attributes to list of configurations.""" + return [ + f"scann.num_leaves_to_search = {self.num_leaves_to_search}", + f"scann.pre_reordering_num_neighbors = {self.pre_reordering_num_neighbors}", + ] + def to_string(self) -> str: """Convert index attributes to string.""" + warnings.warn( + "to_string is deprecated, use to_parameter instead.", + DeprecationWarning, + ) return f"scann.num_leaves_to_search = {self.num_leaves_to_search}, scann.pre_reordering_num_neighbors = {self.pre_reordering_num_neighbors}" diff --git a/src/llama_index_alloydb_pg/reader.py b/src/llama_index_alloydb_pg/reader.py new file mode 100644 index 0000000..aac1582 --- /dev/null +++ b/src/llama_index_alloydb_pg/reader.py @@ -0,0 +1,187 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import AsyncIterable, Callable, Iterable, List, Optional + +from llama_index.core.bridge.pydantic import ConfigDict +from llama_index.core.readers.base import BasePydanticReader +from llama_index.core.schema import Document + +from .async_reader import AsyncAlloyDBReader +from .engine import AlloyDBEngine + +DEFAULT_METADATA_COL = "li_metadata" + + +class AlloyDBReader(BasePydanticReader): + """Chat Store Table stored in an AlloyDB for PostgreSQL database.""" + + __create_key = object() + is_remote: bool = True + + def __init__( + self, + key: object, + engine: AlloyDBEngine, + reader: AsyncAlloyDBReader, + is_remote: bool = True, + ) -> None: + """AlloyDBReader constructor. + + Args: + key (object): Prevent direct constructor usage. + engine (AlloyDBEngine): AlloyDB with pool connection to the alloydb database + reader (AsyncAlloyDBReader): The async only AlloyDBReader implementation + is_remote (Optional[bool]): Whether the data is loaded from a remote API or a local file. + + Raises: + Exception: If called directly by user. 
+ """ + if key != AlloyDBReader.__create_key: + raise Exception("Only create class through 'create' method!") + + super().__init__(is_remote=is_remote) + + self._engine = engine + self.__reader = reader + + @classmethod + async def create( + cls: type[AlloyDBReader], + engine: AlloyDBEngine, + query: Optional[str] = None, + table_name: Optional[str] = None, + schema_name: str = "public", + content_columns: Optional[list[str]] = None, + metadata_columns: Optional[list[str]] = None, + metadata_json_column: Optional[str] = None, + format: Optional[str] = None, + formatter: Optional[Callable] = None, + is_remote: bool = True, + ) -> AlloyDBReader: + """Asynchronously create an AlloyDBReader instance. + + Args: + engine (AlloyDBEngine): AlloyDBEngine with pool connection to the alloydb database + query (Optional[str], optional): SQL query. Defaults to None. + table_name (Optional[str], optional): Name of table to query. Defaults to None. + schema_name (str, optional): Name of the schema where table is located. Defaults to "public". + content_columns (Optional[list[str]], optional): Column that represent a Document's page_content. Defaults to the first column. + metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None. + metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata". + format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'. + formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None. + is_remote (Optional[bool]): Whether the data is loaded from a remote API or a local file. + + + Returns: + AlloyDBReader: A newly created instance of AlloyDBReader. 
+ """ + coro = AsyncAlloyDBReader.create( + engine=engine, + query=query, + table_name=table_name, + schema_name=schema_name, + content_columns=content_columns, + metadata_columns=metadata_columns, + metadata_json_column=metadata_json_column, + format=format, + formatter=formatter, + is_remote=is_remote, + ) + reader = await engine._run_as_async(coro) + return cls(cls.__create_key, engine, reader, is_remote) + + @classmethod + def create_sync( + cls: type[AlloyDBReader], + engine: AlloyDBEngine, + query: Optional[str] = None, + table_name: Optional[str] = None, + schema_name: str = "public", + content_columns: Optional[list[str]] = None, + metadata_columns: Optional[list[str]] = None, + metadata_json_column: Optional[str] = None, + format: Optional[str] = None, + formatter: Optional[Callable] = None, + is_remote: bool = True, + ) -> AlloyDBReader: + """Synchronously create an AlloyDBReader instance. + + Args: + engine (AlloyDBEngine):AsyncEngine with pool connection to the alloydb database + query (Optional[str], optional): SQL query. Defaults to None. + table_name (Optional[str], optional): Name of table to query. Defaults to None. + schema_name (str, optional): Name of the schema where table is located. Defaults to "public". + content_columns (Optional[list[str]], optional): Column that represent a Document's page_content. Defaults to the first column. + metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None. + metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata". + format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'. + formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None. + is_remote (Optional[bool]): Whether the data is loaded from a remote API or a local file. 
+ + + Returns: + AlloyDBReader: A newly created instance of AlloyDBReader. + """ + coro = AsyncAlloyDBReader.create( + engine=engine, + query=query, + table_name=table_name, + schema_name=schema_name, + content_columns=content_columns, + metadata_columns=metadata_columns, + metadata_json_column=metadata_json_column, + format=format, + formatter=formatter, + is_remote=is_remote, + ) + reader = engine._run_as_sync(coro) + return cls(cls.__create_key, engine, reader, is_remote) + + @classmethod + def class_name(cls) -> str: + """Get class name.""" + return "AlloyDBReader" + + async def aload_data(self) -> list[Document]: + """Asynchronously load AlloyDB data into Document objects.""" + return await self._engine._run_as_async(self.__reader.aload_data()) + + def load_data(self) -> list[Document]: + """Synchronously load AlloyDB data into Document objects.""" + return self._engine._run_as_sync(self.__reader.aload_data()) + + async def alazy_load_data(self) -> AsyncIterable[Document]: # type: ignore + """Asynchronously load AlloyDB data into Document objects lazily.""" + # The return type in the underlying base class is an Iterable which we are overriding to an AsyncIterable in this implementation. 
+ iterator = self.__reader.alazy_load_data().__aiter__() + while True: + try: + result = await self._engine._run_as_async(iterator.__anext__()) + yield result + except StopAsyncIteration: + break + + def lazy_load_data(self) -> Iterable[Document]: # type: ignore + """Synchronously aoad AlloyDB data into Document objects lazily.""" + iterator = self.__reader.alazy_load_data().__aiter__() + while True: + try: + result = self._engine._run_as_sync(iterator.__anext__()) + yield result + except StopAsyncIteration: + break diff --git a/src/llama_index_alloydb_pg/vector_store.py b/src/llama_index_alloydb_pg/vector_store.py index 852b6b6..d4ef2f6 100644 --- a/src/llama_index_alloydb_pg/vector_store.py +++ b/src/llama_index_alloydb_pg/vector_store.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any, List, Optional, Sequence, Type +from typing import Any, Optional, Sequence from llama_index.core.schema import BaseNode from llama_index.core.vector_stores.types import ( @@ -71,7 +71,7 @@ def __init__( @classmethod async def create( - cls: Type[AlloyDBVectorStore], + cls: type[AlloyDBVectorStore], engine: AlloyDBEngine, table_name: str, schema_name: str = "public", @@ -79,7 +79,7 @@ async def create( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -97,7 +97,7 @@ async def create( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. 
+ metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -138,7 +138,7 @@ async def create( @classmethod def create_sync( - cls: Type[AlloyDBVectorStore], + cls: type[AlloyDBVectorStore], engine: AlloyDBEngine, table_name: str, schema_name: str = "public", @@ -146,7 +146,7 @@ def create_sync( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -164,7 +164,7 @@ def create_sync( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". 
@@ -212,11 +212,11 @@ def client(self) -> Any: """Get client.""" return self._engine - async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> List[str]: + async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> list[str]: """Asynchronously add nodes to the table.""" return await self._engine._run_as_async(self.__vs.async_add(nodes, **kwargs)) - def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> List[str]: + def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]: """Synchronously add nodes to the table.""" return self._engine._run_as_sync(self.__vs.async_add(nodes, **add_kwargs)) @@ -230,7 +230,7 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: async def adelete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: @@ -241,7 +241,7 @@ async def adelete_nodes( def delete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: @@ -260,17 +260,17 @@ def clear(self) -> None: async def aget_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: """Asynchronously get nodes from the table matching the provided nodes and filters.""" return await self._engine._run_as_async(self.__vs.aget_nodes(node_ids, filters)) def get_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: """Asynchronously get nodes from the table matching the provided nodes and filters.""" return self._engine._run_as_sync(self.__vs.aget_nodes(node_ids, filters)) diff --git a/src/llama_index_alloydb_pg/version.py b/src/llama_index_alloydb_pg/version.py 
index c1c8212..20c5861 100644 --- a/src/llama_index_alloydb_pg/version.py +++ b/src/llama_index_alloydb_pg/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.1.0" +__version__ = "0.2.0" diff --git a/tests/test_async_chat_store.py b/tests/test_async_chat_store.py new file mode 100644 index 0000000..f64f824 --- /dev/null +++ b/tests/test_async_chat_store.py @@ -0,0 +1,225 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

import os
import uuid
from typing import Sequence

import pytest
import pytest_asyncio
from llama_index.core.llms import ChatMessage
from sqlalchemy import RowMapping, text

from llama_index_alloydb_pg import AlloyDBEngine
from llama_index_alloydb_pg.async_chat_store import AsyncAlloyDBChatStore

# Unique table name per test run so parallel/leftover runs cannot collide.
default_table_name_async = "chat_store_" + str(uuid.uuid4())


async def aexecute(engine: AlloyDBEngine, query: str) -> None:
    """Execute a DDL/DML statement on the engine's connection pool and commit."""
    async with engine._pool.connect() as conn:
        await conn.execute(text(query))
        await conn.commit()


async def afetch(engine: AlloyDBEngine, query: str) -> Sequence[RowMapping]:
    """Run a SELECT on the engine's connection pool and return all rows as mappings."""
    async with engine._pool.connect() as conn:
        result = await conn.execute(text(query))
        result_map = result.mappings()
        result_fetch = result_map.fetchall()
    return result_fetch


def get_env_var(key: str, desc: str) -> str:
    """Return the value of a required environment variable or raise with a hint."""
    v = os.environ.get(key)
    if v is None:
        raise ValueError(f"Must set env var {key} to: {desc}")
    return v


@pytest.mark.asyncio(loop_scope="class")
class TestAsyncAlloyDBChatStore:
    """Integration tests for AsyncAlloyDBChatStore against a live AlloyDB instance."""

    @pytest.fixture(scope="module")
    def db_project(self) -> str:
        return get_env_var("PROJECT_ID", "project id for google cloud")

    @pytest.fixture(scope="module")
    def db_region(self) -> str:
        return get_env_var("REGION", "region for AlloyDB instance")

    @pytest.fixture(scope="module")
    def db_cluster(self) -> str:
        return get_env_var("CLUSTER_ID", "cluster for AlloyDB")

    @pytest.fixture(scope="module")
    def db_instance(self) -> str:
        return get_env_var("INSTANCE_ID", "instance for AlloyDB")

    @pytest.fixture(scope="module")
    def db_name(self) -> str:
        return get_env_var("DATABASE_ID", "database name on AlloyDB instance")

    @pytest.fixture(scope="module")
    def user(self) -> str:
        return get_env_var("DB_USER", "database user for AlloyDB")

    @pytest.fixture(scope="module")
    def password(self) -> str:
        return get_env_var("DB_PASSWORD", "database password for AlloyDB")

    @pytest_asyncio.fixture(scope="class")
    async def async_engine(
        self,
        db_project,
        db_region,
        db_cluster,
        db_instance,
        db_name,
    ):
        """Yield an AlloyDBEngine for the class and close it (and its connector) afterwards."""
        async_engine = await AlloyDBEngine.afrom_instance(
            project_id=db_project,
            instance=db_instance,
            cluster=db_cluster,
            region=db_region,
            database=db_name,
        )

        yield async_engine

        await async_engine.close()
        await async_engine._connector.close()

    @pytest_asyncio.fixture(scope="class")
    async def chat_store(self, async_engine):
        """Create the backing table and a chat store for the class; drop the table afterwards."""
        await async_engine._ainit_chat_store_table(table_name=default_table_name_async)

        chat_store = await AsyncAlloyDBChatStore.create(
            engine=async_engine, table_name=default_table_name_async
        )

        yield chat_store

        query = f'DROP TABLE IF EXISTS "{default_table_name_async}"'
        await aexecute(async_engine, query)

    async def test_init_with_constructor(self, async_engine):
        # Direct construction must be rejected; only create() is allowed.
        with pytest.raises(Exception):
            AsyncAlloyDBChatStore(
                engine=async_engine, table_name=default_table_name_async
            )

    async def test_async_add_message(self, async_engine, chat_store):
        key = "test_add_key"

        message = ChatMessage(content="add_message_test", role="user")
        await chat_store.async_add_message(key, message=message)

        query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';"""
        results = await afetch(async_engine, query)
        result = results[0]
        assert result["message"] == message.model_dump()

    async def test_aset_and_aget_messages(self, chat_store):
        message_1 = ChatMessage(content="First message", role="user")
        message_2 = ChatMessage(content="Second message", role="user")
        messages = [message_1, message_2]
        key = "test_set_and_get_key"
        await chat_store.aset_messages(key, messages)

        results = await chat_store.aget_messages(key)

        assert len(results) == 2
        assert results[0].content == message_1.content
        assert results[1].content == message_2.content

    async def test_adelete_messages(self, async_engine, chat_store):
        messages = [ChatMessage(content="Message to delete", role="user")]
        key = "test_delete_key"
        await chat_store.aset_messages(key, messages)

        await chat_store.adelete_messages(key)
        query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;"""
        results = await afetch(async_engine, query)

        assert len(results) == 0

    async def test_adelete_message(self, async_engine, chat_store):
        message_1 = ChatMessage(content="Keep me", role="user")
        message_2 = ChatMessage(content="Delete me", role="user")
        messages = [message_1, message_2]
        key = "test_delete_message_key"
        await chat_store.aset_messages(key, messages)

        # Delete the message at index 1; only the first message should remain.
        await chat_store.adelete_message(key, 1)
        query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;"""
        results = await afetch(async_engine, query)

        assert len(results) == 1
        assert results[0]["message"] == message_1.model_dump()

    async def test_adelete_last_message(self, async_engine, chat_store):
        message_1 = ChatMessage(content="Message 1", role="user")
        message_2 = ChatMessage(content="Message 2", role="user")
        message_3 = ChatMessage(content="Message 3", role="user")
        messages = [message_1, message_2, message_3]
        key = "test_delete_last_message_key"
        await chat_store.aset_messages(key, messages)

        await chat_store.adelete_last_message(key)
        query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;"""
        results = await afetch(async_engine, query)

        assert len(results) == 2
        assert results[0]["message"] == message_1.model_dump()
        assert results[1]["message"] == message_2.model_dump()

    async def test_aget_keys(self, async_engine, chat_store):
        message_1 = [ChatMessage(content="First message", role="user")]
        message_2 = [ChatMessage(content="Second message", role="user")]
        key_1 = "key1"
        key_2 = "key2"
        await chat_store.aset_messages(key_1, message_1)
        await chat_store.aset_messages(key_2, message_2)

        keys = await chat_store.aget_keys()

        assert key_1 in keys
        assert key_2 in keys

    async def test_set_existing_key(self, async_engine, chat_store):
        message_1 = [ChatMessage(content="First message", role="user")]
        key = "test_set_existing_key"
        await chat_store.aset_messages(key, message_1)

        query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';"""
        results = await afetch(async_engine, query)

        assert len(results) == 1
        result = results[0]
        assert result["message"] == message_1[0].model_dump()

        message_2 = ChatMessage(content="Second message", role="user")
        message_3 = ChatMessage(content="Third message", role="user")
        messages = [message_2, message_3]

        await chat_store.aset_messages(key, messages)

        query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';"""
        results = await afetch(async_engine, query)

        # Assert the previous messages are deleted and only the newest ones exist.
        assert len(results) == 2

        assert results[0]["message"] == message_2.model_dump()
        assert results[1]["message"] == message_3.model_dump()
async def aexecute(engine: AlloyDBEngine, query: str) -> None: @@ -96,6 +97,7 @@ async def async_engine( yield async_engine await async_engine.close() + await async_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def doc_store(self, async_engine): @@ -124,9 +126,16 @@ async def custom_doc_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): AsyncAlloyDBDocumentStore( - engine=async_engine, table_name=default_table_name_async + key, engine=async_engine, table_name=default_table_name_async + ) + + async def test_create_without_table(self, async_engine): + with pytest.raises(ValueError): + await AsyncAlloyDBDocumentStore.create( + engine=async_engine, table_name="non-existent-table" ) async def test_warning(self, custom_doc_store): @@ -187,7 +196,7 @@ async def test_add_hash_before_data(self, async_engine, doc_store): result = results[0] assert result["node_data"][DATA_KEY]["text"] == document_text - async def test_ref_doc_exists(self, doc_store): + async def test_aref_doc_exists(self, doc_store): # Create a ref_doc & a doc and add them to the store. ref_doc = Document( text="first doc", id_="doc_exists_doc_1", metadata={"doc": "info"} @@ -244,6 +253,8 @@ async def test_adelete_ref_doc(self, doc_store): assert ( await doc_store.aget_document(doc_id=doc.doc_id, raise_error=False) is None ) + # Confirm deleting an non-existent reference doc returns None. + assert await doc_store.adelete_ref_doc(ref_doc_id=ref_doc.doc_id) is None async def test_set_and_get_document_hash(self, doc_store): # Set a doc hash for a document @@ -254,6 +265,9 @@ async def test_set_and_get_document_hash(self, doc_store): # Assert with get that the hash is same as the one set. 
assert await doc_store.aget_document_hash(doc_id=doc_id) == doc_hash + async def test_aget_document_hash(self, doc_store): + assert await doc_store.aget_document_hash(doc_id="non-existent-doc") is None + async def test_set_and_get_document_hashes(self, doc_store): # Create a dictionary of doc_id -> doc_hash mappings and add it to the table. document_dict = { @@ -288,7 +302,7 @@ async def test_doc_store_basic(self, doc_store): retrieved_node = await doc_store.aget_document(doc_id=node.node_id) assert retrieved_node == node - async def test_delete_document(self, async_engine, doc_store): + async def test_adelete_document(self, async_engine, doc_store): # Create a doc and add it to the store. doc = Document(text="document_2", id_="doc_id_2", metadata={"doc": "info"}) await doc_store.async_add_documents([doc]) @@ -301,6 +315,11 @@ async def test_delete_document(self, async_engine, doc_store): result = await afetch(async_engine, query) assert len(result) == 0 + async def test_delete_non_existent_document(self, async_engine, doc_store): + await doc_store.adelete_document(doc_id="non-existent-doc", raise_error=False) + with pytest.raises(ValueError): + await doc_store.adelete_document(doc_id="non-existent-doc") + async def test_doc_store_ref_doc_not_added(self, async_engine, doc_store): # Create a ref_doc & doc. 
ref_doc = Document( @@ -376,3 +395,61 @@ async def test_doc_store_delete_all_ref_doc_nodes(self, async_engine, doc_store) query = f"""select * from "public"."{default_table_name_async}" where id = '{ref_doc.doc_id}';""" result = await afetch(async_engine, query) assert len(result) == 0 + + async def test_docs(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.docs() + + async def test_add_documents(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.add_documents([]) + + async def test_get_document(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_document("test_doc_id", raise_error=True) + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_document("test_doc_id", raise_error=False) + + async def test_delete_document(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_document("test_doc_id", raise_error=True) + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_document("test_doc_id", raise_error=False) + + async def test_document_exists(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.document_exists("test_doc_id") + + async def test_ref_doc_exists(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.ref_doc_exists(ref_doc_id="test_ref_doc_id") + + async def test_set_document_hash(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.set_document_hash("test_doc_id", "test_doc_hash") + + async def test_set_document_hashes(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.set_document_hashes({"test_doc_id": "test_doc_hash"}) + + async def test_get_document_hash(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + 
doc_store.get_document_hash(doc_id="test_doc_id") + + async def test_get_all_document_hashes(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_all_document_hashes() + + async def test_get_all_ref_doc_info(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_all_ref_doc_info() + + async def test_get_ref_doc_info(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_ref_doc_info(ref_doc_id="test_doc_id") + + async def test_delete_ref_doc(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_ref_doc(ref_doc_id="test_doc_id", raise_error=False) + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_ref_doc(ref_doc_id="test_doc_id", raise_error=True) diff --git a/tests/test_async_index_store.py b/tests/test_async_index_store.py index 44a32cb..10297cb 100644 --- a/tests/test_async_index_store.py +++ b/tests/test_async_index_store.py @@ -19,13 +19,19 @@ import pytest import pytest_asyncio -from llama_index.core.data_structs.data_structs import IndexDict, IndexGraph, IndexList +from llama_index.core.data_structs.data_structs import ( + IndexDict, + IndexGraph, + IndexList, + IndexStruct, +) from sqlalchemy import RowMapping, text from llama_index_alloydb_pg import AlloyDBEngine from llama_index_alloydb_pg.async_index_store import AsyncAlloyDBIndexStore default_table_name_async = "index_store_" + str(uuid.uuid4()) +sync_method_exception_str = "Sync methods are not implemented for AsyncAlloyDBIndexStore . Use AlloyDBIndexStore interface instead." 
async def aexecute(engine: AlloyDBEngine, query: str) -> None: @@ -81,7 +87,7 @@ def password(self) -> str: @pytest_asyncio.fixture(scope="class") async def async_engine( - self, db_project, db_region, db_cluster, db_instance, db_name, user, password + self, db_project, db_region, db_cluster, db_instance, db_name ): async_engine = await AlloyDBEngine.afrom_instance( project_id=db_project, @@ -94,6 +100,7 @@ async def async_engine( yield async_engine await async_engine.close() + await async_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def index_store(self, async_engine): @@ -109,9 +116,16 @@ async def index_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): AsyncAlloyDBIndexStore( - engine=async_engine, table_name=default_table_name_async + key, engine=async_engine, table_name=default_table_name_async + ) + + async def test_create_without_table(self, async_engine): + with pytest.raises(ValueError): + await AsyncAlloyDBIndexStore.create( + engine=async_engine, table_name="non-existent-table" ) async def test_add_and_delete_index(self, index_store, async_engine): @@ -169,3 +183,20 @@ async def test_warning(self, index_store): assert "No struct_id specified and more than one struct exists." 
in str( w[-1].message ) + + async def test_index_structs(self, index_store): + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.index_structs() + + async def test_add_index_struct(self, index_store): + index_struct = IndexGraph() + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.add_index_struct(index_struct) + + async def test_delete_index_struct(self, index_store): + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.delete_index_struct("non_existent_key") + + async def test_get_index_struct(self, index_store): + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.get_index_struct(struct_id="non_existent_id") diff --git a/tests/test_async_reader.py b/tests/test_async_reader.py new file mode 100644 index 0000000..50a1be0 --- /dev/null +++ b/tests/test_async_reader.py @@ -0,0 +1,503 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import uuid +from typing import Sequence + +import pytest +import pytest_asyncio +from llama_index.core.schema import Document +from sqlalchemy import RowMapping, text + +from llama_index_alloydb_pg import AlloyDBEngine +from llama_index_alloydb_pg.async_reader import AsyncAlloyDBReader + +default_table_name_async = "reader_test_" + str(uuid.uuid4()) +sync_method_exception_str = "Sync methods are not implemented for AsyncAlloyDBReader. Use AlloyDBReader interface instead." 
+ + +async def aexecute(engine: AlloyDBEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + +async def afetch(engine: AlloyDBEngine, query: str) -> Sequence[RowMapping]: + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestAsyncAlloyDBReader: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for AlloyDB instance") + + @pytest.fixture(scope="module") + def db_cluster(self) -> str: + return get_env_var("CLUSTER_ID", "cluster for AlloyDB") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for AlloyDB") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on AlloyDB instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for AlloyDB") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for AlloyDB") + + @pytest_asyncio.fixture(scope="class") + async def async_engine( + self, db_project, db_region, db_cluster, db_instance, db_name + ): + async_engine = await AlloyDBEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + cluster=db_cluster, + region=db_region, + database=db_name, + ) + await self._create_default_table(async_engine) + + yield async_engine + + await self._cleanup_table(async_engine) + 
+ await async_engine.close() + await async_engine._connector.close() + + async def _cleanup_table(self, engine): + await aexecute(engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"') + + async def _create_default_table(self, engine): + create_query = f""" + CREATE TABLE IF NOT EXISTS "{default_table_name_async}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(engine, create_query) + + async def _collect_async_items(self, docs_generator): + """Collects items from an async generator.""" + docs = [] + async for doc in docs_generator: + docs.append(doc) + return docs + + async def test_create_reader_with_invalid_parameters(self, async_engine): + with pytest.raises(ValueError): + await AsyncAlloyDBReader.create( + engine=async_engine, + ) + with pytest.raises(ValueError): + + def fake_formatter(): + return None + + await AsyncAlloyDBReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="text", + formatter=fake_formatter, + ) + with pytest.raises(ValueError): + await AsyncAlloyDBReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="fake_format", + ) + + async def test_lazy_load_data(self, async_engine): + with pytest.raises(Exception, match=sync_method_exception_str): + reader = await AsyncAlloyDBReader.create( + engine=async_engine, + table_name=default_table_name_async, + ) + + reader.lazy_load_data() + + async def test_load_data(self, async_engine): + with pytest.raises(Exception, match=sync_method_exception_str): + reader = await AsyncAlloyDBReader.create( + engine=async_engine, + table_name=default_table_name_async, + ) + + reader.load_data() + + async def test_load_from_query_default(self, async_engine): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY 
KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" ( + fruit_name, variety, quantity_in_stock, price_per_unit, organic + ) VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncAlloyDBReader.create( + engine=async_engine, + table_name=table_name, + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_document = Document( + text="1", + metadata={ + "fruit_name": "Apple", + "variety": "Granny Smith", + "quantity_in_stock": 150, + "price_per_unit": 1, + "organic": 1, + }, + ) + + assert documents[0].text == expected_document.text + assert documents[0].metadata == expected_document.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_customized_metadata( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + expected_docs = [ + Document( + text="Apple Smith 150 1 1", + metadata={"fruit_id": 1}, + ), + Document( + text="Banana Cavendish 200 1 0", + metadata={"fruit_id": 2}, + ), + Document( + text="Orange Navel 80 1 1", + metadata={"fruit_id": 3}, + ), + ] + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Smith', 150, 0.99, 1), + ('Banana', 'Cavendish', 200, 0.59, 0), + ('Orange', 'Navel', 80, 1.29, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncAlloyDBReader.create( + 
engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "fruit_name", + "variety", + "quantity_in_stock", + "price_per_unit", + "organic", + ], + metadata_columns=["fruit_id"], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + # Compare the full list of documents to make sure all are in sync. + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncAlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_text_docs = [ + Document( + text="Granny Smith 150 1", + metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1}, + ) + ] + + for expected, actual in zip(expected_text_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + reader = await AsyncAlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="JSON", + ) + + 
actual_documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text='{"variety": "Granny Smith", "quantity_in_stock": 150, "price_per_unit": 1}', + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, actual_documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_with_json(self, async_engine): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}"( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety JSON NOT NULL, + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + li_metadata JSON NOT NULL + ) + """ + await aexecute(async_engine, query) + + metadata = json.dumps({"organic": 1}) + variety = json.dumps({"type": "Granny Smith"}) + insert_query = f""" + INSERT INTO "{table_name}" + (fruit_name, variety, quantity_in_stock, price_per_unit, li_metadata) + VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" + await aexecute(async_engine, insert_query) + + reader = await AsyncAlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + metadata_columns=[ + "variety", + ], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text="1", + metadata={ + "variety": {"type": "Granny Smith"}, + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_formatter( + self, async_engine + ): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF 
NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + def my_formatter(row, content_columns): + return "-".join( + str(row[column]) for column in content_columns if column in row + ) + + reader = await AsyncAlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + formatter=my_formatter, + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_documents = [ + Document( + text="Granny Smith-150-1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_documents, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_page_content_format( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncAlloyDBReader.create( + 
engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="YAML", + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text="variety: Granny Smith\nquantity_in_stock: 150\nprice_per_unit: 1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') diff --git a/tests/test_async_vector_store.py b/tests/test_async_vector_store.py index f3200b5..36f05d1 100644 --- a/tests/test_async_vector_store.py +++ b/tests/test_async_vector_store.py @@ -14,7 +14,8 @@ import os import uuid -from typing import List, Sequence +import warnings +from typing import Sequence import pytest import pytest_asyncio @@ -31,6 +32,7 @@ from llama_index_alloydb_pg import AlloyDBEngine, Column from llama_index_alloydb_pg.async_vector_store import AsyncAlloyDBVectorStore +from llama_index_alloydb_pg.indexes import HNSWQueryOptions, ScaNNQueryOptions DEFAULT_TABLE = "test_table" + str(uuid.uuid4()) DEFAULT_TABLE_CUSTOM_VS = "test_table" + str(uuid.uuid4()) @@ -114,9 +116,10 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): ) yield engine - await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE_CUSTOM_VS}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_CUSTOM_VS}"') await engine.close() + await engine._connector.close() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): @@ -154,12 +157,30 @@ async def custom_vs(self, engine): "nullable_int_field", "nullable_str_field", ], + 
index_query_options=HNSWQueryOptions(ef_search=1), + ) + yield vs + + @pytest_asyncio.fixture(scope="class") + async def custom_vs_scann(self, engine, custom_vs): + vs = await AsyncAlloyDBVectorStore.create( + engine, + table_name=DEFAULT_TABLE_CUSTOM_VS, + metadata_columns=[ + "len", + "nullable_int_field", + "nullable_str_field", + ], + index_query_options=ScaNNQueryOptions( + num_leaves_to_search=1, pre_reordering_num_neighbors=2 + ), ) yield vs async def test_init_with_constructor(self, engine): + key = object() with pytest.raises(Exception): - AsyncAlloyDBVectorStore(engine, table_name=DEFAULT_TABLE) + AsyncAlloyDBVectorStore(key, engine, table_name=DEFAULT_TABLE) async def test_validate_id_column_create(self, engine, vs): test_id_column = "test_id_column" @@ -318,6 +339,89 @@ async def test_aquery(self, engine, vs): assert len(results.nodes) == 3 assert results.nodes[0].get_content(metadata_mode=MetadataMode.NONE) == "foo" + async def test_aquery_scann(self, engine, custom_vs_scann): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE_CUSTOM_VS}"') + # setting extra metadata to be indexed in separate column + for node in nodes: + node.metadata["len"] = len(node.text) + + await custom_vs_scann.async_add(nodes) + query = VectorStoreQuery( + query_embedding=[1.0] * VECTOR_SIZE, similarity_top_k=3 + ) + results = await custom_vs_scann.aquery(query) + + assert results.nodes is not None + assert results.ids is not None + assert results.similarities is not None + assert len(results.nodes) == 3 + + async def test_aquery_filters(self, engine, custom_vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE_CUSTOM_VS}"') + # setting extra metadata to be indexed in 
separate column + for node in nodes: + node.metadata["len"] = len(node.text) + + await custom_vs.async_add(nodes) + + filters = MetadataFilters( + filters=[ + MetadataFilter( + key="some_test_column", + value=["value_should_be_ignored"], + operator=FilterOperator.CONTAINS, + ), + MetadataFilter( + key="len", + value=3, + operator=FilterOperator.LTE, + ), + MetadataFilter( + key="len", + value=3, + operator=FilterOperator.GTE, + ), + MetadataFilter( + key="len", + value=2, + operator=FilterOperator.GT, + ), + MetadataFilter( + key="len", + value=4, + operator=FilterOperator.LT, + ), + MetadataFilters( + filters=[ + MetadataFilter( + key="len", + value=6.0, + operator=FilterOperator.NE, + ), + ], + condition=FilterCondition.OR, + ), + ], + condition=FilterCondition.AND, + ) + query = VectorStoreQuery( + query_embedding=[1.0] * VECTOR_SIZE, filters=filters, similarity_top_k=-1 + ) + with warnings.catch_warnings(record=True) as w: + results = await custom_vs.aquery(query) + + assert len(w) == 1 + assert "Expecting a scalar in the filter value" in str(w[-1].message) + + assert results.nodes is not None + assert results.ids is not None + assert results.similarities is not None + assert len(results.nodes) == 3 + async def test_aclear(self, engine, vs): # Note: To be migrated to a pytest dependency on test_adelete # Blocked due to unexpected fixtures reloads while running integration test suite diff --git a/tests/test_async_vector_store_index.py b/tests/test_async_vector_store_index.py index 1b78e34..954edbb 100644 --- a/tests/test_async_vector_store_index.py +++ b/tests/test_async_vector_store_index.py @@ -15,13 +15,11 @@ import os import uuid -from typing import List, Sequence import pytest import pytest_asyncio -from llama_index.core.schema import MetadataMode, NodeRelationship, TextNode +from llama_index.core.schema import TextNode from sqlalchemy import text -from sqlalchemy.engine.row import RowMapping from llama_index_alloydb_pg import AlloyDBEngine from 
llama_index_alloydb_pg.async_vector_store import AsyncAlloyDBVectorStore @@ -31,7 +29,6 @@ HNSWIndex, IVFFlatIndex, ) -from llama_index_alloydb_pg.vector_store import AlloyDBVectorStore DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") DEFAULT_INDEX_NAME = DEFAULT_TABLE + DEFAULT_INDEX_NAME_SUFFIX @@ -107,6 +104,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") await engine.close() + await engine._connector.close() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): @@ -126,6 +124,7 @@ async def test_aapply_vector_index(self, vs): index = HNSWIndex() await vs.aapply_vector_index(index) assert await vs.is_valid_index(DEFAULT_INDEX_NAME) + await vs.adrop_vector_index(DEFAULT_INDEX_NAME) async def test_areindex(self, vs): if not await vs.is_valid_index(DEFAULT_INDEX_NAME): @@ -134,6 +133,7 @@ async def test_areindex(self, vs): await vs.areindex() await vs.areindex(DEFAULT_INDEX_NAME) assert await vs.is_valid_index(DEFAULT_INDEX_NAME) + await vs.adrop_vector_index(DEFAULT_INDEX_NAME) async def test_dropindex(self, vs): await vs.adrop_vector_index() @@ -150,6 +150,7 @@ async def test_aapply_vector_index_ivfflat(self, vs): ) await vs.aapply_vector_index(index) assert await vs.is_valid_index("secondindex") + await vs.adrop_vector_index() await vs.adrop_vector_index("secondindex") async def test_is_valid_index(self, vs): diff --git a/tests/test_chat_store.py b/tests/test_chat_store.py new file mode 100644 index 0000000..d10e0cb --- /dev/null +++ b/tests/test_chat_store.py @@ -0,0 +1,398 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import uuid +import warnings +from typing import Sequence + +import pytest +import pytest_asyncio +from llama_index.core.llms import ChatMessage +from sqlalchemy import RowMapping, text + +from llama_index_alloydb_pg import AlloyDBChatStore, AlloyDBEngine + +default_table_name_async = "chat_store_" + str(uuid.uuid4()) +default_table_name_sync = "chat_store_" + str(uuid.uuid4()) + + +async def aexecute( + engine: AlloyDBEngine, + query: str, +) -> None: + async def run(engine, query): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +async def afetch(engine: AlloyDBEngine, query: str) -> Sequence[RowMapping]: + async def run(engine, query): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + return await engine._run_as_async(run(engine, query)) + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestAlloyDBChatStoreAsync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for AlloyDB instance") + + @pytest.fixture(scope="module") + def db_cluster(self) -> str: + return 
get_env_var("CLUSTER_ID", "cluster for AlloyDB") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for AlloyDB") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on AlloyDB instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for AlloyDB") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for AlloyDB") + + @pytest_asyncio.fixture(scope="class") + async def async_engine( + self, db_project, db_region, db_cluster, db_instance, db_name + ): + async_engine = await AlloyDBEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + cluster=db_cluster, + region=db_region, + database=db_name, + ) + + yield async_engine + + await async_engine.close() + await async_engine._connector.close() + + @pytest_asyncio.fixture(scope="class") + async def async_chat_store(self, async_engine): + await async_engine.ainit_chat_store_table(table_name=default_table_name_async) + + async_chat_store = await AlloyDBChatStore.create( + engine=async_engine, table_name=default_table_name_async + ) + + yield async_chat_store + + query = f'DROP TABLE IF EXISTS "{default_table_name_async}"' + await aexecute(async_engine, query) + + async def test_init_with_constructor(self, async_engine): + with pytest.raises(Exception): + AlloyDBChatStore(engine=async_engine, table_name=default_table_name_async) + + async def test_async_add_message(self, async_engine, async_chat_store): + key = "test_add_key" + + message = ChatMessage(content="add_message_test", role="user") + await async_chat_store.async_add_message(key, message=message) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + result = results[0] + assert result["message"] == 
message.model_dump() + + async def test_aset_and_aget_messages(self, async_chat_store): + message_1 = ChatMessage(content="First message", role="user") + message_2 = ChatMessage(content="Second message", role="user") + messages = [message_1, message_2] + key = "test_set_and_get_key" + await async_chat_store.aset_messages(key, messages) + + results = await async_chat_store.aget_messages(key) + + assert len(results) == 2 + assert results[0].content == message_1.content + assert results[1].content == message_2.content + + async def test_adelete_messages(self, async_engine, async_chat_store): + messages = [ChatMessage(content="Message to delete", role="user")] + key = "test_delete_key" + await async_chat_store.aset_messages(key, messages) + + await async_chat_store.adelete_messages(key) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 0 + + async def test_adelete_message(self, async_engine, async_chat_store): + message_1 = ChatMessage(content="Keep me", role="user") + message_2 = ChatMessage(content="Delete me", role="user") + messages = [message_1, message_2] + key = "test_delete_message_key" + await async_chat_store.aset_messages(key, messages) + + await async_chat_store.adelete_message(key, 1) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 1 + assert results[0]["message"] == message_1.model_dump() + + async def test_adelete_last_message(self, async_engine, async_chat_store): + message_1 = ChatMessage(content="Message 1", role="user") + message_2 = ChatMessage(content="Message 2", role="user") + message_3 = ChatMessage(content="Message 3", role="user") + messages = [message_1, message_2, message_3] + key = "test_delete_last_message_key" + await async_chat_store.aset_messages(key, messages) + + await 
async_chat_store.adelete_last_message(key) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 2 + assert results[0]["message"] == message_1.model_dump() + assert results[1]["message"] == message_2.model_dump() + + async def test_aget_keys(self, async_engine, async_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + message_2 = [ChatMessage(content="Second message", role="user")] + key_1 = "key1" + key_2 = "key2" + await async_chat_store.aset_messages(key_1, message_1) + await async_chat_store.aset_messages(key_2, message_2) + + keys = await async_chat_store.aget_keys() + + assert key_1 in keys + assert key_2 in keys + + async def test_set_exisiting_key(self, async_engine, async_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + key = "test_set_exisiting_key" + await async_chat_store.aset_messages(key, message_1) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + + assert len(results) == 1 + result = results[0] + assert result["message"] == message_1[0].model_dump() + + message_2 = ChatMessage(content="Second message", role="user") + message_3 = ChatMessage(content="Third message", role="user") + messages = [message_2, message_3] + + await async_chat_store.aset_messages(key, messages) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + + # Assert the previous messages are deleted and only the newest ones exist. 
+ assert len(results) == 2 + + assert results[0]["message"] == message_2.model_dump() + assert results[1]["message"] == message_3.model_dump() + + +@pytest.mark.asyncio(loop_scope="class") +class TestAlloyDBChatStoreSync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for AlloyDB instance") + + @pytest.fixture(scope="module") + def db_cluster(self) -> str: + return get_env_var("CLUSTER_ID", "cluster for AlloyDB") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for AlloyDB") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on AlloyDB instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for AlloyDB") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for AlloyDB") + + @pytest_asyncio.fixture(scope="class") + async def sync_engine( + self, db_project, db_region, db_cluster, db_instance, db_name + ): + sync_engine = AlloyDBEngine.from_instance( + project_id=db_project, + instance=db_instance, + cluster=db_cluster, + region=db_region, + database=db_name, + ) + + yield sync_engine + + await sync_engine.close() + await sync_engine._connector.close() + + @pytest_asyncio.fixture(scope="class") + async def sync_chat_store(self, sync_engine): + sync_engine.init_chat_store_table(table_name=default_table_name_sync) + + sync_chat_store = AlloyDBChatStore.create_sync( + engine=sync_engine, table_name=default_table_name_sync + ) + + yield sync_chat_store + + query = f'DROP TABLE IF EXISTS "{default_table_name_sync}"' + await aexecute(sync_engine, query) + + async def test_init_with_constructor(self, sync_engine): + with 
pytest.raises(Exception): + AlloyDBChatStore(engine=sync_engine, table_name=default_table_name_sync) + + async def test_async_add_message(self, sync_engine, sync_chat_store): + key = "test_add_key" + + message = ChatMessage(content="add_message_test", role="user") + sync_chat_store.add_message(key, message=message) + + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}';""" + results = await afetch(sync_engine, query) + result = results[0] + assert result["message"] == message.model_dump() + + async def test_aset_and_aget_messages(self, sync_chat_store): + message_1 = ChatMessage(content="First message", role="user") + message_2 = ChatMessage(content="Second message", role="user") + messages = [message_1, message_2] + key = "test_set_and_get_key" + sync_chat_store.set_messages(key, messages) + + results = sync_chat_store.get_messages(key) + + assert len(results) == 2 + assert results[0].content == message_1.content + assert results[1].content == message_2.content + + async def test_adelete_messages(self, sync_engine, sync_chat_store): + messages = [ChatMessage(content="Message to delete", role="user")] + key = "test_delete_key" + sync_chat_store.set_messages(key, messages) + + sync_chat_store.delete_messages(key) + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}' ORDER BY id;""" + results = await afetch(sync_engine, query) + + assert len(results) == 0 + + async def test_adelete_message(self, sync_engine, sync_chat_store): + message_1 = ChatMessage(content="Keep me", role="user") + message_2 = ChatMessage(content="Delete me", role="user") + messages = [message_1, message_2] + key = "test_delete_message_key" + sync_chat_store.set_messages(key, messages) + + sync_chat_store.delete_message(key, 1) + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}' ORDER BY id;""" + results = await afetch(sync_engine, query) + + assert len(results) == 1 + assert results[0]["message"] == 
message_1.model_dump() + + async def test_adelete_last_message(self, sync_engine, sync_chat_store): + message_1 = ChatMessage(content="Message 1", role="user") + message_2 = ChatMessage(content="Message 2", role="user") + message_3 = ChatMessage(content="Message 3", role="user") + messages = [message_1, message_2, message_3] + key = "test_delete_last_message_key" + sync_chat_store.set_messages(key, messages) + + sync_chat_store.delete_last_message(key) + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}' ORDER BY id;""" + results = await afetch(sync_engine, query) + + assert len(results) == 2 + assert results[0]["message"] == message_1.model_dump() + assert results[1]["message"] == message_2.model_dump() + + async def test_aget_keys(self, sync_engine, sync_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + message_2 = [ChatMessage(content="Second message", role="user")] + key_1 = "key1" + key_2 = "key2" + sync_chat_store.set_messages(key_1, message_1) + sync_chat_store.set_messages(key_2, message_2) + + keys = sync_chat_store.get_keys() + + assert key_1 in keys + assert key_2 in keys + + async def test_set_exisiting_key(self, sync_engine, sync_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + key = "test_set_exisiting_key" + sync_chat_store.set_messages(key, message_1) + + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}';""" + results = await afetch(sync_engine, query) + + assert len(results) == 1 + result = results[0] + assert result["message"] == message_1[0].model_dump() + + message_2 = ChatMessage(content="Second message", role="user") + message_3 = ChatMessage(content="Third message", role="user") + messages = [message_2, message_3] + + sync_chat_store.set_messages(key, messages) + + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}';""" + results = await afetch(sync_engine, query) + + # Assert the 
previous messages are deleted and only the newest ones exist. + assert len(results) == 2 + + assert results[0]["message"] == message_2.model_dump() + assert results[1]["message"] == message_3.model_dump() diff --git a/tests/test_document_store.py b/tests/test_document_store.py index 7a9d391..6a7094e 100644 --- a/tests/test_document_store.py +++ b/tests/test_document_store.py @@ -108,6 +108,7 @@ async def async_engine( yield async_engine await async_engine.close() + await async_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def doc_store(self, async_engine): @@ -123,9 +124,10 @@ async def doc_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): AlloyDBDocumentStore( - engine=async_engine, table_name=default_table_name_async + key, engine=async_engine, table_name=default_table_name_async ) async def test_async_add_document(self, async_engine, doc_store): @@ -399,6 +401,7 @@ async def sync_engine( yield sync_engine await sync_engine.close() + await sync_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def sync_doc_store(self, sync_engine): @@ -414,8 +417,11 @@ async def sync_doc_store(self, sync_engine): await aexecute(sync_engine, query) async def test_init_with_constructor(self, sync_engine): + key = object() with pytest.raises(Exception): - AlloyDBDocumentStore(engine=sync_engine, table_name=default_table_name_sync) + AlloyDBDocumentStore( + key, engine=sync_engine, table_name=default_table_name_sync + ) async def test_docs(self, sync_doc_store): # Create and add document into the doc store. 
diff --git a/tests/test_engine.py b/tests/test_engine.py index 3d30bc0..c1cf94e 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -14,7 +14,7 @@ import os import uuid -from typing import Dict, Sequence +from typing import Sequence import asyncpg # type: ignore import pytest @@ -34,6 +34,8 @@ DEFAULT_IS_TABLE_SYNC = "index_store_" + str(uuid.uuid4()) DEFAULT_VS_TABLE = "vector_store_" + str(uuid.uuid4()) DEFAULT_VS_TABLE_SYNC = "vector_store_" + str(uuid.uuid4()) +DEFAULT_CS_TABLE = "chat_store_" + str(uuid.uuid4()) +DEFAULT_CS_TABLE_SYNC = "chat_store_" + str(uuid.uuid4()) VECTOR_SIZE = 768 @@ -118,7 +120,40 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE}"') + await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE}"') await engine.close() + await engine._connector.close() + + async def test_init_with_constructor( + self, + db_project, + db_region, + db_cluster, + db_instance, + db_name, + user, + password, + ): + async def getconn() -> asyncpg.Connection: + conn = await connector.connect( # type: ignore + f"projects/{db_project}/locations/{db_region}/clusters/{db_cluster}/instances/{db_instance}", + "asyncpg", + user=user, + password=password, + db=db_name, + enable_iam_auth=False, + ip_type=IPTypes.PUBLIC, + ) + return conn + + engine = create_async_engine( + "postgresql+asyncpg://", + async_creator=getconn, + ) + + key = object() + with pytest.raises(Exception): + AlloyDBEngine(key, engine) async def test_password( self, @@ -145,6 +180,35 @@ async def test_password( AlloyDBEngine._connector = None await engine.close() + async def test_missing_user_or_password( + self, + db_project, + db_region, + db_cluster, + db_instance, + db_name, + user, + password, + ): + with pytest.raises(ValueError): + await AlloyDBEngine.afrom_instance( + 
project_id=db_project, + instance=db_instance, + region=db_region, + cluster=db_cluster, + database=db_name, + user=user, + ) + with pytest.raises(ValueError): + await AlloyDBEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + cluster=db_cluster, + database=db_name, + password=password, + ) + async def test_from_engine( self, db_project, @@ -244,7 +308,9 @@ async def test_iam_account_override( async def test_init_document_store(self, engine): await engine.ainit_doc_store_table( - table_name=DEFAULT_DS_TABLE, schema_name="public", overwrite_existing=True + table_name=DEFAULT_DS_TABLE, + schema_name="public", + overwrite_existing=True, ) stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{DEFAULT_DS_TABLE}';" results = await afetch(engine, stmt) @@ -272,13 +338,21 @@ async def test_init_vector_store(self, engine): "data_type": "character varying", "is_nullable": "NO", }, - {"column_name": "li_metadata", "data_type": "jsonb", "is_nullable": "NO"}, + { + "column_name": "li_metadata", + "data_type": "jsonb", + "is_nullable": "NO", + }, { "column_name": "embedding", "data_type": "USER-DEFINED", "is_nullable": "YES", }, - {"column_name": "node_data", "data_type": "json", "is_nullable": "NO"}, + { + "column_name": "node_data", + "data_type": "json", + "is_nullable": "NO", + }, { "column_name": "ref_doc_id", "data_type": "character varying", @@ -307,6 +381,22 @@ async def test_init_index_store(self, engine): for row in results: assert row in expected + async def test_init_chat_store(self, engine): + await engine.ainit_chat_store_table( + table_name=DEFAULT_CS_TABLE, + schema_name="public", + overwrite_existing=True, + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{DEFAULT_CS_TABLE}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "id", "data_type": "integer"}, + {"column_name": "key", "data_type": "character varying"}, 
+ {"column_name": "message", "data_type": "json"}, + ] + for row in results: + assert row in expected + @pytest.mark.asyncio class TestEngineSync: @@ -359,7 +449,9 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE_SYNC}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE_SYNC}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE_SYNC}"') await engine.close() + await engine._connector.close() async def test_password( self, @@ -414,6 +506,7 @@ async def test_iam_account_override( assert engine await aexecute(engine, "SELECT 1") await engine.close() + await engine._connector.close() async def test_init_document_store(self, engine): engine.init_doc_store_table( @@ -447,13 +540,21 @@ async def test_init_vector_store(self, engine): "data_type": "character varying", "is_nullable": "NO", }, - {"column_name": "li_metadata", "data_type": "jsonb", "is_nullable": "NO"}, + { + "column_name": "li_metadata", + "data_type": "jsonb", + "is_nullable": "NO", + }, { "column_name": "embedding", "data_type": "USER-DEFINED", "is_nullable": "YES", }, - {"column_name": "node_data", "data_type": "json", "is_nullable": "NO"}, + { + "column_name": "node_data", + "data_type": "json", + "is_nullable": "NO", + }, { "column_name": "ref_doc_id", "data_type": "character varying", @@ -481,3 +582,19 @@ async def test_init_index_store(self, engine): ] for row in results: assert row in expected + + async def test_init_chat_store(self, engine): + engine.init_chat_store_table( + table_name=DEFAULT_CS_TABLE_SYNC, + schema_name="public", + overwrite_existing=True, + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{DEFAULT_CS_TABLE_SYNC}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "id", "data_type": "integer"}, + {"column_name": "key", "data_type": "character 
varying"}, + {"column_name": "message", "data_type": "json"}, + ] + for row in results: + assert row in expected diff --git a/tests/test_index_store.py b/tests/test_index_store.py index 03a2eb9..c9a4bce 100644 --- a/tests/test_index_store.py +++ b/tests/test_index_store.py @@ -108,6 +108,7 @@ async def async_engine( yield async_engine await async_engine.close() + await async_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def index_store(self, async_engine): @@ -123,8 +124,11 @@ async def index_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): - AlloyDBIndexStore(engine=async_engine, table_name=default_table_name_async) + AlloyDBIndexStore( + key, engine=async_engine, table_name=default_table_name_async + ) async def test_add_and_delete_index(self, index_store, async_engine): index_struct = IndexGraph() @@ -233,6 +237,7 @@ async def async_engine( yield async_engine await async_engine.close() + await async_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def index_store(self, async_engine): @@ -248,8 +253,11 @@ async def index_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): - AlloyDBIndexStore(engine=async_engine, table_name=default_table_name_sync) + AlloyDBIndexStore( + key, engine=async_engine, table_name=default_table_name_sync + ) async def test_add_and_delete_index(self, index_store, async_engine): index_struct = IndexGraph() diff --git a/tests/test_indexes.py b/tests/test_indexes.py new file mode 100644 index 0000000..c2a781c --- /dev/null +++ b/tests/test_indexes.py @@ -0,0 +1,122 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +from llama_index_alloydb_pg.indexes import ( + DistanceStrategy, + HNSWIndex, + HNSWQueryOptions, + IVFFlatIndex, + IVFFlatQueryOptions, + IVFIndex, + IVFQueryOptions, + ScaNNIndex, + ScaNNQueryOptions, +) + + +class TestAlloyDBIndex: + def test_distance_strategy(self): + assert DistanceStrategy.EUCLIDEAN.operator == "<->" + assert DistanceStrategy.EUCLIDEAN.search_function == "l2_distance" + assert DistanceStrategy.EUCLIDEAN.index_function == "vector_l2_ops" + assert DistanceStrategy.EUCLIDEAN.scann_index_function == "l2" + + assert DistanceStrategy.COSINE_DISTANCE.operator == "<=>" + assert DistanceStrategy.COSINE_DISTANCE.search_function == "cosine_distance" + assert DistanceStrategy.COSINE_DISTANCE.index_function == "vector_cosine_ops" + assert DistanceStrategy.COSINE_DISTANCE.scann_index_function == "cosine" + + assert DistanceStrategy.INNER_PRODUCT.operator == "<#>" + assert DistanceStrategy.INNER_PRODUCT.search_function == "inner_product" + assert DistanceStrategy.INNER_PRODUCT.index_function == "vector_ip_ops" + assert DistanceStrategy.INNER_PRODUCT.scann_index_function == "dot_product" + + def test_hnsw_index(self): + index = HNSWIndex(name="test_index", m=32, ef_construction=128) + assert index.index_type == "hnsw" + assert index.m == 32 + assert index.ef_construction == 128 + assert index.index_options() == "(m = 32, ef_construction = 128)" + + def test_hnsw_query_options(self): + options = HNSWQueryOptions(ef_search=80) + assert options.to_parameter() == ["hnsw.ef_search = 80"] + + with warnings.catch_warnings(record=True) 
as w: + options.to_string() + + assert len(w) == 1 + assert "to_string is deprecated, use to_parameter instead." in str( + w[-1].message + ) + + def test_ivfflat_index(self): + index = IVFFlatIndex(name="test_index", lists=200) + assert index.index_type == "ivfflat" + assert index.lists == 200 + assert index.index_options() == "(lists = 200)" + + def test_ivfflat_query_options(self): + options = IVFFlatQueryOptions(probes=2) + assert options.to_parameter() == ["ivfflat.probes = 2"] + + with warnings.catch_warnings(record=True) as w: + options.to_string() + assert len(w) == 1 + assert "to_string is deprecated, use to_parameter instead." in str( + w[-1].message + ) + + def test_ivf_index(self): + index = IVFIndex(name="test_index", lists=200) + assert index.index_type == "ivf" + assert index.lists == 200 + assert index.quantizer == "sq8" # Check default value + assert index.index_options() == "(lists = 200, quantizer = sq8)" + + def test_ivf_query_options(self): + options = IVFQueryOptions(probes=2) + assert options.to_parameter() == ["ivf.probes = 2"] + + with warnings.catch_warnings(record=True) as w: + options.to_string() + assert len(w) == 1 + assert "to_string is deprecated, use to_parameter instead." in str( + w[-1].message + ) + + def test_scann_index(self): + index = ScaNNIndex(name="test_index", num_leaves=10) + assert index.index_type == "ScaNN" + assert index.num_leaves == 10 + assert index.quantizer == "sq8" # Check default value + assert index.index_options() == "(num_leaves = 10, quantizer = sq8)" + + def test_scann_query_options(self): + options = ScaNNQueryOptions( + num_leaves_to_search=2, pre_reordering_num_neighbors=10 + ) + assert options.to_parameter() == [ + "scann.num_leaves_to_search = 2", + "scann.pre_reordering_num_neighbors = 10", + ] + + with warnings.catch_warnings(record=True) as w: + options.to_string() + assert len(w) == 1 + assert "to_string is deprecated, use to_parameter instead." 
in str( + w[-1].message + ) diff --git a/tests/test_reader.py b/tests/test_reader.py new file mode 100644 index 0000000..2f9df85 --- /dev/null +++ b/tests/test_reader.py @@ -0,0 +1,912 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import uuid +from typing import Sequence + +import pytest +import pytest_asyncio +from llama_index.core.schema import Document +from sqlalchemy import RowMapping, text + +from llama_index_alloydb_pg import AlloyDBEngine, AlloyDBReader + +default_table_name_async = "async_reader_test_" + str(uuid.uuid4()) +default_table_name_sync = "sync_reader_test_" + str(uuid.uuid4()) + + +async def aexecute( + engine: AlloyDBEngine, + query: str, +) -> None: + async def run(engine, query): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +async def afetch(engine: AlloyDBEngine, query: str) -> Sequence[RowMapping]: + async def run(engine, query): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + return await engine._run_as_async(run(engine, query)) + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class 
TestAlloyDBReaderAsync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for AlloyDB instance") + + @pytest.fixture(scope="module") + def db_cluster(self) -> str: + return get_env_var("CLUSTER_ID", "cluster for AlloyDB") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for AlloyDB") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on AlloyDB instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for AlloyDB") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for AlloyDB") + + @pytest_asyncio.fixture(scope="class") + async def async_engine( + self, + db_project, + db_region, + db_cluster, + db_instance, + db_name, + ): + async_engine = await AlloyDBEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + cluster=db_cluster, + region=db_region, + database=db_name, + ) + + yield async_engine + + await aexecute( + async_engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"' + ) + + await async_engine.close() + + async def _cleanup_table(self, engine): + await aexecute(engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"') + + async def _collect_async_items(self, docs_generator): + """Collects items from an async generator.""" + docs = [] + async for doc in docs_generator: + docs.append(doc) + return docs + + async def test_create_reader_with_invalid_parameters(self, async_engine): + with pytest.raises(ValueError): + await AlloyDBReader.create( + engine=async_engine, + ) + with pytest.raises(ValueError): + + def fake_formatter(): + return None + + await AlloyDBReader.create( + engine=async_engine, 
+ table_name=default_table_name_async, + format="text", + formatter=fake_formatter, + ) + with pytest.raises(ValueError): + await AlloyDBReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="fake_format", + ) + + async def test_load_from_query_default(self, async_engine): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" ( + fruit_name, variety, quantity_in_stock, price_per_unit, organic + ) VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AlloyDBReader.create( + engine=async_engine, + table_name=table_name, + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_document = Document( + text="1", + metadata={ + "fruit_name": "Apple", + "variety": "Granny Smith", + "quantity_in_stock": 150, + "price_per_unit": 1, + "organic": 1, + }, + ) + + assert documents[0].text == expected_document.text + assert documents[0].metadata == expected_document.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_customized_metadata( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + expected_docs = [ + Document( + text="Apple Smith 150 1 1", + metadata={"fruit_id": 1}, + ), + Document( + text="Banana Cavendish 200 1 0", + metadata={"fruit_id": 2}, + ), + Document( + text="Orange Navel 80 1 1", + metadata={"fruit_id": 3}, + ), + ] + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + 
price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Smith', 150, 0.99, 1), + ('Banana', 'Cavendish', 200, 0.59, 0), + ('Orange', 'Navel', 80, 1.29, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "fruit_name", + "variety", + "quantity_in_stock", + "price_per_unit", + "organic", + ], + metadata_columns=["fruit_id"], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + # Compare the full list of documents to make sure all are in sync. + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_text_docs = [ + Document( + text="Granny Smith 150 1", + 
metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1}, + ) + ] + + for expected, actual in zip(expected_text_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + reader = await AlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="JSON", + ) + + actual_documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text='{"variety": "Granny Smith", "quantity_in_stock": 150, "price_per_unit": 1}', + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, actual_documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_with_json(self, async_engine): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}"( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety JSON NOT NULL, + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + li_metadata JSON NOT NULL + ) + """ + await aexecute(async_engine, query) + + metadata = json.dumps({"organic": 1}) + variety = json.dumps({"type": "Granny Smith"}) + insert_query = f""" + INSERT INTO "{table_name}" + (fruit_name, variety, quantity_in_stock, price_per_unit, li_metadata) + VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" + await aexecute(async_engine, insert_query) + + reader = await AlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + metadata_columns=[ + "variety", + ], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text="1", + metadata={ + "variety": {"type": "Granny Smith"}, + "organic": 1, + 
}, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_formatter( + self, async_engine + ): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + def my_formatter(row, content_columns): + return "-".join( + str(row[column]) for column in content_columns if column in row + ) + + reader = await AlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + formatter=my_formatter, + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_documents = [ + Document( + text="Granny Smith-150-1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_documents, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_page_content_format( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + 
variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="YAML", + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text="variety: Granny Smith\nquantity_in_stock: 150\nprice_per_unit: 1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + +@pytest.mark.asyncio(loop_scope="class") +class TestAlloyDBReaderSync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for AlloyDB instance") + + @pytest.fixture(scope="module") + def db_cluster(self) -> str: + return get_env_var("CLUSTER_ID", "cluster for AlloyDB") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for AlloyDB") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on AlloyDB instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for AlloyDB") + + @pytest.fixture(scope="module") + def password(self) -> str: + return 
get_env_var("DB_PASSWORD", "database password for AlloyDB") + + @pytest_asyncio.fixture(scope="class") + async def sync_engine( + self, + db_project, + db_region, + db_cluster, + db_instance, + db_name, + ): + sync_engine = await AlloyDBEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + cluster=db_cluster, + region=db_region, + database=db_name, + ) + + yield sync_engine + + await aexecute( + sync_engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"' + ) + + await sync_engine.close() + + async def _cleanup_table(self, engine): + await aexecute(engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"') + + def _collect_items(self, docs_generator): + """Collects items from a generator.""" + docs = [] + for doc in docs_generator: + docs.append(doc) + return docs + + async def test_create_reader_with_invalid_parameters(self, sync_engine): + with pytest.raises(ValueError): + AlloyDBReader.create_sync( + engine=sync_engine, + ) + with pytest.raises(ValueError): + + def fake_formatter(): + return None + + AlloyDBReader.create_sync( + engine=sync_engine, + table_name=default_table_name_async, + format="text", + formatter=fake_formatter, + ) + with pytest.raises(ValueError): + AlloyDBReader.create_sync( + engine=sync_engine, + table_name=default_table_name_async, + format="fake_format", + ) + + async def test_load_from_query_default(self, sync_engine): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(sync_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" ( + fruit_name, variety, quantity_in_stock, price_per_unit, organic + ) VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(sync_engine, insert_query) + + reader = AlloyDBReader.create_sync( + 
engine=sync_engine, + table_name=table_name, + ) + + documents = self._collect_items(reader.lazy_load_data()) + + expected_document = Document( + text="1", + metadata={ + "fruit_name": "Apple", + "variety": "Granny Smith", + "quantity_in_stock": 150, + "price_per_unit": 1, + "organic": 1, + }, + ) + + assert documents[0].text == expected_document.text + assert documents[0].metadata == expected_document.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_customized_metadata( + self, sync_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + expected_docs = [ + Document( + text="Apple Smith 150 1 1", + metadata={"fruit_id": 1}, + ), + Document( + text="Banana Cavendish 200 1 0", + metadata={"fruit_id": 2}, + ), + Document( + text="Orange Navel 80 1 1", + metadata={"fruit_id": 3}, + ), + ] + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(sync_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Smith', 150, 0.99, 1), + ('Banana', 'Cavendish', 200, 0.59, 0), + ('Orange', 'Navel', 80, 1.29, 1); + """ + await aexecute(sync_engine, insert_query) + + reader = AlloyDBReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "fruit_name", + "variety", + "quantity_in_stock", + "price_per_unit", + "organic", + ], + metadata_columns=["fruit_id"], + ) + + documents = self._collect_items(reader.lazy_load_data()) + + # Compare the full list of documents to make sure all are in sync. 
+ for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata( + self, sync_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(sync_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(sync_engine, insert_query) + + reader = AlloyDBReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + ) + + documents = self._collect_items(reader.lazy_load_data()) + + expected_text_docs = [ + Document( + text="Granny Smith 150 1", + metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1}, + ) + ] + + for expected, actual in zip(expected_text_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + reader = AlloyDBReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="JSON", + ) + + actual_documents = self._collect_items(reader.lazy_load_data()) + + expected_docs = [ + Document( + text='{"variety": "Granny Smith", "quantity_in_stock": 150, "price_per_unit": 1}', + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, actual_documents): + assert expected.text == actual.text + assert 
expected.metadata == actual.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_with_json(self, sync_engine): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}"( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety JSON NOT NULL, + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + li_metadata JSON NOT NULL + ) + """ + await aexecute(sync_engine, query) + + metadata = json.dumps({"organic": 1}) + variety = json.dumps({"type": "Granny Smith"}) + insert_query = f""" + INSERT INTO "{table_name}" + (fruit_name, variety, quantity_in_stock, price_per_unit, li_metadata) + VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" + await aexecute(sync_engine, insert_query) + + reader = AlloyDBReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + metadata_columns=[ + "variety", + ], + ) + + documents = self._collect_items(reader.lazy_load_data()) + + expected_docs = [ + Document( + text="1", + metadata={ + "variety": {"type": "Granny Smith"}, + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_formatter( + self, sync_engine + ): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(sync_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + 
""" + await aexecute(sync_engine, insert_query) + + def my_formatter(row, content_columns): + return "-".join( + str(row[column]) for column in content_columns if column in row + ) + + reader = AlloyDBReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + formatter=my_formatter, + ) + + documents = self._collect_items(reader.lazy_load_data()) + + expected_documents = [ + Document( + text="Granny Smith-150-1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_documents, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_page_content_format( + self, sync_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(sync_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(sync_engine, insert_query) + + reader = AlloyDBReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="YAML", + ) + + documents = self._collect_items(reader.lazy_load_data()) + + expected_docs = [ + Document( + text="variety: Granny Smith\nquantity_in_stock: 150\nprice_per_unit: 1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in 
zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py index b55e582..2e7361a 100644 --- a/tests/test_vector_store.py +++ b/tests/test_vector_store.py @@ -14,7 +14,7 @@ import os import uuid -from typing import List, Sequence +from typing import Sequence import pytest import pytest_asyncio @@ -124,6 +124,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): await aexecute(sync_engine, f'DROP TABLE "{DEFAULT_TABLE}"') await sync_engine.close() + await sync_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): @@ -134,8 +135,9 @@ async def vs(self, engine): yield vs async def test_init_with_constructor(self, engine): + key = object() with pytest.raises(Exception): - AlloyDBVectorStore(engine, table_name=DEFAULT_TABLE) + AlloyDBVectorStore(key, engine, table_name=DEFAULT_TABLE) async def test_validate_id_column_create(self, engine, vs): test_id_column = "test_id_column" @@ -503,6 +505,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): await aexecute(sync_engine, f'DROP TABLE "{DEFAULT_TABLE}"') await sync_engine.close() + await sync_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): @@ -513,8 +516,9 @@ async def vs(self, engine): yield vs async def test_init_with_constructor(self, engine): + key = object() with pytest.raises(Exception): - AlloyDBVectorStore(engine, table_name=DEFAULT_TABLE) + AlloyDBVectorStore(key, engine, table_name=DEFAULT_TABLE) async def test_validate_id_column_create(self, engine, vs): test_id_column = "test_id_column" diff --git a/tests/test_vector_store_index.py b/tests/test_vector_store_index.py index 405f762..94bd002 100644 --- a/tests/test_vector_store_index.py +++ 
b/tests/test_vector_store_index.py @@ -14,15 +14,12 @@ import os -import sys import uuid -from typing import List, Sequence import pytest import pytest_asyncio -from llama_index.core.schema import MetadataMode, NodeRelationship, TextNode +from llama_index.core.schema import TextNode from sqlalchemy import text -from sqlalchemy.engine.row import RowMapping from sqlalchemy.ext.asyncio import create_async_engine from llama_index_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore @@ -38,11 +35,8 @@ DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") DEFAULT_TABLE_ASYNC = "test_table" + str(uuid.uuid4()).replace("-", "_") -DEFAULT_TABLE_OMNI = "test_table" + str(uuid.uuid4()).replace("-", "_") -CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") DEFAULT_INDEX_NAME = DEFAULT_TABLE + DEFAULT_INDEX_NAME_SUFFIX DEFAULT_INDEX_NAME_ASYNC = DEFAULT_TABLE_ASYNC + DEFAULT_INDEX_NAME_SUFFIX -DEFAULT_INDEX_NAME_OMNI = DEFAULT_TABLE_OMNI + DEFAULT_INDEX_NAME_SUFFIX VECTOR_SIZE = 5 @@ -122,17 +116,58 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") await engine.close() + await engine._connector.close() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): - engine.init_vector_store_table(DEFAULT_TABLE, VECTOR_SIZE) + engine.init_vector_store_table( + DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True + ) vs = AlloyDBVectorStore.create_sync( engine, table_name=DEFAULT_TABLE, ) - await vs.async_add(nodes) + vs.add(nodes) + vs.drop_vector_index() + yield vs + + @pytest.fixture(scope="module") + def omni_host(self) -> str: + return get_env_var("OMNI_HOST", "AlloyDB Omni host address") + + @pytest.fixture(scope="module") + def omni_user(self) -> str: + return get_env_var("OMNI_USER", "AlloyDB Omni user name") + + @pytest.fixture(scope="module") + def omni_password(self) -> str: + return get_env_var("OMNI_PASSWORD", "AlloyDB 
Omni password") + + @pytest.fixture(scope="module") + def omni_database_name(self) -> str: + return get_env_var("OMNI_DATABASE_ID", "AlloyDB Omni database name") + @pytest_asyncio.fixture(scope="class") + async def omni_engine( + self, omni_host, omni_user, omni_password, omni_database_name + ): + connstring = f"postgresql+asyncpg://{omni_user}:{omni_password}@{omni_host}:5432/{omni_database_name}" + omni_engine = AlloyDBEngine.from_connection_string(connstring) + yield omni_engine + await aexecute(omni_engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") + await omni_engine.close() + + @pytest_asyncio.fixture(scope="class") + async def omni_vs(self, omni_engine): + omni_engine.init_vector_store_table( + DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True + ) + vs = AlloyDBVectorStore.create_sync( + omni_engine, + table_name=DEFAULT_TABLE, + ) + vs.add(nodes) vs.drop_vector_index() yield vs @@ -140,6 +175,7 @@ async def test_aapply_vector_index(self, vs): index = HNSWIndex() vs.apply_vector_index(index) assert vs.is_valid_index(DEFAULT_INDEX_NAME) + vs.drop_vector_index(DEFAULT_INDEX_NAME) async def test_areindex(self, vs): if not vs.is_valid_index(DEFAULT_INDEX_NAME): @@ -148,6 +184,7 @@ async def test_areindex(self, vs): vs.reindex() vs.reindex(DEFAULT_INDEX_NAME) assert vs.is_valid_index(DEFAULT_INDEX_NAME) + vs.drop_vector_index(DEFAULT_INDEX_NAME) async def test_dropindex(self, vs): vs.drop_vector_index() @@ -165,11 +202,54 @@ async def test_aapply_vector_index_ivfflat(self, vs): vs.apply_vector_index(index) assert vs.is_valid_index("secondindex") vs.drop_vector_index("secondindex") + vs.drop_vector_index(DEFAULT_INDEX_NAME) + + async def test_apply_vector_index_scann(self, vs): + index = ScaNNIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) + vs.set_maintenance_work_mem(index.num_leaves, VECTOR_SIZE) + vs.apply_vector_index(index, concurrently=True) + assert vs.is_valid_index(DEFAULT_INDEX_NAME) + index = ScaNNIndex( + name="secondindex", + 
distance_strategy=DistanceStrategy.COSINE_DISTANCE, + ) + vs.apply_vector_index(index) + assert vs.is_valid_index("secondindex") + vs.drop_vector_index("secondindex") + vs.drop_vector_index(DEFAULT_INDEX_NAME) async def test_is_valid_index(self, vs): is_valid = vs.is_valid_index("invalid_index") assert is_valid == False + async def test_aapply_vector_index_scann(self, vs): + index = ScaNNIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) + await vs.aset_maintenance_work_mem(index.num_leaves, VECTOR_SIZE) + await vs.aapply_vector_index(index, concurrently=True) + assert await vs.ais_valid_index(DEFAULT_INDEX_NAME) + index = ScaNNIndex( + name="secondindex", + distance_strategy=DistanceStrategy.COSINE_DISTANCE, + ) + await vs.aapply_vector_index(index) + assert await vs.ais_valid_index("secondindex") + await vs.adrop_vector_index("secondindex") + await vs.adrop_vector_index() + + async def test_apply_vector_index_scann_omni(self, omni_vs): + index = ScaNNIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) + omni_vs.set_maintenance_work_mem(index.num_leaves, VECTOR_SIZE) + omni_vs.apply_vector_index(index, concurrently=True) + assert omni_vs.is_valid_index(DEFAULT_INDEX_NAME) + index = ScaNNIndex( + name="secondindex", + distance_strategy=DistanceStrategy.COSINE_DISTANCE, + ) + omni_vs.apply_vector_index(index) + assert omni_vs.is_valid_index("secondindex") + omni_vs.drop_vector_index("secondindex") + omni_vs.drop_vector_index() + @pytest.mark.asyncio(loop_scope="class") class TestAsyncIndex: @@ -213,10 +293,13 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_ASYNC}") await engine.close() + await engine._connector.close() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): - await engine.ainit_vector_store_table(DEFAULT_TABLE_ASYNC, VECTOR_SIZE) + await engine.ainit_vector_store_table( + DEFAULT_TABLE_ASYNC, VECTOR_SIZE, overwrite_existing=True 
+ ) vs = await AlloyDBVectorStore.create( engine, table_name=DEFAULT_TABLE_ASYNC, @@ -247,28 +330,30 @@ async def omni_engine( self, omni_host, omni_user, omni_password, omni_database_name ): connstring = f"postgresql+asyncpg://{omni_user}:{omni_password}@{omni_host}:5432/{omni_database_name}" - print(f"Connecting to AlloyDB Omni with {connstring}") - async_engine = create_async_engine(connstring) omni_engine = AlloyDBEngine.from_engine(async_engine) yield omni_engine - await aexecute(omni_engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_OMNI}") + await aexecute(omni_engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_ASYNC}") await omni_engine.close() @pytest_asyncio.fixture(scope="class") - async def omni_vs(self, engine): - await engine.ainit_vector_store_table(DEFAULT_TABLE_OMNI, VECTOR_SIZE) + async def omni_vs(self, omni_engine): + await omni_engine.ainit_vector_store_table( + DEFAULT_TABLE_ASYNC, VECTOR_SIZE, overwrite_existing=True + ) vs = await AlloyDBVectorStore.create( - engine, - table_name=DEFAULT_TABLE_OMNI, + omni_engine, + table_name=DEFAULT_TABLE_ASYNC, ) await vs.async_add(nodes) + await vs.adrop_vector_index() yield vs async def test_aapply_vector_index(self, vs): index = HNSWIndex() await vs.aapply_vector_index(index) assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) + await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC) async def test_areindex(self, vs): if not await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC): @@ -277,6 +362,7 @@ async def test_areindex(self, vs): await vs.areindex() await vs.areindex(DEFAULT_INDEX_NAME_ASYNC) assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) + await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC) async def test_dropindex(self, vs): await vs.adrop_vector_index() @@ -313,11 +399,11 @@ async def test_aapply_vector_index_ivf(self, vs): await vs.adrop_vector_index("secondindex") await vs.adrop_vector_index() - async def test_aapply_alloydb_scann_index_ScaNN(self, omni_vs): + async def 
test_aapply_vector_index_scann_omni(self, omni_vs): index = ScaNNIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) await omni_vs.aset_maintenance_work_mem(index.num_leaves, VECTOR_SIZE) await omni_vs.aapply_vector_index(index, concurrently=True) - assert await omni_vs.ais_valid_index(DEFAULT_INDEX_NAME_OMNI) + assert await omni_vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) index = ScaNNIndex( name="secondindex", distance_strategy=DistanceStrategy.COSINE_DISTANCE, @@ -326,3 +412,17 @@ async def test_aapply_alloydb_scann_index_ScaNN(self, omni_vs): assert await omni_vs.ais_valid_index("secondindex") await omni_vs.adrop_vector_index("secondindex") await omni_vs.adrop_vector_index() + + async def test_apply_vector_index_scann(self, vs): + index = ScaNNIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) + vs.set_maintenance_work_mem(index.num_leaves, VECTOR_SIZE) + vs.apply_vector_index(index, concurrently=True) + assert vs.is_valid_index(DEFAULT_INDEX_NAME_ASYNC) + index = ScaNNIndex( + name="secondindex", + distance_strategy=DistanceStrategy.COSINE_DISTANCE, + ) + vs.apply_vector_index(index) + assert vs.is_valid_index("secondindex") + vs.drop_vector_index("secondindex") + vs.drop_vector_index()