diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 3e5996ce..f3b29a86 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -31,10 +31,10 @@ jobs: steps: - name: Checkout Repository - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4 + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4 - name: Setup Python - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: python-version: "3.11" diff --git a/CHANGELOG.md b/CHANGELOG.md index b62c7d1d..a5bc737a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,33 @@ # Changelog +## [0.10.0](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/compare/v0.9.0...v0.10.0) (2024-09-17) + + +### ⚠ BREAKING CHANGES + +* support async and sync versions of indexing methods +* remove _aexecute(), _execute(), _afetch(), and _fetch() methods + +### Features + +* Add from_engine_args method ([de16842](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/de168427f9884f33332086b68308e1225ee9e952)) +* Add support for sync from_engine ([de16842](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/de168427f9884f33332086b68308e1225ee9e952)) +* Allow non-uuid data types for vectorstore primary key ([#209](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/issues/209)) ([ffaa87f](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/ffaa87fd864d1c3ffeb00a34370af9e986a37cf5)) +* Refactor to support both async and sync usage ([de16842](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/de168427f9884f33332086b68308e1225ee9e952)) + + +### Bug Fixes + +* Replacing cosine_similarity and maximal_marginal_relevance local methods with the ones in langchain core. ([#190](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/issues/190)) ([7f27092](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/7f2709225a1a5a71b33522dafd354dc7159c358f)) +* Support async and sync versions of indexing methods ([de16842](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/de168427f9884f33332086b68308e1225ee9e952)) +* Updating the minimum langchain core version to 0.2.36 ([#205](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/issues/205)) ([0651231](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/0651231b7d77e0451ae769f78fe6dce3e724dec4)) + + +### Documentation + +* Update sample python notebooks to reflect the support for custom schema. ([#204](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/issues/204)) ([7ef9335](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/7ef9335a45578273e9ffc0921f60a1c6cc3e89ed)) + + ## [0.9.0](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/compare/v0.8.0...v0.9.0) (2024-09-05) diff --git a/DEVELOPER.md b/DEVELOPER.md index 9e9ef215..a2089553 100644 --- a/DEVELOPER.md +++ b/DEVELOPER.md @@ -18,7 +18,7 @@ Learn more by reading [How should I write my commits?](https://github.com/google ### Run tests locally -1. Set environment variables for `INSTANCE_ID`, `DATABASE_ID`, `REGION`, `DB_USER`, `DB_PASSWORD` +1. Set environment variables for `INSTANCE_ID`, `DATABASE_ID`, `REGION`, `DB_USER`, `DB_PASSWORD`, `IAM_ACCOUNT`. 1. 
Run pytest to automatically run all tests: @@ -26,6 +26,14 @@ Learn more by reading [How should I write my commits?](https://github.com/google pytest ``` +Notes: + +* Tests use both IAM and built-in authentication. + * Learn how to set up a built-in database user at [Cloud SQL built-in database authentication](https://cloud.google.com/sql/docs/postgres/built-in-authentication). + * Local tests will run against your `gcloud` credentials. Use `gcloud` to log in with your personal account or a service account. This account will be used to run the IAM tests. Learn how to set up access to the database at [Manage users with IAM database authentication](https://cloud.google.com/sql/docs/postgres/add-manage-iam-users). The `IAM_ACCOUNT` environment variable is used to test overriding the local account for authentication; it can be set to either a personal account or a service account. + * You may need to grant access to the public schema for your new database user: `GRANT ALL ON SCHEMA public TO "myaccount@example.com";` + + ### CI Platform Setup Cloud Build is used to run tests against Google Cloud resources in test project: langchain-cloud-sql-testing. @@ -121,4 +129,4 @@ The kokoro docs pipeline runs when a new release is created. See `.kokoro/` for [triggers]: https://console.cloud.google.com/cloud-build/triggers?e=13802955&project=langchain-cloud-sql-testing [vectorstore]: https://github.com/googleapis/langchain-google-cloud-sql-pg-python/tree/main/docs/vector_store.ipynb [loader]: https://github.com/googleapis/langchain-google-cloud-sql-pg-python/tree/main/docs/document_loader.ipynb -[history]: https://github.com/googleapis/langchain-google-cloud-sql-pg-python/tree/main/docs/chat_message_history.ipynb \ No newline at end of file +[history]: https://github.com/googleapis/langchain-google-cloud-sql-pg-python/tree/main/docs/chat_message_history.ipynb diff --git a/docs/chat_message_history.ipynb b/docs/chat_message_history.ipynb index e9493dee..02e6f04f 100644 --- a/docs/chat_message_history.ipynb +++ b/docs/chat_message_history.ipynb @@ -287,6 +287,24 @@ "engine.init_chat_history_table(table_name=TABLE_NAME)" ] }, + { + "cell_type": "markdown", + "id": "345b76b8", + "metadata": {}, + "source": [ + "#### Optional Tip: 💡\n", + "You can also specify a schema name by passing `schema_name` wherever you pass `table_name`, e.g.:\n", + "\n", + "```python\n", + "SCHEMA_NAME=\"my_schema\"\n", + "\n", + "engine.init_chat_history_table(\n", + " table_name=TABLE_NAME,\n", + " schema_name=SCHEMA_NAME # Default: \"public\"\n", + ")\n", + "```" + ] + }, { "cell_type": "markdown", "id": "zSYQTYf3UfOi", @@ -300,7 +318,8 @@ "\n", "1. `engine` - An instance of a `PostgresEngine` engine.\n", "1. `session_id` - A unique identifier string that specifies an id for the session.\n", - "1. `table_name` : The name of the table within the Cloud SQL database to store the chat message history." + "1. `table_name` : The name of the table within the Cloud SQL database to store the chat message history.\n", + "1. `schema_name` : The name of the database schema containing the chat message history table."
] }, { @@ -315,7 +334,10 @@ "from langchain_google_cloud_sql_pg import PostgresChatMessageHistory\n", "\n", "history = PostgresChatMessageHistory.create_sync(\n", - " engine, session_id=\"test_session\", table_name=TABLE_NAME\n", + " engine,\n", + " session_id=\"test_session\",\n", + " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", ")\n", "history.add_user_message(\"hi!\")\n", "history.add_ai_message(\"whats up?\")" @@ -456,6 +478,7 @@ " engine,\n", " session_id=session_id,\n", " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", " ),\n", " input_messages_key=\"question\",\n", " history_messages_key=\"history\",\n", diff --git a/docs/document_loader.ipynb b/docs/document_loader.ipynb index 84db4ae4..7be30a9b 100644 --- a/docs/document_loader.ipynb +++ b/docs/document_loader.ipynb @@ -257,6 +257,25 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Optional Tip: 💡\n", + "You can also specify a schema name by passing `schema_name` wherever you pass `table_name`, e.g.:\n", + "\n", + "```python\n", + "SCHEMA_NAME=\"my_schema\"\n", + "\n", + "await engine.ainit_document_table(\n", + " table_name=TABLE_NAME,\n", + " schema_name=SCHEMA_NAME, # Default: \"public\"\n", + " \n", + " ...\n", + ")\n", + "```" + ] + }, { "cell_type": "markdown", "metadata": { @@ -277,7 +296,11 @@ "from langchain_google_cloud_sql_pg import PostgresLoader\n", "\n", "# Creating a basic PostgreSQL object\n", - "loader = await PostgresLoader.create(engine, table_name=TABLE_NAME)" + "loader = await PostgresLoader.create(\n", + " engine,\n", + " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", + ")" ] }, { @@ -304,7 +327,11 @@ "from langchain_google_cloud_sql_pg import PostgresLoader\n", "\n", "# Creating a basic PostgresLoader object\n", - "loader = await PostgresLoader.create(engine, table_name=TABLE_NAME)\n", + "loader = await PostgresLoader.create(\n", + " engine,\n", + " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", + ")\n", "\n", "docs = await loader.aload()\n", "print(docs)" @@ -328,6 +355,7 @@ "loader = await PostgresLoader.create(\n", " engine,\n", " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", " content_columns=[\"product_name\"], # Optional\n", " metadata_columns=[\"id\"], # Optional\n", ")\n", @@ -356,6 +384,7 @@ "loader = await PostgresLoader.create(\n", " engine,\n", " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", " content_columns=[\"product_name\", \"description\"],\n", " format=\"YAML\",\n", ")\n", @@ -383,6 +412,7 @@ "saver = await PostgresDocumentSaver.create(\n", " engine,\n", " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", " content_column=\"product_name\",\n", " metadata_columns=[\"description\", \"content\"],\n", " metadata_json_column=\"metadata\",\n", @@ -427,7 +457,7 @@ "metadata": {}, "source": [ "### Load the documents with PostgresLoader\n", - "PostgresLoader can be used with `TABLE_NAME` to query and load the whole table." + "PostgresLoader can be used with `TABLE_NAME` (and optionally `SCHEMA_NAME`) to query and load the whole table."
] }, { @@ -436,7 +466,11 @@ "metadata": {}, "outputs": [], "source": [ - "loader = await PostgresLoader.create(engine, table_name=TABLE_NAME)\n", + "loader = await PostgresLoader.create(\n", + " engine,\n", + " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", + ")\n", "docs = await loader.aload()\n", "\n", "print(docs)" diff --git a/docs/vector_store.ipynb b/docs/vector_store.ipynb index 60839763..a253b3a6 100644 --- a/docs/vector_store.ipynb +++ b/docs/vector_store.ipynb @@ -258,6 +258,25 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Optional Tip: 💡\n", + "You can also specify a schema name by passing `schema_name` wherever you pass `table_name`, e.g.:\n", + "\n", + "```python\n", + "SCHEMA_NAME=\"my_schema\"\n", + "\n", + "await engine.ainit_vectorstore_table(\n", + " table_name=TABLE_NAME,\n", + " schema_name=SCHEMA_NAME, # Default: \"public\"\n", + " \n", + " ...\n", + ")\n", + "```" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -322,6 +341,7 @@ "store = await PostgresVectorStore.create( # Use .create() to initialize an async vector store\n", " engine=engine,\n", " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", " embedding_service=embedding,\n", ")" ] @@ -365,6 +385,7 @@ " ids=ids,\n", " engine=engine,\n", " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", " embedding_service=embedding,\n", ")" ] @@ -515,9 +536,11 @@ "\n", "# Set table name\n", "TABLE_NAME = \"vectorstore_custom\"\n", + "# SCHEMA_NAME = \"my_schema\"\n", "\n", "await engine.ainit_vectorstore_table(\n", " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", " vector_size=768, # VertexAI model: textembedding-gecko@latest\n", " metadata_columns=[Column(\"len\", \"INTEGER\")],\n", ")\n", @@ -527,6 +550,7 @@ "custom_store = await PostgresVectorStore.create(\n", " engine=engine,\n", " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", " embedding_service=embedding,\n", " metadata_columns=[\"len\"],\n", " # Connect to an existing VectorStore by customizing the table schema:\n", diff --git a/integration.cloudbuild.yaml b/integration.cloudbuild.yaml index 3bece7aa..e685c3bb 100644 --- a/integration.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -23,16 +23,30 @@ steps: entrypoint: pip args: ["install", ".[test]", "--user"] + - id: proxy-install + name: alpine:3.10 + entrypoint: sh + args: + - -c + - | + wget -O /workspace/cloud-sql-proxy https://storage.googleapis.com/cloud-sql-connectors/cloud-sql-proxy/v2.13.0/cloud-sql-proxy.linux.386 + chmod +x /workspace/cloud-sql-proxy + - id: Run integration tests name: python:${_VERSION} - entrypoint: python - args: ["-m", "pytest", "--cov=langchain_google_cloud_sql_pg", "--cov-config=.coveragerc", "tests/"] + entrypoint: /bin/bash env: - "PROJECT_ID=$PROJECT_ID" - "INSTANCE_ID=$_INSTANCE_ID" - "DATABASE_ID=$_DATABASE_ID" - "REGION=$_REGION" + - "IP_ADDRESS=$_IP_ADDRESS" secretEnv: ["DB_USER", "DB_PASSWORD", "IAM_ACCOUNT"] + args: + - "-c" + - | + /workspace/cloud-sql-proxy ${_INSTANCE_CONNECTION_NAME} --port $_DATABASE_PORT & sleep 2; + python -m pytest --cov=langchain_google_cloud_sql_pg --cov-config=.coveragerc tests/ availableSecrets: secretManager: @@ -44,9 +58,12 @@ availableSecrets: env: "IAM_ACCOUNT" substitutions: + _INSTANCE_CONNECTION_NAME: ${PROJECT_ID}:${_REGION}:${_INSTANCE_ID} + _DATABASE_PORT: "5432" _DATABASE_ID: test-database _REGION: us-central1 _VERSION: "3.8" + _IP_ADDRESS: "127.0.0.1" options: dynamicSubstitutions: true diff --git a/pyproject.toml
b/pyproject.toml index b6627534..83a3bc23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = [ dependencies = [ "cloud-sql-python-connector[asyncpg] >= 1.10.0, <2.0.0", - "langchain-core>=0.1.1, <1.0.0 ", + "langchain-core>=0.2.36, <1.0.0 ", "numpy>=1.24.4, <2.0.0", "pgvector>=0.2.5, <1.0.0", "SQLAlchemy[asyncio]>=2.0.25, <3.0.0" diff --git a/requirements.txt b/requirements.txt index c6a36c54..ec4b6dd2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ -cloud-sql-python-connector[asyncpg]==1.10.0 -langchain-core==0.2.35 +cloud-sql-python-connector[asyncpg]==1.12.0 +langchain-core==0.2.38 numpy==1.24.4; python_version<='3.8' numpy==1.26.4; python_version>'3.8' pgvector==0.3.2 -SQLAlchemy[asyncio]==2.0.32 +SQLAlchemy[asyncio]==2.0.34 diff --git a/samples/index_tuning_sample/requirements.txt b/samples/index_tuning_sample/requirements.txt index 0a7a2867..d59a3563 100644 --- a/samples/index_tuning_sample/requirements.txt +++ b/samples/index_tuning_sample/requirements.txt @@ -1,3 +1,3 @@ -langchain-google-cloud-sql-pg==0.7.0 -langchain==0.2.14 +langchain-community==0.2.16 +langchain-google-cloud-sql-pg==0.9.0 langchain-google-vertexai==1.0.10 \ No newline at end of file diff --git a/samples/langchain_on_vertexai/README.md b/samples/langchain_on_vertexai/README.md index 63974b05..3c6ea3dc 100644 --- a/samples/langchain_on_vertexai/README.md +++ b/samples/langchain_on_vertexai/README.md @@ -30,5 +30,3 @@ Build and deploy an Agent with RAG tool and Memory | [retriever_agent_with_histo 1. Use [`create_embeddings.py`](create_embeddings.py) to add data to your vector store. Learn more at [Deploying a RAG Application with Cloud SQL for Postgres to LangChain on Vertex AI](https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/reasoning-engine/tutorial_cloud_sql_pg_rag_agent.ipynb). 
- - diff --git a/samples/langchain_on_vertexai/prebuilt_langchain_agent_template.py b/samples/langchain_on_vertexai/prebuilt_langchain_agent_template.py index 64bbda0c..f8c6ca71 100644 --- a/samples/langchain_on_vertexai/prebuilt_langchain_agent_template.py +++ b/samples/langchain_on_vertexai/prebuilt_langchain_agent_template.py @@ -99,11 +99,7 @@ def similarity_search(query: str) -> List[Document]: "temperature": 0.1, }, ), - requirements=[ - "google-cloud-aiplatform[reasoningengine,langchain]==1.57.0", - "langchain-google-cloud-sql-pg==0.6.1", - "langchain-google-vertexai==1.0.4", - ], + requirements="requirements.txt", display_name=DISPLAY_NAME, sys_version="3.11", extra_packages=["config.py"], diff --git a/samples/langchain_on_vertexai/requirements.txt b/samples/langchain_on_vertexai/requirements.txt index 95f38d0f..5996cbf4 100644 --- a/samples/langchain_on_vertexai/requirements.txt +++ b/samples/langchain_on_vertexai/requirements.txt @@ -1,4 +1,5 @@ -google-cloud-aiplatform[reasoningengine,langchain]==1.57.0 -langchain-google-cloud-sql-pg==0.6.1 -langchain-google-vertexai==1.0.4 -google-cloud-resource-manager==1.12.3 \ No newline at end of file +google-cloud-aiplatform[reasoningengine,langchain]==1.65.0 +google-cloud-resource-manager==1.12.5 +langchain-community==0.2.16 +langchain-google-cloud-sql-pg==0.9.0 +langchain-google-vertexai==1.0.10 diff --git a/samples/langchain_on_vertexai/retriever_agent_with_history_template.py b/samples/langchain_on_vertexai/retriever_agent_with_history_template.py index 548d5e67..df8a1f8e 100644 --- a/samples/langchain_on_vertexai/retriever_agent_with_history_template.py +++ b/samples/langchain_on_vertexai/retriever_agent_with_history_template.py @@ -186,12 +186,7 @@ def query(self, input: str, session_id: str) -> str: user=USER, password=PASSWORD, ), - requirements=[ - "google-cloud-aiplatform[reasoningengine,langchain]==1.57.0", - "langchain-google-cloud-sql-pg==0.6.1", - "langchain-google-vertexai==1.0.4", - "langchainhub==0.1.20", - ], + requirements="requirements.txt", display_name=DISPLAY_NAME, sys_version="3.11", extra_packages=["config.py"], diff --git a/samples/langchain_on_vertexai/retriever_chain_template.py b/samples/langchain_on_vertexai/retriever_chain_template.py index d65842db..c032cb79 100644 --- a/samples/langchain_on_vertexai/retriever_chain_template.py +++ b/samples/langchain_on_vertexai/retriever_chain_template.py @@ -155,11 +155,7 @@ def query(self, input: str) -> str: user=USER, password=PASSWORD, ), - requirements=[ - "google-cloud-aiplatform[reasoningengine,langchain]==1.57.0", - "langchain-google-cloud-sql-pg==0.6.1", - "langchain-google-vertexai==1.0.4", - ], + requirements="requirements.txt", display_name=DISPLAY_NAME, sys_version="3.11", extra_packages=["config.py"], diff --git a/samples/requirements.txt b/samples/requirements.txt index 462950ce..13db4cc8 100644 --- a/samples/requirements.txt +++ b/samples/requirements.txt @@ -1,4 +1,5 @@ -google-cloud-aiplatform[reasoningengine,langchain] -langchain-google-vertexai -langchain-community -google-cloud-resource-manager \ No newline at end of file +google-cloud-aiplatform[reasoningengine,langchain]==1.65.0 +google-cloud-resource-manager==1.12.5 +langchain-community==0.2.16 +langchain-google-cloud-sql-pg==0.9.0 +langchain-google-vertexai==1.0.10 \ No newline at end of file diff --git a/src/langchain_google_cloud_sql_pg/async_chat_message_history.py b/src/langchain_google_cloud_sql_pg/async_chat_message_history.py new file mode 100644 index 00000000..a7873aae --- /dev/null +++ 
b/src/langchain_google_cloud_sql_pg/async_chat_message_history.py @@ -0,0 +1,148 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +from typing import List, Sequence + +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.messages import BaseMessage, messages_from_dict +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine + +from .engine import PostgresEngine + + +class AsyncPostgresChatMessageHistory(BaseChatMessageHistory): + """Chat message history stored in a Cloud SQL for PostgreSQL database.""" + + __create_key = object() + + def __init__( + self, + key: object, + pool: AsyncEngine, + session_id: str, + table_name: str, + schema_name: str = "public", + ): + """AsyncPostgresChatMessageHistory constructor. + + Args: + key (object): Key to prevent direct constructor usage. + pool (AsyncEngine): Database connection pool. + session_id (str): Retrieve the table content with this session ID. + table_name (str): Table name that stores the chat message history. + schema_name (str, optional): Database schema name of the chat message history table. Defaults to "public". + + Raises: + Exception: If constructor is directly called by the user. + """ + if key != AsyncPostgresChatMessageHistory.__create_key: + raise Exception( + "Only create class through 'create' or 'create_sync' methods!" + ) + self.pool = pool + self.session_id = session_id + self.table_name = table_name + self.schema_name = schema_name + + @classmethod + async def create( + cls, + engine: PostgresEngine, + session_id: str, + table_name: str, + schema_name: str = "public", + ) -> AsyncPostgresChatMessageHistory: + """Create a new AsyncPostgresChatMessageHistory instance. + + Args: + engine (PostgresEngine): Postgres engine to use. + session_id (str): Retrieve the table content with this session ID. + table_name (str): Table name that stores the chat message history. + schema_name (str, optional): Database schema name for the chat message history table. Defaults to "public". + + Raises: + IndexError: If the table provided does not contain the required schema. + + Returns: + AsyncPostgresChatMessageHistory: A newly created instance of AsyncPostgresChatMessageHistory. + """ + table_schema = await engine._aload_table_schema(table_name, schema_name) + column_names = table_schema.columns.keys() + + required_columns = ["id", "session_id", "data", "type"] + + if not (all(x in column_names for x in required_columns)): + raise IndexError( + f"Table '{schema_name}'.'{table_name}' has incorrect schema.
Got " + f"column names '{column_names}' but required column names " + f"'{required_columns}'.\nPlease create table with following schema:" + f"\nCREATE TABLE {schema_name}.{table_name} (" + "\n id INT AUTO_INCREMENT PRIMARY KEY," + "\n session_id TEXT NOT NULL," + "\n data JSON NOT NULL," + "\n type TEXT NOT NULL" + "\n);" + ) + return cls(cls.__create_key, engine._pool, session_id, table_name) + + async def aadd_message(self, message: BaseMessage) -> None: + """Append the message to the record in PostgreSQL""" + query = f"""INSERT INTO "{self.schema_name}"."{self.table_name}"(session_id, data, type) + VALUES (:session_id, :data, :type); + """ + async with self.pool.connect() as conn: + await conn.execute( + text(query), + { + "session_id": self.session_id, + "data": json.dumps(message.dict()), + "type": message.type, + }, + ) + await conn.commit() + + async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: + """Append a list of messages to the record in PostgreSQL""" + for message in messages: + await self.aadd_message(message) + + async def aclear(self) -> None: + """Clear session memory from PostgreSQL""" + query = f"""DELETE FROM "{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id;""" + async with self.pool.connect() as conn: + await conn.execute(text(query), {"session_id": self.session_id}) + await conn.commit() + + async def _aget_messages(self) -> List[BaseMessage]: + """Retrieve the messages from PostgreSQL.""" + query = f"""SELECT data, type FROM "{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id ORDER BY id;""" + async with self.pool.connect() as conn: + result = await conn.execute(text(query), {"session_id": self.session_id}) + result_map = result.mappings() + results = result_map.fetchall() + if not results: + return [] + + items = [{"data": result["data"], "type": result["type"]} for result in results] + messages = messages_from_dict(items) + return messages + + def clear(self) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatMessageHistory. Use PostgresChatMessageHistory interface instead." + ) diff --git a/src/langchain_google_cloud_sql_pg/async_loader.py b/src/langchain_google_cloud_sql_pg/async_loader.py new file mode 100644 index 00000000..90e94526 --- /dev/null +++ b/src/langchain_google_cloud_sql_pg/async_loader.py @@ -0,0 +1,450 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + from __future__ import annotations + + import json + from typing import Any, AsyncIterator, Callable, Dict, Iterable, List, Optional + + from langchain_core.document_loaders.base import BaseLoader + from langchain_core.documents import Document + from sqlalchemy import text + from sqlalchemy.ext.asyncio import AsyncEngine + + from .engine import PostgresEngine + + DEFAULT_CONTENT_COL = "page_content" + DEFAULT_METADATA_COL = "langchain_metadata" + + + def text_formatter(row: dict, content_columns: List[str]) -> str: + """Text document formatter.""" + return " ".join(str(row[column]) for column in content_columns if column in row) + + + def csv_formatter(row: dict, content_columns: List[str]) -> str: + """CSV document formatter.""" + return ", ".join(str(row[column]) for column in content_columns if column in row) + + + def yaml_formatter(row: dict, content_columns: List[str]) -> str: + """YAML document formatter.""" + return "\n".join( + f"{column}: {str(row[column])}" for column in content_columns if column in row + ) + + + def json_formatter(row: dict, content_columns: List[str]) -> str: + """JSON document formatter.""" + dictionary = {} + for column in content_columns: + if column in row: + dictionary[column] = row[column] + return json.dumps(dictionary) + + + def _parse_doc_from_row( + content_columns: Iterable[str], + metadata_columns: Iterable[str], + row: dict, + metadata_json_column: Optional[str] = DEFAULT_METADATA_COL, + formatter: Callable = text_formatter, + ) -> Document: + """Parse row into document.""" + page_content = formatter(row, content_columns) + metadata: Dict[str, Any] = {} + # unnest metadata from langchain_metadata column + if metadata_json_column and row.get(metadata_json_column): + for k, v in row[metadata_json_column].items(): + metadata[k] = v + # load metadata from other columns + for column in metadata_columns: + if column in row and column != metadata_json_column: + metadata[column] = row[column] + + return Document(page_content=page_content, metadata=metadata) + + + def _parse_row_from_doc( + doc: Document, + column_names: Iterable[str], + content_column: str = DEFAULT_CONTENT_COL, + metadata_json_column: Optional[str] = DEFAULT_METADATA_COL, + ) -> Dict: + """Parse document into a dictionary of rows.""" + doc_metadata = doc.metadata.copy() + row: Dict[str, Any] = {content_column: doc.page_content} + for entry in doc.metadata: + if entry in column_names: + row[entry] = doc_metadata[entry] + del doc_metadata[entry] + # store extra metadata in langchain_metadata column in json format + if metadata_json_column: + row[metadata_json_column] = doc_metadata + return row + + + class AsyncPostgresLoader(BaseLoader): + """Load documents from PostgreSQL. + + Each document represents one row of the result. The `content_columns` are + written into the `page_content` of the document. The `metadata_columns` are written + into the `metadata` of the document. By default, the first column is written into + the `page_content` and everything else into the `metadata`. + """ + + __create_key = object() + + def __init__( + self, + key: object, + pool: AsyncEngine, + query: str, + content_columns: List[str], + metadata_columns: List[str], + formatter: Callable, + metadata_json_column: Optional[str] = None, + ) -> None: + """AsyncPostgresLoader constructor. + + Args: + key (object): Prevent direct constructor usage. + pool (AsyncEngine): AsyncEngine with pool connection to the Postgres database. + query (str): SQL query.
+ content_columns (List[str]): Column(s) that represent a Document's page_content. + metadata_columns (List[str]): Column(s) that represent a Document's metadata. + formatter (Callable): A function to format page content. + metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to None. + + + Raises: + Exception: If called directly by the user. + """ + if key != AsyncPostgresLoader.__create_key: + raise Exception( + "Only create class through 'create' or 'create_sync' methods!" + ) + + self.pool = pool + self.query = query + self.content_columns = content_columns + self.metadata_columns = metadata_columns + self.formatter = formatter + self.metadata_json_column = metadata_json_column + + @classmethod + async def create( + cls, + engine: PostgresEngine, + query: Optional[str] = None, + table_name: Optional[str] = None, + schema_name: str = "public", + content_columns: Optional[List[str]] = None, + metadata_columns: Optional[List[str]] = None, + metadata_json_column: Optional[str] = None, + format: Optional[str] = None, + formatter: Optional[Callable] = None, + ) -> AsyncPostgresLoader: + """Create a new AsyncPostgresLoader instance. + + Args: + engine (PostgresEngine): PostgresEngine with pool connection to the Postgres database. + query (Optional[str], optional): SQL query. Defaults to None. + table_name (Optional[str], optional): Name of table to query. Defaults to None. + schema_name (str, optional): Database schema name of the table. Defaults to "public". + content_columns (Optional[List[str]], optional): Column(s) that represent a Document's page_content. Defaults to the first column. + metadata_columns (Optional[List[str]], optional): Column(s) that represent a Document's metadata. Defaults to None. + metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "langchain_metadata". + format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'. + formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None.
+ + Returns: + AsyncPostgresLoader + """ + if table_name and query: + raise ValueError("Only one of 'table_name' or 'query' should be specified.") + if not table_name and not query: + raise ValueError( + "At least one of the parameters 'table_name' or 'query' needs to be provided" + ) + if format and formatter: + raise ValueError("Only one of 'format' or 'formatter' should be specified.") + + if format and format not in ["csv", "text", "JSON", "YAML"]: + raise ValueError("format must be type: 'csv', 'text', 'JSON', 'YAML'") + if formatter: + formatter = formatter + elif format == "csv": + formatter = csv_formatter + elif format == "YAML": + formatter = yaml_formatter + elif format == "JSON": + formatter = json_formatter + else: + formatter = text_formatter + + if not query: + query = f'SELECT * FROM "{schema_name}"."{table_name}"' + + async with engine._pool.connect() as connection: + result_proxy = await connection.execute(text(query)) + column_names = list(result_proxy.keys()) + # Select content or default to first column + content_columns = content_columns or [column_names[0]] + # Select metadata columns + metadata_columns = metadata_columns or [ + col for col in column_names if col not in content_columns + ] + # Check validity of metadata json column + if metadata_json_column and metadata_json_column not in column_names: + raise ValueError( + f"Column {metadata_json_column} not found in query result {column_names}." + ) + # use default metadata json column if not specified + if metadata_json_column and metadata_json_column in column_names: + metadata_json_column = metadata_json_column + elif DEFAULT_METADATA_COL in column_names: + metadata_json_column = DEFAULT_METADATA_COL + else: + metadata_json_column = None + + # check validity of other column + all_names = content_columns + metadata_columns + for name in all_names: + if name not in column_names: + raise ValueError( + f"Column {name} not found in query result {column_names}." + ) + return cls( + cls.__create_key, + engine._pool, + query, + content_columns, + metadata_columns, + formatter, + metadata_json_column, + ) + + async def aload(self) -> List[Document]: + """Load PostgreSQL data into Document objects.""" + return [doc async for doc in self.alazy_load()] + + async def alazy_load(self) -> AsyncIterator[Document]: + """Load PostgreSQL data into Document objects lazily.""" + async with self.pool.connect() as connection: + result_proxy = await connection.execute(text(self.query)) + # load document one by one + while True: + row = result_proxy.fetchone() + if not row: + break + + row_data = {} + column_names = self.content_columns + self.metadata_columns + column_names += ( + [self.metadata_json_column] if self.metadata_json_column else [] + ) + for column in column_names: + value = getattr(row, column) + row_data[column] = value + + yield _parse_doc_from_row( + self.content_columns, + self.metadata_columns, + row_data, + self.metadata_json_column, + self.formatter, + ) + + +class AsyncPostgresDocumentSaver: + """A class for saving langchain documents into a PostgreSQL database table.""" + + __create_key = object() + + def __init__( + self, + key: object, + pool: AsyncEngine, + table_name: str, + content_column: str, + schema_name: str = "public", + metadata_columns: List[str] = [], + metadata_json_column: Optional[str] = None, + ): + """AsyncPostgresDocumentSaver constructor. + + Args: + key (object): Prevent direct constructor usage. 
+ pool (AsyncEngine): AsyncEngine with pool connection to the Postgres database. + table_name (str): Name of the table to store documents. + content_column (str): Column that represents a Document's page_content. + schema_name (str, optional): Database schema name of the table. Defaults to "public". + metadata_columns (Optional[List[str]], optional): Column(s) that represent a Document's metadata. Defaults to None. + metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "langchain_metadata". + + Raises: + Exception: If called directly by the user. + """ + if key != AsyncPostgresDocumentSaver.__create_key: + raise Exception( + "Only create class through 'create' or 'create_sync' methods!" + ) + self.pool = pool + self.table_name = table_name + self.content_column = content_column + self.schema_name = schema_name + self.metadata_columns = metadata_columns + self.metadata_json_column = metadata_json_column + + @classmethod + async def create( + cls, + engine: PostgresEngine, + table_name: str, + schema_name: str = "public", + content_column: str = DEFAULT_CONTENT_COL, + metadata_columns: List[str] = [], + metadata_json_column: Optional[str] = DEFAULT_METADATA_COL, + ) -> AsyncPostgresDocumentSaver: + """Create an AsyncPostgresDocumentSaver instance. + + Args: + engine (PostgresEngine): PostgresEngine with pool connection to the Postgres database. + table_name (str): Name of the table to store documents. + schema_name (str, optional): Database schema name of the table. Defaults to "public". + content_column (str, optional): Column that represents a Document's page_content. Defaults to "page_content". + metadata_columns (Optional[List[str]], optional): Column(s) that represent a Document's metadata. Defaults to None. + metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "langchain_metadata". + + Returns: + AsyncPostgresDocumentSaver + """ + table_schema = await engine._aload_table_schema(table_name, schema_name) + column_names = table_schema.columns.keys() + if content_column not in column_names: + raise ValueError(f"Content column, {content_column}, does not exist.") + + # Set metadata columns to all columns if not set + if len(metadata_columns) == 0: + metadata_columns = [ + column + for column in column_names + if column != content_column and column != metadata_json_column + ] + + # Check and set metadata json column + for column in metadata_columns: + if column not in column_names: + raise ValueError(f"Metadata column, {column}, does not exist.") + + if ( + metadata_json_column + and metadata_json_column != DEFAULT_METADATA_COL + and metadata_json_column not in column_names + ): + raise ValueError(f"Metadata JSON column, {metadata_json_column}, does not exist.") + elif metadata_json_column not in column_names: + metadata_json_column = None + + return cls( + cls.__create_key, + engine._pool, + table_name, + content_column, + schema_name, + metadata_columns, + metadata_json_column, + ) + + async def aadd_documents(self, docs: List[Document]) -> None: + """ + Save documents in the DocumentSaver table. A document's metadata is added to matching columns if found, or + stored in the langchain_metadata JSON column. + + Args: + docs (List[langchain_core.documents.Document]): a list of documents to be saved.
+ """ + + for doc in docs: + row = _parse_row_from_doc( + doc, + self.metadata_columns, + self.content_column, + self.metadata_json_column, + ) + for key, value in row.items(): + if isinstance(value, dict): + row[key] = json.dumps(value) + + # Create list of column names + insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"({self.content_column}' + values_stmt = f"VALUES (:{self.content_column}" + + # Add metadata + for metadata_column in self.metadata_columns: + if metadata_column in doc.metadata: + insert_stmt += f", {metadata_column}" + values_stmt += f", :{metadata_column}" + + # Add JSON column and/or close statement + insert_stmt += ( + f", {self.metadata_json_column})" if self.metadata_json_column else ")" + ) + if self.metadata_json_column: + values_stmt += f", :{self.metadata_json_column})" + else: + values_stmt += ")" + + query = insert_stmt + values_stmt + async with self.pool.connect() as conn: + await conn.execute(text(query), row) + await conn.commit() + + async def adelete(self, docs: List[Document]) -> None: + """ + Delete all instances of a document from the DocumentSaver table by matching the entire Document + object. + + Args: + docs (List[langchain_core.documents.Document]): a list of documents to be deleted. + """ + for doc in docs: + row = _parse_row_from_doc( + doc, + self.metadata_columns, + self.content_column, + self.metadata_json_column, + ) + # delete by matching all fields of document + where_conditions_list = [] + for key, value in row.items(): + if isinstance(value, dict): + where_conditions_list.append( + f"{key}::jsonb @> '{json.dumps(value)}'::jsonb" + ) + else: + # Handle simple key-value pairs + where_conditions_list.append(f"{key} = :{key}") + + where_conditions = " AND ".join(where_conditions_list) + stmt = f'DELETE FROM "{self.schema_name}"."{self.table_name}" WHERE {where_conditions};' + values = {} + for key, value in row.items(): + if type(value) is int: + values[key] = str(value) + else: + values[key] = value + async with self.pool.connect() as conn: + await conn.execute(text(stmt), values) + await conn.commit() diff --git a/src/langchain_google_cloud_sql_pg/async_vectorstore.py b/src/langchain_google_cloud_sql_pg/async_vectorstore.py new file mode 100644 index 00000000..fcf92dbf --- /dev/null +++ b/src/langchain_google_cloud_sql_pg/async_vectorstore.py @@ -0,0 +1,898 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + # TODO: Remove below import when minimum supported Python version is 3.10 + from __future__ import annotations + + import json + import uuid + from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type + + import numpy as np + from langchain_core.documents import Document + from langchain_core.embeddings import Embeddings + from langchain_core.vectorstores import VectorStore, utils + from sqlalchemy import text + from sqlalchemy.engine.row import RowMapping + from sqlalchemy.ext.asyncio import AsyncEngine + + from .engine import PostgresEngine + from .indexes import ( + DEFAULT_DISTANCE_STRATEGY, + DEFAULT_INDEX_NAME_SUFFIX, + BaseIndex, + DistanceStrategy, + ExactNearestNeighbor, + QueryOptions, + ) + + + class AsyncPostgresVectorStore(VectorStore): + """Google Cloud SQL for PostgreSQL Vector Store class""" + + __create_key = object() + + def __init__( + self, + key: object, + pool: AsyncEngine, + embedding_service: Embeddings, + table_name: str, + schema_name: str = "public", + content_column: str = "content", + embedding_column: str = "embedding", + metadata_columns: List[str] = [], + id_column: str = "langchain_id", + metadata_json_column: Optional[str] = "langchain_metadata", + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + index_query_options: Optional[QueryOptions] = None, + ): + """AsyncPostgresVectorStore constructor. + Args: + key (object): Prevent direct constructor usage. + pool (AsyncEngine): Connection pool engine for managing connections to the Postgres database. + embedding_service (Embeddings): Text embedding model to use. + table_name (str): Name of the existing table or the table to be created. + schema_name (str, optional): Database schema name of the table. Defaults to "public". + content_column (str): Column that represents a Document's page_content. Defaults to "content". + embedding_column (str): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". + metadata_columns (List[str]): Column(s) that represent a document's metadata. + id_column (str): Column that represents the Document's id. Defaults to "langchain_id". + metadata_json_column (str): Column to store metadata as JSON. Defaults to "langchain_metadata". + distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. + k (int): Number of Documents to return from search. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. + index_query_options (QueryOptions): Index query option. + + + Raises: + Exception: If called directly by the user. + """ + if key != AsyncPostgresVectorStore.__create_key: + raise Exception( + "Only create class through 'create' or 'create_sync' methods!"
+ ) + + self.pool = pool + self.embedding_service = embedding_service + self.table_name = table_name + self.schema_name = schema_name + self.content_column = content_column + self.embedding_column = embedding_column + self.metadata_columns = metadata_columns + self.id_column = id_column + self.metadata_json_column = metadata_json_column + self.distance_strategy = distance_strategy + self.k = k + self.fetch_k = fetch_k + self.lambda_mult = lambda_mult + self.index_query_options = index_query_options + + @classmethod + async def create( + cls, + engine: PostgresEngine, + embedding_service: Embeddings, + table_name: str, + schema_name: str = "public", + content_column: str = "content", + embedding_column: str = "embedding", + metadata_columns: List[str] = [], + ignore_metadata_columns: Optional[List[str]] = None, + id_column: str = "langchain_id", + metadata_json_column: Optional[str] = "langchain_metadata", + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + index_query_options: Optional[QueryOptions] = None, + ) -> AsyncPostgresVectorStore: + """Create a new AsyncPostgresVectorStore instance. + + Args: + engine (PostgresEngine): Connection pool engine for managing connections to Cloud SQL for PostgreSQL database. + embedding_service (Embeddings): Text embedding model to use. + table_name (str): Name of an existing table or table to be created. + schema_name (str, optional): Database schema name of the table. Defaults to "public". + content_column (str): Column that represent a Document's page_content. Defaults to "content". + embedding_column (str): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". + metadata_columns (List[str]): Column(s) that represent a document's metadata. + ignore_metadata_columns (List[str]): Column(s) to ignore in pre-existing tables for a document's metadata. Can not be used with metadata_columns. Defaults to None. + id_column (str): Column that represents the Document's id. Defaults to "langchain_id". + metadata_json_column (str): Column to store metadata as JSON. Defaults to "langchain_metadata". + distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. + k (int): Number of Documents to return from search. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. + index_query_options (QueryOptions): Index query option. + + Returns: + AsyncPostgresVectorStore + """ + if metadata_columns and ignore_metadata_columns: + raise ValueError( + "Can not use both metadata_columns and ignore_metadata_columns." 
+ ) + # Get field type information + async with engine._pool.connect() as conn: + result = await conn.execute( + text( + f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}' AND table_schema = '{schema_name}'" + ) + ) + result_map = result.mappings() + results = result_map.fetchall() + + columns = {} + for field in results: + columns[field["column_name"]] = field["data_type"] + + # Check columns + if id_column not in columns: + raise ValueError(f"Id column, {id_column}, does not exist.") + if content_column not in columns: + raise ValueError(f"Content column, {content_column}, does not exist.") + content_type = columns[content_column] + if content_type != "text" and "char" not in content_type: + raise ValueError( + f"Content column, {content_column}, is type, {content_type}. It must be a type of character string." + ) + if embedding_column not in columns: + raise ValueError(f"Embedding column, {embedding_column}, does not exist.") + if columns[embedding_column] != "USER-DEFINED": + raise ValueError( + f"Embedding column, {embedding_column}, is not type Vector." + ) + + metadata_json_column = ( + None if metadata_json_column not in columns else metadata_json_column + ) + + # If using metadata_columns check to make sure column exists + for column in metadata_columns: + if column not in columns: + raise ValueError(f"Metadata column, {column}, does not exist.") + + # If using ignore_metadata_columns, filter out known columns and set known metadata columns + all_columns = columns + if ignore_metadata_columns: + for column in ignore_metadata_columns: + del all_columns[column] + + del all_columns[id_column] + del all_columns[content_column] + del all_columns[embedding_column] + metadata_columns = [k for k in all_columns.keys()] + + return cls( + cls.__create_key, + engine._pool, + embedding_service, + table_name, + schema_name, + content_column, + embedding_column, + metadata_columns, + id_column, + metadata_json_column, + distance_strategy, + k, + fetch_k, + lambda_mult, + index_query_options, + ) + + @property + def embeddings(self) -> Embeddings: + return self.embedding_service + + async def __aadd_embeddings( + self, + texts: Iterable[str], + embeddings: List[List[float]], + metadatas: Optional[List[dict]] = None, + ids: Optional[List] = None, + **kwargs: Any, + ) -> List[str]: + """Add embeddings to the table. + + Raises: + :class:`InvalidTextRepresentationError`: if the `ids` data type does not match that of the `id_column`.
+ """ + if not ids: + ids = [str(uuid.uuid4()) for _ in texts] + if not metadatas: + metadatas = [{} for _ in texts] + # Insert embeddings + for id, content, embedding, metadata in zip(ids, texts, embeddings, metadatas): + metadata_col_names = ( + ", " + ", ".join(self.metadata_columns) + if len(self.metadata_columns) > 0 + else "" + ) + insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"({self.id_column}, {self.content_column}, {self.embedding_column}{metadata_col_names}' + values = {"id": id, "content": content, "embedding": str(embedding)} + values_stmt = "VALUES (:id, :content, :embedding" + + # Add metadata + extra = metadata + for metadata_column in self.metadata_columns: + if metadata_column in metadata: + values_stmt += f", :{metadata_column}" + values[metadata_column] = metadata[metadata_column] + del extra[metadata_column] + else: + values_stmt += ",null" + + # Add JSON column and/or close statement + insert_stmt += ( + f", {self.metadata_json_column})" if self.metadata_json_column else ")" + ) + if self.metadata_json_column: + values_stmt += ", :extra)" + values["extra"] = json.dumps(extra) + else: + values_stmt += ")" + + query = insert_stmt + values_stmt + async with self.pool.connect() as conn: + await conn.execute(text(query), values) + await conn.commit() + + return ids + + async def aadd_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List] = None, + **kwargs: Any, + ) -> List[str]: + """Embed texts and add to the table. + + Raises: + :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. + """ + embeddings = self.embedding_service.embed_documents(list(texts)) + ids = await self.__aadd_embeddings( + texts, embeddings, metadatas=metadatas, ids=ids, **kwargs + ) + return ids + + async def aadd_documents( + self, + documents: List[Document], + ids: Optional[List] = None, + **kwargs: Any, + ) -> List[str]: + """Embed documents and add to the table. + + Raises: + :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. + """ + texts = [doc.page_content for doc in documents] + metadatas = [doc.metadata for doc in documents] + ids = await self.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs) + return ids + + async def adelete( + self, + ids: Optional[List] = None, + **kwargs: Any, + ) -> Optional[bool]: + """Delete records from the table. + + Raises: + :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. 
+ """ + if not ids: + return False + + id_list = ", ".join([f"'{id}'" for id in ids]) + query = f'DELETE FROM "{self.schema_name}"."{self.table_name}" WHERE {self.id_column} in ({id_list})' + async with self.pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + return True + + @classmethod + async def afrom_texts( # type: ignore[override] + cls: Type[AsyncPostgresVectorStore], + texts: List[str], + embedding: Embeddings, + engine: PostgresEngine, + table_name: str, + schema_name: str = "public", + metadatas: Optional[List[dict]] = None, + ids: Optional[List] = None, + content_column: str = "content", + embedding_column: str = "embedding", + metadata_columns: List[str] = [], + ignore_metadata_columns: Optional[List[str]] = None, + id_column: str = "langchain_id", + metadata_json_column: str = "langchain_metadata", + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + index_query_options: Optional[QueryOptions] = None, + **kwargs: Any, + ) -> AsyncPostgresVectorStore: + """Create an AsyncPostgresVectorStore instance from texts. + + Args: + texts (List[str]): Texts to add to the vector store. + embedding (Embeddings): Text embedding model to use. + engine (PostgresEngine): Connection pool engine for managing connections to Postgres database. + table_name (str): Name of the existing table or the table to be created. + schema_name (str, optional): Database schema name of the table. Defaults to "public". + metadatas (Optional[List[dict]]): List of metadatas to add to table records. + ids: (Optional[List[str]]): List of IDs to add to table records. + content_column (str): Column that represent a Document’s page_content. Defaults to "content". + embedding_column (str): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". + metadata_columns (List[str]): Column(s) that represent a document's metadata. + ignore_metadata_columns (List[str]): Column(s) to ignore in pre-existing tables for a document's metadata. Can not be used with metadata_columns. Defaults to None. + id_column (str): Column that represents the Document's id. Defaults to "langchain_id". + metadata_json_column (str): Column to store metadata as JSON. Defaults to "langchain_metadata". + distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. + k (int): Number of Documents to return from search. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. + index_query_options (QueryOptions): Index query option. + + Raises: + :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. 
+ + Returns: + AsyncPostgresVectorStore + """ + vs = await cls.create( + engine, + embedding, + table_name, + schema_name, + content_column, + embedding_column, + metadata_columns, + ignore_metadata_columns, + id_column, + metadata_json_column, + distance_strategy, + k, + fetch_k, + lambda_mult, + index_query_options, + ) + await vs.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs) + return vs + + @classmethod + async def afrom_documents( # type: ignore[override] + cls: Type[AsyncPostgresVectorStore], + documents: List[Document], + embedding: Embeddings, + engine: PostgresEngine, + table_name: str, + schema_name: str = "public", + ids: Optional[List] = None, + content_column: str = "content", + embedding_column: str = "embedding", + metadata_columns: List[str] = [], + ignore_metadata_columns: Optional[List[str]] = None, + id_column: str = "langchain_id", + metadata_json_column: str = "langchain_metadata", + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + index_query_options: Optional[QueryOptions] = None, + **kwargs: Any, + ) -> AsyncPostgresVectorStore: + """Create an AsyncPostgresVectorStore instance from documents. + + Args: + documents (List[Document]): Documents to add to the vector store. + embedding (Embeddings): Text embedding model to use. + engine (PostgresEngine): Connection pool engine for managing connections to Postgres database. + table_name (str): Name of the existing table or the table to be created. + schema_name (str, optional): Database schema name of the table. Defaults to "public". + metadatas (Optional[List[dict]]): List of metadatas to add to table records. + ids: (Optional[List[str]]): List of IDs to add to table records. + content_column (str): Column that represent a Document's page_content. Defaults to "content". + embedding_column (str): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". + metadata_columns (List[str]): Column(s) that represent a document's metadata. + ignore_metadata_columns (List[str]): Column(s) to ignore in pre-existing tables for a document's metadata. Can not be used with metadata_columns. Defaults to None. + id_column (str): Column that represents the Document's id. Defaults to "langchain_id". + metadata_json_column (str): Column to store metadata as JSON. Defaults to "langchain_metadata". + distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. + k (int): Number of Documents to return from search. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. + index_query_options (QueryOptions): Index query option. + + Raises: + :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. 
+ + Returns: + AsyncPostgresVectorStore + """ + vs = await cls.create( + engine, + embedding, + table_name, + schema_name, + content_column, + embedding_column, + metadata_columns, + ignore_metadata_columns, + id_column, + metadata_json_column, + distance_strategy, + k, + fetch_k, + lambda_mult, + index_query_options, + ) + texts = [doc.page_content for doc in documents] + metadatas = [doc.metadata for doc in documents] + await vs.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs) + return vs + + async def __query_collection( + self, + embedding: List[float], + k: Optional[int] = None, + filter: Optional[str] = None, + **kwargs: Any, + ) -> Sequence[RowMapping]: + """Perform similarity search query on the vector store table.""" + k = k if k else self.k + operator = self.distance_strategy.operator + search_function = self.distance_strategy.search_function + + filter = f"WHERE {filter}" if filter else "" + stmt = f"SELECT *, {search_function}({self.embedding_column}, '{embedding}') as distance FROM \"{self.schema_name}\".\"{self.table_name}\" {filter} ORDER BY {self.embedding_column} {operator} '{embedding}' LIMIT {k};" + if self.index_query_options: + async with self.pool.connect() as conn: + await conn.execute( + text(f"SET LOCAL {self.index_query_options.to_string()};") + ) + result = await conn.execute(text(stmt)) + result_map = result.mappings() + results = result_map.fetchall() + else: + async with self.pool.connect() as conn: + result = await conn.execute(text(stmt)) + result_map = result.mappings() + results = result_map.fetchall() + return results + + async def asimilarity_search( + self, + query: str, + k: Optional[int] = None, + filter: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected by similarity search on query.""" + embedding = self.embedding_service.embed_query(text=query) + + return await self.asimilarity_search_by_vector( + embedding=embedding, k=k, filter=filter, **kwargs + ) + + def _select_relevance_score_fn(self) -> Callable[[float], float]: + """Select a relevance function based on distance strategy.""" + # Calculate distance strategy provided in + # vectorstore constructor + if self.distance_strategy == DistanceStrategy.COSINE_DISTANCE: + return self._cosine_relevance_score_fn + if self.distance_strategy == DistanceStrategy.INNER_PRODUCT: + return self._max_inner_product_relevance_score_fn + elif self.distance_strategy == DistanceStrategy.EUCLIDEAN: + return self._euclidean_relevance_score_fn + + async def asimilarity_search_with_score( + self, + query: str, + k: Optional[int] = None, + filter: Optional[str] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs and distance scores selected by similarity search on query.""" + embedding = self.embedding_service.embed_query(query) + docs = await self.asimilarity_search_with_score_by_vector( + embedding=embedding, k=k, filter=filter, **kwargs + ) + return docs + + async def asimilarity_search_by_vector( + self, + embedding: List[float], + k: Optional[int] = None, + filter: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected by vector similarity search.""" + docs_and_scores = await self.asimilarity_search_with_score_by_vector( + embedding=embedding, k=k, filter=filter, **kwargs + ) + + return [doc for doc, _ in docs_and_scores] + + async def asimilarity_search_with_score_by_vector( + self, + embedding: List[float], + k: Optional[int] = None, + filter: Optional[str] = None, + **kwargs: Any, + ) -> List[Tuple[Document, 
float]]: + """Return docs and distance scores selected by vector similarity search.""" + results = await self.__query_collection( + embedding=embedding, k=k, filter=filter, **kwargs + ) + + documents_with_scores = [] + for row in results: + metadata = ( + row[self.metadata_json_column] + if self.metadata_json_column and row[self.metadata_json_column] + else {} + ) + for col in self.metadata_columns: + metadata[col] = row[col] + documents_with_scores.append( + ( + Document( + page_content=row[self.content_column], + metadata=metadata, + ), + row["distance"], + ) + ) + + return documents_with_scores + + async def amax_marginal_relevance_search( + self, + query: str, + k: Optional[int] = None, + fetch_k: Optional[int] = None, + lambda_mult: Optional[float] = None, + filter: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance.""" + embedding = self.embedding_service.embed_query(text=query) + + return await self.amax_marginal_relevance_search_by_vector( + embedding=embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, + ) + + async def amax_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: Optional[int] = None, + fetch_k: Optional[int] = None, + lambda_mult: Optional[float] = None, + filter: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance.""" + docs_and_scores = ( + await self.amax_marginal_relevance_search_with_score_by_vector( + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, + ) + ) + + return [result[0] for result in docs_and_scores] + + async def amax_marginal_relevance_search_with_score_by_vector( + self, + embedding: List[float], + k: Optional[int] = None, + fetch_k: Optional[int] = None, + lambda_mult: Optional[float] = None, + filter: Optional[str] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs and distance scores selected using the maximal marginal relevance.""" + results = await self.__query_collection( + embedding=embedding, k=fetch_k, filter=filter, **kwargs + ) + + k = k if k else self.k + fetch_k = fetch_k if fetch_k else self.fetch_k + lambda_mult = lambda_mult if lambda_mult else self.lambda_mult + embedding_list = [json.loads(row[self.embedding_column]) for row in results] + mmr_selected = utils.maximal_marginal_relevance( + np.array(embedding, dtype=np.float32), + embedding_list, + k=k, + lambda_mult=lambda_mult, + ) + + documents_with_scores = [] + for row in results: + metadata = ( + row[self.metadata_json_column] + if self.metadata_json_column and row[self.metadata_json_column] + else {} + ) + for col in self.metadata_columns: + metadata[col] = row[col] + documents_with_scores.append( + ( + Document( + page_content=row[self.content_column], + metadata=metadata, + ), + row["distance"], + ) + ) + + return [r for i, r in enumerate(documents_with_scores) if i in mmr_selected] + + async def aapply_vector_index( + self, + index: BaseIndex, + name: Optional[str] = None, + concurrently: bool = False, + ) -> None: + """Create an index on the vector store table.""" + if isinstance(index, ExactNearestNeighbor): + await self.adrop_vector_index() + return + + filter = f"WHERE ({index.partial_indexes})" if index.partial_indexes else "" + params = "WITH " + index.index_options() + function = index.distance_strategy.index_function + if name is None: + if index.name == None: + index.name = 
self.table_name + DEFAULT_INDEX_NAME_SUFFIX + name = index.name + stmt = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} {name} ON "{self.schema_name}"."{self.table_name}" USING {index.index_type} ({self.embedding_column} {function}) {params} {filter};' + if concurrently: + async with self.pool.connect() as conn: + await conn.execute(text("COMMIT")) + await conn.execute(text(stmt)) + else: + async with self.pool.connect() as conn: + await conn.execute(text(stmt)) + await conn.commit() + + async def areindex(self, index_name: Optional[str] = None) -> None: + """Re-index the vector store table.""" + index_name = index_name or self.table_name + DEFAULT_INDEX_NAME_SUFFIX + query = f"REINDEX INDEX {index_name};" + async with self.pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + async def adrop_vector_index( + self, + index_name: Optional[str] = None, + ) -> None: + """Drop the vector index.""" + index_name = index_name or self.table_name + DEFAULT_INDEX_NAME_SUFFIX + query = f"DROP INDEX IF EXISTS {index_name};" + async with self.pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + async def is_valid_index( + self, + index_name: Optional[str] = None, + ) -> bool: + """Check if index exists in the table.""" + index_name = index_name or self.table_name + DEFAULT_INDEX_NAME_SUFFIX + stmt = f""" + SELECT tablename, indexname + FROM pg_indexes + WHERE tablename = '{self.table_name}' AND schemaname = '{self.schema_name}' AND indexname = '{index_name}'; + """ + async with self.pool.connect() as conn: + result = await conn.execute(text(stmt)) + result_map = result.mappings() + results = result_map.fetchall() + + return bool(len(results) == 1) + + def similarity_search( + self, + query: str, + k: Optional[int] = None, + filter: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." + ) + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List] = None, + **kwargs: Any, + ) -> List[str]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." + ) + + def add_documents( + self, + documents: List[Document], + ids: Optional[List] = None, + **kwargs: Any, + ) -> List[str]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." + ) + + def delete( + self, + ids: Optional[List] = None, + **kwargs: Any, + ) -> Optional[bool]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." + ) + + @classmethod + def from_texts( # type: ignore[override] + cls: Type[AsyncPostgresVectorStore], + texts: List[str], + embedding: Embeddings, + engine: PostgresEngine, + table_name: str, + metadatas: Optional[List[dict]] = None, + ids: Optional[List] = None, + content_column: str = "content", + embedding_column: str = "embedding", + metadata_columns: List[str] = [], + ignore_metadata_columns: Optional[List[str]] = None, + id_column: str = "langchain_id", + metadata_json_column: str = "langchain_metadata", + **kwargs: Any, + ) -> AsyncPostgresVectorStore: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." 
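Review note: the index helpers above (`aapply_vector_index`, `areindex`, `adrop_vector_index`, `is_valid_index`) are typically driven as in the sketch below, assuming this package's `HNSWIndex` and the `store` from the earlier note. Passing `concurrently=True` takes the `CREATE INDEX CONCURRENTLY` path, which first issues a `COMMIT` because that statement cannot run inside a transaction block.

```python
# Sketch: build, verify, rebuild, and drop an ANN index (names assumed).
from langchain_google_cloud_sql_pg.indexes import HNSWIndex

index = HNSWIndex(name="my_vectors_hnsw")  # name falls back to table name + suffix
await store.aapply_vector_index(index, concurrently=True)
assert await store.is_valid_index("my_vectors_hnsw")
await store.areindex("my_vectors_hnsw")    # rebuild after heavy churn
await store.adrop_vector_index("my_vectors_hnsw")
```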
+ ) + + @classmethod + def from_documents( # type: ignore[override] + cls: Type[AsyncPostgresVectorStore], + documents: List[Document], + embedding: Embeddings, + engine: PostgresEngine, + table_name: str, + ids: Optional[List] = None, + content_column: str = "content", + embedding_column: str = "embedding", + metadata_columns: List[str] = [], + ignore_metadata_columns: Optional[List[str]] = None, + id_column: str = "langchain_id", + metadata_json_column: str = "langchain_metadata", + **kwargs: Any, + ) -> AsyncPostgresVectorStore: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." + ) + + def similarity_search_with_score( + self, + query: str, + k: Optional[int] = None, + filter: Optional[str] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." + ) + + def similarity_search_by_vector( + self, + embedding: List[float], + k: Optional[int] = None, + filter: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." + ) + + def similarity_search_with_score_by_vector( + self, + embedding: List[float], + k: Optional[int] = None, + filter: Optional[str] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." + ) + + def max_marginal_relevance_search( + self, + query: str, + k: Optional[int] = None, + fetch_k: Optional[int] = None, + lambda_mult: Optional[float] = None, + filter: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." + ) + + def max_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: Optional[int] = None, + fetch_k: Optional[int] = None, + lambda_mult: Optional[float] = None, + filter: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." + ) + + def max_marginal_relevance_search_with_score_by_vector( + self, + embedding: List[float], + k: Optional[int] = None, + fetch_k: Optional[int] = None, + lambda_mult: Optional[float] = None, + filter: Optional[str] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." 
+ ) diff --git a/src/langchain_google_cloud_sql_pg/chat_message_history.py b/src/langchain_google_cloud_sql_pg/chat_message_history.py index 0150fa63..306dba15 100644 --- a/src/langchain_google_cloud_sql_pg/chat_message_history.py +++ b/src/langchain_google_cloud_sql_pg/chat_message_history.py @@ -14,32 +14,15 @@ from __future__ import annotations -import json from typing import List, Sequence from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import BaseMessage, messages_from_dict +from .async_chat_message_history import AsyncPostgresChatMessageHistory from .engine import PostgresEngine -async def _aget_messages( - engine: PostgresEngine, - session_id: str, - table_name: str, - schema_name: str = "public", -) -> List[BaseMessage]: - """Retrieve the messages from PostgreSQL.""" - query = f"""SELECT data, type FROM "{schema_name}"."{table_name}" WHERE session_id = :session_id ORDER BY id;""" - results = await engine._afetch(query, {"session_id": session_id}) - if not results: - return [] - - items = [{"data": result["data"], "type": result["type"]} for result in results] - messages = messages_from_dict(items) - return messages - - class PostgresChatMessageHistory(BaseChatMessageHistory): """Chat message history stored in an Cloud SQL for PostgreSQL database.""" @@ -49,21 +32,14 @@ def __init__( self, key: object, engine: PostgresEngine, - session_id: str, - table_name: str, - messages: List[BaseMessage], - schema_name: str = "public", + history: AsyncPostgresChatMessageHistory, ): """PostgresChatMessageHistory constructor. Args: key (object): Key to prevent direct constructor usage. engine (PostgresEngine): Database connection pool. - session_id (str): Retrieve the table content with this session ID. - table_name (str): Table name that stores the chat message history. - messages (List[BaseMessage]): Messages to store. - schema_name (str, optional): Database schema name of the chat message history table. Defaults to "public". - + history (AsyncPostgresChatMessageHistory): Native async implementation Raises: Exception: If constructor is directly called by the user. """ @@ -71,11 +47,8 @@ def __init__( raise Exception( "Only create class through 'create' or 'create_sync' methods!" ) - self.engine = engine - self.session_id = session_id - self.table_name = table_name - self.messages = messages - self.schema_name = schema_name + self._engine = engine + self._history = history @classmethod async def create( @@ -99,27 +72,11 @@ async def create( Returns: PostgresChatMessageHistory: A newly created instance of PostgresChatMessageHistory. """ - table_schema = await engine._aload_table_schema(table_name, schema_name) - column_names = table_schema.columns.keys() - - required_columns = ["id", "session_id", "data", "type"] - - if not (all(x in column_names for x in required_columns)): - raise IndexError( - f"Table '{schema_name}'.'{table_name}' has incorrect schema. 
Got " - f"column names '{column_names}' but required column names " - f"'{required_columns}'.\nPlease create table with following schema:" - f"\nCREATE TABLE {schema_name}.{table_name} (" - "\n id INT AUTO_INCREMENT PRIMARY KEY," - "\n session_id TEXT NOT NULL," - "\n data JSON NOT NULL," - "\n type TEXT NOT NULL" - "\n);" - ) - messages = await _aget_messages(engine, session_id, table_name, schema_name) - return cls( - cls.__create_key, engine, session_id, table_name, messages, schema_name + coro = AsyncPostgresChatMessageHistory.create( + engine, session_id, table_name, schema_name ) + history = await engine._run_as_async(coro) + return cls(cls.__create_key, engine, history) @classmethod def create_sync( @@ -143,55 +100,37 @@ def create_sync( Returns: PostgresChatMessageHistory: A newly created instance of PostgresChatMessageHistory. """ - coro = cls.create(engine, session_id, table_name, schema_name) - return engine._run_as_sync(coro) + coro = AsyncPostgresChatMessageHistory.create( + engine, session_id, table_name, schema_name + ) + history = engine._run_as_sync(coro) + return cls(cls.__create_key, engine, history) + + @property # type: ignore[override] + def messages(self) -> List[BaseMessage]: + """The abstraction required a property.""" + return self._engine._run_as_sync(self._history._aget_messages()) async def aadd_message(self, message: BaseMessage) -> None: """Append the message to the record in PostgreSQL""" - query = f"""INSERT INTO "{self.schema_name}"."{self.table_name}"(session_id, data, type) - VALUES (:session_id, :data, :type); - """ - await self.engine._aexecute( - query, - { - "session_id": self.session_id, - "data": json.dumps(message.dict()), - "type": message.type, - }, - ) - self.messages = await _aget_messages( - self.engine, self.session_id, self.table_name, self.schema_name - ) + await self._engine._run_as_async(self._history.aadd_message(message)) def add_message(self, message: BaseMessage) -> None: """Append the message to the record in PostgreSQL""" - self.engine._run_as_sync(self.aadd_message(message)) + self._engine._run_as_sync(self._history.aadd_message(message)) async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: """Append a list of messages to the record in PostgreSQL""" - for message in messages: - await self.aadd_message(message) + await self._engine._run_as_async(self._history.aadd_messages(messages)) def add_messages(self, messages: Sequence[BaseMessage]) -> None: """Append a list of messages to the record in PostgreSQL""" - self.engine._run_as_sync(self.aadd_messages(messages)) + self._engine._run_as_sync(self._history.aadd_messages(messages)) async def aclear(self) -> None: """Clear session memory from PostgreSQL""" - query = f"""DELETE FROM "{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id;""" - await self.engine._aexecute(query, {"session_id": self.session_id}) - self.messages = [] + await self._engine._run_as_async(self._history.aclear()) def clear(self) -> None: """Clear session memory from PostgreSQL""" - self.engine._run_as_sync(self.aclear()) - - async def async_messages(self) -> None: - """Retrieve the messages from Postgres.""" - self.messages = await _aget_messages( - self.engine, self.session_id, self.table_name, self.schema_name - ) - - def sync_messages(self) -> None: - """Retrieve the messages from Postgres.""" - self.engine._run_as_sync(self.async_messages()) + self._engine._run_as_sync(self._history.aclear()) diff --git a/src/langchain_google_cloud_sql_pg/engine.py 
b/src/langchain_google_cloud_sql_pg/engine.py index dc3137cd..33b7a83f 100644 --- a/src/langchain_google_cloud_sql_pg/engine.py +++ b/src/langchain_google_cloud_sql_pg/engine.py @@ -15,25 +15,17 @@ from __future__ import annotations import asyncio +from concurrent.futures import Future from dataclasses import dataclass from threading import Thread -from typing import ( - TYPE_CHECKING, - Awaitable, - Dict, - List, - Optional, - Sequence, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, TypeVar, Union import aiohttp import google.auth # type: ignore import google.auth.transport.requests # type: ignore from google.cloud.sql.connector import Connector, IPTypes, RefreshStrategy from sqlalchemy import MetaData, Table, text -from sqlalchemy.engine.row import RowMapping +from sqlalchemy.engine import URL from sqlalchemy.exc import InvalidRequestError from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine @@ -114,17 +106,17 @@ class PostgresEngine: def __init__( self, key: object, - engine: AsyncEngine, + pool: AsyncEngine, loop: Optional[asyncio.AbstractEventLoop], thread: Optional[Thread], ): """PostgresEngine constructor. Args: - key(object): Prevent direct constructor usage. - engine(AsyncEngine): Async engine connection pool. + key (object): Prevent direct constructor usage. + pool (AsyncEngine): Async engine connection pool. loop (Optional[asyncio.AbstractEventLoop]): Async event loop used to create the engine. - thread (Optional[Thread] = None): Thread used to create the engine async. + thread (Optional[Thread]): Thread used to create the engine async. Raises: Exception: If the constructor is called directly by the user. @@ -133,62 +125,10 @@ def __init__( raise Exception( "Only create class through 'create' or 'create_sync' methods!" ) - self._engine = engine + self._pool = pool self._loop = loop self._thread = thread - @classmethod - def from_instance( - cls, - project_id: str, - region: str, - instance: str, - database: str, - user: Optional[str] = None, - password: Optional[str] = None, - ip_type: Union[str, IPTypes] = IPTypes.PUBLIC, - quota_project: Optional[str] = None, - iam_account_email: Optional[str] = None, - ) -> PostgresEngine: - """Create a PostgresEngine from a Postgres instance. - - Args: - project_id (str): GCP project ID. - region (str): Postgres instance region. - instance (str): Postgres instance name. - database (str): Database name. - user (Optional[str], optional): Postgres user name. Defaults to None. - password (Optional[str], optional): Postgres user password. Defaults to None. - ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC. - quota_project (Optional[str]): Project that provides quota for API calls. - iam_account_email (Optional[str], optional): IAM service account email. Defaults to None. - - Returns: - PostgresEngine: A newly created PostgresEngine instance. 
- """ - # Running a loop in a background thread allows us to support - # async methods from non-async environments - if cls._default_loop is None: - cls._default_loop = asyncio.new_event_loop() - cls._default_thread = Thread( - target=cls._default_loop.run_forever, daemon=True - ) - cls._default_thread.start() - coro = cls._create( - project_id, - region, - instance, - database, - ip_type, - user, - password, - loop=cls._default_loop, - thread=cls._default_thread, - quota_project=quota_project, - iam_account_email=iam_account_email, - ) - return asyncio.run_coroutine_threadsafe(coro, cls._default_loop).result() - @classmethod async def _create( cls, @@ -211,13 +151,13 @@ async def _create( region (str): Postgres instance region. instance (str): Postgres instance name. database (str): Database name. - ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC. - user (Optional[str], optional): Postgres user name. Defaults to None. - password (Optional[str], optional): Postgres user password. Defaults to None. + ip_type (Union[str, IPTypes]): IP address type. Defaults to IPTypes.PUBLIC. + user (Optional[str]): Postgres user name. Defaults to None. + password (Optional[str]): Postgres user password. Defaults to None. loop (Optional[asyncio.AbstractEventLoop]): Async event loop used to create the engine. - thread (Optional[Thread] = None): Thread used to create the engine async. + thread (Optional[Thread]): Thread used to create the engine async. quota_project (Optional[str]): Project that provides quota for API calls. - iam_account_email (Optional[str], optional): IAM service account email. Defaults to None. + iam_account_email (Optional[str]): IAM service account email. Defaults to None. Raises: ValueError: If only one of `user` and `password` is specified. @@ -275,7 +215,43 @@ async def getconn() -> asyncpg.Connection: return cls(cls.__create_key, engine, loop, thread) @classmethod - async def afrom_instance( + def __start_background_loop( + cls, + project_id: str, + region: str, + instance: str, + database: str, + user: Optional[str] = None, + password: Optional[str] = None, + ip_type: Union[str, IPTypes] = IPTypes.PUBLIC, + quota_project: Optional[str] = None, + iam_account_email: Optional[str] = None, + ) -> Future: + # Running a loop in a background thread allows us to support + # async methods from non-async environments + if cls._default_loop is None: + cls._default_loop = asyncio.new_event_loop() + cls._default_thread = Thread( + target=cls._default_loop.run_forever, daemon=True + ) + cls._default_thread.start() + coro = cls._create( + project_id, + region, + instance, + database, + ip_type, + user, + password, + loop=cls._default_loop, + thread=cls._default_thread, + quota_project=quota_project, + iam_account_email=iam_account_email, + ) + return asyncio.run_coroutine_threadsafe(coro, cls._default_loop) + + @classmethod + def from_instance( cls, project_id: str, region: str, @@ -303,73 +279,129 @@ async def afrom_instance( Returns: PostgresEngine: A newly created PostgresEngine instance. 
""" - return await cls._create( + future = cls.__start_background_loop( project_id, region, instance, database, + user, + password, ip_type, + quota_project=quota_project, + iam_account_email=iam_account_email, + ) + return future.result() + + @classmethod + async def afrom_instance( + cls, + project_id: str, + region: str, + instance: str, + database: str, + user: Optional[str] = None, + password: Optional[str] = None, + ip_type: Union[str, IPTypes] = IPTypes.PUBLIC, + quota_project: Optional[str] = None, + iam_account_email: Optional[str] = None, + ) -> PostgresEngine: + """Create a PostgresEngine from a Postgres instance. + + Args: + project_id (str): GCP project ID. + region (str): Postgres instance region. + instance (str): Postgres instance name. + database (str): Database name. + user (Optional[str], optional): Postgres user name. Defaults to None. + password (Optional[str], optional): Postgres user password. Defaults to None. + ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC. + quota_project (Optional[str]): Project that provides quota for API calls. + iam_account_email (Optional[str], optional): IAM service account email. Defaults to None. + + Returns: + PostgresEngine: A newly created PostgresEngine instance. + """ + future = cls.__start_background_loop( + project_id, + region, + instance, + database, user, password, + ip_type, quota_project=quota_project, iam_account_email=iam_account_email, ) + return await asyncio.wrap_future(future) @classmethod - def from_engine(cls, engine: AsyncEngine) -> PostgresEngine: + def from_engine( + cls, + engine: AsyncEngine, + loop: Optional[asyncio.AbstractEventLoop] = None, + ) -> PostgresEngine: """Create an PostgresEngine instance from an AsyncEngine.""" - return cls(cls.__create_key, engine, None, None) + return cls(cls.__create_key, engine, loop, None) - async def _aexecute(self, query: str, params: Optional[dict] = None) -> None: - """Execute a SQL query.""" - async with self._engine.connect() as conn: - await conn.execute(text(query), params) - await conn.commit() + @classmethod + def from_engine_args( + cls, + url: Union[str | URL], + **kwargs: Any, + ) -> PostgresEngine: + """Create an PostgresEngine instance from arguments. These parameters are pass directly into sqlalchemy's create_async_engine function. 
-    async def _aexecute_outside_tx(self, query: str) -> None:
-        """Execute a SQL query in a new transaction."""
-        async with self._engine.connect() as conn:
-            await conn.execute(text("COMMIT"))
-            await conn.execute(text(query))
+        Args:
+            url (Union[str, URL]): The URL used to connect to a database.
+            **kwargs (Any, optional): sqlalchemy `create_async_engine` arguments
+
+        Raises:
+            ValueError: If `postgresql+asyncpg` is not specified as the PG driver
+
+        Returns:
+            PostgresEngine
+        """
+        # Running a loop in a background thread allows us to support
+        # async methods from non-async environments
+        if cls._default_loop is None:
+            cls._default_loop = asyncio.new_event_loop()
+            cls._default_thread = Thread(
+                target=cls._default_loop.run_forever, daemon=True
+            )
+            cls._default_thread.start()
 
-    async def _afetch(
-        self, query: str, params: Optional[dict] = None
-    ) -> Sequence[RowMapping]:
-        """Fetch results from a SQL query."""
-        async with self._engine.connect() as conn:
-            result = await conn.execute(text(query), params)
-            result_map = result.mappings()
-            result_fetch = result_map.fetchall()
-
-        return result_fetch
-
-    async def _afetch_with_query_options(
-        self, query: str, query_options: str
-    ) -> Sequence[RowMapping]:
-        """Set temporary database flags and fetch results from a SQL query."""
-        async with self._engine.connect() as conn:
-            await conn.execute(text(query_options))
-            result = await conn.execute(text(query))
-            result_map = result.mappings()
-            result_fetch = result_map.fetchall()
-
-        return result_fetch
-
-    def _execute(self, query: str, params: Optional[dict] = None) -> None:
-        """Execute a SQL query."""
-        return self._run_as_sync(self._aexecute(query, params))
-
-    def _fetch(self, query: str, params: Optional[dict] = None) -> Sequence[RowMapping]:
-        """Fetch results from a SQL query."""
-        return self._run_as_sync(self._afetch(query, params))
+        driver = "postgresql+asyncpg"
+        if (isinstance(url, str) and not url.startswith(driver)) or (
+            isinstance(url, URL) and url.drivername != driver
+        ):
+            raise ValueError("Driver must be type 'postgresql+asyncpg'")
+
+        engine = create_async_engine(url, **kwargs)
+        return cls(cls.__create_key, engine, cls._default_loop, cls._default_thread)
+
+    async def _run_as_async(self, coro: Awaitable[T]) -> T:
+        """Run an async coroutine asynchronously"""
+        # If a loop has not been provided, attempt to run in current thread
+        if not self._loop:
+            return await coro
+        # Otherwise, run in the background thread
+        return await asyncio.wrap_future(
+            asyncio.run_coroutine_threadsafe(coro, self._loop)
+        )
 
     def _run_as_sync(self, coro: Awaitable[T]) -> T:
         """Run an async coroutine synchronously"""
         if not self._loop:
-            raise Exception("Engine was initialized async.")
+            raise Exception(
+                "Engine was initialized without a background loop and cannot call sync methods."
+            )
         return asyncio.run_coroutine_threadsafe(coro, self._loop).result()
 
-    async def ainit_vectorstore_table(
+    async def close(self) -> None:
+        """Dispose of connection pool"""
+        await self._pool.dispose()
+
+    async def _ainit_vectorstore_table(
         self,
         table_name: str,
         vector_size: int,
@@ -378,7 +410,7 @@
         embedding_column: str = "embedding",
         metadata_columns: List[Column] = [],
         metadata_json_column: str = "langchain_metadata",
-        id_column: str = "langchain_id",
+        id_column: Union[str, Column] = "langchain_id",
         overwrite_existing: bool = False,
         store_metadata: bool = True,
     ) -> None:
@@ -398,22 +430,31 @@
            metadata. Default: [].
Optional. metadata_json_column (str): The column to store extra metadata in JSON format. Default: "langchain_metadata". Optional. - id_column (str): Name of the column to store ids. - Default: "langchain_id". Optional, + id_column (Union[str, Column]) : Column to store ids. + Default: "langchain_id" column name with data type UUID. Optional. overwrite_existing (bool): Whether to drop existing table. Default: False. store_metadata (bool): Whether to store metadata in the table. Default: True. - Raises: :class:`DuplicateTableError `: if table already exists and overwrite flag is not set. + :class:`UndefinedObjectError `: if the data type of the id column is not a postgreSQL data type. """ - await self._aexecute("CREATE EXTENSION IF NOT EXISTS vector") + async with self._pool.connect() as conn: + await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) + await conn.commit() if overwrite_existing: - await self._aexecute(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"') + async with self._pool.connect() as conn: + await conn.execute( + text(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"') + ) + await conn.commit() + + id_data_type = "UUID" if isinstance(id_column, str) else id_column.data_type + id_column_name = id_column if isinstance(id_column, str) else id_column.name query = f"""CREATE TABLE "{schema_name}"."{table_name}"( - "{id_column}" UUID PRIMARY KEY, + "{id_column_name}" {id_data_type} PRIMARY KEY, "{content_column}" TEXT NOT NULL, "{embedding_column}" vector({vector_size}) NOT NULL""" for column in metadata_columns: @@ -423,7 +464,59 @@ async def ainit_vectorstore_table( query += f""",\n"{metadata_json_column}" JSON""" query += "\n);" - await self._aexecute(query) + async with self._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + async def ainit_vectorstore_table( + self, + table_name: str, + vector_size: int, + schema_name: str = "public", + content_column: str = "content", + embedding_column: str = "embedding", + metadata_columns: List[Column] = [], + metadata_json_column: str = "langchain_metadata", + id_column: Union[str, Column] = "langchain_id", + overwrite_existing: bool = False, + store_metadata: bool = True, + ) -> None: + """ + Create a table for saving of vectors to be used with PostgresVectorStore. + + Args: + table_name (str): The Postgres database table name. + vector_size (int): Vector size for the embedding model to be used. + schema_name (str): The schema name to store Postgres database table. + Default: "public". + content_column (str): Name of the column to store document content. + Default: "page_content". + embedding_column (str) : Name of the column to store vector embeddings. + Default: "embedding". + metadata_columns (List[Column]): A list of Columns to create for custom + metadata. Default: []. Optional. + metadata_json_column (str): The column to store extra metadata in JSON format. + Default: "langchain_metadata". Optional. + id_column (Union[str, Column]) : Column to store ids. + Default: "langchain_id" column name with data type UUID. Optional. + overwrite_existing (bool): Whether to drop existing table. Default: False. + store_metadata (bool): Whether to store metadata in the table. + Default: True. 
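Review note: the widened `id_column: Union[str, Column]` is the user-visible payoff of this hunk; a plain string keeps the default UUID key, while a `Column` picks any Postgres type. A sketch assuming this package's `Column` dataclass and placeholder instance details:

```python
# Sketch: create a vector table keyed by TEXT ids instead of the default UUID.
from langchain_google_cloud_sql_pg import Column, PostgresEngine

engine = PostgresEngine.from_instance(
    project_id="my-project",
    region="us-central1",
    instance="my-instance",
    database="my_database",
)
engine.init_vectorstore_table(
    table_name="my_vectors",
    vector_size=768,
    id_column=Column("doc_id", "TEXT"),  # a bare string would rename but keep UUID
)
```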
+ """ + await self._run_as_async( + self._ainit_vectorstore_table( + table_name, + vector_size, + schema_name, + content_column, + embedding_column, + metadata_columns, + metadata_json_column, + id_column, + overwrite_existing, + store_metadata, + ) + ) def init_vectorstore_table( self, @@ -434,7 +527,7 @@ def init_vectorstore_table( embedding_column: str = "embedding", metadata_columns: List[Column] = [], metadata_json_column: str = "langchain_metadata", - id_column: str = "langchain_id", + id_column: Union[str, Column] = "langchain_id", overwrite_existing: bool = False, store_metadata: bool = True, ) -> None: @@ -454,14 +547,16 @@ def init_vectorstore_table( metadata. Default: []. Optional. metadata_json_column (str): The column to store extra metadata in JSON format. Default: "langchain_metadata". Optional. - id_column (str): Name of the column to store ids. - Default: "langchain_id". Optional, + id_column (Union[str, Column]) : Column to store ids. + Default: "langchain_id" column name with data type UUID. Optional. overwrite_existing (bool): Whether to drop existing table. Default: False. store_metadata (bool): Whether to store metadata in the table. Default: True. + Raises: + :class:`UndefinedObjectError `: if the `ids` data type does not match that of the `id_column`. """ - return self._run_as_sync( - self.ainit_vectorstore_table( + self._run_as_sync( + self._ainit_vectorstore_table( table_name, vector_size, schema_name, @@ -475,7 +570,7 @@ def init_vectorstore_table( ) ) - async def ainit_chat_history_table( + async def _ainit_chat_history_table( self, table_name: str, schema_name: str = "public" ) -> None: """Create a Cloud SQL table to store chat history. @@ -494,7 +589,24 @@ async def ainit_chat_history_table( data JSONB NOT NULL, type TEXT NOT NULL );""" - await self._aexecute(create_table_query) + async with self._pool.connect() as conn: + await conn.execute(text(create_table_query)) + await conn.commit() + + async def ainit_chat_history_table( + self, table_name: str, schema_name: str = "public" + ) -> None: + """Create a Cloud SQL table to store chat history. + + Args: + table_name (str): Table name to store chat history. + + Returns: + None + """ + await self._run_as_async( + self._ainit_chat_history_table(table_name, schema_name) + ) def init_chat_history_table( self, table_name: str, schema_name: str = "public" @@ -509,13 +621,37 @@ def init_chat_history_table( Returns: None """ - return self._run_as_sync( - self.ainit_chat_history_table( + self._run_as_sync( + self._ainit_chat_history_table( table_name, schema_name, ) ) + async def _ainit_document_table( + self, + table_name: str, + schema_name: str = "public", + content_column: str = "page_content", + metadata_columns: List[Column] = [], + metadata_json_column: str = "langchain_metadata", + store_metadata: bool = True, + ) -> None: + query = f"""CREATE TABLE "{schema_name}"."{table_name}"( + {content_column} TEXT NOT NULL + """ + for column in metadata_columns: + nullable = "NOT NULL" if not column.nullable else "" + query += f',\n"{column.name}" {column.data_type} {nullable}' + metadata_json_column = metadata_json_column or "langchain_metadata" + if store_metadata: + query += f',\n"{metadata_json_column}" JSON' + query += "\n);" + + async with self._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + async def ainit_document_table( self, table_name: str, @@ -544,19 +680,16 @@ async def ainit_document_table( Raises: :class:`DuplicateTableError `: if table already exists. 
""" - - query = f"""CREATE TABLE "{schema_name}"."{table_name}"( - {content_column} TEXT NOT NULL - """ - for column in metadata_columns: - nullable = "NOT NULL" if not column.nullable else "" - query += f',\n"{column.name}" {column.data_type} {nullable}' - metadata_json_column = metadata_json_column or "langchain_metadata" - if store_metadata: - query += f',\n"{metadata_json_column}" JSON' - query += "\n);" - - await self._aexecute(query) + await self._run_as_async( + self._ainit_document_table( + table_name, + schema_name, + content_column, + metadata_columns, + metadata_json_column, + store_metadata, + ) + ) def init_document_table( self, @@ -575,13 +708,19 @@ def init_document_table( schema_name (str): The schema name to store PgSQL database table. Default: "public". content_column (str): Name of the column to store document content. + Default: "page_content". metadata_columns (List[sqlalchemy.Column]): A list of SQLAlchemy Columns to create for custom metadata. Optional. + metadata_json_column (str): The column to store extra metadata in JSON format. + Default: "langchain_metadata". Optional. store_metadata (bool): Whether to store extra metadata in a metadata column if not described in 'metadata' field list (Default: True). + + Raises: + :class:`DuplicateTableError `: if table already exists. """ - return self._run_as_sync( - self.ainit_document_table( + self._run_as_sync( + self._ainit_document_table( table_name, schema_name, content_column, @@ -602,7 +741,7 @@ async def _aload_table_schema( (sqlalchemy.Table): The loaded table. """ metadata = MetaData() - async with self._engine.connect() as conn: + async with self._pool.connect() as conn: try: await conn.run_sync( metadata.reflect, schema=schema_name, only=[table_name] diff --git a/src/langchain_google_cloud_sql_pg/loader.py b/src/langchain_google_cloud_sql_pg/loader.py index 39ee8935..b9f14a4b 100644 --- a/src/langchain_google_cloud_sql_pg/loader.py +++ b/src/langchain_google_cloud_sql_pg/loader.py @@ -14,95 +14,18 @@ from __future__ import annotations -import json -from typing import ( - Any, - AsyncIterator, - Callable, - Dict, - Iterable, - Iterator, - List, - Optional, -) - -import sqlalchemy +from typing import AsyncIterator, Callable, Iterator, List, Optional + from langchain_core.document_loaders.base import BaseLoader from langchain_core.documents import Document +from .async_loader import AsyncPostgresDocumentSaver, AsyncPostgresLoader from .engine import PostgresEngine DEFAULT_CONTENT_COL = "page_content" DEFAULT_METADATA_COL = "langchain_metadata" -def text_formatter(row: dict, content_columns: List[str]) -> str: - """txt document formatter.""" - return " ".join(str(row[column]) for column in content_columns if column in row) - - -def csv_formatter(row: dict, content_columns: List[str]) -> str: - """CSV document formatter.""" - return ", ".join(str(row[column]) for column in content_columns if column in row) - - -def yaml_formatter(row: dict, content_columns: List[str]) -> str: - """YAML document formatter.""" - return "\n".join( - f"{column}: {str(row[column])}" for column in content_columns if column in row - ) - - -def json_formatter(row: dict, content_columns: List[str]) -> str: - """JSON document formatter.""" - dictionary = {} - for column in content_columns: - if column in row: - dictionary[column] = row[column] - return json.dumps(dictionary) - - -def _parse_doc_from_row( - content_columns: Iterable[str], - metadata_columns: Iterable[str], - row: dict, - metadata_json_column: Optional[str] = 
DEFAULT_METADATA_COL, - formatter: Callable = text_formatter, -) -> Document: - """Parse row into document.""" - page_content = formatter(row, content_columns) - metadata: Dict[str, Any] = {} - # unnest metadata from langchain_metadata column - if metadata_json_column and row.get(metadata_json_column): - for k, v in row[metadata_json_column].items(): - metadata[k] = v - # load metadata from other columns - for column in metadata_columns: - if column in row and column != metadata_json_column: - metadata[column] = row[column] - - return Document(page_content=page_content, metadata=metadata) - - -def _parse_row_from_doc( - doc: Document, - column_names: Iterable[str], - content_column: str = DEFAULT_CONTENT_COL, - metadata_json_column: Optional[str] = DEFAULT_METADATA_COL, -) -> Dict: - """Parse document into a dictionary of rows.""" - doc_metadata = doc.metadata.copy() - row: Dict[str, Any] = {content_column: doc.page_content} - for entry in doc.metadata: - if entry in column_names: - row[entry] = doc_metadata[entry] - del doc_metadata[entry] - # store extra metadata in langchain_metadata column in json format - if metadata_json_column: - row[metadata_json_column] = doc_metadata - return row - - class PostgresLoader(BaseLoader): """Load documents from PostgreSQL`. @@ -115,14 +38,7 @@ class PostgresLoader(BaseLoader): __create_key = object() def __init__( - self, - key: object, - engine: PostgresEngine, - query: str, - content_columns: List[str], - metadata_columns: List[str], - formatter: Callable, - metadata_json_column: Optional[str] = None, + self, key: object, engine: PostgresEngine, loader: AsyncPostgresLoader ) -> None: """PostgresLoader constructor. @@ -144,12 +60,8 @@ def __init__( "Only create class through 'create' or 'create_sync' methods!" 
) - self.engine = engine - self.query = query - self.content_columns = content_columns - self.metadata_columns = metadata_columns - self.formatter = formatter - self.metadata_json_column = metadata_json_column + self._engine = engine + self._loader = loader @classmethod async def create( @@ -180,71 +92,19 @@ async def create( Returns: PostgresLoader """ - if table_name and query: - raise ValueError("Only one of 'table_name' or 'query' should be specified.") - if not table_name and not query: - raise ValueError( - "At least one of the parameters 'table_name' or 'query' needs to be provided" - ) - if format and formatter: - raise ValueError("Only one of 'format' or 'formatter' should be specified.") - - if format and format not in ["csv", "text", "JSON", "YAML"]: - raise ValueError("format must be type: 'csv', 'text', 'JSON', 'YAML'") - if formatter: - formatter = formatter - elif format == "csv": - formatter = csv_formatter - elif format == "YAML": - formatter = yaml_formatter - elif format == "JSON": - formatter = json_formatter - else: - formatter = text_formatter - - if not query: - query = f'SELECT * FROM "{schema_name}"."{table_name}"' - stmt = sqlalchemy.text(query) - - async with engine._engine.connect() as connection: - result_proxy = await connection.execute(stmt) - - column_names = list(result_proxy.keys()) - # Select content or default to first column - content_columns = content_columns or [column_names[0]] - # Select metadata columns - metadata_columns = metadata_columns or [ - col for col in column_names if col not in content_columns - ] - # Check validity of metadata json column - if metadata_json_column and metadata_json_column not in column_names: - raise ValueError( - f"Column {metadata_json_column} not found in query result {column_names}." - ) - # use default metadata json column if not specified - if metadata_json_column and metadata_json_column in column_names: - metadata_json_column = metadata_json_column - elif DEFAULT_METADATA_COL in column_names: - metadata_json_column = DEFAULT_METADATA_COL - else: - metadata_json_column = None - - # check validity of other column - all_names = content_columns + metadata_columns - for name in all_names: - if name not in column_names: - raise ValueError( - f"Column {name} not found in query result {column_names}." 
- ) - return cls( - cls.__create_key, + coro = AsyncPostgresLoader.create( engine, query, + table_name, + schema_name, content_columns, metadata_columns, - formatter, metadata_json_column, + format, + formatter, ) + loader = await engine._run_as_async(coro) + return cls(cls.__create_key, engine, loader) @classmethod def create_sync( @@ -275,7 +135,7 @@ def create_sync( Returns: PostgresLoader """ - coro = cls.create( + coro = AsyncPostgresLoader.create( engine, query, table_name, @@ -286,56 +146,36 @@ def create_sync( format, formatter, ) - return engine._run_as_sync(coro) - - async def _collect_async_items(self, docs_generator): - """Exhause document generator into a list.""" - return [doc async for doc in docs_generator] + loader = engine._run_as_sync(coro) + return cls(cls.__create_key, engine, loader) def load(self) -> List[Document]: """Load PostgreSQL data into Document objects.""" - documents = self.engine._run_as_sync( - self._collect_async_items(self.alazy_load()) - ) - return documents + return self._engine._run_as_sync(self._loader.aload()) async def aload(self) -> List[Document]: """Load PostgreSQL data into Document objects.""" - return [doc async for doc in self.alazy_load()] + return await self._engine._run_as_async(self._loader.aload()) def lazy_load(self) -> Iterator[Document]: """Load PostgreSQL data into Document objects lazily.""" - yield from self.engine._run_as_sync( - self._collect_async_items(self.alazy_load()) - ) + iterator = self._loader.alazy_load() + while True: + try: + result = self._engine._run_as_sync(iterator.__anext__()) + yield result + except StopAsyncIteration: + break async def alazy_load(self) -> AsyncIterator[Document]: """Load PostgreSQL data into Document objects lazily.""" - stmt = sqlalchemy.text(self.query) - async with self.engine._engine.connect() as connection: - result_proxy = await connection.execute(stmt) - # load document one by one - while True: - row = result_proxy.fetchone() - if not row: - break - - row_data = {} - column_names = self.content_columns + self.metadata_columns - column_names += ( - [self.metadata_json_column] if self.metadata_json_column else [] - ) - for column in column_names: - value = getattr(row, column) - row_data[column] = value - - yield _parse_doc_from_row( - self.content_columns, - self.metadata_columns, - row_data, - self.metadata_json_column, - self.formatter, - ) + iterator = self._loader.alazy_load() + while True: + try: + result = await self._engine._run_as_async(iterator.__anext__()) + yield result + except StopAsyncIteration: + break class PostgresDocumentSaver: @@ -347,11 +187,7 @@ def __init__( self, key: object, engine: PostgresEngine, - table_name: str, - content_column: str, - schema_name: str = "public", - metadata_columns: List[str] = [], - metadata_json_column: Optional[str] = None, + saver: AsyncPostgresDocumentSaver, ): """PostgresDocumentSaver constructor. @@ -371,12 +207,8 @@ def __init__( raise Exception( "Only create class through 'create' or 'create_sync' methods!" 
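Review note: with the loading logic moved into `AsyncPostgresLoader`, the public `PostgresLoader` only bridges event loops; `lazy_load` pumps the async generator one `__anext__` at a time through `_run_as_sync`. A sketch against the hypothetical `my_docs` table from the earlier note:

```python
# Sketch: sync iteration over the async generator (table/column names assumed).
from langchain_google_cloud_sql_pg import PostgresLoader

loader = PostgresLoader.create_sync(
    engine,
    table_name="my_docs",
    content_columns=["page_content"],
    format="text",  # one of "csv", "text", "JSON", "YAML"
)
for doc in loader.lazy_load():  # each item hops through the background loop
    print(doc.page_content[:80], doc.metadata)
```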
) - self.engine = engine - self.table_name = table_name - self.content_column = content_column - self.schema_name = schema_name - self.metadata_columns = metadata_columns - self.metadata_json_column = metadata_json_column + self._engine = engine + self._saver = saver @classmethod async def create( @@ -401,42 +233,16 @@ async def create( Returns: PostgresDocumentSaver """ - table_schema = await engine._aload_table_schema(table_name, schema_name) - column_names = table_schema.columns.keys() - if content_column not in column_names: - raise ValueError(f"Content column, {content_column}, does not exist.") - - # Set metadata columns to all columns if not set - if len(metadata_columns) == 0: - metadata_columns = [ - column - for column in column_names - if column != content_column and column != metadata_json_column - ] - - # Check and set metadata json column - for column in metadata_columns: - if column not in column_names: - raise ValueError(f"Metadata column, {column}, does not exist.") - - if ( - metadata_json_column - and metadata_json_column != DEFAULT_METADATA_COL - and metadata_json_column not in column_names - ): - raise ValueError(f"Metadata JSON column, {column}, does not exist.") - elif metadata_json_column not in column_names: - metadata_json_column = None - - return cls( - cls.__create_key, + coro = AsyncPostgresDocumentSaver.create( engine, table_name, - content_column, schema_name, + content_column, metadata_columns, metadata_json_column, ) + saver = await engine._run_as_async(coro) + return cls(cls.__create_key, engine, saver) @classmethod def create_sync( @@ -461,7 +267,7 @@ def create_sync( Returns: PostgresDocumentSaver """ - coro = cls.create( + coro = AsyncPostgresDocumentSaver.create( engine, table_name, schema_name, @@ -469,7 +275,8 @@ def create_sync( metadata_columns, metadata_json_column, ) - return engine._run_as_sync(coro) + saver = engine._run_as_sync(coro) + return cls(cls.__create_key, engine, saver) async def aadd_documents(self, docs: List[Document]) -> None: """ @@ -479,39 +286,7 @@ async def aadd_documents(self, docs: List[Document]) -> None: Args: docs (List[langchain_core.documents.Document]): a list of documents to be saved. """ - - for doc in docs: - row = _parse_row_from_doc( - doc, - self.metadata_columns, - self.content_column, - self.metadata_json_column, - ) - for key, value in row.items(): - if isinstance(value, dict): - row[key] = json.dumps(value) - - # Create list of column names - insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"({self.content_column}' - values_stmt = f"VALUES (:{self.content_column}" - - # Add metadata - for metadata_column in self.metadata_columns: - if metadata_column in doc.metadata: - insert_stmt += f", {metadata_column}" - values_stmt += f", :{metadata_column}" - - # Add JSON column and/or close statement - insert_stmt += ( - f", {self.metadata_json_column})" if self.metadata_json_column else ")" - ) - if self.metadata_json_column: - values_stmt += f", :{self.metadata_json_column})" - else: - values_stmt += ")" - - query = insert_stmt + values_stmt - await self.engine._aexecute(query, row) + await self._engine._run_as_async(self._saver.aadd_documents(docs)) def add_documents(self, docs: List[Document]) -> None: """ @@ -521,7 +296,7 @@ def add_documents(self, docs: List[Document]) -> None: Args: docs (List[langchain_core.documents.Document]): a list of documents to be saved. 
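Review note: a saver round trip, sketched with the same assumed `engine` and `my_docs` table; per the delegated implementation, `delete` matches on the documents' stored fields rather than on an id.

```python
# Sketch: save documents, then delete them by full-field match.
from langchain_core.documents import Document
from langchain_google_cloud_sql_pg import PostgresDocumentSaver

saver = PostgresDocumentSaver.create_sync(engine, table_name="my_docs")
docs = [Document(page_content="hello", metadata={"source": "unit-test"})]
saver.add_documents(docs)
saver.delete(docs)
```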
""" - self.engine._run_as_sync(self.aadd_documents(docs)) + self._engine._run_as_sync(self._saver.aadd_documents(docs)) async def adelete(self, docs: List[Document]) -> None: """ @@ -531,34 +306,7 @@ async def adelete(self, docs: List[Document]) -> None: Args: docs (List[langchain_core.documents.Document]): a list of documents to be deleted. """ - for doc in docs: - row = _parse_row_from_doc( - doc, - self.metadata_columns, - self.content_column, - self.metadata_json_column, - ) - # delete by matching all fields of document - where_conditions_list = [] - for key, value in row.items(): - if isinstance(value, dict): - where_conditions_list.append( - f"{key}::jsonb @> '{json.dumps(value)}'::jsonb" - ) - else: - # Handle simple key-value pairs - where_conditions_list.append(f"{key} = :{key}") - - where_conditions = " AND ".join(where_conditions_list) - stmt = f'DELETE FROM "{self.schema_name}"."{self.table_name}" WHERE {where_conditions};' - values = {} - for key, value in row.items(): - if type(value) is int: - values[key] = str(value) - else: - values[key] = value - - await self.engine._aexecute(stmt, values) + await self._engine._run_as_async(self._saver.adelete(docs)) def delete(self, docs: List[Document]) -> None: """ @@ -568,4 +316,4 @@ def delete(self, docs: List[Document]) -> None: Args: docs (List[langchain_core.documents.Document]): a list of documents to be deleted. """ - self.engine._run_as_sync(self.adelete(docs)) + self._engine._run_as_sync(self._saver.adelete(docs)) diff --git a/src/langchain_google_cloud_sql_pg/vectorstore.py b/src/langchain_google_cloud_sql_pg/vectorstore.py index 964e9df9..109e5d9a 100644 --- a/src/langchain_google_cloud_sql_pg/vectorstore.py +++ b/src/langchain_google_cloud_sql_pg/vectorstore.py @@ -15,23 +15,19 @@ # TODO: Remove below import when minimum supported Python version is 3.10 from __future__ import annotations -import json -import uuid -from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Iterable, List, Optional, Tuple, Type import numpy as np from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore -from sqlalchemy.engine.row import RowMapping +from .async_vectorstore import AsyncPostgresVectorStore from .engine import PostgresEngine from .indexes import ( DEFAULT_DISTANCE_STRATEGY, - DEFAULT_INDEX_NAME_SUFFIX, BaseIndex, DistanceStrategy, - ExactNearestNeighbor, QueryOptions, ) @@ -42,41 +38,13 @@ class PostgresVectorStore(VectorStore): __create_key = object() def __init__( - self, - key: object, - engine: PostgresEngine, - embedding_service: Embeddings, - table_name: str, - schema_name: str = "public", - content_column: str = "content", - embedding_column: str = "embedding", - metadata_columns: List[str] = [], - id_column: str = "langchain_id", - metadata_json_column: Optional[str] = "langchain_metadata", - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - index_query_options: Optional[QueryOptions] = None, + self, key: object, engine: PostgresEngine, vs: AsyncPostgresVectorStore ): """PostgresVectorStore constructor. Args: key (object): Prevent direct constructor usage. engine (PostgresEngine): Connection pool engine for managing connections to Postgres database. - embedding_service (Embeddings): Text embedding model to use. - table_name (str): Name of the existing table or the table to be created. 
- schema_name (str, optional): Database schema name of the table. Defaults to "public". - content_column (str): Column that represent a Document’s page_content. Defaults to "content". - embedding_column (str): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". - metadata_columns (List[str]): Column(s) that represent a document's metadata. - id_column (str): Column that represents the Document's id. Defaults to "langchain_id". - metadata_json_column (str): Column to store metadata as JSON. Defaults to "langchain_metadata". - distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. - k (int): Number of Documents to return from search. Defaults to 4. - fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. - lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. - index_query_options (QueryOptions): Index query option. - + vs (AsyncPostgresVectorstore): The async only VectorStore implementation Raises: Exception: If called directly by user. @@ -86,20 +54,8 @@ def __init__( "Only create class through 'create' or 'create_sync' methods!" ) - self.engine = engine - self.embedding_service = embedding_service - self.table_name = table_name - self.schema_name = schema_name - self.content_column = content_column - self.embedding_column = embedding_column - self.metadata_columns = metadata_columns - self.id_column = id_column - self.metadata_json_column = metadata_json_column - self.distance_strategy = distance_strategy - self.k = k - self.fetch_k = fetch_k - self.lambda_mult = lambda_mult - self.index_query_options = index_query_options + self._engine = engine + self.__vs = vs @classmethod async def create( @@ -142,56 +98,7 @@ async def create( Returns: PostgresVectorStore """ - if metadata_columns and ignore_metadata_columns: - raise ValueError( - "Can not use both metadata_columns and ignore_metadata_columns." - ) - # Get field type information - stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}' AND table_schema = '{schema_name}'" - results = await engine._afetch(stmt) - columns = {} - for field in results: - columns[field["column_name"]] = field["data_type"] - - # Check columns - if id_column not in columns: - raise ValueError(f"Id column, {id_column}, does not exist.") - if content_column not in columns: - raise ValueError(f"Content column, {content_column}, does not exist.") - content_type = columns[content_column] - if content_type != "text" and "char" not in content_type: - raise ValueError( - f"Content column, {content_column}, is type, {content_type}. It must be a type of character string." - ) - if embedding_column not in columns: - raise ValueError(f"Embedding column, {embedding_column}, does not exist.") - if columns[embedding_column] != "USER-DEFINED": - raise ValueError( - f"Embedding column, {embedding_column}, is not type Vector." 
- ) - - metadata_json_column = ( - None if metadata_json_column not in columns else metadata_json_column - ) - - # If using metadata_columns check to make sure column exists - for column in metadata_columns: - if column not in columns: - raise ValueError(f"Metadata column, {column}, does not exist.") - - # If using ignore_metadata_columns, filter out known columns and set known metadata columns - all_columns = columns - if ignore_metadata_columns: - for column in ignore_metadata_columns: - del all_columns[column] - - del all_columns[id_column] - del all_columns[content_column] - del all_columns[embedding_column] - metadata_columns = [k for k in all_columns.keys()] - - return cls( - cls.__create_key, + coro = AsyncPostgresVectorStore.create( engine, embedding_service, table_name, @@ -199,6 +106,7 @@ async def create( content_column, embedding_column, metadata_columns, + ignore_metadata_columns, id_column, metadata_json_column, distance_strategy, @@ -207,6 +115,8 @@ async def create( lambda_mult, index_query_options, ) + vs = await engine._run_as_async(coro) + return cls(cls.__create_key, engine, vs) @classmethod def create_sync( @@ -249,7 +159,7 @@ def create_sync( Returns: PostgresVectorStore """ - coro = cls.create( + coro = AsyncPostgresVectorStore.create( engine, embedding_service, table_name, @@ -266,129 +176,98 @@ def create_sync( lambda_mult, index_query_options, ) - return engine._run_as_sync(coro) + vs = engine._run_as_sync(coro) + return cls(cls.__create_key, engine, vs) @property def embeddings(self) -> Embeddings: - return self.embedding_service + return self.__vs.embedding_service - async def _aadd_embeddings( + async def aadd_texts( self, texts: Iterable[str], - embeddings: List[List[float]], metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, + ids: Optional[List] = None, **kwargs: Any, ) -> List[str]: - """Add embeddings to the table.""" - if not ids: - ids = [str(uuid.uuid4()) for _ in texts] - if not metadatas: - metadatas = [{} for _ in texts] - # Insert embeddings - for id, content, embedding, metadata in zip(ids, texts, embeddings, metadatas): - metadata_col_names = ( - ", " + ", ".join(self.metadata_columns) - if len(self.metadata_columns) > 0 - else "" - ) - insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"({self.id_column}, {self.content_column}, {self.embedding_column}{metadata_col_names}' - values = {"id": id, "content": content, "embedding": str(embedding)} - values_stmt = "VALUES (:id, :content, :embedding" - - # Add metadata - extra = metadata - for metadata_column in self.metadata_columns: - if metadata_column in metadata: - values_stmt += f", :{metadata_column}" - values[metadata_column] = metadata[metadata_column] - del extra[metadata_column] - else: - values_stmt += ",null" - - # Add JSON column and/or close statement - insert_stmt += ( - f", {self.metadata_json_column})" if self.metadata_json_column else ")" - ) - if self.metadata_json_column: - values_stmt += ", :extra)" - values["extra"] = json.dumps(extra) - else: - values_stmt += ")" + """Embed texts and add to the table. - query = insert_stmt + values_stmt - await self.engine._aexecute(query, values) - - return ids + Raises: + :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. 
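+
+        Example (illustrative sketch; assumes an initialized ``store`` whose
+        ``id_column`` accepts integer ids):
+            ids = await store.aadd_texts(["foo", "bar"], ids=[1, 2])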
+ """ + return await self._engine._run_as_async( + self.__vs.aadd_texts(texts, metadatas, ids, **kwargs) + ) - async def aadd_texts( + def add_texts( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, + ids: Optional[List] = None, **kwargs: Any, ) -> List[str]: - """Embed texts and add to the table.""" - embeddings = self.embedding_service.embed_documents(list(texts)) - ids = await self._aadd_embeddings( - texts, embeddings, metadatas=metadatas, ids=ids, **kwargs + """Embed texts and add to the table. + + Raises: + :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. + """ + return self._engine._run_as_sync( + self.__vs.aadd_texts(texts, metadatas, ids, **kwargs) ) - return ids async def aadd_documents( self, documents: List[Document], - ids: Optional[List[str]] = None, + ids: Optional[List] = None, **kwargs: Any, ) -> List[str]: - """Embed documents and add to the table""" - texts = [doc.page_content for doc in documents] - metadatas = [doc.metadata for doc in documents] - ids = await self.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs) - return ids + """Embed documents and add to the table. - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, - ) -> List[str]: - """Embed texts and add to the table.""" - return self.engine._run_as_sync( - self.aadd_texts(texts, metadatas, ids, **kwargs) + Raises: + :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. + """ + return await self._engine._run_as_async( + self.__vs.aadd_documents(documents, ids, **kwargs) ) def add_documents( self, documents: List[Document], - ids: Optional[List[str]] = None, + ids: Optional[List] = None, **kwargs: Any, ) -> List[str]: - """Embed documents and add to the table.""" - return self.engine._run_as_sync(self.aadd_documents(documents, ids, **kwargs)) + """Embed documents and add to the table. + + Raises: + :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. + """ + return self._engine._run_as_sync( + self.__vs.aadd_documents(documents, ids, **kwargs) + ) async def adelete( self, - ids: Optional[List[str]] = None, + ids: Optional[List] = None, **kwargs: Any, ) -> Optional[bool]: - """Delete records from the table.""" - if not ids: - return False + """Delete records from the table. - id_list = ", ".join([f"'{id}'" for id in ids]) - query = f'DELETE FROM "{self.schema_name}"."{self.table_name}" WHERE {self.id_column} in ({id_list})' - await self.engine._aexecute(query) - return True + Raises: + :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. + """ + return await self._engine._run_as_async(self.__vs.adelete(ids, **kwargs)) def delete( self, - ids: Optional[List[str]] = None, + ids: Optional[List] = None, **kwargs: Any, ) -> Optional[bool]: - """Delete records from the table.""" - return self.engine._run_as_sync(self.adelete(ids, **kwargs)) + """Delete records from the table. + + Raises: + :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. 
+ """ + return self._engine._run_as_sync(self.__vs.adelete(ids, **kwargs)) @classmethod async def afrom_texts( # type: ignore[override] @@ -399,16 +278,21 @@ async def afrom_texts( # type: ignore[override] table_name: str, schema_name: str = "public", metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, + ids: Optional[List] = None, content_column: str = "content", embedding_column: str = "embedding", metadata_columns: List[str] = [], ignore_metadata_columns: Optional[List[str]] = None, id_column: str = "langchain_id", metadata_json_column: str = "langchain_metadata", - **kwargs: Any, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + index_query_options: Optional[QueryOptions] = None, ) -> PostgresVectorStore: """Create an PostgresVectorStore instance from texts. + Args: texts (List[str]): Texts to add to the vector store. embedding (Embeddings): Text embedding model to use. @@ -416,13 +300,21 @@ async def afrom_texts( # type: ignore[override] table_name (str): Name of the existing table or the table to be created. schema_name (str, optional): Database schema name of the table. Defaults to "public". metadatas (Optional[List[dict]]): List of metadatas to add to table records. - ids: (Optional[List[str]]): List of IDs to add to table records. + ids: (Optional[List]): List of IDs to add to table records. content_column (str): Column that represent a Document’s page_content. Defaults to "content". embedding_column (str): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". metadata_columns (List[str]): Column(s) that represent a document's metadata. ignore_metadata_columns (List[str]): Column(s) to ignore in pre-existing tables for a document's metadata. Can not be used with metadata_columns. Defaults to None. id_column (str): Column that represents the Document's id. Defaults to "langchain_id". metadata_json_column (str): Column to store metadata as JSON. Defaults to "langchain_metadata". + distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. + k (int): Number of Documents to return from search. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. + index_query_options (QueryOptions): Index query option. + + Raises: + :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. 
Returns: PostgresVectorStore @@ -438,8 +330,13 @@ async def afrom_texts( # type: ignore[override] ignore_metadata_columns, id_column, metadata_json_column, + distance_strategy, + k, + fetch_k, + lambda_mult, + index_query_options, ) - await vs.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs) + await vs.aadd_texts(texts, metadatas=metadatas, ids=ids) return vs @classmethod @@ -450,14 +347,18 @@ async def afrom_documents( # type: ignore[override] engine: PostgresEngine, table_name: str, schema_name: str = "public", - ids: Optional[List[str]] = None, + ids: Optional[List] = None, content_column: str = "content", embedding_column: str = "embedding", metadata_columns: List[str] = [], ignore_metadata_columns: Optional[List[str]] = None, id_column: str = "langchain_id", metadata_json_column: str = "langchain_metadata", - **kwargs: Any, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + index_query_options: Optional[QueryOptions] = None, ) -> PostgresVectorStore: """Create an PostgresVectorStore instance from documents. @@ -468,13 +369,21 @@ async def afrom_documents( # type: ignore[override] table_name (str): Name of the existing table or the table to be created. schema_name (str, optional): Database schema name of the table. Defaults to "public". metadatas (Optional[List[dict]]): List of metadatas to add to table records. - ids: (Optional[List[str]]): List of IDs to add to table records. + ids: (Optional[List]): List of IDs to add to table records. content_column (str): Column that represent a Document’s page_content. Defaults to "content". embedding_column (str): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". metadata_columns (List[str]): Column(s) that represent a document's metadata. ignore_metadata_columns (List[str]): Column(s) to ignore in pre-existing tables for a document's metadata. Can not be used with metadata_columns. Defaults to None. id_column (str): Column that represents the Document's id. Defaults to "langchain_id". metadata_json_column (str): Column to store metadata as JSON. Defaults to "langchain_metadata". + distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. + k (int): Number of Documents to return from search. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. + index_query_options (QueryOptions): Index query option. + + Raises: + :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. 
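+
+        Example (illustrative sketch; ``docs``, ``engine``, and ``embedding``
+        are assumed to exist):
+            store = await PostgresVectorStore.afrom_documents(
+                docs, embedding, engine, table_name="my_table"
+            )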
Returns: PostgresVectorStore @@ -490,10 +399,13 @@ async def afrom_documents( # type: ignore[override] ignore_metadata_columns, id_column, metadata_json_column, + distance_strategy, + k, + fetch_k, + lambda_mult, + index_query_options, ) - texts = [doc.page_content for doc in documents] - metadatas = [doc.metadata for doc in documents] - await vs.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs) + await vs.aadd_documents(documents, ids=ids) return vs @classmethod @@ -505,16 +417,21 @@ def from_texts( # type: ignore[override] table_name: str, schema_name: str = "public", metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, + ids: Optional[List] = None, content_column: str = "content", embedding_column: str = "embedding", metadata_columns: List[str] = [], ignore_metadata_columns: Optional[List[str]] = None, id_column: str = "langchain_id", metadata_json_column: str = "langchain_metadata", - **kwargs: Any, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + index_query_options: Optional[QueryOptions] = None, ) -> PostgresVectorStore: """Create an PostgresVectorStore instance from texts. + Args: texts (List[str]): Texts to add to the vector store. embedding (Embeddings): Text embedding model to use. @@ -522,34 +439,44 @@ def from_texts( # type: ignore[override] table_name (str): Name of the existing table or the table to be created. schema_name (str, optional): Database schema name of the table. Defaults to "public". metadatas (Optional[List[dict]]): List of metadatas to add to table records. - ids: (Optional[List[str]]): List of IDs to add to table records. + ids: (Optional[List]): List of IDs to add to table records. content_column (str): Column that represent a Document’s page_content. Defaults to "content". embedding_column (str): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". metadata_columns (List[str]): Column(s) that represent a document's metadata. ignore_metadata_columns (List[str]): Column(s) to ignore in pre-existing tables for a document's metadata. Can not be used with metadata_columns. Defaults to None. id_column (str): Column that represents the Document's id. Defaults to "langchain_id". metadata_json_column (str): Column to store metadata as JSON. Defaults to "langchain_metadata". + distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. + k (int): Number of Documents to return from search. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. + index_query_options (QueryOptions): Index query option. + + Raises: + :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. 
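+
+        Example (illustrative sketch; assumes an initialized PostgresEngine
+        and Embeddings model):
+            store = PostgresVectorStore.from_texts(
+                ["foo", "bar"], embedding, engine, table_name="my_table"
+            )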
Returns: PostgresVectorStore """ - coro = cls.afrom_texts( - texts, - embedding, + vs = cls.create_sync( engine, + embedding, table_name, schema_name, - metadatas=metadatas, - content_column=content_column, - embedding_column=embedding_column, - metadata_columns=metadata_columns, - ignore_metadata_columns=ignore_metadata_columns, - metadata_json_column=metadata_json_column, - id_column=id_column, - ids=ids, - **kwargs, + content_column, + embedding_column, + metadata_columns, + ignore_metadata_columns, + id_column, + metadata_json_column, + distance_strategy, + k, + fetch_k, + lambda_mult, + index_query_options, ) - return engine._run_as_sync(coro) + vs.add_texts(texts, metadatas=metadatas, ids=ids) + return vs @classmethod def from_documents( # type: ignore[override] @@ -559,14 +486,18 @@ def from_documents( # type: ignore[override] engine: PostgresEngine, table_name: str, schema_name: str = "public", - ids: Optional[List[str]] = None, + ids: Optional[List] = None, content_column: str = "content", embedding_column: str = "embedding", metadata_columns: List[str] = [], ignore_metadata_columns: Optional[List[str]] = None, id_column: str = "langchain_id", metadata_json_column: str = "langchain_metadata", - **kwargs: Any, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + index_query_options: Optional[QueryOptions] = None, ) -> PostgresVectorStore: """Create an PostgresVectorStore instance from documents. @@ -577,58 +508,46 @@ def from_documents( # type: ignore[override] table_name (str): Name of the existing table or the table to be created. schema_name (str, optional): Database schema name of the table. Defaults to "public". metadatas (Optional[List[dict]]): List of metadatas to add to table records. - ids: (Optional[List[str]]): List of IDs to add to table records. + ids: (Optional[List]): List of IDs to add to table records. content_column (str): Column that represent a Document’s page_content. Defaults to "content". embedding_column (str): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". metadata_columns (List[str]): Column(s) that represent a document's metadata. ignore_metadata_columns (List[str]): Column(s) to ignore in pre-existing tables for a document's metadata. Can not be used with metadata_columns. Defaults to None. id_column (str): Column that represents the Document's id. Defaults to "langchain_id". metadata_json_column (str): Column to store metadata as JSON. Defaults to "langchain_metadata". + distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. + k (int): Number of Documents to return from search. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. + index_query_options (QueryOptions): Index query option. + + Raises: + :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. 
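+
+        Example (illustrative sketch; ``docs``, ``engine``, and ``embedding``
+        are assumed to exist):
+            store = PostgresVectorStore.from_documents(
+                docs, embedding, engine, table_name="my_table"
+            )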
Returns: PostgresVectorStore """ - coro = cls.afrom_documents( - documents, - embedding, + vs = cls.create_sync( engine, + embedding, table_name, schema_name, - content_column=content_column, - embedding_column=embedding_column, - metadata_columns=metadata_columns, - ignore_metadata_columns=ignore_metadata_columns, - metadata_json_column=metadata_json_column, - id_column=id_column, - ids=ids, - **kwargs, + content_column, + embedding_column, + metadata_columns, + ignore_metadata_columns, + id_column, + metadata_json_column, + distance_strategy, + k, + fetch_k, + lambda_mult, + index_query_options, ) - return engine._run_as_sync(coro) - - async def __query_collection( - self, - embedding: List[float], - k: Optional[int] = None, - filter: Optional[str] = None, - **kwargs: Any, - ) -> Sequence[RowMapping]: - """Perform similarity search query on the vector store table.""" - k = k if k else self.k - operator = self.distance_strategy.operator - search_function = self.distance_strategy.search_function - - filter = f"WHERE {filter}" if filter else "" - stmt = f"SELECT *, {search_function}({self.embedding_column}, '{embedding}') as distance FROM \"{self.schema_name}\".\"{self.table_name}\" {filter} ORDER BY {self.embedding_column} {operator} '{embedding}' LIMIT {k};" - if self.index_query_options: - query_options_stmt = f"SET LOCAL {self.index_query_options.to_string()};" - results = await self.engine._afetch_with_query_options( - stmt, query_options_stmt - ) - else: - results = await self.engine._afetch(stmt) - return results + vs.add_documents(documents, ids=ids) + return vs - def similarity_search( + async def asimilarity_search( self, query: str, k: Optional[int] = None, @@ -636,11 +555,11 @@ def similarity_search( **kwargs: Any, ) -> List[Document]: """Return docs selected by similarity search on query.""" - return self.engine._run_as_sync( - self.asimilarity_search(query, k=k, filter=filter, **kwargs) + return await self._engine._run_as_async( + self.__vs.asimilarity_search(query, k, filter, **kwargs) ) - async def asimilarity_search( + def similarity_search( self, query: str, k: Optional[int] = None, @@ -648,21 +567,19 @@ async def asimilarity_search( **kwargs: Any, ) -> List[Document]: """Return docs selected by similarity search on query.""" - embedding = self.embedding_service.embed_query(text=query) - - return await self.asimilarity_search_by_vector( - embedding=embedding, k=k, filter=filter, **kwargs + return self._engine._run_as_sync( + self.__vs.asimilarity_search(query, k, filter, **kwargs) ) + # Required for (a)similarity_search_with_relevance_scores def _select_relevance_score_fn(self) -> Callable[[float], float]: """Select a relevance function based on distance strategy.""" - # Calculate distance strategy provided in - # vectorstore constructor - if self.distance_strategy == DistanceStrategy.COSINE_DISTANCE: + # Calculate distance strategy provided in vectorstore constructor + if self.__vs.distance_strategy == DistanceStrategy.COSINE_DISTANCE: return self._cosine_relevance_score_fn - if self.distance_strategy == DistanceStrategy.INNER_PRODUCT: + if self.__vs.distance_strategy == DistanceStrategy.INNER_PRODUCT: return self._max_inner_product_relevance_score_fn - elif self.distance_strategy == DistanceStrategy.EUCLIDEAN: + elif self.__vs.distance_strategy == DistanceStrategy.EUCLIDEAN: return self._euclidean_relevance_score_fn async def asimilarity_search_with_score( @@ -673,11 +590,21 @@ async def asimilarity_search_with_score( **kwargs: Any, ) -> List[Tuple[Document, float]]: 
"""Return docs and distance scores selected by similarity search on query.""" - embedding = self.embedding_service.embed_query(query) - docs = await self.asimilarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter, **kwargs + return await self._engine._run_as_async( + self.__vs.asimilarity_search_with_score(query, k, filter, **kwargs) + ) + + def similarity_search_with_score( + self, + query: str, + k: Optional[int] = None, + filter: Optional[str] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs and distance scores selected by similarity search on query.""" + return self._engine._run_as_sync( + self.__vs.asimilarity_search_with_score(query, k, filter, **kwargs) ) - return docs async def asimilarity_search_by_vector( self, @@ -687,11 +614,21 @@ async def asimilarity_search_by_vector( **kwargs: Any, ) -> List[Document]: """Return docs selected by vector similarity search.""" - docs_and_scores = await self.asimilarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter, **kwargs + return await self._engine._run_as_async( + self.__vs.asimilarity_search_by_vector(embedding, k, filter, **kwargs) ) - return [doc for doc, _ in docs_and_scores] + def similarity_search_by_vector( + self, + embedding: List[float], + k: Optional[int] = None, + filter: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected by vector similarity search.""" + return self._engine._run_as_sync( + self.__vs.asimilarity_search_by_vector(embedding, k, filter, **kwargs) + ) async def asimilarity_search_with_score_by_vector( self, @@ -701,30 +638,25 @@ async def asimilarity_search_with_score_by_vector( **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs and distance scores selected by vector similarity search.""" - results = await self.__query_collection( - embedding=embedding, k=k, filter=filter, **kwargs + return await self._engine._run_as_async( + self.__vs.asimilarity_search_with_score_by_vector( + embedding, k, filter, **kwargs + ) ) - documents_with_scores = [] - for row in results: - metadata = ( - row[self.metadata_json_column] - if self.metadata_json_column and row[self.metadata_json_column] - else {} - ) - for col in self.metadata_columns: - metadata[col] = row[col] - documents_with_scores.append( - ( - Document( - page_content=row[self.content_column], - metadata=metadata, - ), - row["distance"], - ) + def similarity_search_with_score_by_vector( + self, + embedding: List[float], + k: Optional[int] = None, + filter: Optional[str] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs and distance scores selected by similarity search on vector.""" + return self._engine._run_as_sync( + self.__vs.asimilarity_search_with_score_by_vector( + embedding, k, filter, **kwargs ) - - return documents_with_scores + ) async def amax_marginal_relevance_search( self, @@ -736,20 +668,15 @@ async def amax_marginal_relevance_search( **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance.""" - embedding = self.embedding_service.embed_query(text=query) - - return await self.amax_marginal_relevance_search_by_vector( - embedding=embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - **kwargs, + return await self._engine._run_as_async( + self.__vs.amax_marginal_relevance_search( + query, k, fetch_k, lambda_mult, filter, **kwargs + ) ) - async def amax_marginal_relevance_search_by_vector( + def max_marginal_relevance_search( self, - 
embedding: List[float], + query: str, k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, @@ -757,20 +684,13 @@ async def amax_marginal_relevance_search_by_vector( **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance.""" - docs_and_scores = ( - await self.amax_marginal_relevance_search_with_score_by_vector( - embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - **kwargs, + return self._engine._run_as_sync( + self.__vs.amax_marginal_relevance_search( + query, k, fetch_k, lambda_mult, filter, **kwargs ) ) - return [result[0] for result in docs_and_scores] - - async def amax_marginal_relevance_search_with_score_by_vector( + async def amax_marginal_relevance_search_by_vector( self, embedding: List[float], k: Optional[int] = None, @@ -778,82 +698,17 @@ async def amax_marginal_relevance_search_with_score_by_vector( lambda_mult: Optional[float] = None, filter: Optional[str] = None, **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs and distance scores selected using the maximal marginal relevance.""" - results = await self.__query_collection( - embedding=embedding, k=fetch_k, filter=filter, **kwargs - ) - - k = k if k else self.k - fetch_k = fetch_k if fetch_k else self.fetch_k - lambda_mult = lambda_mult if lambda_mult else self.lambda_mult - embedding_list = [json.loads(row[self.embedding_column]) for row in results] - mmr_selected = maximal_marginal_relevance( - np.array(embedding, dtype=np.float32), - embedding_list, - k=k, - lambda_mult=lambda_mult, - ) - - documents_with_scores = [] - for row in results: - metadata = ( - row[self.metadata_json_column] - if self.metadata_json_column and row[self.metadata_json_column] - else {} - ) - for col in self.metadata_columns: - metadata[col] = row[col] - documents_with_scores.append( - ( - Document( - page_content=row[self.content_column], - metadata=metadata, - ), - row["distance"], - ) - ) - - return [r for i, r in enumerate(documents_with_scores) if i in mmr_selected] - - def similarity_search_with_score( - self, - query: str, - k: Optional[int] = None, - filter: Optional[str] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs and distance scores selected by similarity search on query.""" - coro = self.asimilarity_search_with_score(query, k, filter=filter, **kwargs) - return self.engine._run_as_sync(coro) - - def similarity_search_by_vector( - self, - embedding: List[float], - k: Optional[int] = None, - filter: Optional[str] = None, - **kwargs: Any, ) -> List[Document]: - """Return docs selected by vector similarity search.""" - coro = self.asimilarity_search_by_vector(embedding, k, filter=filter, **kwargs) - return self.engine._run_as_sync(coro) - - def similarity_search_with_score_by_vector( - self, - embedding: List[float], - k: Optional[int] = None, - filter: Optional[str] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs and distance scores selected by similarity search on vector.""" - coro = self.asimilarity_search_with_score_by_vector( - embedding, k, filter=filter, **kwargs + """Return docs selected using the maximal marginal relevance.""" + return await self._engine._run_as_async( + self.__vs.amax_marginal_relevance_search_by_vector( + embedding, k, fetch_k, lambda_mult, filter, **kwargs + ) ) - return self.engine._run_as_sync(coro) - def max_marginal_relevance_search( + def max_marginal_relevance_search_by_vector( self, - query: str, + embedding: 
List[float], k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, @@ -861,17 +716,13 @@ def max_marginal_relevance_search( **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance.""" - coro = self.amax_marginal_relevance_search( - query, - k, - filter=filter, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - **kwargs, + return self._engine._run_as_sync( + self.__vs.amax_marginal_relevance_search_by_vector( + embedding, k, fetch_k, lambda_mult, filter, **kwargs + ) ) - return self.engine._run_as_sync(coro) - def max_marginal_relevance_search_by_vector( + async def amax_marginal_relevance_search_with_score_by_vector( self, embedding: List[float], k: Optional[int] = None, @@ -879,17 +730,13 @@ def max_marginal_relevance_search_by_vector( lambda_mult: Optional[float] = None, filter: Optional[str] = None, **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance.""" - coro = self.amax_marginal_relevance_search_by_vector( - embedding, - k, - filter=filter, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - **kwargs, + ) -> List[Tuple[Document, float]]: + """Return docs and distance scores selected using the maximal marginal relevance.""" + return await self._engine._run_as_async( + self.__vs.amax_marginal_relevance_search_with_score_by_vector( + embedding, k, fetch_k, lambda_mult, filter, **kwargs + ) ) - return self.engine._run_as_sync(coro) def max_marginal_relevance_search_with_score_by_vector( self, @@ -901,15 +748,11 @@ def max_marginal_relevance_search_with_score_by_vector( **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs and distance scores selected using the maximal marginal relevance.""" - coro = self.amax_marginal_relevance_search_with_score_by_vector( - embedding, - k, - filter=filter, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - **kwargs, + return self._engine._run_as_sync( + self.__vs.amax_marginal_relevance_search_with_score_by_vector( + embedding, k, fetch_k, lambda_mult, filter, **kwargs + ) ) - return self.engine._run_as_sync(coro) async def aapply_vector_index( self, @@ -918,121 +761,55 @@ async def aapply_vector_index( concurrently: bool = False, ) -> None: """Create an index on the vector store table.""" - if isinstance(index, ExactNearestNeighbor): - await self.adrop_vector_index() - return - - filter = f"WHERE ({index.partial_indexes})" if index.partial_indexes else "" - params = "WITH " + index.index_options() - function = index.distance_strategy.index_function - if name is None: - if index.name == None: - index.name = self.table_name + DEFAULT_INDEX_NAME_SUFFIX - name = index.name - stmt = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} {name} ON "{self.schema_name}"."{self.table_name}" USING {index.index_type} ({self.embedding_column} {function}) {params} {filter};' - if concurrently: - await self.engine._aexecute_outside_tx(stmt) - else: - await self.engine._aexecute(stmt) + return await self._engine._run_as_async( + self.__vs.aapply_vector_index(index, name, concurrently) + ) + + def apply_vector_index( + self, + index: BaseIndex, + name: Optional[str] = None, + concurrently: bool = False, + ) -> None: + """Create an index on the vector store table.""" + return self._engine._run_as_sync( + self.__vs.aapply_vector_index(index, name, concurrently) + ) async def areindex(self, index_name: Optional[str] = None) -> None: """Re-index the vector store table.""" - index_name = index_name or self.table_name + DEFAULT_INDEX_NAME_SUFFIX 
- query = f"REINDEX INDEX {index_name};" - await self.engine._aexecute(query) + return await self._engine._run_as_async(self.__vs.areindex(index_name)) + + def reindex(self, index_name: Optional[str] = None) -> None: + """Re-index the vector store table.""" + return self._engine._run_as_sync(self.__vs.areindex(index_name)) async def adrop_vector_index( self, index_name: Optional[str] = None, ) -> None: """Drop the vector index.""" - index_name = index_name or self.table_name + DEFAULT_INDEX_NAME_SUFFIX - query = f"DROP INDEX IF EXISTS {index_name};" - await self.engine._aexecute(query) + return await self._engine._run_as_async( + self.__vs.adrop_vector_index(index_name) + ) + + def drop_vector_index( + self, + index_name: Optional[str] = None, + ) -> None: + """Drop the vector index.""" + return self._engine._run_as_sync(self.__vs.adrop_vector_index(index_name)) - async def is_valid_index( + async def ais_valid_index( self, index_name: Optional[str] = None, ) -> bool: """Check if index exists in the table.""" - index_name = index_name or self.table_name + DEFAULT_INDEX_NAME_SUFFIX - query = f""" - SELECT tablename, indexname - FROM pg_indexes - WHERE tablename = '{self.table_name}' AND schemaname = '{self.schema_name}' AND indexname = '{index_name}'; - """ - results = await self.engine._afetch(query) - return bool(len(results) == 1) - - -### The following is copied from langchain-community until it's moved into core - -Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] - - -def maximal_marginal_relevance( - query_embedding: np.ndarray, - embedding_list: list, - lambda_mult: float = 0.5, - k: int = 4, -) -> List[int]: - """Calculate maximal marginal relevance.""" - if min(k, len(embedding_list)) <= 0: - return [] - if query_embedding.ndim == 1: - query_embedding = np.expand_dims(query_embedding, axis=0) - similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0] - most_similar = int(np.argmax(similarity_to_query)) - idxs = [most_similar] - selected = np.array([embedding_list[most_similar]]) - while len(idxs) < min(k, len(embedding_list)): - best_score = -np.inf - idx_to_add = -1 - similarity_to_selected = cosine_similarity(embedding_list, selected) - for i, query_score in enumerate(similarity_to_query): - if i in idxs: - continue - redundant_score = max(similarity_to_selected[i]) - equation_score = ( - lambda_mult * query_score - (1 - lambda_mult) * redundant_score - ) - if equation_score > best_score: - best_score = equation_score - idx_to_add = i - idxs.append(idx_to_add) - selected = np.append(selected, [embedding_list[idx_to_add]], axis=0) - return idxs - - -def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: - """Row-wise cosine similarity between two equal-width matrices.""" - if len(X) == 0 or len(Y) == 0: - return np.array([]) - - X = np.array(X) - Y = np.array(Y) - if X.shape[1] != Y.shape[1]: - raise ValueError( - f"Number of columns in X and Y must be the same. X has shape {X.shape} " - f"and Y has shape {Y.shape}." - ) - try: - import simsimd as simd # type: ignore - - X = np.array(X, dtype=np.float32) - Y = np.array(Y, dtype=np.float32) - Z = 1 - simd.cdist(X, Y, metric="cosine") - if isinstance(Z, float): - return np.array([Z]) - return Z - except ImportError: - X_norm = np.linalg.norm(X, axis=1) - Y_norm = np.linalg.norm(Y, axis=1) - # Ignore divide by zero errors run time warnings as those are handled below. 
- with np.errstate(divide="ignore", invalid="ignore"): - similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) - similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 - return similarity - - -### End code from langchain-community + return await self._engine._run_as_async(self.__vs.is_valid_index(index_name)) + + def is_valid_index( + self, + index_name: Optional[str] = None, + ) -> bool: + """Check if index exists in the table.""" + return self._engine._run_as_sync(self.__vs.is_valid_index(index_name)) diff --git a/src/langchain_google_cloud_sql_pg/version.py b/src/langchain_google_cloud_sql_pg/version.py index ba03825a..00f17d64 100644 --- a/src/langchain_google_cloud_sql_pg/version.py +++ b/src/langchain_google_cloud_sql_pg/version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "0.9.0" +__version__ = "0.10.0" diff --git a/tests/test_async_chatmessagehistory.py b/tests/test_async_chatmessagehistory.py new file mode 100644 index 00000000..b626674b --- /dev/null +++ b/tests/test_async_chatmessagehistory.py @@ -0,0 +1,124 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import uuid +from typing import Any, Generator + +import pytest +import pytest_asyncio +from langchain_core.messages.ai import AIMessage +from langchain_core.messages.human import HumanMessage +from sqlalchemy import text + +from langchain_google_cloud_sql_pg import PostgresEngine +from langchain_google_cloud_sql_pg.async_chat_message_history import ( + AsyncPostgresChatMessageHistory, +) + +project_id = os.environ["PROJECT_ID"] +region = os.environ["REGION"] +instance_id = os.environ["INSTANCE_ID"] +db_name = os.environ["DATABASE_ID"] +table_name = "message_store" + str(uuid.uuid4()) +table_name_async = "message_store" + str(uuid.uuid4()) + + +async def aexecute(engine: PostgresEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + +@pytest_asyncio.fixture +async def async_engine(): + async_engine = await PostgresEngine.afrom_instance( + project_id=project_id, + region=region, + instance=instance_id, + database=db_name, + ) + await async_engine._ainit_chat_history_table(table_name=table_name_async) + yield async_engine + # use default table for AsyncPostgresChatMessageHistory + query = f'DROP TABLE IF EXISTS "{table_name_async}"' + await aexecute(async_engine, query) + await async_engine.close() + + +@pytest.mark.asyncio +async def test_chat_message_history_async( + async_engine: PostgresEngine, +) -> None: + history = await AsyncPostgresChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=table_name_async + ) + msg1 = HumanMessage(content="hi!") + msg2 = AIMessage(content="whats up?") + await history.aadd_message(msg1) + await history.aadd_message(msg2) + messages = await history._aget_messages() + + # verify messages are correct + assert messages[0].content == "hi!" 
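+    # messages should come back in insertion order with their original types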
+ assert type(messages[0]) is HumanMessage + assert messages[1].content == "whats up?" + assert type(messages[1]) is AIMessage + + # verify clear() clears message history + await history.aclear() + assert len(await history._aget_messages()) == 0 + + +@pytest.mark.asyncio +async def test_chat_message_history_sync_messages( + async_engine: PostgresEngine, +) -> None: + history1 = await AsyncPostgresChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=table_name_async + ) + history2 = await AsyncPostgresChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=table_name_async + ) + msg1 = HumanMessage(content="hi!") + msg2 = AIMessage(content="whats up?") + await history1.aadd_message(msg1) + await history2.aadd_message(msg2) + + assert len(await history1._aget_messages()) == 2 + assert len(await history2._aget_messages()) == 2 + + # verify clear() clears message history + await history2.aclear() + assert len(await history2._aget_messages()) == 0 + + +@pytest.mark.asyncio +async def test_chat_table_async(async_engine): + with pytest.raises(ValueError): + await AsyncPostgresChatMessageHistory.create( + engine=async_engine, session_id="test", table_name="doesnotexist" + ) + + +@pytest.mark.asyncio +async def test_chat_schema_async(async_engine): + table_name = "test_table" + str(uuid.uuid4()) + await async_engine._ainit_document_table(table_name=table_name) + with pytest.raises(IndexError): + await AsyncPostgresChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=table_name + ) + + query = f'DROP TABLE IF EXISTS "{table_name}"' + await aexecute(async_engine, query) diff --git a/tests/test_async_loader.py b/tests/test_async_loader.py new file mode 100644 index 00000000..c29a82f7 --- /dev/null +++ b/tests/test_async_loader.py @@ -0,0 +1,757 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
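+
+# Integration tests for AsyncPostgresLoader and AsyncPostgresDocumentSaver.
+# They expect PROJECT_ID, REGION, INSTANCE_ID, and DATABASE_ID in the environment.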
+ +import json +import os +import uuid + +import pytest +import pytest_asyncio +from langchain_core.documents import Document +from sqlalchemy import text + +from langchain_google_cloud_sql_pg import Column, PostgresEngine +from langchain_google_cloud_sql_pg.async_loader import ( + AsyncPostgresDocumentSaver, + AsyncPostgresLoader, +) + +project_id = os.environ["PROJECT_ID"] +region = os.environ["REGION"] +instance_id = os.environ["INSTANCE_ID"] +db_name = os.environ["DATABASE_ID"] +table_name = "test-table" + str(uuid.uuid4()) + + +async def aexecute(engine: PostgresEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + +@pytest.mark.asyncio(scope="class") +class TestLoaderAsync: + + @pytest_asyncio.fixture(scope="class") + async def engine(self): + PostgresEngine._connector = None + engine = await PostgresEngine.afrom_instance( + project_id=project_id, + instance=instance_id, + region=region, + database=db_name, + ) + yield engine + + await engine.close() + + async def _collect_async_items(self, docs_generator): + """Collects items from an async generator.""" + docs = [] + async for doc in docs_generator: + docs.append(doc) + return docs + + async def _cleanup_table(self, engine): + await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_create_loader_with_invalid_parameters(self, engine): + with pytest.raises(ValueError): + await AsyncPostgresLoader.create( + engine=engine, + ) + with pytest.raises(ValueError): + + def fake_formatter(): + return None + + await AsyncPostgresLoader.create( + engine=engine, + table_name=table_name, + format="text", + formatter=fake_formatter, + ) + with pytest.raises(ValueError): + await AsyncPostgresLoader.create( + engine=engine, + table_name=table_name, + format="fake_format", + ) + + async def test_load_from_query_default(self, engine): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" ( + fruit_name, variety, quantity_in_stock, price_per_unit, organic + ) VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(engine, insert_query) + + loader = await AsyncPostgresLoader.create( + engine=engine, + table_name=table_name, + ) + + documents = await self._collect_async_items(loader.alazy_load()) + + assert documents == [ + Document( + page_content="1", + metadata={ + "fruit_name": "Apple", + "variety": "Granny Smith", + "quantity_in_stock": 150, + "price_per_unit": 1, + "organic": 1, + }, + ) + ] + await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_customized_metadata(self, engine): + await self._cleanup_table(engine) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 0.99, 1), + ('Banana', 'Cavendish', 200, 0.59, 0), + ('Orange', 'Navel', 80, 1.29, 1); + """ + await 
aexecute(engine, insert_query) + + loader = await AsyncPostgresLoader.create( + engine=engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "fruit_name", + "variety", + "quantity_in_stock", + "price_per_unit", + "organic", + ], + metadata_columns=["fruit_id"], + ) + + documents = await self._collect_async_items(loader.alazy_load()) + + assert documents == [ + Document( + page_content="Apple Granny Smith 150 1 1", + metadata={"fruit_id": 1}, + ), + Document( + page_content="Banana Cavendish 200 1 0", + metadata={"fruit_id": 2}, + ), + Document( + page_content="Orange Navel 80 1 1", + metadata={"fruit_id": 3}, + ), + ] + + await self._cleanup_table(engine) + + async def test_load_from_query_customized_content_default_metadata(self, engine): + await self._cleanup_table(engine) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(engine, insert_query) + + loader = await AsyncPostgresLoader.create( + engine=engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + ) + + documents = [] + async for docs in loader.alazy_load(): + documents.append(docs) + + assert documents == [ + Document( + page_content="Granny Smith 150 1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + loader = await AsyncPostgresLoader.create( + engine=engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="JSON", + ) + + documents = await self._collect_async_items(loader.alazy_load()) + + assert documents == [ + Document( + page_content='{"variety": "Granny Smith", "quantity_in_stock": 150, "price_per_unit": 1}', + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + await self._cleanup_table(engine) + + async def test_load_from_query_default_content_customized_metadata(self, engine): + await self._cleanup_table(engine) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" ( + fruit_name, + variety, + quantity_in_stock, + price_per_unit, + organic + ) VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(engine, insert_query) + + loader = await AsyncPostgresLoader.create( + engine=engine, + query=f'SELECT * FROM "{table_name}";', + metadata_columns=["fruit_name", "organic"], + ) + + documents = await self._collect_async_items(loader.alazy_load()) + + assert documents == [ + Document( + page_content="1", + metadata={"fruit_name": "Apple", "organic": 1}, + ) + ] + await self._cleanup_table(engine) + + async def test_load_from_query_with_langchain_metadata(self, engine): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}"( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock 
INT NOT NULL, + price_per_unit INT NOT NULL, + langchain_metadata JSON NOT NULL + ) + """ + await aexecute(engine, query) + + metadata = json.dumps({"organic": 1}) + insert_query = f""" + INSERT INTO "{table_name}" + (fruit_name, variety, quantity_in_stock, price_per_unit, langchain_metadata) + VALUES ('Apple', 'Granny Smith', 150, 1, '{metadata}');""" + await aexecute(engine, insert_query) + + loader = await AsyncPostgresLoader.create( + engine=engine, + query=f'SELECT * FROM "{table_name}";', + metadata_columns=[ + "fruit_name", + "langchain_metadata", + ], + ) + + documents = await self._collect_async_items(loader.alazy_load()) + + assert documents == [ + Document( + page_content="1", + metadata={ + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_with_json(self, engine): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}"( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety JSON NOT NULL, + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + langchain_metadata JSON NOT NULL + ) + """ + await aexecute(engine, query) + + metadata = json.dumps({"organic": 1}) + variety = json.dumps({"type": "Granny Smith"}) + insert_query = f""" + INSERT INTO "{table_name}" + (fruit_name, variety, quantity_in_stock, price_per_unit, langchain_metadata) + VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" + await aexecute(engine, insert_query) + + loader = await AsyncPostgresLoader.create( + engine=engine, + query=f'SELECT * FROM "{table_name}";', + metadata_columns=[ + "variety", + ], + ) + + documents = await self._collect_async_items(loader.alazy_load()) + + assert documents == [ + Document( + page_content="1", + metadata={ + "variety": {"type": "Granny Smith"}, + "organic": 1, + }, + ) + ] + await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_formatter( + self, engine + ): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(engine, insert_query) + + def my_formatter(row, content_columns): + return "-".join( + str(row[column]) for column in content_columns if column in row + ) + + loader = await AsyncPostgresLoader.create( + engine=engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + formatter=my_formatter, + ) + + documents = await self._collect_async_items(loader.alazy_load()) + + assert documents == [ + Document( + page_content="Granny Smith-150-1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_page_content_format( + self, engine + ): + await self._cleanup_table(engine) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + 
variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(engine, insert_query) + + loader = await AsyncPostgresLoader.create( + engine=engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="YAML", + ) + + documents = await self._collect_async_items(loader.alazy_load()) + + assert documents == [ + Document( + page_content="variety: Granny Smith\nquantity_in_stock: 150\nprice_per_unit: 1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + await self._cleanup_table(engine) + + async def test_save_doc_with_default_metadata(self, engine): + + await self._cleanup_table(engine) + await engine._ainit_document_table(table_name) + test_docs = [ + Document( + page_content="Apple Granny Smith 150 0.99 1", + metadata={"fruit_id": 1}, + ), + Document( + page_content="Banana Cavendish 200 0.59 0", + metadata={"fruit_id": 2}, + ), + Document( + page_content="Orange Navel 80 1.29 1", + metadata={"fruit_id": 3}, + ), + ] + saver = await AsyncPostgresDocumentSaver.create( + engine=engine, table_name=table_name + ) + loader = await AsyncPostgresLoader.create(engine=engine, table_name=table_name) + + await saver.aadd_documents(test_docs) + docs = await self._collect_async_items(loader.alazy_load()) + + assert docs == test_docs + assert (await engine._aload_table_schema(table_name)).columns.keys() == [ + "page_content", + "langchain_metadata", + ] + await self._cleanup_table(engine) + + @pytest.mark.parametrize("store_metadata", [True, False]) + async def test_save_doc_with_customized_metadata(self, engine, store_metadata): + table_name = "test-table" + str(uuid.uuid4()) + await engine._ainit_document_table( + table_name, + metadata_columns=[ + Column("fruit_name", "VARCHAR"), + Column("organic", "BOOLEAN"), + ], + store_metadata=store_metadata, + ) + test_docs = [ + Document( + page_content="Granny Smith 150 0.99", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": True, + }, + ), + ] + saver = await AsyncPostgresDocumentSaver.create( + engine=engine, table_name=table_name + ) + loader = await AsyncPostgresLoader.create( + engine=engine, + table_name=table_name, + metadata_columns=[ + "fruit_name", + "organic", + ], + ) + + await saver.aadd_documents(test_docs) + docs = await self._collect_async_items(loader.alazy_load()) + + if store_metadata: + docs == test_docs + assert (await engine._aload_table_schema(table_name)).columns.keys() == [ + "page_content", + "fruit_name", + "organic", + "langchain_metadata", + ] + else: + assert docs == [ + Document( + page_content="Granny Smith 150 0.99", + metadata={"fruit_name": "Apple", "organic": True}, + ), + ] + assert (await engine._aload_table_schema(table_name)).columns.keys() == [ + "page_content", + "fruit_name", + "organic", + ] + await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_save_doc_without_metadata(self, engine): + table_name = "test-table" + str(uuid.uuid4()) + await engine._ainit_document_table(table_name, store_metadata=False) + test_docs = [ + Document( + page_content="Granny Smith 150 0.99", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ), + ] + saver = await 
AsyncPostgresDocumentSaver.create( + engine=engine, table_name=table_name + ) + await saver.aadd_documents(test_docs) + + loader = await AsyncPostgresLoader.create( + engine=engine, + table_name=table_name, + ) + + docs = await self._collect_async_items(loader.alazy_load()) + + assert docs == [ + Document( + page_content="Granny Smith 150 0.99", + metadata={}, + ), + ] + assert (await engine._aload_table_schema(table_name)).columns.keys() == [ + "page_content", + ] + await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_delete_doc_with_default_metadata(self, engine): + table_name = "test-table" + str(uuid.uuid4()) + await engine._ainit_document_table(table_name) + + test_docs = [ + Document( + page_content="Apple Granny Smith 150 0.99 1", + metadata={"fruit_id": 1}, + ), + Document( + page_content="Banana Cavendish 200 0.59 0 1", + metadata={"fruit_id": 2}, + ), + ] + saver = await AsyncPostgresDocumentSaver.create( + engine=engine, table_name=table_name + ) + loader = await AsyncPostgresLoader.create(engine=engine, table_name=table_name) + + await saver.aadd_documents(test_docs) + docs = await self._collect_async_items(loader.alazy_load()) + assert docs == test_docs + + await saver.adelete(docs[:1]) + assert len(await self._collect_async_items(loader.alazy_load())) == 1 + + await saver.adelete(docs) + assert len(await self._collect_async_items(loader.alazy_load())) == 0 + await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_delete_doc_with_query(self, engine): + await self._cleanup_table(engine) + await engine._ainit_document_table( + table_name, + metadata_columns=[ + Column( + "fruit_name", + "VARCHAR", + ), + Column( + "organic", + "BOOLEAN", + ), + ], + store_metadata=True, + ) + + test_docs = [ + Document( + page_content="Granny Smith 150 0.99", + metadata={ + "fruit-id": 1, + "fruit_name": "Apple", + "organic": True, + }, + ), + Document( + page_content="Cavendish 200 0.59 0", + metadata={ + "fruit_id": 2, + "fruit_name": "Banana", + "organic": False, + }, + ), + Document( + page_content="Navel 80 1.29 1", + metadata={ + "fruit_id": 3, + "fruit_name": "Orange", + "organic": True, + }, + ), + ] + saver = await AsyncPostgresDocumentSaver.create( + engine=engine, table_name=table_name + ) + query = f"SELECT * FROM \"{table_name}\" WHERE fruit_name='Apple';" + loader = await AsyncPostgresLoader.create(engine=engine, query=query) + + await saver.aadd_documents(test_docs) + docs = await self._collect_async_items(loader.alazy_load()) + assert len(docs) == 1 + + await saver.adelete(docs) + assert len(await self._collect_async_items(loader.alazy_load())) == 0 + await self._cleanup_table(engine) + + @pytest.mark.parametrize("metadata_json_column", [None, "metadata_col_test"]) + async def test_delete_doc_with_customized_metadata( + self, engine, metadata_json_column + ): + table_name = "test-table" + str(uuid.uuid4()) + content_column = "content_col_test" + await engine._ainit_document_table( + table_name, + metadata_columns=[ + Column("fruit_name", "VARCHAR"), + Column("organic", "BOOLEAN"), + ], + content_column=content_column, + metadata_json_column=metadata_json_column, + ) + test_docs = [ + Document( + page_content="Granny Smith 150 0.99", + metadata={ + "fruit-id": 1, + "fruit_name": "Apple", + "organic": True, + }, + ), + Document( + page_content="Cavendish 200 0.59 0", + metadata={ + "fruit_id": 2, + "fruit_name": "Banana", + "organic": True, + }, + ), + ] + saver = await AsyncPostgresDocumentSaver.create( + engine=engine, + 
table_name=table_name, + content_column=content_column, + metadata_json_column=metadata_json_column, + ) + loader = await AsyncPostgresLoader.create( + engine=engine, + table_name=table_name, + content_columns=[content_column], + metadata_json_column=metadata_json_column, + ) + + await saver.aadd_documents(test_docs) + + docs = await loader.aload() + assert len(docs) == 2 + + await saver.adelete(docs[:1]) + assert len(await self._collect_async_items(loader.alazy_load())) == 1 + + await saver.adelete(docs) + assert len(await self._collect_async_items(loader.alazy_load())) == 0 + await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"') diff --git a/tests/test_cloudsql_vectorstore.py b/tests/test_async_vectorstore.py similarity index 67% rename from tests/test_cloudsql_vectorstore.py rename to tests/test_async_vectorstore.py index 081853a8..52016149 100644 --- a/tests/test_cloudsql_vectorstore.py +++ b/tests/test_async_vectorstore.py @@ -14,13 +14,17 @@ import os import uuid +from typing import Sequence import pytest import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding +from sqlalchemy import text +from sqlalchemy.engine.row import RowMapping -from langchain_google_cloud_sql_pg import Column, PostgresEngine, PostgresVectorStore +from langchain_google_cloud_sql_pg import Column, PostgresEngine +from langchain_google_cloud_sql_pg.async_vectorstore import AsyncPostgresVectorStore DEFAULT_TABLE = "test_table" + str(uuid.uuid4()) DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4()) @@ -45,6 +49,20 @@ def get_env_var(key: str, desc: str) -> str: return v +async def aexecute(engine: PostgresEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + +async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + @pytest.mark.asyncio(scope="class") class TestVectorStore: @pytest.fixture(scope="module") @@ -73,45 +91,23 @@ async def engine(self, db_project, db_region, db_instance, db_name): ) yield engine - - @pytest_asyncio.fixture(scope="class") - def engine_sync(self, db_project, db_region, db_instance, db_name): - engine = PostgresEngine.from_instance( - project_id=db_project, - instance=db_instance, - region=db_region, - database=db_name, - ) - yield engine - - @pytest_asyncio.fixture(scope="class") - def vs_sync(self, engine_sync): - engine_sync.init_vectorstore_table(DEFAULT_TABLE_SYNC, VECTOR_SIZE) - - vs = PostgresVectorStore.create_sync( - engine_sync, - embedding_service=embeddings_service, - table_name=DEFAULT_TABLE_SYNC, - ) - yield vs - engine_sync._execute(f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_SYNC}"') - engine_sync._engine.dispose() + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{CUSTOM_TABLE}"') + await engine.close() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): - await engine.ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE) - vs = await PostgresVectorStore.create( + await engine._ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE) + vs = await AsyncPostgresVectorStore.create( engine, embedding_service=embeddings_service, table_name=DEFAULT_TABLE, ) yield vs - await engine._aexecute(f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') - 
await engine._engine.dispose() @pytest_asyncio.fixture(scope="class") async def vs_custom(self, engine): - await engine.ainit_vectorstore_table( + await engine._ainit_vectorstore_table( CUSTOM_TABLE, VECTOR_SIZE, id_column="myid", @@ -120,7 +116,7 @@ async def vs_custom(self, engine): metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], metadata_json_column="mymeta", ) - vs = await PostgresVectorStore.create( + vs = await AsyncPostgresVectorStore.create( engine, embedding_service=embeddings_service, table_name=CUSTOM_TABLE, @@ -131,11 +127,10 @@ async def vs_custom(self, engine): metadata_json_column="mymeta", ) yield vs - await engine._aexecute(f'DROP TABLE IF EXISTS "{CUSTOM_TABLE}"') async def test_init_with_constructor(self, engine): with pytest.raises(Exception): - PostgresVectorStore( + AsyncPostgresVectorStore( engine, embedding_service=embeddings_service, table_name=CUSTOM_TABLE, @@ -148,7 +143,7 @@ async def test_init_with_constructor(self, engine): async def test_post_init(self, engine): with pytest.raises(ValueError): - await PostgresVectorStore.create( + await AsyncPostgresVectorStore.create( engine, embedding_service=embeddings_service, table_name=CUSTOM_TABLE, @@ -162,58 +157,55 @@ async def test_post_init(self, engine): async def test_aadd_texts(self, engine, vs): ids = [str(uuid.uuid4()) for i in range(len(texts))] await vs.aadd_texts(texts, ids=ids) - results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"') + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') assert len(results) == 3 ids = [str(uuid.uuid4()) for i in range(len(texts))] await vs.aadd_texts(texts, metadatas, ids) - results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"') + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') assert len(results) == 6 - await engine._aexecute(f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') async def test_aadd_texts_edge_cases(self, engine, vs): texts = ["Taylor's", '"Swift"', "best-friend"] ids = [str(uuid.uuid4()) for i in range(len(texts))] await vs.aadd_texts(texts, ids=ids) - results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"') + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') assert len(results) == 3 - await engine._aexecute(f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') async def test_aadd_docs(self, engine, vs): ids = [str(uuid.uuid4()) for i in range(len(texts))] await vs.aadd_documents(docs, ids=ids) - results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"') - assert len(results) == 3 - await engine._aexecute(f'TRUNCATE TABLE "{DEFAULT_TABLE}"') - - async def test_aadd_embedding(self, engine, vs): - ids = [str(uuid.uuid4()) for i in range(len(texts))] - await vs._aadd_embeddings(texts, embeddings, metadatas, ids) - results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"') + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') assert len(results) == 3 - await engine._aexecute(f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') - async def test_aadd_embedding_without_id(self, engine, vs): - await vs._aadd_embeddings(texts, embeddings, metadatas) - results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"') + async def test_aadd_docs_no_ids(self, engine, vs): + await vs.aadd_documents(docs) + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') assert len(results) == 3 - 
assert results[0]["langchain_id"] - await engine._aexecute(f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') async def test_adelete(self, engine, vs): ids = [str(uuid.uuid4()) for i in range(len(texts))] await vs.aadd_texts(texts, ids=ids) - results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"') + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') assert len(results) == 3 # delete an ID await vs.adelete([ids[0]]) - results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"') + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') assert len(results) == 2 + # delete with no ids + result = await vs.adelete() + assert result == False + + ##### Custom Vector Store ##### async def test_aadd_texts_custom(self, engine, vs_custom): ids = [str(uuid.uuid4()) for i in range(len(texts))] await vs_custom.aadd_texts(texts, ids=ids) - results = await engine._afetch(f'SELECT * FROM "{CUSTOM_TABLE}"') + results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"') assert len(results) == 3 assert results[0]["mycontent"] == "foo" assert results[0]["myembedding"] @@ -222,9 +214,9 @@ async def test_aadd_texts_custom(self, engine, vs_custom): ids = [str(uuid.uuid4()) for i in range(len(texts))] await vs_custom.aadd_texts(texts, metadatas, ids) - results = await engine._afetch(f'SELECT * FROM "{CUSTOM_TABLE}"') + results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"') assert len(results) == 6 - await engine._aexecute(f'TRUNCATE TABLE "{CUSTOM_TABLE}"') + await aexecute(engine, f'TRUNCATE TABLE "{CUSTOM_TABLE}"') async def test_aadd_docs_custom(self, engine, vs_custom): ids = [str(uuid.uuid4()) for i in range(len(texts))] @@ -237,51 +229,32 @@ async def test_aadd_docs_custom(self, engine, vs_custom): ] await vs_custom.aadd_documents(docs, ids=ids) - results = await engine._afetch(f'SELECT * FROM "{CUSTOM_TABLE}"') + results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"') assert len(results) == 3 assert results[0]["mycontent"] == "foo" assert results[0]["myembedding"] assert results[0]["page"] == "0" assert results[0]["source"] == "google.com" - await engine._aexecute(f'TRUNCATE TABLE "{CUSTOM_TABLE}"') - - async def test_aadd_embedding_custom(self, engine, vs_custom): - ids = [str(uuid.uuid4()) for i in range(len(texts))] - await vs_custom._aadd_embeddings(texts, embeddings, metadatas, ids) - results = await engine._afetch(f'SELECT * FROM "{CUSTOM_TABLE}"') - assert len(results) == 3 - await engine._aexecute(f'TRUNCATE TABLE "{CUSTOM_TABLE}"') + await aexecute(engine, f'TRUNCATE TABLE "{CUSTOM_TABLE}"') async def test_adelete_custom(self, engine, vs_custom): ids = [str(uuid.uuid4()) for i in range(len(texts))] await vs_custom.aadd_texts(texts, ids=ids) - results = await engine._afetch(f'SELECT * FROM "{CUSTOM_TABLE}"') + results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"') content = [result["mycontent"] for result in results] assert len(results) == 3 assert "foo" in content # delete an ID await vs_custom.adelete([ids[0]]) - results = await engine._afetch(f'SELECT * FROM "{CUSTOM_TABLE}"') + results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"') content = [result["mycontent"] for result in results] assert len(results) == 2 assert "foo" not in content - async def test_add_docs(self, engine_sync, vs_sync): - ids = [str(uuid.uuid4()) for i in range(len(texts))] - vs_sync.add_documents(docs, ids=ids) - results = engine_sync._fetch(f'SELECT * FROM "{DEFAULT_TABLE_SYNC}"') - 
assert len(results) == 3 - - async def test_add_texts(self, engine_sync, vs_sync): - ids = [str(uuid.uuid4()) for i in range(len(texts))] - vs_sync.add_texts(texts, ids=ids) - results = engine_sync._fetch(f'SELECT * FROM "{DEFAULT_TABLE_SYNC}"') - assert len(results) == 6 - - async def test_ignore_metadata_columns(self, vs_custom): + async def test_ignore_metadata_columns(self, engine): column_to_ignore = "source" - vs = await PostgresVectorStore.create( - vs_custom.engine, + vs = await AsyncPostgresVectorStore.create( + engine, embedding_service=embeddings_service, table_name=CUSTOM_TABLE, ignore_metadata_columns=[column_to_ignore], @@ -292,10 +265,10 @@ async def test_ignore_metadata_columns(self, vs_custom): ) assert column_to_ignore not in vs.metadata_columns - async def test_create_vectorstore_with_invalid_parameters(self, vs_custom): + async def test_create_vectorstore_with_invalid_parameters_1(self, engine): with pytest.raises(ValueError): - await PostgresVectorStore.create( - vs_custom.engine, + await AsyncPostgresVectorStore.create( + engine, embedding_service=embeddings_service, table_name=CUSTOM_TABLE, id_column="myid", @@ -303,9 +276,11 @@ async def test_create_vectorstore_with_invalid_parameters(self, vs_custom): embedding_column="myembedding", metadata_columns=["random_column"], # invalid metadata column ) + + async def test_create_vectorstore_with_invalid_parameters_2(self, engine): with pytest.raises(ValueError): - await PostgresVectorStore.create( - vs_custom.engine, + await AsyncPostgresVectorStore.create( + engine, embedding_service=embeddings_service, table_name=CUSTOM_TABLE, id_column="myid", @@ -313,9 +288,11 @@ async def test_create_vectorstore_with_invalid_parameters(self, vs_custom): embedding_column="myembedding", metadata_columns=["random_column"], ) + + async def test_create_vectorstore_with_invalid_parameters_3(self, engine): with pytest.raises(ValueError): - await PostgresVectorStore.create( - vs_custom.engine, + await AsyncPostgresVectorStore.create( + engine, embedding_service=embeddings_service, table_name=CUSTOM_TABLE, id_column="myid", @@ -323,9 +300,11 @@ async def test_create_vectorstore_with_invalid_parameters(self, vs_custom): embedding_column="random_column", # invalid embedding column metadata_columns=["random_column"], ) + + async def test_create_vectorstore_with_invalid_parameters_4(self, engine): with pytest.raises(ValueError): - await PostgresVectorStore.create( - vs_custom.engine, + await AsyncPostgresVectorStore.create( + engine, embedding_service=embeddings_service, table_name=CUSTOM_TABLE, id_column="myid", @@ -334,4 +313,30 @@ async def test_create_vectorstore_with_invalid_parameters(self, vs_custom): metadata_columns=["random_column"], ) - # Need tests for store metadata=False + async def test_create_vectorstore_with_invalid_parameters_5(self, engine): + with pytest.raises(ValueError): + await AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_TABLE, + id_column="myid", + content_column="mycontent", + embedding_column="langchain_id", + metadata_columns=["random_column"], + ignore_metadata_columns=[ + "one", + "two", + ], # invalid use of metadata_columns and ignore columns + ) + + async def test_create_vectorstore_with_init(self, engine): + with pytest.raises(Exception): + await AsyncPostgresVectorStore( + engine._pool, + embedding_service=embeddings_service, + table_name=CUSTOM_TABLE, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + 
metadata_columns=["random_column"],  # invalid metadata column
+            )
diff --git a/tests/test_cloudsql_vectorstore_from_methods.py b/tests/test_async_vectorstore_from_methods.py
similarity index 62%
rename from tests/test_cloudsql_vectorstore_from_methods.py
rename to tests/test_async_vectorstore_from_methods.py
index e7e143cb..59274f6a 100644
--- a/tests/test_cloudsql_vectorstore_from_methods.py
+++ b/tests/test_async_vectorstore_from_methods.py
@@ -14,17 +14,24 @@
 import os
 import uuid
+from typing import Sequence
 
 import pytest
 import pytest_asyncio
 from langchain_core.documents import Document
 from langchain_core.embeddings import DeterministicFakeEmbedding
+from sqlalchemy import text
+from sqlalchemy.engine.row import RowMapping
 
-from langchain_google_cloud_sql_pg import Column, PostgresEngine, PostgresVectorStore
+from langchain_google_cloud_sql_pg import Column, PostgresEngine
+from langchain_google_cloud_sql_pg.async_vectorstore import AsyncPostgresVectorStore
 
 DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
 DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4()).replace("-", "_")
 CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_")
+CUSTOM_TABLE_WITH_INT_ID = "test_table_custom_with_int_id" + str(uuid.uuid4()).replace(
+    "-", "_"
+)
 VECTOR_SIZE = 768
@@ -46,6 +53,20 @@ def get_env_var(key: str, desc: str) -> str:
     return v
 
 
+async def aexecute(engine: PostgresEngine, query: str) -> None:
+    async with engine._pool.connect() as conn:
+        await conn.execute(text(query))
+        await conn.commit()
+
+
+async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]:
+    async with engine._pool.connect() as conn:
+        result = await conn.execute(text(query))
+        result_map = result.mappings()
+        result_fetch = result_map.fetchall()
+        return result_fetch
+
+
 @pytest.mark.asyncio
 class TestVectorStoreFromMethods:
     @pytest.fixture(scope="module")
@@ -72,8 +93,8 @@ async def engine(self, db_project, db_region, db_instance, db_name):
             region=db_region,
             database=db_name,
         )
-        await engine.ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE)
-        await engine.ainit_vectorstore_table(
+        await engine._ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE)
+        await engine._ainit_vectorstore_table(
             CUSTOM_TABLE,
             VECTOR_SIZE,
             id_column="myid",
@@ -82,29 +103,24 @@ async def engine(self, db_project, db_region, db_instance, db_name):
             metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")],
             store_metadata=False,
         )
-        yield engine
-        await engine._aexecute(f"DROP TABLE IF EXISTS {DEFAULT_TABLE}")
-        await engine._aexecute(f"DROP TABLE IF EXISTS {CUSTOM_TABLE}")
-        await engine._engine.dispose()
-
-    @pytest_asyncio.fixture
-    def engine_sync(self, db_project, db_region, db_instance, db_name):
-        engine = PostgresEngine.from_instance(
-            project_id=db_project,
-            instance=db_instance,
-            region=db_region,
-            database=db_name,
+        await engine._ainit_vectorstore_table(
+            CUSTOM_TABLE_WITH_INT_ID,
+            VECTOR_SIZE,
+            id_column=Column(name="integer_id", data_type="INTEGER", nullable=False),
+            content_column="mycontent",
+            embedding_column="myembedding",
+            metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")],
+            store_metadata=False,
        )
-        engine.init_vectorstore_table(DEFAULT_TABLE_SYNC, VECTOR_SIZE)
-        yield engine
-        engine._execute(f"DROP TABLE IF EXISTS {DEFAULT_TABLE_SYNC}")
-
-        engine._engine.dispose()
+        yield engine
+        await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}")
+        await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE}")
+        await aexecute(engine, f"DROP TABLE
IF EXISTS {CUSTOM_TABLE_WITH_INT_ID}") + await engine.close() async def test_afrom_texts(self, engine): ids = [str(uuid.uuid4()) for i in range(len(texts))] - await PostgresVectorStore.afrom_texts( + await AsyncPostgresVectorStore.afrom_texts( texts, embeddings_service, engine, @@ -112,53 +128,26 @@ async def test_afrom_texts(self, engine): metadatas=metadatas, ids=ids, ) - results = await engine._afetch(f"SELECT * FROM {DEFAULT_TABLE}") - assert len(results) == 3 - await engine._aexecute(f"TRUNCATE TABLE {DEFAULT_TABLE}") - - async def test_from_texts(self, engine_sync): - ids = [str(uuid.uuid4()) for i in range(len(texts))] - PostgresVectorStore.from_texts( - texts, - embeddings_service, - engine_sync, - DEFAULT_TABLE_SYNC, - metadatas=metadatas, - ids=ids, - ) - results = engine_sync._fetch(f"SELECT * FROM {DEFAULT_TABLE_SYNC}") + results = await afetch(engine, f"SELECT * FROM {DEFAULT_TABLE}") assert len(results) == 3 - engine_sync._execute(f"TRUNCATE TABLE {DEFAULT_TABLE_SYNC}") + await aexecute(engine, f"TRUNCATE TABLE {DEFAULT_TABLE}") async def test_afrom_docs(self, engine): ids = [str(uuid.uuid4()) for i in range(len(texts))] - await PostgresVectorStore.afrom_documents( + await AsyncPostgresVectorStore.afrom_documents( docs, embeddings_service, engine, DEFAULT_TABLE, ids=ids, ) - results = await engine._afetch(f"SELECT * FROM {DEFAULT_TABLE}") + results = await afetch(engine, f"SELECT * FROM {DEFAULT_TABLE}") assert len(results) == 3 - await engine._aexecute(f"TRUNCATE TABLE {DEFAULT_TABLE}") - - async def test_from_docs(self, engine_sync): - ids = [str(uuid.uuid4()) for i in range(len(texts))] - PostgresVectorStore.from_documents( - docs, - embeddings_service, - engine_sync, - DEFAULT_TABLE_SYNC, - ids=ids, - ) - results = engine_sync._fetch(f"SELECT * FROM {DEFAULT_TABLE_SYNC}") - assert len(results) == 3 - engine_sync._execute(f"TRUNCATE TABLE {DEFAULT_TABLE_SYNC}") + await aexecute(engine, f"TRUNCATE TABLE {DEFAULT_TABLE}") async def test_afrom_texts_custom(self, engine): ids = [str(uuid.uuid4()) for i in range(len(texts))] - await PostgresVectorStore.afrom_texts( + await AsyncPostgresVectorStore.afrom_texts( texts, embeddings_service, engine, @@ -169,7 +158,7 @@ async def test_afrom_texts_custom(self, engine): embedding_column="myembedding", metadata_columns=["page", "source"], ) - results = await engine._afetch(f"SELECT * FROM {CUSTOM_TABLE}") + results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE}") assert len(results) == 3 assert results[0]["mycontent"] == "foo" assert results[0]["myembedding"] @@ -185,7 +174,7 @@ async def test_afrom_docs_custom(self, engine): ) for i in range(len(texts)) ] - await PostgresVectorStore.afrom_documents( + await AsyncPostgresVectorStore.afrom_documents( docs, embeddings_service, engine, @@ -197,10 +186,37 @@ async def test_afrom_docs_custom(self, engine): metadata_columns=["page", "source"], ) - results = await engine._afetch(f"SELECT * FROM {CUSTOM_TABLE}") + results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE}") assert len(results) == 3 assert results[0]["mycontent"] == "foo" assert results[0]["myembedding"] assert results[0]["page"] == "0" assert results[0]["source"] == "google.com" - await engine._aexecute(f"TRUNCATE TABLE {CUSTOM_TABLE}") + await aexecute(engine, f"TRUNCATE TABLE {CUSTOM_TABLE}") + + async def test_afrom_docs_custom_with_int_id(self, engine): + ids = [i for i in range(len(texts))] + docs = [ + Document( + page_content=texts[i], + metadata={"page": str(i), "source": "google.com"}, + ) + for i in 
range(len(texts)) + ] + await AsyncPostgresVectorStore.afrom_documents( + docs, + embeddings_service, + engine, + CUSTOM_TABLE_WITH_INT_ID, + ids=ids, + id_column="integer_id", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["page", "source"], + ) + + results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE_WITH_INT_ID}") + assert len(results) == 3 + for row in results: + assert isinstance(row["integer_id"], int) + await aexecute(engine, f"TRUNCATE TABLE {CUSTOM_TABLE_WITH_INT_ID}") diff --git a/tests/test_cloudsql_vectorstore_index.py b/tests/test_async_vectorstore_index.py similarity index 88% rename from tests/test_cloudsql_vectorstore_index.py rename to tests/test_async_vectorstore_index.py index 10baf13a..a3ff8c12 100644 --- a/tests/test_cloudsql_vectorstore_index.py +++ b/tests/test_async_vectorstore_index.py @@ -21,8 +21,10 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding +from sqlalchemy import text -from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore +from langchain_google_cloud_sql_pg import PostgresEngine +from langchain_google_cloud_sql_pg.async_vectorstore import AsyncPostgresVectorStore from langchain_google_cloud_sql_pg.indexes import ( DEFAULT_INDEX_NAME_SUFFIX, DistanceStrategy, @@ -54,6 +56,12 @@ def get_env_var(key: str, desc: str) -> str: return v +async def aexecute(engine: PostgresEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + @pytest.mark.asyncio(scope="class") class TestIndex: @pytest.fixture(scope="module") @@ -81,11 +89,13 @@ async def engine(self, db_project, db_region, db_instance, db_name): database=db_name, ) yield engine + await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") + await engine.close() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): - await engine.ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE) - vs = await PostgresVectorStore.create( + await engine._ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE) + vs = await AsyncPostgresVectorStore.create( engine, embedding_service=embeddings_service, table_name=DEFAULT_TABLE, @@ -94,16 +104,12 @@ async def vs(self, engine): await vs.aadd_texts(texts, ids=ids) await vs.adrop_vector_index() yield vs - await engine._aexecute(f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") - await engine._engine.dispose() - @pytest.mark.run(order=1) async def test_aapply_vector_index(self, vs): index = HNSWIndex() await vs.aapply_vector_index(index) assert await vs.is_valid_index(DEFAULT_INDEX_NAME) - @pytest.mark.run(order=2) async def test_areindex(self, vs): if not await vs.is_valid_index(DEFAULT_INDEX_NAME): index = HNSWIndex() @@ -112,7 +118,6 @@ async def test_areindex(self, vs): await vs.areindex(DEFAULT_INDEX_NAME) assert await vs.is_valid_index(DEFAULT_INDEX_NAME) - @pytest.mark.run(order=3) async def test_dropindex(self, vs): await vs.adrop_vector_index() result = await vs.is_valid_index(DEFAULT_INDEX_NAME) diff --git a/tests/test_async_vectorstore_search.py b/tests/test_async_vectorstore_search.py new file mode 100644 index 00000000..d918415a --- /dev/null +++ b/tests/test_async_vectorstore_search.py @@ -0,0 +1,270 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import uuid
+
+import pytest
+import pytest_asyncio
+from langchain_core.documents import Document
+from langchain_core.embeddings import DeterministicFakeEmbedding
+from sqlalchemy import text
+
+from langchain_google_cloud_sql_pg import Column, PostgresEngine
+from langchain_google_cloud_sql_pg.async_vectorstore import AsyncPostgresVectorStore
+from langchain_google_cloud_sql_pg.indexes import DistanceStrategy, HNSWQueryOptions
+
+DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
+CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_")
+VECTOR_SIZE = 768
+
+embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
+
+texts = ["foo", "bar", "baz", "boo"]
+ids = [str(uuid.uuid4()) for i in range(len(texts))]
+metadatas = [{"page": str(i), "source": "google.com"} for i in range(len(texts))]
+docs = [
+    Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts))
+]
+
+embeddings = [embeddings_service.embed_query("foo") for i in range(len(texts))]
+
+
+def get_env_var(key: str, desc: str) -> str:
+    v = os.environ.get(key)
+    if v is None:
+        raise ValueError(f"Must set env var {key} to: {desc}")
+    return v
+
+
+async def aexecute(
+    engine: PostgresEngine,
+    query: str,
+) -> None:
+    async with engine._pool.connect() as conn:
+        await conn.execute(text(query))
+        await conn.commit()
+
+
+@pytest.mark.asyncio(scope="class")
+class TestVectorStoreSearch:
+    @pytest.fixture(scope="module")
+    def db_project(self) -> str:
+        return get_env_var("PROJECT_ID", "project id for google cloud")
+
+    @pytest.fixture(scope="module")
+    def db_region(self) -> str:
+        return get_env_var("REGION", "region for cloud sql instance")
+
+    @pytest.fixture(scope="module")
+    def db_instance(self) -> str:
+        return get_env_var("INSTANCE_ID", "instance for cloud sql")
+
+    @pytest.fixture(scope="module")
+    def db_name(self) -> str:
+        return get_env_var("DATABASE_ID", "database name on cloud sql instance")
+
+    @pytest_asyncio.fixture(scope="class")
+    async def engine(self, db_project, db_region, db_instance, db_name):
+        engine = await PostgresEngine.afrom_instance(
+            project_id=db_project,
+            instance=db_instance,
+            region=db_region,
+            database=db_name,
+        )
+        yield engine
+        await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}")
+        await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE}")
+        await engine.close()
+
+    @pytest_asyncio.fixture(scope="class")
+    async def vs(self, engine):
+        await engine._ainit_vectorstore_table(
+            DEFAULT_TABLE, VECTOR_SIZE, store_metadata=False
+        )
+        vs = await AsyncPostgresVectorStore.create(
+            engine,
+            embedding_service=embeddings_service,
+            table_name=DEFAULT_TABLE,
+        )
+        ids = [str(uuid.uuid4()) for i in range(len(texts))]
+        await vs.aadd_documents(docs, ids=ids)
+        yield vs
+
+    @pytest_asyncio.fixture(scope="class")
+    async def vs_custom(self, engine):
+        await engine._ainit_vectorstore_table(
+            CUSTOM_TABLE,
+            VECTOR_SIZE,
+            id_column="myid",
+            content_column="mycontent",
+            embedding_column="myembedding",
+            metadata_columns=[
+                Column("page", "TEXT"),
+                Column("source", "TEXT"),
+            ],
+            store_metadata=False,
+        )
+
+
vs_custom = await AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_TABLE, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + index_query_options=HNSWQueryOptions(ef_search=1), + ) + await vs_custom.aadd_documents(docs, ids=ids) + yield vs_custom + + async def test_asimilarity_search(self, vs): + results = await vs.asimilarity_search("foo", k=1) + assert len(results) == 1 + assert results == [Document(page_content="foo")] + results = await vs.asimilarity_search("foo", k=1, filter="content = 'bar'") + assert results == [Document(page_content="bar")] + + async def test_asimilarity_search_score(self, vs): + results = await vs.asimilarity_search_with_score("foo") + assert len(results) == 4 + assert results[0][0] == Document(page_content="foo") + assert results[0][1] == 0 + + async def test_asimilarity_search_by_vector(self, vs): + embedding = embeddings_service.embed_query("foo") + results = await vs.asimilarity_search_by_vector(embedding) + assert len(results) == 4 + assert results[0] == Document(page_content="foo") + results = await vs.asimilarity_search_with_score_by_vector(embedding) + assert results[0][0] == Document(page_content="foo") + assert results[0][1] == 0 + + async def test_similarity_search_with_relevance_scores_threshold_cosine(self, vs): + score_threshold = {"score_threshold": 0} + results = await vs.asimilarity_search_with_relevance_scores( + "foo", **score_threshold + ) + assert len(results) == 4 + + score_threshold = {"score_threshold": 0.02} + results = await vs.asimilarity_search_with_relevance_scores( + "foo", **score_threshold + ) + assert len(results) == 2 + + score_threshold = {"score_threshold": 0.9} + results = await vs.asimilarity_search_with_relevance_scores( + "foo", **score_threshold + ) + assert len(results) == 1 + assert results[0][0] == Document(page_content="foo") + + score_threshold = {"score_threshold": 0.02} + vs.distance_strategy = DistanceStrategy.EUCLIDEAN + results = await vs.asimilarity_search_with_relevance_scores( + "foo", **score_threshold + ) + assert len(results) == 1 + + async def test_similarity_search_with_relevance_scores_threshold_euclidean( + self, engine + ): + vs = await AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_TABLE, + distance_strategy=DistanceStrategy.EUCLIDEAN, + ) + + score_threshold = {"score_threshold": 0.9} + results = await vs.asimilarity_search_with_relevance_scores( + "foo", **score_threshold + ) + assert len(results) == 1 + assert results[0][0] == Document(page_content="foo") + + async def test_amax_marginal_relevance_search(self, vs): + results = await vs.amax_marginal_relevance_search("bar") + assert results[0] == Document(page_content="bar") + results = await vs.amax_marginal_relevance_search( + "bar", filter="content = 'boo'" + ) + assert results[0] == Document(page_content="boo") + + async def test_amax_marginal_relevance_search_vector(self, vs): + embedding = embeddings_service.embed_query("bar") + results = await vs.amax_marginal_relevance_search_by_vector(embedding) + assert results[0] == Document(page_content="bar") + + async def test_amax_marginal_relevance_search_vector_score(self, vs): + embedding = embeddings_service.embed_query("bar") + results = await vs.amax_marginal_relevance_search_with_score_by_vector( + embedding + ) + assert results[0][0] == Document(page_content="bar") + + results = await vs.amax_marginal_relevance_search_with_score_by_vector( + embedding, 
lambda_mult=0.75, fetch_k=10 + ) + assert results[0][0] == Document(page_content="bar") + + async def test_similarity_search(self, vs_custom): + results = await vs_custom.asimilarity_search("foo", k=1) + assert len(results) == 1 + assert results == [Document(page_content="foo")] + results = await vs_custom.asimilarity_search( + "foo", k=1, filter="mycontent = 'bar'" + ) + assert results == [Document(page_content="bar")] + + async def test_similarity_search_score(self, vs_custom): + results = await vs_custom.asimilarity_search_with_score("foo") + assert len(results) == 4 + assert results[0][0] == Document(page_content="foo") + assert results[0][1] == 0 + + async def test_similarity_search_by_vector(self, vs_custom): + embedding = embeddings_service.embed_query("foo") + results = await vs_custom.asimilarity_search_by_vector(embedding) + assert len(results) == 4 + assert results[0] == Document(page_content="foo") + results = await vs_custom.asimilarity_search_with_score_by_vector(embedding) + assert results[0][0] == Document(page_content="foo") + assert results[0][1] == 0 + + async def test_max_marginal_relevance_search(self, vs_custom): + results = await vs_custom.amax_marginal_relevance_search("bar") + assert results[0] == Document(page_content="bar") + results = await vs_custom.amax_marginal_relevance_search( + "bar", filter="mycontent = 'boo'" + ) + assert results[0] == Document(page_content="boo") + + async def test_max_marginal_relevance_search_vector(self, vs_custom): + embedding = embeddings_service.embed_query("bar") + results = await vs_custom.amax_marginal_relevance_search_by_vector(embedding) + assert results[0] == Document(page_content="bar") + + async def test_max_marginal_relevance_search_vector_score(self, vs_custom): + embedding = embeddings_service.embed_query("bar") + results = await vs_custom.amax_marginal_relevance_search_with_score_by_vector( + embedding + ) + assert results[0][0] == Document(page_content="bar") + + results = await vs_custom.amax_marginal_relevance_search_with_score_by_vector( + embedding, lambda_mult=0.75, fetch_k=10 + ) + assert results[0][0] == Document(page_content="bar") diff --git a/tests/test_postgresql_chatmessagehistory.py b/tests/test_chatmessagehistory.py similarity index 60% rename from tests/test_postgresql_chatmessagehistory.py rename to tests/test_chatmessagehistory.py index ea0b85ee..b0a9420a 100644 --- a/tests/test_postgresql_chatmessagehistory.py +++ b/tests/test_chatmessagehistory.py @@ -19,6 +19,7 @@ import pytest_asyncio from langchain_core.messages.ai import AIMessage from langchain_core.messages.human import HumanMessage +from sqlalchemy import text from langchain_google_cloud_sql_pg import PostgresChatMessageHistory, PostgresEngine @@ -28,10 +29,24 @@ db_name = os.environ["DATABASE_ID"] table_name = "message_store" + str(uuid.uuid4()) table_name_async = "message_store" + str(uuid.uuid4()) +user = os.environ["DB_USER"] +password = os.environ["DB_PASSWORD"] -@pytest.fixture(name="memory_engine") -def setup() -> Generator: +async def aexecute( + engine: PostgresEngine, + query: str, +) -> None: + async def run(engine, query): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +@pytest_asyncio.fixture +async def engine(): engine = PostgresEngine.from_instance( project_id=project_id, region=region, @@ -42,27 +57,29 @@ def setup() -> Generator: yield engine # use default table for PostgresChatMessageHistory query = f'DROP TABLE IF 
EXISTS "{table_name}"' - engine._execute(query) + await aexecute(engine, query) + await engine.close() @pytest_asyncio.fixture async def async_engine(): - engine = await PostgresEngine.afrom_instance( + async_engine = await PostgresEngine.afrom_instance( project_id=project_id, region=region, instance=instance_id, database=db_name, ) - await engine.ainit_chat_history_table(table_name=table_name_async) - yield engine + await async_engine.ainit_chat_history_table(table_name=table_name_async) + yield async_engine # use default table for PostgresChatMessageHistory - query = f'DROP TABLE IF EXISTS "{table_name}"' - await engine._aexecute(query) + query = f'DROP TABLE IF EXISTS "{table_name_async}"' + await aexecute(async_engine, query) + await async_engine.close() -def test_chat_message_history(memory_engine: PostgresEngine) -> None: +def test_chat_message_history(engine: PostgresEngine) -> None: history = PostgresChatMessageHistory.create_sync( - engine=memory_engine, session_id="test", table_name=table_name + engine=engine, session_id="test", table_name=table_name ) history.add_user_message("hi!") history.add_ai_message("whats up?") @@ -79,23 +96,24 @@ def test_chat_message_history(memory_engine: PostgresEngine) -> None: assert len(history.messages) == 0 -def test_chat_table(memory_engine: Any) -> None: +def test_chat_table(engine: Any) -> None: with pytest.raises(ValueError): PostgresChatMessageHistory.create_sync( - engine=memory_engine, session_id="test", table_name="doesnotexist" + engine=engine, session_id="test", table_name="doesnotexist" ) -def test_chat_schema(memory_engine: Any) -> None: +@pytest.mark.asyncio +async def test_chat_schema(engine: Any) -> None: doc_table_name = "test_table" + str(uuid.uuid4()) - memory_engine.init_document_table(table_name=doc_table_name) + engine.init_document_table(table_name=doc_table_name) with pytest.raises(IndexError): PostgresChatMessageHistory.create_sync( - engine=memory_engine, session_id="test", table_name=doc_table_name + engine=engine, session_id="test", table_name=doc_table_name ) query = f'DROP TABLE IF EXISTS "{doc_table_name}"' - memory_engine._execute(query) + await aexecute(engine, query) @pytest.mark.asyncio @@ -137,11 +155,8 @@ async def test_chat_message_history_sync_messages( await history1.aadd_message(msg1) await history2.aadd_message(msg2) - assert len(history1.messages) == 1 - assert len(history2.messages) == 2 - - await history1.async_messages() assert len(history1.messages) == 2 + assert len(history2.messages) == 2 # verify clear() clears message history await history2.aclear() @@ -166,4 +181,51 @@ async def test_chat_schema_async(async_engine): ) query = f'DROP TABLE IF EXISTS "{table_name}"' - await async_engine._aexecute(query) + await aexecute(async_engine, query) + + +@pytest.mark.asyncio +async def test_cross_env_chat_message_history(engine): + history = PostgresChatMessageHistory.create_sync( + engine=engine, session_id="test_cross", table_name=table_name + ) + await history.aadd_message(HumanMessage(content="hi!")) + messages = history.messages + assert messages[0].content == "hi!" + history.clear() + + history = await PostgresChatMessageHistory.create( + engine=engine, session_id="test_cross", table_name=table_name + ) + history.add_message(HumanMessage(content="hi!")) + messages = history.messages + assert messages[0].content == "hi!" 
+ history.clear() + + +@pytest.mark.asyncio +async def test_from_engine_args_url(): + host = os.environ["IP_ADDRESS"] + port = "5432" + url = f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{db_name}" + engine = PostgresEngine.from_engine_args(url) + table_name = "test_table" + str(uuid.uuid4()).replace("-", "_") + await engine.ainit_chat_history_table(table_name) + + history = PostgresChatMessageHistory.create_sync( + engine=engine, session_id="test_cross", table_name=table_name + ) + await history.aadd_message(HumanMessage(content="hi!")) + history.add_message(HumanMessage(content="bye!")) + assert len(history.messages) == 2 + await history.aclear() + + history2 = await PostgresChatMessageHistory.create( + engine=engine, session_id="test_cross", table_name=table_name + ) + await history2.aadd_message(HumanMessage(content="hi!")) + history2.add_message(HumanMessage(content="bye!")) + assert len(history2.messages) == 2 + history2.clear() + + await aexecute(engine, f"DROP TABLE {table_name}") diff --git a/tests/test_postgresql_engine.py b/tests/test_engine.py similarity index 59% rename from tests/test_postgresql_engine.py rename to tests/test_engine.py index 9ccd43a0..5e117b0e 100644 --- a/tests/test_postgresql_engine.py +++ b/tests/test_engine.py @@ -14,22 +14,33 @@ import os import uuid +from typing import Sequence import asyncpg # type: ignore import pytest import pytest_asyncio from google.cloud.sql.connector import Connector, IPTypes from langchain_core.embeddings import DeterministicFakeEmbedding -from sqlalchemy import VARCHAR +from sqlalchemy import VARCHAR, text +from sqlalchemy.engine import URL +from sqlalchemy.engine.row import RowMapping from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.pool import NullPool from langchain_google_cloud_sql_pg import Column, PostgresEngine DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") +INT_ID_CUSTOM_TABLE = "test_table_custom_int_id" + str(uuid.uuid4()).replace("-", "_") +DEFAULT_TABLE_SYNC = "test_table" + str(uuid.uuid4()).replace("-", "_") +CUSTOM_TABLE_SYNC = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") +INT_ID_CUSTOM_TABLE_SYNC = "test_table_custom_int_id" + str(uuid.uuid4()).replace( + "-", "_" +) VECTOR_SIZE = 768 embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) +host = os.environ["IP_ADDRESS"] def get_env_var(key: str, desc: str) -> str: @@ -39,7 +50,30 @@ def get_env_var(key: str, desc: str) -> str: return v -@pytest.mark.asyncio +async def aexecute( + engine: PostgresEngine, + query: str, +) -> None: + async def run(engine, query): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: + async def run(engine, query): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + return await engine._run_as_async(run(engine, query)) + + +@pytest.mark.asyncio(scope="module") class TestEngineAsync: @pytest.fixture(scope="module") def db_project(self) -> str: @@ -69,7 +103,7 @@ def password(self) -> str: def iam_account(self) -> str: return get_env_var("IAM_ACCOUNT", "Cloud SQL IAM account email") - @pytest_asyncio.fixture + @pytest_asyncio.fixture(scope="class") async def engine(self, 
db_project, db_region, db_instance, db_name):
         engine = await PostgresEngine.afrom_instance(
             project_id=db_project,
@@ -78,10 +112,10 @@ async def engine(self, db_project, db_region, db_instance, db_name):
             database=db_name,
         )
         yield engine
-        await engine._engine.dispose()
-
-    async def test_execute(self, engine):
-        await engine._aexecute("SELECT 1")
+        await aexecute(engine, f'DROP TABLE "{CUSTOM_TABLE}"')
+        await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE}"')
+        await aexecute(engine, f'DROP TABLE "{INT_ID_CUSTOM_TABLE}"')
+        await engine.close()
 
     async def test_init_table(self, engine):
         await engine.ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE)
@@ -89,12 +123,7 @@ async def test_init_table(self, engine):
         content = "coffee"
         embedding = await embeddings_service.aembed_query(content)
         stmt = f"INSERT INTO {DEFAULT_TABLE} (langchain_id, content, embedding) VALUES ('{id}', '{content}','{embedding}');"
-        await engine._aexecute(stmt)
-
-    async def test_fetch(self, engine):
-        results = await engine._afetch(f"SELECT * FROM {DEFAULT_TABLE}")
-        assert len(results) > 0
-        await engine._aexecute(f"DROP TABLE {DEFAULT_TABLE}")
+        await aexecute(engine, stmt)
 
     async def test_init_table_custom(self, engine):
         await engine.ainit_vectorstore_table(
@@ -107,7 +136,7 @@ async def test_init_table_custom(self, engine):
             store_metadata=True,
         )
         stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{CUSTOM_TABLE}';"
-        results = await engine._afetch(stmt)
+        results = await afetch(engine, stmt)
         expected = [
             {"column_name": "uuid", "data_type": "uuid"},
             {"column_name": "my_embedding", "data_type": "USER-DEFINED"},
@@ -119,7 +148,28 @@ async def test_init_table_custom(self, engine):
         for row in results:
             assert row in expected
 
-        await engine._aexecute(f"DROP TABLE {CUSTOM_TABLE}")
+    async def test_init_table_with_int_id(self, engine):
+        await engine.ainit_vectorstore_table(
+            INT_ID_CUSTOM_TABLE,
+            VECTOR_SIZE,
+            id_column=Column(name="integer_id", data_type="INTEGER", nullable=False),
+            content_column="my-content",
+            embedding_column="my_embedding",
+            metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")],
+            store_metadata=True,
+        )
+        stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{INT_ID_CUSTOM_TABLE}';"
+        results = await afetch(engine, stmt)
+        expected = [
+            {"column_name": "integer_id", "data_type": "integer"},
+            {"column_name": "my_embedding", "data_type": "USER-DEFINED"},
+            {"column_name": "langchain_metadata", "data_type": "json"},
+            {"column_name": "my-content", "data_type": "text"},
+            {"column_name": "page", "data_type": "text"},
+            {"column_name": "source", "data_type": "text"},
+        ]
+        for row in results:
+            assert row in expected
 
     async def test_password(
         self,
@@ -140,7 +190,7 @@ async def test_password(
             password=password,
         )
         assert engine
-        await engine._aexecute("SELECT 1")
+        await aexecute(engine, "SELECT 1")
         PostgresEngine._connector = None
 
     async def test_from_engine(
@@ -172,7 +222,49 @@ async def getconn() -> asyncpg.Connection:
         )
 
         engine = PostgresEngine.from_engine(engine)
-        await engine._aexecute("SELECT 1")
+        await aexecute(engine, "SELECT 1")
+        await engine.close()
+
+    async def test_from_engine_args_url(
+        self,
+        db_name,
+        user,
+        password,
+    ):
+        port = "5432"
+        url = f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{db_name}"
+        engine = PostgresEngine.from_engine_args(
+            url,
+            echo=True,
+            poolclass=NullPool,
+        )
+        await aexecute(engine, "SELECT 1")
+        await engine.close()
+
+        engine =
PostgresEngine.from_engine_args( + URL.create("postgresql+asyncpg", user, password, host, port, db_name) + ) + await aexecute(engine, "SELECT 1") + await engine.close() + + async def test_from_engine_args_url_error( + self, + db_name, + user, + password, + ): + port = "5432" + url = f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{db_name}" + with pytest.raises(TypeError): + engine = PostgresEngine.from_engine_args(url, random=False) + with pytest.raises(ValueError): + PostgresEngine.from_engine_args( + f"postgresql+pg8000://{user}:{password}@{host}:{port}/{db_name}", + ) + with pytest.raises(ValueError): + PostgresEngine.from_engine_args( + URL.create("postgresql+pg8000", user, password, host, port, db_name) + ) async def test_column(self, engine): with pytest.raises(ValueError): @@ -187,6 +279,7 @@ async def test_iam_account_override( db_region, db_name, iam_account, + engine, ): engine = await PostgresEngine.afrom_instance( project_id=db_project, @@ -196,12 +289,11 @@ async def test_iam_account_override( iam_account_email=iam_account, ) assert engine - await engine._aexecute("SELECT 1") - await engine._connector.close_async() - await engine._engine.dispose() + await aexecute(engine, "SELECT 1") + await engine.close() -@pytest.mark.asyncio +@pytest.mark.asyncio(scope="module") class TestEngineSync: @pytest.fixture(scope="module") def db_project(self) -> str: @@ -231,8 +323,8 @@ def password(self) -> str: def iam_account(self) -> str: return get_env_var("IAM_ACCOUNT", "Cloud SQL IAM account email") - @pytest_asyncio.fixture - def engine(self, db_project, db_region, db_instance, db_name): + @pytest_asyncio.fixture(scope="class") + async def engine(self, db_project, db_region, db_instance, db_name): engine = PostgresEngine.from_instance( project_id=db_project, instance=db_instance, @@ -240,27 +332,22 @@ def engine(self, db_project, db_region, db_instance, db_name): database=db_name, ) yield engine - engine._engine.dispose() - - async def test_execute(self, engine): - engine._execute("SELECT 1") + await aexecute(engine, f'DROP TABLE "{CUSTOM_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE "{INT_ID_CUSTOM_TABLE_SYNC}"') + await engine.close() async def test_init_table(self, engine): - engine.init_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE) + engine.init_vectorstore_table(DEFAULT_TABLE_SYNC, VECTOR_SIZE) id = str(uuid.uuid4()) content = "coffee" embedding = await embeddings_service.aembed_query(content) - stmt = f"INSERT INTO {DEFAULT_TABLE} (langchain_id, content, embedding) VALUES ('{id}', '{content}','{embedding}');" - engine._execute(stmt) - - async def test_fetch(self, engine): - results = engine._fetch(f"SELECT * FROM {DEFAULT_TABLE}") - assert len(results) > 0 - engine._execute(f"DROP TABLE {DEFAULT_TABLE}") + stmt = f"INSERT INTO {DEFAULT_TABLE_SYNC} (langchain_id, content, embedding) VALUES ('{id}', '{content}','{embedding}');" + await aexecute(engine, stmt) async def test_init_table_custom(self, engine): engine.init_vectorstore_table( - CUSTOM_TABLE, + CUSTOM_TABLE_SYNC, VECTOR_SIZE, id_column="uuid", content_column="my-content", @@ -268,8 +355,8 @@ async def test_init_table_custom(self, engine): metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], store_metadata=True, ) - stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{CUSTOM_TABLE}';" - results = engine._fetch(stmt) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE 
table_name = '{CUSTOM_TABLE_SYNC}';" + results = await afetch(engine, stmt) expected = [ {"column_name": "uuid", "data_type": "uuid"}, {"column_name": "my_embedding", "data_type": "USER-DEFINED"}, @@ -281,7 +368,28 @@ async def test_init_table_custom(self, engine): for row in results: assert row in expected - engine._execute(f"DROP TABLE {CUSTOM_TABLE}") + async def test_init_table_with_int_id(self, engine): + engine.init_vectorstore_table( + INT_ID_CUSTOM_TABLE_SYNC, + VECTOR_SIZE, + id_column=Column(name="integer_id", data_type="INTEGER", nullable=False), + content_column="my-content", + embedding_column="my_embedding", + metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], + store_metadata=True, + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{INT_ID_CUSTOM_TABLE_SYNC}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "integer_id", "data_type": "integer"}, + {"column_name": "my_embedding", "data_type": "USER-DEFINED"}, + {"column_name": "langchain_metadata", "data_type": "json"}, + {"column_name": "my-content", "data_type": "text"}, + {"column_name": "page", "data_type": "text"}, + {"column_name": "source", "data_type": "text"}, + ] + for row in results: + assert row in expected async def test_password( self, @@ -303,7 +411,7 @@ async def test_password( quota_project=db_project, ) assert engine - engine._execute("SELECT 1") + await aexecute(engine, "SELECT 1") PostgresEngine._connector = None async def test_engine_constructor_key( @@ -314,13 +422,14 @@ async def test_engine_constructor_key( with pytest.raises(Exception): PostgresEngine(key, engine) - def test_iam_account_override( + async def test_iam_account_override( self, db_project, db_instance, db_region, db_name, iam_account, + engine, ): engine = PostgresEngine.from_instance( project_id=db_project, @@ -330,6 +439,5 @@ def test_iam_account_override( iam_account_email=iam_account, ) assert engine - engine._execute("SELECT 1") - engine._connector.close() - engine._engine.dispose() + await aexecute(engine, "SELECT 1") + await engine.close() diff --git a/tests/test_postgresql_loader.py b/tests/test_loader.py similarity index 93% rename from tests/test_postgresql_loader.py rename to tests/test_loader.py index 8f4f5e35..87ac979c 100644 --- a/tests/test_postgresql_loader.py +++ b/tests/test_loader.py @@ -19,6 +19,7 @@ import pytest import pytest_asyncio from langchain_core.documents import Document +from sqlalchemy import text from langchain_google_cloud_sql_pg import ( Column, @@ -34,6 +35,18 @@ table_name = "test-table" + str(uuid.uuid4()) +async def aexecute( + engine: PostgresEngine, + query: str, +) -> None: + async def run(engine, query): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + @pytest.mark.asyncio class TestLoaderAsync: @pytest_asyncio.fixture @@ -67,7 +80,7 @@ async def _collect_async_items(self, docs_generator): async def _cleanup_table(self, engine): query = f'DROP TABLE IF EXISTS "{table_name}"' - await engine._aexecute(query) + await aexecute(engine, query) async def test_create_loader_with_invalid_parameters(self, engine): with pytest.raises(ValueError): @@ -105,14 +118,14 @@ async def test_load_from_query_default(self, engine): organic INT NOT NULL ) """ - await engine._aexecute(query) + await aexecute(engine, query) insert_query = f""" INSERT INTO "{table_name}" ( fruit_name, variety, quantity_in_stock, price_per_unit, 
organic ) VALUES ('Apple', 'Granny Smith', 150, 1, 1); """ - await engine._aexecute(insert_query) + await aexecute(engine, insert_query) loader = await PostgresLoader.create( engine=engine, @@ -152,7 +165,7 @@ async def test_load_from_query_customized_content_customized_metadata(self, engi organic INT NOT NULL ) """ - await engine._aexecute(query) + await aexecute(engine, query) insert_query = f""" INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) @@ -160,7 +173,7 @@ async def test_load_from_query_customized_content_customized_metadata(self, engi ('Banana', 'Cavendish', 200, 0.59, 0), ('Orange', 'Navel', 80, 1.29, 1); """ - await engine._aexecute(insert_query) + await aexecute(engine, insert_query) loader = await PostgresLoader.create( engine=engine, @@ -208,13 +221,13 @@ async def test_load_from_query_customized_content_default_metadata(self, engine) organic INT NOT NULL ) """ - await engine._aexecute(query) + await aexecute(engine, query) insert_query = f""" INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) VALUES ('Apple', 'Granny Smith', 150, 1, 1); """ - await engine._aexecute(insert_query) + await aexecute(engine, insert_query) loader = await PostgresLoader.create( engine=engine, @@ -226,7 +239,9 @@ async def test_load_from_query_customized_content_default_metadata(self, engine) ], ) - documents = await self._collect_async_items(loader.alazy_load()) + documents = [] + for docs in loader.lazy_load(): + documents.append(docs) assert documents == [ Document( @@ -279,7 +294,7 @@ async def test_load_from_query_default_content_customized_metadata(self, engine) organic INT NOT NULL ) """ - await engine._aexecute(query) + await aexecute(engine, query) insert_query = f""" INSERT INTO "{table_name}" ( @@ -290,7 +305,7 @@ async def test_load_from_query_default_content_customized_metadata(self, engine) organic ) VALUES ('Apple', 'Granny Smith', 150, 1, 1); """ - await engine._aexecute(insert_query) + await aexecute(engine, insert_query) loader = await PostgresLoader.create( engine=engine, @@ -323,14 +338,14 @@ async def test_load_from_query_with_langchain_metadata(self, engine): langchain_metadata JSON NOT NULL ) """ - await engine._aexecute(query) + await aexecute(engine, query) metadata = json.dumps({"organic": 1}) insert_query = f""" INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, langchain_metadata) VALUES ('Apple', 'Granny Smith', 150, 1, '{metadata}');""" - await engine._aexecute(insert_query) + await aexecute(engine, insert_query) loader = await PostgresLoader.create( engine=engine, @@ -369,7 +384,7 @@ async def test_load_from_query_with_json(self, engine): langchain_metadata JSON NOT NULL ) """ - await engine._aexecute(query) + await aexecute(engine, query) metadata = json.dumps({"organic": 1}) variety = json.dumps({"type": "Granny Smith"}) @@ -377,7 +392,7 @@ async def test_load_from_query_with_json(self, engine): INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, langchain_metadata) VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" - await engine._aexecute(insert_query) + await aexecute(engine, insert_query) loader = await PostgresLoader.create( engine=engine, @@ -417,13 +432,13 @@ async def test_load_from_query_customized_content_default_metadata_custom_format organic INT NOT NULL ) """ - await engine._aexecute(query) + await aexecute(engine, query) insert_query = f""" INSERT INTO "{table_name}" (fruit_name, variety, 
quantity_in_stock, price_per_unit, organic) VALUES ('Apple', 'Granny Smith', 150, 1, 1); """ - await engine._aexecute(insert_query) + await aexecute(engine, insert_query) def my_formatter(row, content_columns): return "-".join( @@ -472,13 +487,13 @@ async def test_load_from_query_customized_content_default_metadata_custom_page_c organic INT NOT NULL ) """ - await engine._aexecute(query) + await aexecute(engine, query) insert_query = f""" INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) VALUES ('Apple', 'Granny Smith', 150, 1, 1); """ - await engine._aexecute(insert_query) + await aexecute(engine, insert_query) loader = await PostgresLoader.create( engine=engine, @@ -534,7 +549,9 @@ async def test_save_doc_with_default_metadata(self, engine): docs = await self._collect_async_items(loader.alazy_load()) assert docs == test_docs - assert (await engine._aload_table_schema(table_name)).columns.keys() == [ + assert ( + await engine._run_as_async(engine._aload_table_schema(table_name)) + ).columns.keys() == [ "page_content", "langchain_metadata", ] @@ -577,7 +594,9 @@ async def test_save_doc_with_customized_metadata(self, engine, store_metadata): if store_metadata: docs == test_docs - assert (await engine._aload_table_schema(table_name)).columns.keys() == [ + assert ( + await engine._run_as_async(engine._aload_table_schema(table_name)) + ).columns.keys() == [ "page_content", "fruit_name", "organic", @@ -590,7 +609,9 @@ async def test_save_doc_with_customized_metadata(self, engine, store_metadata): metadata={"fruit_name": "Apple", "organic": True}, ), ] - assert (await engine._aload_table_schema(table_name)).columns.keys() == [ + assert ( + await engine._run_as_async(engine._aload_table_schema(table_name)) + ).columns.keys() == [ "page_content", "fruit_name", "organic", @@ -628,7 +649,9 @@ async def test_save_doc_without_metadata(self, engine): metadata={}, ), ] - assert (await engine._aload_table_schema(table_name)).columns.keys() == [ + assert ( + await engine._run_as_async(engine._aload_table_schema(table_name)) + ).columns.keys() == [ "page_content", ] finally: @@ -782,7 +805,7 @@ async def test_delete_doc_with_customized_metadata( await saver.adelete(docs) assert len(await self._collect_async_items(loader.alazy_load())) == 0 - def test_sync_engine(self): + async def test_sync_engine(self): PostgresEngine._connector = None engine = PostgresEngine.from_instance( project_id=project_id, @@ -791,6 +814,7 @@ def test_sync_engine(self): database=db_name, ) assert engine + await engine.close() async def test_load_from_query_default_sync(self, sync_engine): try: @@ -815,7 +839,7 @@ async def test_load_from_query_default_sync(self, sync_engine): engine=sync_engine, query=f'SELECT * FROM "{table_name}";', ) - documents = loader.load() + documents = await loader.aload() assert documents == test_docs saver.delete(test_docs) diff --git a/tests/test_vectorstore.py b/tests/test_vectorstore.py new file mode 100644 index 00000000..7995cd63 --- /dev/null +++ b/tests/test_vectorstore.py @@ -0,0 +1,532 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+import uuid
+from threading import Thread
+from typing import Sequence
+
+import pytest
+import pytest_asyncio
+from google.cloud.sql.connector import Connector, IPTypes
+from langchain_core.documents import Document
+from langchain_core.embeddings import DeterministicFakeEmbedding
+from sqlalchemy import text
+from sqlalchemy.engine.row import RowMapping
+from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
+
+from langchain_google_cloud_sql_pg import Column, PostgresEngine, PostgresVectorStore
+
+DEFAULT_TABLE = "test_table" + str(uuid.uuid4())
+DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4())
+CUSTOM_TABLE = "test-table-custom" + str(uuid.uuid4())
+VECTOR_SIZE = 768
+
+embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
+host = os.environ["IP_ADDRESS"]
+
+texts = ["foo", "bar", "baz"]
+metadatas = [{"page": str(i), "source": "google.com"} for i in range(len(texts))]
+docs = [
+    Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts))
+]
+
+embeddings = [embeddings_service.embed_query(texts[i]) for i in range(len(texts))]
+
+
+def get_env_var(key: str, desc: str) -> str:
+    v = os.environ.get(key)
+    if v is None:
+        raise ValueError(f"Must set env var {key} to: {desc}")
+    return v
+
+
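+# Helpers standing in for the engine._aexecute()/_afetch() methods removed in
+# this release: each wraps raw SQL in a coroutine and hands it to
+# engine._run_as_async(), so the same helper works for engines created from
+# either the sync or the async code path.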
+async def aexecute(
+    engine: PostgresEngine,
+    query: str,
+) -> None:
+    async def run(engine, query):
+        async with engine._pool.connect() as conn:
+            await conn.execute(text(query))
+            await conn.commit()
+
+    await engine._run_as_async(run(engine, query))
+
+
+async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]:
+    async def run(engine, query):
+        async with engine._pool.connect() as conn:
+            result = await conn.execute(text(query))
+            result_map = result.mappings()
+            result_fetch = result_map.fetchall()
+            return result_fetch
+
+    return await engine._run_as_async(run(engine, query))
+
+
+@pytest.mark.asyncio(scope="class")
+class TestVectorStore:
+    @pytest.fixture(scope="module")
+    def db_project(self) -> str:
+        return get_env_var("PROJECT_ID", "project id for google cloud")
+
+    @pytest.fixture(scope="module")
+    def db_region(self) -> str:
+        return get_env_var("REGION", "region for cloud sql instance")
+
+    @pytest.fixture(scope="module")
+    def db_instance(self) -> str:
+        return get_env_var("INSTANCE_ID", "instance for cloud sql")
+
+    @pytest.fixture(scope="module")
+    def db_name(self) -> str:
+        return get_env_var("DATABASE_ID", "database name on cloud sql instance")
+
+    @pytest.fixture(scope="module")
+    def user(self) -> str:
+        return get_env_var("DB_USER", "database user for cloud sql")
+
+    @pytest.fixture(scope="module")
+    def password(self) -> str:
+        return get_env_var("DB_PASSWORD", "database password for cloud sql")
+
+    @pytest_asyncio.fixture(scope="class")
+    async def engine(self, db_project, db_region, db_instance, db_name):
+        engine = await PostgresEngine.afrom_instance(
+            project_id=db_project,
+            instance=db_instance,
+            region=db_region,
+            database=db_name,
+        )
+
+        yield engine
+        await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"')
+        await engine.close()
+
+    @pytest_asyncio.fixture(scope="class")
+    async def vs(self, engine):
+        await engine.ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE)
+        vs = await PostgresVectorStore.create(
+            engine,
+            embedding_service=embeddings_service,
+            table_name=DEFAULT_TABLE,
+        )
+        yield vs
+
+    @pytest_asyncio.fixture(scope="class")
+    async def engine_sync(self, db_project, db_region, db_instance, db_name):
+        engine_sync = PostgresEngine.from_instance(
+            project_id=db_project,
+            instance=db_instance,
+            region=db_region,
+            database=db_name,
+        )
+        yield engine_sync
+
+        await aexecute(engine_sync, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_SYNC}"')
+        await engine_sync.close()
+
+    @pytest_asyncio.fixture(scope="class")
+    def vs_sync(self, engine_sync):
+        engine_sync.init_vectorstore_table(DEFAULT_TABLE_SYNC, VECTOR_SIZE)
+
+        vs = PostgresVectorStore.create_sync(
+            engine_sync,
+            embedding_service=embeddings_service,
+            table_name=DEFAULT_TABLE_SYNC,
+        )
+        yield vs
+
+    @pytest_asyncio.fixture(scope="class")
+    async def vs_custom(self, engine):
+        await engine.ainit_vectorstore_table(
+            CUSTOM_TABLE,
+            VECTOR_SIZE,
+            id_column="myid",
+            content_column="mycontent",
+            embedding_column="myembedding",
+            metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")],
+            metadata_json_column="mymeta",
+        )
+        vs = await PostgresVectorStore.create(
+            engine,
+            embedding_service=embeddings_service,
+            table_name=CUSTOM_TABLE,
+            id_column="myid",
+            content_column="mycontent",
+            embedding_column="myembedding",
+            metadata_columns=["page", "source"],
+            metadata_json_column="mymeta",
+        )
+        yield vs
+        await aexecute(engine, f'DROP TABLE IF EXISTS "{CUSTOM_TABLE}"')
+
+    async def test_init_with_constructor(self, engine):
+        with pytest.raises(Exception):
+            PostgresVectorStore(
+                engine,
+                embedding_service=embeddings_service,
+                table_name=CUSTOM_TABLE,
+                id_column="myid",
+                content_column="noname",
+                embedding_column="myembedding",
+                metadata_columns=["page", "source"],
+                metadata_json_column="mymeta",
+            )
+
+    async def test_post_init(self, engine):
+        with pytest.raises(ValueError):
+            await PostgresVectorStore.create(
+                engine,
+                embedding_service=embeddings_service,
+                table_name=CUSTOM_TABLE,
+                id_column="myid",
+                content_column="noname",
+                embedding_column="myembedding",
+                metadata_columns=["page", "source"],
+                metadata_json_column="mymeta",
+            )
+
+    async def test_aadd_texts(self, engine, vs):
+        ids = [str(uuid.uuid4()) for i in range(len(texts))]
+        await vs.aadd_texts(texts, ids=ids)
+        results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"')
+        assert len(results) == 3
+
+        ids = [str(uuid.uuid4()) for i in range(len(texts))]
+        await vs.aadd_texts(texts, metadatas, ids)
+        results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"')
+        assert len(results) == 6
+        await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"')
+
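+    # "Cross env" coverage: sync methods are called on a vectorstore created
+    # through the async API; the 0.10.0 refactor routes these calls onto the
+    # async implementation via the engine.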
+    async def test_cross_env_add_texts(self, engine, vs):
+        ids = [str(uuid.uuid4()) for i in range(len(texts))]
+        vs.add_texts(texts, ids=ids)
+        results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"')
+        assert len(results) == 3
+        vs.delete(ids)
+
+    async def test_aadd_texts_edge_cases(self, engine, vs):
+        texts = ["Taylor's", '"Swift"', "best-friend"]
+        ids = [str(uuid.uuid4()) for i in range(len(texts))]
+        await vs.aadd_texts(texts, ids=ids)
+        results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"')
+        assert len(results) == 3
+        await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"')
+
+    async def test_aadd_docs(self, engine, vs):
+        ids = [str(uuid.uuid4()) for i in range(len(texts))]
+        await vs.aadd_documents(docs, ids=ids)
+        results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"')
+        assert len(results) == 3
+        await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"')
+
+    async def test_adelete(self, engine, vs):
+        ids = [str(uuid.uuid4()) for i in range(len(texts))]
+        await vs.aadd_texts(texts, ids=ids)
+        results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"')
+        assert len(results) == 3
+        # delete an ID
+        await vs.adelete([ids[0]])
+        results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"')
+        assert len(results) == 2
+
+    async def test_aadd_texts_custom(self, engine, vs_custom):
+        ids = [str(uuid.uuid4()) for i in range(len(texts))]
+        await vs_custom.aadd_texts(texts, ids=ids)
+        results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"')
+        assert len(results) == 3
+        assert results[0]["mycontent"] == "foo"
+        assert results[0]["myembedding"]
+        assert results[0]["page"] is None
+        assert results[0]["source"] is None
+
+        ids = [str(uuid.uuid4()) for i in range(len(texts))]
+        await vs_custom.aadd_texts(texts, metadatas, ids)
+        results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"')
+        assert len(results) == 6
+        await aexecute(engine, f'TRUNCATE TABLE "{CUSTOM_TABLE}"')
+
+    async def test_aadd_docs_custom(self, engine, vs_custom):
+        ids = [str(uuid.uuid4()) for i in range(len(texts))]
+        docs = [
+            Document(
+                page_content=texts[i],
+                metadata={"page": str(i), "source": "google.com"},
+            )
+            for i in range(len(texts))
+        ]
+        await vs_custom.aadd_documents(docs, ids=ids)
+
+        results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"')
+        assert len(results) == 3
+        assert results[0]["mycontent"] == "foo"
+        assert results[0]["myembedding"]
+        assert results[0]["page"] == "0"
+        assert results[0]["source"] == "google.com"
+        await aexecute(engine, f'TRUNCATE TABLE "{CUSTOM_TABLE}"')
+
+    async def test_adelete_custom(self, engine, vs_custom):
+        ids = [str(uuid.uuid4()) for i in range(len(texts))]
+        await vs_custom.aadd_texts(texts, ids=ids)
+        results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"')
+        content = [result["mycontent"] for result in results]
+        assert len(results) == 3
+        assert "foo" in content
+        # delete an ID
+        await vs_custom.adelete([ids[0]])
+        results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"')
+        content = [result["mycontent"] for result in results]
+        assert len(results) == 2
+        assert "foo" not in content
+
+    async def test_add_docs(self, engine_sync, vs_sync):
+        ids = [str(uuid.uuid4()) for i in range(len(texts))]
+        vs_sync.add_documents(docs, ids=ids)
+        results = await afetch(engine_sync, f'SELECT * FROM "{DEFAULT_TABLE_SYNC}"')
+        assert len(results) == 3
+        vs_sync.delete(ids)
+
+    async def test_add_texts(self, engine_sync, vs_sync):
+        ids = [str(uuid.uuid4()) for i in range(len(texts))]
+        vs_sync.add_texts(texts, ids=ids)
+        results = await afetch(engine_sync, f'SELECT * FROM "{DEFAULT_TABLE_SYNC}"')
+        assert len(results) == 3
+        await vs_sync.adelete(ids)
+
+    async def test_cross_env(self, engine_sync, vs_sync):
+        ids = [str(uuid.uuid4()) for i in range(len(texts))]
+        await vs_sync.aadd_texts(texts, ids=ids)
+        results = await afetch(engine_sync, f'SELECT * FROM "{DEFAULT_TABLE_SYNC}"')
+        assert len(results) == 3
+        await vs_sync.adelete(ids)
+
id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["random_column"], # invalid metadata column + ) + with pytest.raises(ValueError): + await PostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_TABLE, + id_column="myid", + content_column="langchain_id", # invalid content column type + embedding_column="myembedding", + metadata_columns=["random_column"], + ) + with pytest.raises(ValueError): + await PostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_TABLE, + id_column="myid", + content_column="mycontent", + embedding_column="random_column", # invalid embedding column + metadata_columns=["random_column"], + ) + with pytest.raises(ValueError): + await PostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_TABLE, + id_column="myid", + content_column="mycontent", + embedding_column="langchain_id", # invalid embedding column data type + metadata_columns=["random_column"], + ) + + async def test_from_engine( + self, + db_project, + db_region, + db_instance, + db_name, + user, + password, + ): + async with Connector() as connector: + + async def getconn(): + conn = await connector.connect_async( # type: ignore + f"{db_project}:{db_region}:{db_instance}", + "asyncpg", + user=user, + password=password, + db=db_name, + enable_iam_auth=False, + ip_type=IPTypes.PUBLIC, + ) + return conn + + engine = create_async_engine( + "postgresql+asyncpg://", + async_creator=getconn, + ) + + engine = PostgresEngine.from_engine(engine) + table_name = "test_table" + str(uuid.uuid4()).replace("-", "_") + await engine.ainit_vectorstore_table(table_name, VECTOR_SIZE) + vs = await PostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=table_name, + ) + await vs.aadd_texts(["foo"]) + results = await afetch(engine, f"SELECT * FROM {table_name}") + assert len(results) == 1 + + await aexecute(engine, f"DROP TABLE {table_name}") + + async def test_from_engine_loop_connector( + self, + db_project, + db_region, + db_instance, + db_name, + user, + password, + ): + async def init_connection_pool(connector: Connector) -> AsyncEngine: + async def getconn(): + conn = await connector.connect_async( + f"{db_project}:{db_region}:{db_instance}", + "asyncpg", + user=user, + password=password, + db=db_name, + enable_iam_auth=False, + ip_type="PUBLIC", + ) + return conn + + pool = create_async_engine( + "postgresql+asyncpg://", + async_creator=getconn, + ) + return pool + + loop = asyncio.new_event_loop() + thread = Thread(target=loop.run_forever, daemon=True) + thread.start() + + connector = Connector(loop=loop) + coro = init_connection_pool(connector) + pool = asyncio.run_coroutine_threadsafe(coro, loop).result() + engine = PostgresEngine.from_engine(pool, loop) + table_name = "test_table" + str(uuid.uuid4()).replace("-", "_") + await engine.ainit_vectorstore_table(table_name, VECTOR_SIZE) + vs = await PostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=table_name, + ) + await vs.aadd_texts(["foo"]) + vs.add_texts(["foo"]) + results = await afetch(engine, f"SELECT * FROM {table_name}") + assert len(results) == 2 + + await aexecute(engine, f"TRUNCATE TABLE {table_name}") + + vs = PostgresVectorStore.create_sync( + engine, + embedding_service=embeddings_service, + table_name=table_name, + ) + await vs.aadd_texts(["foo"]) + vs.add_texts(["foo"]) + results = await afetch(engine, 
f"SELECT * FROM {table_name}") + assert len(results) == 2 + + await aexecute(engine, f"DROP TABLE {table_name}") + + async def test_from_engine_args_url( + self, + db_name, + user, + password, + ): + port = "5432" + url = f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{db_name}" + engine = PostgresEngine.from_engine_args(url) + table_name = "test_table" + str(uuid.uuid4()).replace("-", "_") + await engine.ainit_vectorstore_table(table_name, VECTOR_SIZE) + vs = await PostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=table_name, + ) + await vs.aadd_texts(["foo"]) + vs.add_texts(["foo"]) + results = await afetch(engine, f"SELECT * FROM {table_name}") + assert len(results) == 2 + + await aexecute(engine, f"TRUNCATE TABLE {table_name}") + vs = PostgresVectorStore.create_sync( + engine, + embedding_service=embeddings_service, + table_name=table_name, + ) + await vs.aadd_texts(["foo"]) + vs.add_texts(["bar"]) + results = await afetch(engine, f"SELECT * FROM {table_name}") + assert len(results) == 2 + await aexecute(engine, f"DROP TABLE {table_name}") + + async def test_from_engine_loop( + self, + db_name, + user, + password, + ): + port = "5432" + url = f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{db_name}" + + loop = asyncio.new_event_loop() + thread = Thread(target=loop.run_forever, daemon=True) + thread.start() + pool = create_async_engine(url) + engine = PostgresEngine.from_engine(pool, loop) + + table_name = "test_table" + str(uuid.uuid4()).replace("-", "_") + await engine.ainit_vectorstore_table(table_name, VECTOR_SIZE) + vs = await PostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=table_name, + ) + await vs.aadd_texts(["foo"]) + vs.add_texts(["foo"]) + results = await afetch(engine, f"SELECT * FROM {table_name}") + assert len(results) == 2 + + await aexecute(engine, f"TRUNCATE TABLE {table_name}") + vs = PostgresVectorStore.create_sync( + engine, + embedding_service=embeddings_service, + table_name=table_name, + ) + await vs.aadd_texts(["foo"]) + vs.add_texts(["bar"]) + results = await afetch(engine, f"SELECT * FROM {table_name}") + assert len(results) == 2 + await aexecute(engine, f"DROP TABLE {table_name}") diff --git a/tests/test_vectorstore_from_methods.py b/tests/test_vectorstore_from_methods.py new file mode 100644 index 00000000..fadf8fc1 --- /dev/null +++ b/tests/test_vectorstore_from_methods.py @@ -0,0 +1,325 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+    async def test_from_engine_args_url(
+        self,
+        db_name,
+        user,
+        password,
+    ):
+        port = "5432"
+        url = f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{db_name}"
+        engine = PostgresEngine.from_engine_args(url)
+        table_name = "test_table" + str(uuid.uuid4()).replace("-", "_")
+        await engine.ainit_vectorstore_table(table_name, VECTOR_SIZE)
+        vs = await PostgresVectorStore.create(
+            engine,
+            embedding_service=embeddings_service,
+            table_name=table_name,
+        )
+        await vs.aadd_texts(["foo"])
+        vs.add_texts(["foo"])
+        results = await afetch(engine, f"SELECT * FROM {table_name}")
+        assert len(results) == 2
+
+        await aexecute(engine, f"TRUNCATE TABLE {table_name}")
+        vs = PostgresVectorStore.create_sync(
+            engine,
+            embedding_service=embeddings_service,
+            table_name=table_name,
+        )
+        await vs.aadd_texts(["foo"])
+        vs.add_texts(["bar"])
+        results = await afetch(engine, f"SELECT * FROM {table_name}")
+        assert len(results) == 2
+        await aexecute(engine, f"DROP TABLE {table_name}")
+
+    async def test_from_engine_loop(
+        self,
+        db_name,
+        user,
+        password,
+    ):
+        port = "5432"
+        url = f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{db_name}"
+
+        loop = asyncio.new_event_loop()
+        thread = Thread(target=loop.run_forever, daemon=True)
+        thread.start()
+        pool = create_async_engine(url)
+        engine = PostgresEngine.from_engine(pool, loop)
+
+        table_name = "test_table" + str(uuid.uuid4()).replace("-", "_")
+        await engine.ainit_vectorstore_table(table_name, VECTOR_SIZE)
+        vs = await PostgresVectorStore.create(
+            engine,
+            embedding_service=embeddings_service,
+            table_name=table_name,
+        )
+        await vs.aadd_texts(["foo"])
+        vs.add_texts(["foo"])
+        results = await afetch(engine, f"SELECT * FROM {table_name}")
+        assert len(results) == 2
+
+        await aexecute(engine, f"TRUNCATE TABLE {table_name}")
+        vs = PostgresVectorStore.create_sync(
+            engine,
+            embedding_service=embeddings_service,
+            table_name=table_name,
+        )
+        await vs.aadd_texts(["foo"])
+        vs.add_texts(["bar"])
+        results = await afetch(engine, f"SELECT * FROM {table_name}")
+        assert len(results) == 2
+        await aexecute(engine, f"DROP TABLE {table_name}")
diff --git a/tests/test_vectorstore_from_methods.py b/tests/test_vectorstore_from_methods.py
new file mode 100644
index 00000000..fadf8fc1
--- /dev/null
+++ b/tests/test_vectorstore_from_methods.py
@@ -0,0 +1,325 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import uuid
+from typing import Sequence
+
+import pytest
+import pytest_asyncio
+from langchain_core.documents import Document
+from langchain_core.embeddings import DeterministicFakeEmbedding
+from sqlalchemy import VARCHAR, text
+from sqlalchemy.engine.row import RowMapping
+from sqlalchemy.ext.asyncio import create_async_engine
+
+from langchain_google_cloud_sql_pg import Column, PostgresEngine, PostgresVectorStore
+
+DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
+DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4()).replace("-", "_")
+CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_")
+CUSTOM_TABLE_WITH_INT_ID = "test_table_with_int_id" + str(uuid.uuid4()).replace(
+    "-", "_"
+)
+CUSTOM_TABLE_WITH_INT_ID_SYNC = "test_table_with_int_id" + str(uuid.uuid4()).replace(
+    "-", "_"
+)
+VECTOR_SIZE = 768
+
+
+embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
+
+texts = ["foo", "bar", "baz"]
+metadatas = [{"page": str(i), "source": "google.com"} for i in range(len(texts))]
+docs = [
+    Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts))
+]
+
+embeddings = [embeddings_service.embed_query(texts[i]) for i in range(len(texts))]
+
+
+def get_env_var(key: str, desc: str) -> str:
+    v = os.environ.get(key)
+    if v is None:
+        raise ValueError(f"Must set env var {key} to: {desc}")
+    return v
+
+
+async def aexecute(
+    engine: PostgresEngine,
+    query: str,
+) -> None:
+    async def run(engine, query):
+        async with engine._pool.connect() as conn:
+            await conn.execute(text(query))
+            await conn.commit()
+
+    await engine._run_as_async(run(engine, query))
+
+
+async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]:
+    async def run(engine, query):
+        async with engine._pool.connect() as conn:
+            result = await conn.execute(text(query))
+            result_map = result.mappings()
+            result_fetch = result_map.fetchall()
+            return result_fetch
+
+    return await engine._run_as_async(run(engine, query))
+
+
+@pytest.mark.asyncio
+class TestVectorStoreFromMethods:
+    @pytest.fixture(scope="module")
+    def db_project(self) -> str:
+        return get_env_var("PROJECT_ID", "project id for google cloud")
+
+    @pytest.fixture(scope="module")
+    def db_region(self) -> str:
+        return get_env_var("REGION", "region for cloud sql instance")
+
+    @pytest.fixture(scope="module")
+    def db_instance(self) -> str:
+        return get_env_var("INSTANCE_ID", "instance for cloud sql")
+
+    @pytest.fixture(scope="module")
+    def db_name(self) -> str:
+        return get_env_var("DATABASE_ID", "database name on cloud sql instance")
+
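+    # The engine fixtures provision three table layouts: the default schema, a
+    # custom column set, and a table keyed by a non-UUID INTEGER id column.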
EXISTS {DEFAULT_TABLE}") + await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE}") + await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE_WITH_INT_ID}") + await engine.close() + + @pytest_asyncio.fixture + async def engine_sync(self, db_project, db_region, db_instance, db_name): + engine = PostgresEngine.from_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + engine.init_vectorstore_table(DEFAULT_TABLE_SYNC, VECTOR_SIZE) + engine.init_vectorstore_table( + CUSTOM_TABLE_WITH_INT_ID_SYNC, + VECTOR_SIZE, + id_column=Column(name="integer_id", data_type="INTEGER", nullable="False"), + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], + store_metadata=False, + ) + yield engine + await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_SYNC}") + await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE_WITH_INT_ID_SYNC}") + await engine.close() + + async def test_afrom_texts(self, engine): + ids = [str(uuid.uuid4()) for i in range(len(texts))] + await PostgresVectorStore.afrom_texts( + texts, + embeddings_service, + engine, + DEFAULT_TABLE, + metadatas=metadatas, + ids=ids, + ) + results = await afetch(engine, f"SELECT * FROM {DEFAULT_TABLE}") + assert len(results) == 3 + await aexecute(engine, f"TRUNCATE TABLE {DEFAULT_TABLE}") + + async def test_from_texts(self, engine_sync): + ids = [str(uuid.uuid4()) for i in range(len(texts))] + PostgresVectorStore.from_texts( + texts, + embeddings_service, + engine_sync, + DEFAULT_TABLE_SYNC, + metadatas=metadatas, + ids=ids, + ) + results = await afetch(engine_sync, f"SELECT * FROM {DEFAULT_TABLE_SYNC}") + assert len(results) == 3 + await aexecute(engine_sync, f"TRUNCATE TABLE {DEFAULT_TABLE_SYNC}") + + async def test_afrom_docs(self, engine): + ids = [str(uuid.uuid4()) for i in range(len(texts))] + await PostgresVectorStore.afrom_documents( + docs, + embeddings_service, + engine, + DEFAULT_TABLE, + ids=ids, + ) + results = await afetch(engine, f"SELECT * FROM {DEFAULT_TABLE}") + assert len(results) == 3 + await aexecute(engine, f"TRUNCATE TABLE {DEFAULT_TABLE}") + + async def test_from_docs(self, engine_sync): + ids = [str(uuid.uuid4()) for i in range(len(texts))] + PostgresVectorStore.from_documents( + docs, + embeddings_service, + engine_sync, + DEFAULT_TABLE_SYNC, + ids=ids, + ) + results = await afetch(engine_sync, f"SELECT * FROM {DEFAULT_TABLE_SYNC}") + assert len(results) == 3 + await aexecute(engine_sync, f"TRUNCATE TABLE {DEFAULT_TABLE_SYNC}") + + async def test_afrom_docs_cross_env(self, engine_sync): + ids = [str(uuid.uuid4()) for i in range(len(texts))] + await PostgresVectorStore.afrom_documents( + docs, + embeddings_service, + engine_sync, + DEFAULT_TABLE_SYNC, + ids=ids, + ) + results = await afetch(engine_sync, f"SELECT * FROM {DEFAULT_TABLE_SYNC}") + assert len(results) == 3 + await aexecute(engine_sync, f"TRUNCATE TABLE {DEFAULT_TABLE_SYNC}") + + async def test_from_docs_cross_env(self, engine, engine_sync): + ids = [str(uuid.uuid4()) for i in range(len(texts))] + PostgresVectorStore.from_documents( + docs, + embeddings_service, + engine, + DEFAULT_TABLE_SYNC, + ids=ids, + ) + results = await afetch(engine, f"SELECT * FROM {DEFAULT_TABLE_SYNC}") + assert len(results) == 3 + await aexecute(engine, f"TRUNCATE TABLE {DEFAULT_TABLE_SYNC}") + + async def test_afrom_texts_custom(self, engine): + ids = [str(uuid.uuid4()) for i in range(len(texts))] + await PostgresVectorStore.afrom_texts( + 
+    async def test_afrom_texts_custom(self, engine):
+        ids = [str(uuid.uuid4()) for i in range(len(texts))]
+        await PostgresVectorStore.afrom_texts(
+            texts,
+            embeddings_service,
+            engine,
+            CUSTOM_TABLE,
+            ids=ids,
+            id_column="myid",
+            content_column="mycontent",
+            embedding_column="myembedding",
+            metadata_columns=["page", "source"],
+        )
+        results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE}")
+        assert len(results) == 3
+        assert results[0]["mycontent"] == "foo"
+        assert results[0]["myembedding"]
+        assert results[0]["page"] is None
+        assert results[0]["source"] is None
+
+    async def test_afrom_docs_custom(self, engine):
+        ids = [str(uuid.uuid4()) for i in range(len(texts))]
+        docs = [
+            Document(
+                page_content=texts[i],
+                metadata={"page": str(i), "source": "google.com"},
+            )
+            for i in range(len(texts))
+        ]
+        await PostgresVectorStore.afrom_documents(
+            docs,
+            embeddings_service,
+            engine,
+            CUSTOM_TABLE,
+            ids=ids,
+            id_column="myid",
+            content_column="mycontent",
+            embedding_column="myembedding",
+            metadata_columns=["page", "source"],
+        )
+
+        results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE}")
+        assert len(results) == 3
+        assert results[0]["mycontent"] == "foo"
+        assert results[0]["myembedding"]
+        assert results[0]["page"] == "0"
+        assert results[0]["source"] == "google.com"
+        await aexecute(engine, f"TRUNCATE TABLE {CUSTOM_TABLE}")
+
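+    # Non-UUID primary keys: plain int ids are written against the INTEGER
+    # id column created by the engine fixture.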
+    async def test_afrom_texts_custom_with_int_id(self, engine):
+        ids = [i for i in range(len(texts))]
+        await PostgresVectorStore.afrom_texts(
+            texts,
+            embeddings_service,
+            engine,
+            CUSTOM_TABLE_WITH_INT_ID,
+            metadatas=metadatas,
+            ids=ids,
+            id_column="integer_id",
+            content_column="mycontent",
+            embedding_column="myembedding",
+            metadata_columns=["page", "source"],
+        )
+        results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE_WITH_INT_ID}")
+        assert len(results) == 3
+        for row in results:
+            assert isinstance(row["integer_id"], int)
+        await aexecute(engine, f"TRUNCATE TABLE {CUSTOM_TABLE_WITH_INT_ID}")
+
+    async def test_from_texts_custom_with_int_id(self, engine_sync):
+        ids = [i for i in range(len(texts))]
+        PostgresVectorStore.from_texts(
+            texts,
+            embeddings_service,
+            engine_sync,
+            CUSTOM_TABLE_WITH_INT_ID_SYNC,
+            metadatas=metadatas,
+            ids=ids,
+            id_column="integer_id",
+            content_column="mycontent",
+            embedding_column="myembedding",
+            metadata_columns=["page", "source"],
+        )
+        results = await afetch(
+            engine_sync, f"SELECT * FROM {CUSTOM_TABLE_WITH_INT_ID_SYNC}"
+        )
+        assert len(results) == 3
+        for row in results:
+            assert isinstance(row["integer_id"], int)
+        await aexecute(engine_sync, f"TRUNCATE TABLE {CUSTOM_TABLE_WITH_INT_ID_SYNC}")
diff --git a/tests/test_vectorstore_index.py b/tests/test_vectorstore_index.py
new file mode 100644
index 00000000..7c240061
--- /dev/null
+++ b/tests/test_vectorstore_index.py
@@ -0,0 +1,223 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import sys
+import uuid
+
+import pytest
+import pytest_asyncio
+from langchain_core.documents import Document
+from langchain_core.embeddings import DeterministicFakeEmbedding
+from sqlalchemy import text
+
+from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore
+from langchain_google_cloud_sql_pg.indexes import (
+    DEFAULT_INDEX_NAME_SUFFIX,
+    DistanceStrategy,
+    HNSWIndex,
+    IVFFlatIndex,
+)
+
+DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
+CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_")
+DEFAULT_INDEX_NAME = DEFAULT_TABLE + DEFAULT_INDEX_NAME_SUFFIX
+VECTOR_SIZE = 768
+
+embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
+
+texts = ["foo", "bar", "baz"]
+ids = [str(uuid.uuid4()) for i in range(len(texts))]
+metadatas = [{"page": str(i), "source": "google.com"} for i in range(len(texts))]
+docs = [
+    Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts))
+]
+
+embeddings = [embeddings_service.embed_query("foo") for i in range(len(texts))]
+
+
+def get_env_var(key: str, desc: str) -> str:
+    v = os.environ.get(key)
+    if v is None:
+        raise ValueError(f"Must set env var {key} to: {desc}")
+    return v
+
+
+async def aexecute(
+    engine: PostgresEngine,
+    query: str,
+) -> None:
+    async def run(engine, query):
+        async with engine._pool.connect() as conn:
+            await conn.execute(text(query))
+            await conn.commit()
+
+    await engine._run_as_async(run(engine, query))
+
+
+@pytest.mark.asyncio(scope="class")
+class TestIndex:
+    @pytest.fixture(scope="module")
+    def db_project(self) -> str:
+        return get_env_var("PROJECT_ID", "project id for google cloud")
+
+    @pytest.fixture(scope="module")
+    def db_region(self) -> str:
+        return get_env_var("REGION", "region for cloud sql instance")
+
+    @pytest.fixture(scope="module")
+    def db_instance(self) -> str:
+        return get_env_var("INSTANCE_ID", "instance for cloud sql")
+
+    @pytest.fixture(scope="module")
+    def db_name(self) -> str:
+        return get_env_var("DATABASE_ID", "database name on cloud sql instance")
+
+    @pytest_asyncio.fixture(scope="class")
+    async def engine(self, db_project, db_region, db_instance, db_name):
+        engine = PostgresEngine.from_instance(
+            project_id=db_project,
+            instance=db_instance,
+            region=db_region,
+            database=db_name,
+        )
+        yield engine
+        await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}")
+        await engine.close()
+
+    @pytest_asyncio.fixture(scope="class")
+    async def vs(self, engine):
+        engine.init_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE)
+        vs = PostgresVectorStore.create_sync(
+            engine,
+            embedding_service=embeddings_service,
+            table_name=DEFAULT_TABLE,
+        )
+
+        vs.add_texts(texts, ids=ids)
+        vs.drop_vector_index()
+        yield vs
+
+    async def test_aapply_vector_index(self, vs):
+        index = HNSWIndex()
+        vs.apply_vector_index(index)
+        assert vs.is_valid_index(DEFAULT_INDEX_NAME)
+
+    async def test_areindex(self, vs):
+        if not vs.is_valid_index(DEFAULT_INDEX_NAME):
+            index = HNSWIndex()
+            vs.apply_vector_index(index)
+        vs.reindex()
+        vs.reindex(DEFAULT_INDEX_NAME)
+        assert vs.is_valid_index(DEFAULT_INDEX_NAME)
+
+    async def test_dropindex(self, vs):
+        vs.drop_vector_index()
+        result = vs.is_valid_index(DEFAULT_INDEX_NAME)
+        assert not result
+
+    async def test_aapply_vector_index_ivfflat(self, vs):
+        index = IVFFlatIndex(distance_strategy=DistanceStrategy.EUCLIDEAN)
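+        # concurrently=True is expected to build the index without blocking
+        # writes (CREATE INDEX CONCURRENTLY).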
+        vs.apply_vector_index(index, concurrently=True)
+        assert vs.is_valid_index(DEFAULT_INDEX_NAME)
+        index = IVFFlatIndex(
+            name="secondindex",
+            distance_strategy=DistanceStrategy.INNER_PRODUCT,
+        )
+        vs.apply_vector_index(index)
+        assert vs.is_valid_index("secondindex")
+        vs.drop_vector_index("secondindex")
+
+    async def test_is_valid_index(self, vs):
+        is_valid = vs.is_valid_index("invalid_index")
+        assert is_valid == False
+
+
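+# Mirror of TestIndex that exercises the async index APIs
+# (aapply_vector_index / areindex / adrop_vector_index / ais_valid_index).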
+@pytest.mark.asyncio(scope="class")
+class TestAsyncIndex:
+    @pytest.fixture(scope="module")
+    def db_project(self) -> str:
+        return get_env_var("PROJECT_ID", "project id for google cloud")
+
+    @pytest.fixture(scope="module")
+    def db_region(self) -> str:
+        return get_env_var("REGION", "region for cloud sql instance")
+
+    @pytest.fixture(scope="module")
+    def db_instance(self) -> str:
+        return get_env_var("INSTANCE_ID", "instance for cloud sql")
+
+    @pytest.fixture(scope="module")
+    def db_name(self) -> str:
+        return get_env_var("DATABASE_ID", "database name on cloud sql instance")
+
+    @pytest_asyncio.fixture(scope="class")
+    async def engine(self, db_project, db_region, db_instance, db_name):
+        engine = await PostgresEngine.afrom_instance(
+            project_id=db_project,
+            instance=db_instance,
+            region=db_region,
+            database=db_name,
+        )
+        yield engine
+        await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}")
+        await engine.close()
+
+    @pytest_asyncio.fixture(scope="class")
+    async def vs(self, engine):
+        await engine.ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE)
+        vs = await PostgresVectorStore.create(
+            engine,
+            embedding_service=embeddings_service,
+            table_name=DEFAULT_TABLE,
+        )
+
+        await vs.aadd_texts(texts, ids=ids)
+        await vs.adrop_vector_index()
+        yield vs
+
+    async def test_aapply_vector_index(self, vs):
+        index = HNSWIndex()
+        await vs.aapply_vector_index(index)
+        assert await vs.ais_valid_index(DEFAULT_INDEX_NAME)
+
+    async def test_areindex(self, vs):
+        if not await vs.ais_valid_index(DEFAULT_INDEX_NAME):
+            index = HNSWIndex()
+            await vs.aapply_vector_index(index)
+        await vs.areindex()
+        await vs.areindex(DEFAULT_INDEX_NAME)
+        assert await vs.ais_valid_index(DEFAULT_INDEX_NAME)
+
+    async def test_dropindex(self, vs):
+        await vs.adrop_vector_index()
+        result = await vs.ais_valid_index(DEFAULT_INDEX_NAME)
+        assert not result
+
+    async def test_aapply_vector_index_ivfflat(self, vs):
+        index = IVFFlatIndex(distance_strategy=DistanceStrategy.EUCLIDEAN)
+        await vs.aapply_vector_index(index, concurrently=True)
+        assert await vs.ais_valid_index(DEFAULT_INDEX_NAME)
+        index = IVFFlatIndex(
+            name="secondindex",
+            distance_strategy=DistanceStrategy.INNER_PRODUCT,
+        )
+        await vs.aapply_vector_index(index)
+        assert await vs.ais_valid_index("secondindex")
+        await vs.adrop_vector_index("secondindex")
+
+    async def test_is_valid_index(self, vs):
+        is_valid = await vs.ais_valid_index("invalid_index")
+        assert is_valid == False
diff --git a/tests/test_cloudsql_vectorstore_search.py b/tests/test_vectorstore_search.py
similarity index 79%
rename from tests/test_cloudsql_vectorstore_search.py
rename to tests/test_vectorstore_search.py
index 65c6d8bc..f2c1cb17 100644
--- a/tests/test_cloudsql_vectorstore_search.py
+++ b/tests/test_vectorstore_search.py
@@ -19,12 +19,14 @@
 import pytest_asyncio
 from langchain_core.documents import Document
 from langchain_core.embeddings import DeterministicFakeEmbedding
+from sqlalchemy import text
 
 from langchain_google_cloud_sql_pg import Column, PostgresEngine, PostgresVectorStore
 from langchain_google_cloud_sql_pg.indexes import DistanceStrategy, HNSWQueryOptions
 
 DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
 CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_")
+CUSTOM_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4()).replace("-", "_")
 VECTOR_SIZE = 768
 
 embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
@@ -46,6 +48,18 @@ def get_env_var(key: str, desc: str) -> str:
     return v
 
 
+async def aexecute(
+    engine: PostgresEngine,
+    query: str,
+) -> None:
+    async def run(engine, query):
+        async with engine._pool.connect() as conn:
+            await conn.execute(text(query))
+            await conn.commit()
+
+    await engine._run_as_async(run(engine, query))
+
+
 @pytest.mark.asyncio(scope="class")
 class TestVectorStoreSearch:
     @pytest.fixture(scope="module")
@@ -73,6 +87,8 @@ async def engine(self, db_project, db_region, db_instance, db_name):
             database=db_name,
         )
         yield engine
+        await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}")
+        await engine.close()
 
     @pytest_asyncio.fixture(scope="class")
     async def vs(self, engine):
@@ -87,11 +103,9 @@ async def vs(self, engine):
         ids = [str(uuid.uuid4()) for i in range(len(texts))]
         await vs.aadd_documents(docs, ids=ids)
         yield vs
-        await engine._aexecute(f"DROP TABLE IF EXISTS {DEFAULT_TABLE}")
-        await engine._engine.dispose()
 
     @pytest_asyncio.fixture(scope="class")
-    def engine_sync(self, db_project, db_region, db_instance, db_name):
+    async def engine_sync(self, db_project, db_region, db_instance, db_name):
         engine = PostgresEngine.from_instance(
             project_id=db_project,
             instance=db_instance,
@@ -99,6 +113,8 @@ def engine_sync(self, db_project, db_region, db_instance, db_name):
             database=db_name,
         )
         yield engine
+        await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE}")
+        await engine.close()
 
     @pytest_asyncio.fixture(scope="class")
     async def vs_custom(self, engine_sync):
@@ -126,8 +142,6 @@ async def vs_custom(self, engine_sync):
         )
         vs_custom.add_documents(docs, ids=ids)
         yield vs_custom
-        engine_sync._aexecute(f"DROP TABLE IF EXISTS {CUSTOM_TABLE}")
-        engine_sync._engine.dispose()
 
     async def test_asimilarity_search(self, vs):
         results = await vs.asimilarity_search("foo", k=1)
@@ -213,6 +227,63 @@ async def test_amax_marginal_relevance_search_vector_score(self, vs):
         )
         assert results[0][0] == Document(page_content="bar")
 
+
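+# Sync counterpart of TestVectorStoreSearch: the engine, table, and store are
+# created with the synchronous APIs and searches run without awaiting.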
id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + index_query_options=HNSWQueryOptions(ef_search=1), + ) + vs_custom.add_documents(docs, ids=ids) + yield vs_custom + def test_similarity_search(self, vs_custom): results = vs_custom.similarity_search("foo", k=1) assert len(results) == 1