From ec094c8363e69bb66120cb1c964675b33ba4f389 Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Wed, 4 Dec 2024 17:41:29 -0500 Subject: [PATCH 01/39] chore: update small typos in samples notebooks (#30) * chore: update llama_index_doc_store.ipynb * chore: Update llama_index_vector_store.ipynb --- samples/llama_index_doc_store.ipynb | 2 +- samples/llama_index_vector_store.ipynb | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/samples/llama_index_doc_store.ipynb b/samples/llama_index_doc_store.ipynb index bc14192..149dd2f 100644 --- a/samples/llama_index_doc_store.ipynb +++ b/samples/llama_index_doc_store.ipynb @@ -40,7 +40,7 @@ "id": "IR54BmgvdHT_" }, "source": [ - "### 🦜🔗 Library Installation\n", + "### 🦙 Library Installation\n", "Install the integration library, `llama-index-alloydb-pg`, and the library for the embedding service, `llama-index-embeddings-vertex`." ] }, diff --git a/samples/llama_index_vector_store.ipynb b/samples/llama_index_vector_store.ipynb index 40a67f3..873665a 100644 --- a/samples/llama_index_vector_store.ipynb +++ b/samples/llama_index_vector_store.ipynb @@ -14,7 +14,7 @@ "\n", "Learn more about the package on [GitHub](https://github.com/googleapis/llama-index-alloydb-pg-python/).\n", "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/llama-index-alloydb-pg-python/blob/main/docs/vector_store.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/llama-index-alloydb-pg-python/blob/main/samples/llama_index_vector_store.ipynb)" ] }, { @@ -40,7 +40,7 @@ "id": "IR54BmgvdHT_" }, "source": [ - "### 🦜🔗 Library Installation\n", + "### 🦙 Library Installation\n", "Install the integration library, `llama-index-alloydb-pg`, and the library for the embedding service, `llama-index-embeddings-vertex`."
] }, From 31f567b8400649d77b2534924e2db8aa9f34e7e3 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Fri, 6 Dec 2024 11:44:52 +0100 Subject: [PATCH 02/39] fix(deps): update python-nonmajor (#26) Co-authored-by: Vishwaraj Anand --- pyproject.toml | 6 +++--- requirements.txt | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7cb81ae..907866a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,11 +37,11 @@ Changelog = "https://github.com/googleapis/llama-index-alloydb-pg-python/blob/ma [project.optional-dependencies] test = [ - "black[jupyter]==24.8.0", + "black[jupyter]==24.10.0", "isort==5.13.2", - "mypy==1.11.2", + "mypy==1.13.0", "pytest-asyncio==0.24.0", - "pytest==8.3.3", + "pytest==8.3.4", "pytest-cov==6.0.0", "pytest-depends==1.0.1", ] diff --git a/requirements.txt b/requirements.txt index 5ff2916..9a009eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -google-cloud-alloydb-connector[asyncpg]==1.4.0 -llama-index-core==0.12.0 +google-cloud-alloydb-connector[asyncpg]==1.5.0 +llama-index-core==0.12.2 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.36 From e12a9ab7bf3d16ced9ecdf895aa5d9fe84abce20 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Fri, 6 Dec 2024 13:21:42 +0100 Subject: [PATCH 03/39] chore(deps): update dependency llama-index-core to v0.12.3 (#33) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9a009eb..2afc7b6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ google-cloud-alloydb-connector[asyncpg]==1.5.0 -llama-index-core==0.12.2 +llama-index-core==0.12.3 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.36 From 28c5752fa78dc97e0ddce08b5b19db3cb2087e7c Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Fri, 6 Dec 2024 17:47:10 -0500 Subject: [PATCH 04/39] chore: adhere to PEP 585 and remove unused imports (#34) --- .../async_document_store.py | 38 +++++++++---------- .../async_index_store.py | 9 ++--- .../async_vector_store.py | 37 ++++++++---------- src/llama_index_alloydb_pg/document_store.py | 34 ++++++++--------- src/llama_index_alloydb_pg/engine.py | 34 ++++++----------- src/llama_index_alloydb_pg/index_store.py | 10 ++--- src/llama_index_alloydb_pg/indexes.py | 4 +- src/llama_index_alloydb_pg/vector_store.py | 30 +++++++-------- tests/test_async_vector_store.py | 2 +- tests/test_async_vector_store_index.py | 5 +-- tests/test_engine.py | 2 +- tests/test_vector_store.py | 2 +- tests/test_vector_store_index.py | 5 +-- 13 files changed, 95 insertions(+), 117 deletions(-) diff --git a/src/llama_index_alloydb_pg/async_document_store.py b/src/llama_index_alloydb_pg/async_document_store.py index b45b9dc..f72a9bd 100644 --- a/src/llama_index_alloydb_pg/async_document_store.py +++ b/src/llama_index_alloydb_pg/async_document_store.py @@ -16,7 +16,7 @@ import json import warnings -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Optional, Sequence from llama_index.core.constants import DATA_KEY from llama_index.core.schema import BaseNode @@ -119,13 +119,13 @@ async def __afetch_query(self, query): return results async def _put_all_doc_hashes_to_table( - self, rows: List[Tuple[str, str]], batch_size: int = int(DEFAULT_BATCH_SIZE) + self, rows: list[tuple[str, str]], batch_size: int = int(DEFAULT_BATCH_SIZE) ) -> None: """Puts multiple rows of node ids with their doc_hash into the document table.
In case a row with the id already exists, it updates the row with the new doc_hash. Args: - rows (List[Tuple[str, str]]): List of tuples of id and doc_hash + rows (list[tuple[str, str]]): List of tuples of id and doc_hash batch_size (int): batch_size to insert the rows. Defaults to 1. Returns: None """ @@ -173,7 +173,7 @@ async def async_add_documents( """Adds a document to the store. Args: - docs (List[BaseDocument]): documents + docs (list[BaseDocument]): documents allow_update (bool): allow update of docstore from document batch_size (int): batch_size to insert the rows. Defaults to 1. store_text (bool): allow the text content of the node to be stored. @@ -225,7 +225,7 @@ async def async_add_documents( await self.__aexecute_query(query, batch) @property - async def adocs(self) -> Dict[str, BaseNode]: + async def adocs(self) -> dict[str, BaseNode]: """Get all documents. Returns: @@ -300,12 +300,12 @@ async def aget_ref_doc_info(self, ref_doc_id: str) -> Optional[RefDocInfo]: return RefDocInfo(node_ids=node_ids, metadata=merged_metadata) - async def aget_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + async def aget_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: """Get a mapping of ref_doc_id -> RefDocInfo for all ingested documents. Returns: Optional[ - Dict[ + dict[ str, #Ref_doc_id RefDocInfo, #Ref_doc_info of the id ] @@ -356,14 +356,14 @@ async def adocument_exists(self, doc_id: str) -> bool: async def _get_ref_doc_child_node_ids( self, ref_doc_id: str - ) -> Optional[Dict[str, List[str]]]: + ) -> Optional[dict[str, list[str]]]: """Helper function to find the child node mappings of a ref_doc_id. Returns: Optional[ - Dict[ + dict[ str, # Ref_doc_id - List # List of all nodes that refer to ref_doc_id + list # List of all nodes that refer to ref_doc_id ] ]""" query = f"""select id from "{self._schema_name}"."{self._table_name}" where ref_doc_id = '{ref_doc_id}';""" @@ -442,11 +442,11 @@ async def aset_document_hash(self, doc_id: str, doc_hash: str) -> None: await self._put_all_doc_hashes_to_table(rows=[(doc_id, doc_hash)]) - async def aset_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + async def aset_document_hashes(self, doc_hashes: dict[str, str]) -> None: """Set the hash for a given doc_id. Args: - doc_hashes (Dict[str, str]): Dictionary with doc_id as key and doc_hash as value. + doc_hashes (dict[str, str]): Dictionary with doc_id as key and doc_hash as value. Returns: None @@ -473,11 +473,11 @@ async def aget_document_hash(self, doc_id: str) -> Optional[str]: else: return None - async def aget_all_document_hashes(self) -> Dict[str, str]: + async def aget_all_document_hashes(self) -> dict[str, str]: """Get the stored hash for all documents. Returns: - Dict[ + dict[ str, # doc_hash str # doc_id ] @@ -498,11 +498,11 @@ async def aget_all_document_hashes(self) -> Dict[str, str]: return hashes @property - def docs(self) -> Dict[str, BaseNode]: + def docs(self) -> dict[str, BaseNode]: """Get all documents. Returns: - Dict[str, BaseDocument]: documents + dict[str, BaseDocument]: documents """ raise NotImplementedError( @@ -547,7 +547,7 @@ def set_document_hash(self, doc_id: str, doc_hash: str) -> None: "Sync methods are not implemented for AsyncAlloyDBDocumentStore. Use AlloyDBDocumentStore interface instead." ) - def set_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + def set_document_hashes(self, doc_hashes: dict[str, str]) -> None: raise NotImplementedError( "Sync methods are not implemented for AsyncAlloyDBDocumentStore.
Use AlloyDBDocumentStore interface instead." ) @@ -557,12 +557,12 @@ def get_document_hash(self, doc_id: str) -> Optional[str]: "Sync methods are not implemented for AsyncAlloyDBDocumentStore. Use AlloyDBDocumentStore interface instead." ) - def get_all_document_hashes(self) -> Dict[str, str]: + def get_all_document_hashes(self) -> dict[str, str]: raise NotImplementedError( "Sync methods are not implemented for AsyncAlloyDBDocumentStore. Use AlloyDBDocumentStore interface instead." ) - def get_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + def get_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: raise NotImplementedError( "Sync methods are not implemented for AsyncAlloyDBDocumentStore. Use AlloyDBDocumentStore interface instead." ) diff --git a/src/llama_index_alloydb_pg/async_index_store.py b/src/llama_index_alloydb_pg/async_index_store.py index a93255e..09999f3 100644 --- a/src/llama_index_alloydb_pg/async_index_store.py +++ b/src/llama_index_alloydb_pg/async_index_store.py @@ -16,9 +16,8 @@ import json import warnings -from typing import List, Optional +from typing import Optional -from llama_index.core.constants import DATA_KEY from llama_index.core.data_structs.data_structs import IndexStruct from llama_index.core.storage.index_store.types import BaseIndexStore from llama_index.core.storage.index_store.utils import ( @@ -113,11 +112,11 @@ async def __afetch_query(self, query): await conn.commit() return results - async def aindex_structs(self) -> List[IndexStruct]: + async def aindex_structs(self) -> list[IndexStruct]: """Get all index structs. Returns: - List[IndexStruct]: index structs + list[IndexStruct]: index structs """ query = f"""SELECT * from "{self._schema_name}"."{self._table_name}";""" @@ -190,7 +189,7 @@ async def aget_index_struct( return json_to_index_struct(index_data) return None - def index_structs(self) -> List[IndexStruct]: + def index_structs(self) -> list[IndexStruct]: raise NotImplementedError( "Sync methods are not implemented for AsyncAlloyDBIndexStore . Use AlloyDBIndexStore interface instead." ) diff --git a/src/llama_index_alloydb_pg/async_vector_store.py b/src/llama_index_alloydb_pg/async_vector_store.py index 8be4854..d318b11 100644 --- a/src/llama_index_alloydb_pg/async_vector_store.py +++ b/src/llama_index_alloydb_pg/async_vector_store.py @@ -15,14 +15,10 @@ # TODO: Remove below import when minimum supported Python version is 3.10 from __future__ import annotations -import base64 import json -import re -import uuid import warnings -from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type +from typing import Any, Optional, Sequence -import numpy as np from llama_index.core.schema import BaseNode, MetadataMode, NodeRelationship, TextNode from llama_index.core.vector_stores.types import ( BasePydanticVectorStore, @@ -31,7 +27,6 @@ MetadataFilter, MetadataFilters, VectorStoreQuery, - VectorStoreQueryMode, VectorStoreQueryResult, ) from llama_index.core.vector_stores.utils import ( @@ -71,7 +66,7 @@ def __init__( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -89,7 +84,7 @@ def __init__( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. 
Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -121,7 +116,7 @@ def __init__( @classmethod async def create( - cls: Type[AsyncAlloyDBVectorStore], + cls: type[AsyncAlloyDBVectorStore], engine: AlloyDBEngine, table_name: str, schema_name: str = "public", @@ -129,7 +124,7 @@ async def create( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -147,7 +142,7 @@ async def create( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". 
@@ -234,7 +229,7 @@ def client(self) -> Any: """Get client.""" return self._engine - async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> List[str]: + async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> list[str]: """Asynchronously add nodes to the table.""" ids = [] metadata_col_names = ( @@ -293,14 +288,14 @@ async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: async def adelete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: """Asynchronously delete a set of nodes from the table matching the provided nodes and filters.""" if not node_ids and not filters: return - all_filters: List[MetadataFilter | MetadataFilters] = [] + all_filters: list[MetadataFilter | MetadataFilters] = [] if node_ids: all_filters.append( MetadataFilter( @@ -332,9 +327,9 @@ async def aclear(self) -> None: async def aget_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: """Asynchronously get nodes from the table matching the provided nodes and filters.""" query = VectorStoreQuery( node_ids=node_ids, filters=filters, similarity_top_k=-1 @@ -366,7 +361,7 @@ async def aquery( similarities.append(row["distance"]) return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids) - def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> List[str]: + def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]: raise NotImplementedError( "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." ) @@ -378,7 +373,7 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: def delete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: @@ -393,9 +388,9 @@ def clear(self) -> None: def get_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: raise NotImplementedError( "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." 
) @@ -495,7 +490,7 @@ async def __query_columns( **kwargs: Any, ) -> Sequence[RowMapping]: """Perform search query on database.""" - filters: List[MetadataFilter | MetadataFilters] = [] + filters: list[MetadataFilter | MetadataFilters] = [] if query.doc_ids: filters.append( MetadataFilter( diff --git a/src/llama_index_alloydb_pg/document_store.py b/src/llama_index_alloydb_pg/document_store.py index cc86297..033c052 100644 --- a/src/llama_index_alloydb_pg/document_store.py +++ b/src/llama_index_alloydb_pg/document_store.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Dict, List, Optional, Sequence, Type +from typing import Optional, Sequence from llama_index.core.schema import BaseNode from llama_index.core.storage.docstore import BaseDocumentStore @@ -55,7 +55,7 @@ def __init__( @classmethod async def create( - cls: Type[AlloyDBDocumentStore], + cls: type[AlloyDBDocumentStore], engine: AlloyDBEngine, table_name: str, schema_name: str = "public", @@ -83,7 +83,7 @@ async def create( @classmethod def create_sync( - cls: Type[AlloyDBDocumentStore], + cls: type[AlloyDBDocumentStore], engine: AlloyDBEngine, table_name: str, schema_name: str = "public", @@ -110,11 +110,11 @@ def create_sync( return cls(cls.__create_key, engine, document_store) @property - def docs(self) -> Dict[str, BaseNode]: + def docs(self) -> dict[str, BaseNode]: """Get all documents. Returns: - Dict[str, BaseDocument]: documents + dict[str, BaseDocument]: documents """ return self._engine._run_as_sync(self.__document_store.adocs) @@ -291,11 +291,11 @@ def set_document_hash(self, doc_id: str, doc_hash: str) -> None: self.__document_store.aset_document_hash(doc_id, doc_hash) ) - async def aset_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + async def aset_document_hashes(self, doc_hashes: dict[str, str]) -> None: """Set the hash for a given doc_id. Args: - doc_hashes (Dict[str, str]): Dictionary with doc_id as key and doc_hash as value. + doc_hashes (dict[str, str]): Dictionary with doc_id as key and doc_hash as value. Returns: None @@ -304,11 +304,11 @@ async def aset_document_hashes(self, doc_hashes: Dict[str, str]) -> None: self.__document_store.aset_document_hashes(doc_hashes) ) - def set_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + def set_document_hashes(self, doc_hashes: dict[str, str]) -> None: """Set the hash for a given doc_id. Args: - doc_hashes (Dict[str, str]): Dictionary with doc_id as key and doc_hash as value. + doc_hashes (dict[str, str]): Dictionary with doc_id as key and doc_hash as value. Returns: None @@ -343,11 +343,11 @@ def get_document_hash(self, doc_id: str) -> Optional[str]: self.__document_store.aget_document_hash(doc_id) ) - async def aget_all_document_hashes(self) -> Dict[str, str]: + async def aget_all_document_hashes(self) -> dict[str, str]: """Get the stored hash for all documents. Returns: - Dict[ + dict[ str, # doc_hash str # doc_id ] @@ -356,11 +356,11 @@ async def aget_all_document_hashes(self) -> Dict[str, str]: self.__document_store.aget_all_document_hashes() ) - def get_all_document_hashes(self) -> Dict[str, str]: + def get_all_document_hashes(self) -> dict[str, str]: """Get the stored hash for all documents. 
Returns: - Dict[ + dict[ str, # doc_hash str # doc_id ] @@ -369,12 +369,12 @@ def get_all_document_hashes(self) -> Dict[str, str]: self.__document_store.aget_all_document_hashes() ) - async def aget_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + async def aget_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: """Get a mapping of ref_doc_id -> RefDocInfo for all ingested documents. Returns: Optional[ - Dict[ + dict[ str, #Ref_doc_id RefDocInfo, #Ref_doc_info of the id ] @@ -384,12 +384,12 @@ async def aget_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: self.__document_store.aget_all_ref_doc_info() ) - def get_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + def get_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: """Get a mapping of ref_doc_id -> RefDocInfo for all ingested documents. Returns: Optional[ - Dict[ + dict[ str, #Ref_doc_id RefDocInfo, #Ref_doc_info of the id ] diff --git a/src/llama_index_alloydb_pg/engine.py b/src/llama_index_alloydb_pg/engine.py index f9984d0..bcb531b 100644 --- a/src/llama_index_alloydb_pg/engine.py +++ b/src/llama_index_alloydb_pg/engine.py @@ -17,17 +17,7 @@ from concurrent.futures import Future from dataclasses import dataclass from threading import Thread -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Dict, - List, - Optional, - Type, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, Awaitable, Optional, TypeVar, Union import aiohttp import google.auth # type: ignore @@ -76,7 +66,7 @@ async def _get_iam_principal_email( url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}" async with aiohttp.ClientSession() as client: response = await client.get(url, raise_for_status=True) - response_json: Dict = await response.json() + response_json: dict = await response.json() email = response_json.get("email") if email is None: raise ValueError( @@ -179,7 +169,7 @@ def __start_background_loop( @classmethod def from_instance( - cls: Type[AlloyDBEngine], + cls: type[AlloyDBEngine], project_id: str, region: str, cluster: str, @@ -221,7 +211,7 @@ def from_instance( @classmethod async def _create( - cls: Type[AlloyDBEngine], + cls: type[AlloyDBEngine], project_id: str, region: str, cluster: str, @@ -305,7 +295,7 @@ async def getconn() -> asyncpg.Connection: @classmethod async def afrom_instance( - cls: Type[AlloyDBEngine], + cls: type[AlloyDBEngine], project_id: str, region: str, cluster: str, @@ -347,7 +337,7 @@ async def afrom_instance( @classmethod def from_engine( - cls: Type[AlloyDBEngine], + cls: type[AlloyDBEngine], engine: AsyncEngine, loop: Optional[asyncio.AbstractEventLoop] = None, ) -> AlloyDBEngine: @@ -512,7 +502,7 @@ async def _ainit_vector_store_table( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[Column] = [], + metadata_columns: list[Column] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -528,7 +518,7 @@ async def _ainit_vector_store_table( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. 
+ metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -585,7 +575,7 @@ async def ainit_vector_store_table( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[Column] = [], + metadata_columns: list[Column] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -601,7 +591,7 @@ async def ainit_vector_store_table( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -636,7 +626,7 @@ def init_vector_store_table( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[Column] = [], + metadata_columns: list[Column] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -652,7 +642,7 @@ def init_vector_store_table( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". diff --git a/src/llama_index_alloydb_pg/index_store.py b/src/llama_index_alloydb_pg/index_store.py index f99e054..a90f0cc 100644 --- a/src/llama_index_alloydb_pg/index_store.py +++ b/src/llama_index_alloydb_pg/index_store.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import List, Optional +from typing import Optional from llama_index.core.data_structs.data_structs import IndexStruct from llama_index.core.storage.index_store.types import BaseIndexStore @@ -96,20 +96,20 @@ def create_sync( index_store = engine._run_as_sync(coro) return cls(cls.__create_key, engine, index_store) - async def aindex_structs(self) -> List[IndexStruct]: + async def aindex_structs(self) -> list[IndexStruct]: """Get all index structs. 
Returns: - List[IndexStruct]: index structs + list[IndexStruct]: index structs """ return await self._engine._run_as_async(self.__index_store.aindex_structs()) - def index_structs(self) -> List[IndexStruct]: + def index_structs(self) -> list[IndexStruct]: """Get all index structs. Returns: - List[IndexStruct]: index structs + list[IndexStruct]: index structs """ return self._engine._run_as_sync(self.__index_store.aindex_structs()) diff --git a/src/llama_index_alloydb_pg/indexes.py b/src/llama_index_alloydb_pg/indexes.py index 6c69ffa..5793c53 100644 --- a/src/llama_index_alloydb_pg/indexes.py +++ b/src/llama_index_alloydb_pg/indexes.py @@ -15,7 +15,7 @@ import enum from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import List, Optional +from typing import Optional @dataclass @@ -45,7 +45,7 @@ class BaseIndex(ABC): distance_strategy: DistanceStrategy = field( default_factory=lambda: DistanceStrategy.COSINE_DISTANCE ) - partial_indexes: Optional[List[str]] = None + partial_indexes: Optional[list[str]] = None @abstractmethod def index_options(self) -> str: diff --git a/src/llama_index_alloydb_pg/vector_store.py b/src/llama_index_alloydb_pg/vector_store.py index 852b6b6..d4ef2f6 100644 --- a/src/llama_index_alloydb_pg/vector_store.py +++ b/src/llama_index_alloydb_pg/vector_store.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any, List, Optional, Sequence, Type +from typing import Any, Optional, Sequence from llama_index.core.schema import BaseNode from llama_index.core.vector_stores.types import ( @@ -71,7 +71,7 @@ def __init__( @classmethod async def create( - cls: Type[AlloyDBVectorStore], + cls: type[AlloyDBVectorStore], engine: AlloyDBEngine, table_name: str, schema_name: str = "public", @@ -79,7 +79,7 @@ async def create( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -97,7 +97,7 @@ async def create( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -138,7 +138,7 @@ async def create( @classmethod def create_sync( - cls: Type[AlloyDBVectorStore], + cls: type[AlloyDBVectorStore], engine: AlloyDBEngine, table_name: str, schema_name: str = "public", @@ -146,7 +146,7 @@ def create_sync( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -164,7 +164,7 @@ def create_sync( text_column (str): Column that represent text content of a Node. Defaults to "text". 
embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -212,11 +212,11 @@ def client(self) -> Any: """Get client.""" return self._engine - async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> List[str]: + async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> list[str]: """Asynchronously add nodes to the table.""" return await self._engine._run_as_async(self.__vs.async_add(nodes, **kwargs)) - def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> List[str]: + def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]: """Synchronously add nodes to the table.""" return self._engine._run_as_sync(self.__vs.async_add(nodes, **add_kwargs)) @@ -230,7 +230,7 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: async def adelete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: @@ -241,7 +241,7 @@ async def adelete_nodes( def delete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: @@ -260,17 +260,17 @@ def clear(self) -> None: async def aget_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: """Asynchronously get nodes from the table matching the provided nodes and filters.""" return await self._engine._run_as_async(self.__vs.aget_nodes(node_ids, filters)) def get_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: """Asynchronously get nodes from the table matching the provided nodes and filters.""" return self._engine._run_as_sync(self.__vs.aget_nodes(node_ids, filters)) diff --git a/tests/test_async_vector_store.py b/tests/test_async_vector_store.py index f3200b5..459b28d 100644 --- a/tests/test_async_vector_store.py +++ b/tests/test_async_vector_store.py @@ -14,7 +14,7 @@ import os import uuid -from typing import List, Sequence +from typing import Sequence import pytest import pytest_asyncio diff --git a/tests/test_async_vector_store_index.py b/tests/test_async_vector_store_index.py index 1b78e34..f7373c7 100644 --- a/tests/test_async_vector_store_index.py +++ b/tests/test_async_vector_store_index.py @@ -15,13 +15,11 @@ import os import uuid -from typing import List, Sequence import pytest import pytest_asyncio -from llama_index.core.schema import MetadataMode, NodeRelationship, TextNode +from llama_index.core.schema import TextNode from sqlalchemy import text -from sqlalchemy.engine.row import RowMapping from llama_index_alloydb_pg import AlloyDBEngine from llama_index_alloydb_pg.async_vector_store import 
AsyncAlloyDBVectorStore @@ -31,7 +29,6 @@ HNSWIndex, IVFFlatIndex, ) -from llama_index_alloydb_pg.vector_store import AlloyDBVectorStore DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") DEFAULT_INDEX_NAME = DEFAULT_TABLE + DEFAULT_INDEX_NAME_SUFFIX diff --git a/tests/test_engine.py b/tests/test_engine.py index 3d30bc0..744ba84 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -14,7 +14,7 @@ import os import uuid -from typing import Dict, Sequence +from typing import Sequence import asyncpg # type: ignore import pytest diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py index b55e582..fc2d63d 100644 --- a/tests/test_vector_store.py +++ b/tests/test_vector_store.py @@ -14,7 +14,7 @@ import os import uuid -from typing import List, Sequence +from typing import Sequence import pytest import pytest_asyncio diff --git a/tests/test_vector_store_index.py b/tests/test_vector_store_index.py index 405f762..4a2f298 100644 --- a/tests/test_vector_store_index.py +++ b/tests/test_vector_store_index.py @@ -14,15 +14,12 @@ import os -import sys import uuid -from typing import List, Sequence import pytest import pytest_asyncio -from llama_index.core.schema import MetadataMode, NodeRelationship, TextNode +from llama_index.core.schema import TextNode from sqlalchemy import text -from sqlalchemy.engine.row import RowMapping from sqlalchemy.ext.asyncio import create_async_engine from llama_index_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore From 0c50a7d0551913297f2aa5a1466da6e8534453ba Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Tue, 10 Dec 2024 04:14:35 +0530 Subject: [PATCH 05/39] chore: add tests for omni (#25) * chore: add tests for omni * chore: uncomment alloy db scann test * fix: tests for scann index on omni and alloy db hosted * chore: add drop index statements to avoid conflict --- tests/test_vector_store_index.py | 117 +++++++++++++++++++++++++++---- 1 file changed, 102 insertions(+), 15 deletions(-) diff --git a/tests/test_vector_store_index.py b/tests/test_vector_store_index.py index 4a2f298..ae863f1 100644 --- a/tests/test_vector_store_index.py +++ b/tests/test_vector_store_index.py @@ -35,11 +35,8 @@ DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") DEFAULT_TABLE_ASYNC = "test_table" + str(uuid.uuid4()).replace("-", "_") -DEFAULT_TABLE_OMNI = "test_table" + str(uuid.uuid4()).replace("-", "_") -CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") DEFAULT_INDEX_NAME = DEFAULT_TABLE + DEFAULT_INDEX_NAME_SUFFIX DEFAULT_INDEX_NAME_ASYNC = DEFAULT_TABLE_ASYNC + DEFAULT_INDEX_NAME_SUFFIX -DEFAULT_INDEX_NAME_OMNI = DEFAULT_TABLE_OMNI + DEFAULT_INDEX_NAME_SUFFIX VECTOR_SIZE = 5 @@ -122,14 +119,54 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): @pytest_asyncio.fixture(scope="class") async def vs(self, engine): - engine.init_vector_store_table(DEFAULT_TABLE, VECTOR_SIZE) + engine.init_vector_store_table( + DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True + ) vs = AlloyDBVectorStore.create_sync( engine, table_name=DEFAULT_TABLE, ) - await vs.async_add(nodes) + vs.add(nodes) + vs.drop_vector_index() + yield vs + + @pytest.fixture(scope="module") + def omni_host(self) -> str: + return get_env_var("OMNI_HOST", "AlloyDB Omni host address") + + @pytest.fixture(scope="module") + def omni_user(self) -> str: + return get_env_var("OMNI_USER", "AlloyDB Omni user name") + + @pytest.fixture(scope="module") + def omni_password(self) -> str: + return get_env_var("OMNI_PASSWORD", 
"AlloyDB Omni password") + + @pytest.fixture(scope="module") + def omni_database_name(self) -> str: + return get_env_var("OMNI_DATABASE_ID", "AlloyDB Omni database name") + @pytest_asyncio.fixture(scope="class") + async def omni_engine( + self, omni_host, omni_user, omni_password, omni_database_name + ): + connstring = f"postgresql+asyncpg://{omni_user}:{omni_password}@{omni_host}:5432/{omni_database_name}" + omni_engine = AlloyDBEngine.from_connection_string(connstring) + yield omni_engine + await aexecute(omni_engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") + await omni_engine.close() + + @pytest_asyncio.fixture(scope="class") + async def omni_vs(self, omni_engine): + omni_engine.init_vector_store_table( + DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True + ) + vs = AlloyDBVectorStore.create_sync( + omni_engine, + table_name=DEFAULT_TABLE, + ) + vs.add(nodes) vs.drop_vector_index() yield vs @@ -137,6 +174,7 @@ async def test_aapply_vector_index(self, vs): index = HNSWIndex() vs.apply_vector_index(index) assert vs.is_valid_index(DEFAULT_INDEX_NAME) + vs.drop_vector_index(DEFAULT_INDEX_NAME) async def test_areindex(self, vs): if not vs.is_valid_index(DEFAULT_INDEX_NAME): @@ -145,6 +183,7 @@ async def test_areindex(self, vs): vs.reindex() vs.reindex(DEFAULT_INDEX_NAME) assert vs.is_valid_index(DEFAULT_INDEX_NAME) + vs.drop_vector_index(DEFAULT_INDEX_NAME) async def test_dropindex(self, vs): vs.drop_vector_index() @@ -162,11 +201,40 @@ async def test_aapply_vector_index_ivfflat(self, vs): vs.apply_vector_index(index) assert vs.is_valid_index("secondindex") vs.drop_vector_index("secondindex") + vs.drop_vector_index(DEFAULT_INDEX_NAME) async def test_is_valid_index(self, vs): is_valid = vs.is_valid_index("invalid_index") assert is_valid == False + async def test_aapply_vector_index_scann(self, vs): + index = ScaNNIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) + await vs.aset_maintenance_work_mem(index.num_leaves, VECTOR_SIZE) + await vs.aapply_vector_index(index, concurrently=True) + assert await vs.ais_valid_index(DEFAULT_INDEX_NAME) + index = ScaNNIndex( + name="secondindex", + distance_strategy=DistanceStrategy.COSINE_DISTANCE, + ) + await vs.aapply_vector_index(index) + assert await vs.ais_valid_index("secondindex") + await vs.adrop_vector_index("secondindex") + await vs.adrop_vector_index() + + async def test_apply_vector_index_scann_omni(self, omni_vs): + index = ScaNNIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) + omni_vs.set_maintenance_work_mem(index.num_leaves, VECTOR_SIZE) + omni_vs.apply_vector_index(index, concurrently=True) + assert omni_vs.is_valid_index(DEFAULT_INDEX_NAME) + index = ScaNNIndex( + name="secondindex", + distance_strategy=DistanceStrategy.COSINE_DISTANCE, + ) + omni_vs.apply_vector_index(index) + assert omni_vs.is_valid_index("secondindex") + omni_vs.drop_vector_index("secondindex") + omni_vs.drop_vector_index() + @pytest.mark.asyncio(loop_scope="class") class TestAsyncIndex: @@ -213,7 +281,9 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): @pytest_asyncio.fixture(scope="class") async def vs(self, engine): - await engine.ainit_vector_store_table(DEFAULT_TABLE_ASYNC, VECTOR_SIZE) + await engine.ainit_vector_store_table( + DEFAULT_TABLE_ASYNC, VECTOR_SIZE, overwrite_existing=True + ) vs = await AlloyDBVectorStore.create( engine, table_name=DEFAULT_TABLE_ASYNC, @@ -244,28 +314,30 @@ async def omni_engine( self, omni_host, omni_user, omni_password, omni_database_name ): connstring = 
f"postgresql+asyncpg://{omni_user}:{omni_password}@{omni_host}:5432/{omni_database_name}" - print(f"Connecting to AlloyDB Omni with {connstring}") - async_engine = create_async_engine(connstring) omni_engine = AlloyDBEngine.from_engine(async_engine) yield omni_engine - await aexecute(omni_engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_OMNI}") + await aexecute(omni_engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_ASYNC}") await omni_engine.close() @pytest_asyncio.fixture(scope="class") - async def omni_vs(self, engine): - await engine.ainit_vector_store_table(DEFAULT_TABLE_OMNI, VECTOR_SIZE) + async def omni_vs(self, omni_engine): + await omni_engine.ainit_vector_store_table( + DEFAULT_TABLE_ASYNC, VECTOR_SIZE, overwrite_existing=True + ) vs = await AlloyDBVectorStore.create( - engine, - table_name=DEFAULT_TABLE_OMNI, + omni_engine, + table_name=DEFAULT_TABLE_ASYNC, ) await vs.async_add(nodes) + await vs.adrop_vector_index() yield vs async def test_aapply_vector_index(self, vs): index = HNSWIndex() await vs.aapply_vector_index(index) assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) + await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC) async def test_areindex(self, vs): if not await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC): @@ -274,6 +346,7 @@ async def test_areindex(self, vs): await vs.areindex() await vs.areindex(DEFAULT_INDEX_NAME_ASYNC) assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) + await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC) async def test_dropindex(self, vs): await vs.adrop_vector_index() @@ -310,11 +383,11 @@ async def test_aapply_vector_index_ivf(self, vs): await vs.adrop_vector_index("secondindex") await vs.adrop_vector_index() - async def test_aapply_alloydb_scann_index_ScaNN(self, omni_vs): + async def test_aapply_vector_index_scann_omni(self, omni_vs): index = ScaNNIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) await omni_vs.aset_maintenance_work_mem(index.num_leaves, VECTOR_SIZE) await omni_vs.aapply_vector_index(index, concurrently=True) - assert await omni_vs.ais_valid_index(DEFAULT_INDEX_NAME_OMNI) + assert await omni_vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) index = ScaNNIndex( name="secondindex", distance_strategy=DistanceStrategy.COSINE_DISTANCE, @@ -323,3 +396,17 @@ async def test_aapply_alloydb_scann_index_ScaNN(self, omni_vs): assert await omni_vs.ais_valid_index("secondindex") await omni_vs.adrop_vector_index("secondindex") await omni_vs.adrop_vector_index() + + async def test_apply_vector_index_scann(self, vs): + index = ScaNNIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) + vs.set_maintenance_work_mem(index.num_leaves, VECTOR_SIZE) + vs.apply_vector_index(index, concurrently=True) + assert vs.is_valid_index(DEFAULT_INDEX_NAME_ASYNC) + index = ScaNNIndex( + name="secondindex", + distance_strategy=DistanceStrategy.COSINE_DISTANCE, + ) + vs.apply_vector_index(index) + assert vs.is_valid_index("secondindex") + vs.drop_vector_index("secondindex") + vs.drop_vector_index() From de53006d00fe1edd5b3e5c1349613e82f0c94794 Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Mon, 9 Dec 2024 22:55:16 +0000 Subject: [PATCH 06/39] feat: Adding chat store init methods. (#29) * feat: Adding chat store init methods. * Add index and constraints to chat store table. 
* Adjusted to changed schema * Removed unique constraint on key * Fix tests --------- Co-authored-by: Averi Kitsch --- src/llama_index_alloydb_pg/engine.py | 91 ++++++++++++++++++++++++++++ tests/test_engine.py | 36 +++++++++++ 2 files changed, 127 insertions(+) diff --git a/src/llama_index_alloydb_pg/engine.py b/src/llama_index_alloydb_pg/engine.py index bcb531b..f004335 100644 --- a/src/llama_index_alloydb_pg/engine.py +++ b/src/llama_index_alloydb_pg/engine.py @@ -757,6 +757,97 @@ def init_index_store_table( ) ) + async def _ainit_chat_store_table( + self, + table_name: str, + schema_name: str = "public", + overwrite_existing: bool = False, + ) -> None: + """ + Create an AlloyDB table to save chat store. + + Args: + table_name (str): The table name to store chat history. + schema_name (str): The schema name to store the chat store table. + Default: "public". + overwrite_existing (bool): Whether to drop existing table. + Default: False. + + Returns: + None + """ + if overwrite_existing: + async with self._pool.connect() as conn: + await conn.execute( + text(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"') + ) + await conn.commit() + + create_table_query = f"""CREATE TABLE "{schema_name}"."{table_name}"( + id SERIAL PRIMARY KEY, + key VARCHAR NOT NULL, + message JSON NOT NULL + );""" + create_index_query = f"""CREATE INDEX "{table_name}_idx_key" ON "{schema_name}"."{table_name}" (key);""" + async with self._pool.connect() as conn: + await conn.execute(text(create_table_query)) + await conn.execute(text(create_index_query)) + await conn.commit() + + async def ainit_chat_store_table( + self, + table_name: str, + schema_name: str = "public", + overwrite_existing: bool = False, + ) -> None: + """ + Create an AlloyDB table to save chat store. + + Args: + table_name (str): The table name to store chat store. + schema_name (str): The schema name to store the chat store table. + Default: "public". + overwrite_existing (bool): Whether to drop existing table. + Default: False. + + Returns: + None + """ + await self._run_as_async( + self._ainit_chat_store_table( + table_name, + schema_name, + overwrite_existing, + ) + ) + + def init_chat_store_table( + self, + table_name: str, + schema_name: str = "public", + overwrite_existing: bool = False, + ) -> None: + """ + Create an AlloyDB table to save chat store. + + Args: + table_name (str): The table name to store chat store. + schema_name (str): The schema name to store the chat store table. + Default: "public". + overwrite_existing (bool): Whether to drop existing table. + Default: False. 
+ + Returns: + None + """ + self._run_as_sync( + self._ainit_chat_store_table( + table_name, + schema_name, + overwrite_existing, + ) + ) + async def _aload_table_schema( self, table_name: str, schema_name: str = "public" ) -> Table: diff --git a/tests/test_engine.py b/tests/test_engine.py index 744ba84..16cab72 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -34,6 +34,8 @@ DEFAULT_IS_TABLE_SYNC = "index_store_" + str(uuid.uuid4()) DEFAULT_VS_TABLE = "vector_store_" + str(uuid.uuid4()) DEFAULT_VS_TABLE_SYNC = "vector_store_" + str(uuid.uuid4()) +DEFAULT_CS_TABLE = "chat_store_" + str(uuid.uuid4()) +DEFAULT_CS_TABLE_SYNC = "chat_store_" + str(uuid.uuid4()) VECTOR_SIZE = 768 @@ -118,6 +120,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE}"') + await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE}"') await engine.close() async def test_password( @@ -307,6 +310,22 @@ async def test_init_index_store(self, engine): for row in results: assert row in expected + async def test_init_chat_store(self, engine): + await engine.ainit_chat_store_table( + table_name=DEFAULT_CS_TABLE, + schema_name="public", + overwrite_existing=True, + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{DEFAULT_CS_TABLE}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "id", "data_type": "integer"}, + {"column_name": "key", "data_type": "character varying"}, + {"column_name": "message", "data_type": "json"}, + ] + for row in results: + assert row in expected + @pytest.mark.asyncio class TestEngineSync: @@ -359,6 +378,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE_SYNC}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE_SYNC}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE_SYNC}"') await engine.close() async def test_password( @@ -481,3 +501,19 @@ async def test_init_index_store(self, engine): ] for row in results: assert row in expected + + async def test_init_chat_store(self, engine): + engine.init_chat_store_table( + table_name=DEFAULT_CS_TABLE_SYNC, + schema_name="public", + overwrite_existing=True, + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{DEFAULT_CS_TABLE_SYNC}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "id", "data_type": "integer"}, + {"column_name": "key", "data_type": "character varying"}, + {"column_name": "message", "data_type": "json"}, + ] + for row in results: + assert row in expected From ed1f15ab7c3bdb4fe3b38b56557bc93922057d73 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Tue, 10 Dec 2024 20:24:54 +0100 Subject: [PATCH 07/39] chore(deps): update dependency llama-index-core to v0.12.5 (#36) Co-authored-by: Averi Kitsch --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 2afc7b6..d34f5e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ google-cloud-alloydb-connector[asyncpg]==1.5.0 -llama-index-core==0.12.3 +llama-index-core==0.12.5 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.36 From 1abee55bc8bd074ac55c7636412de04525432385 Mon Sep 17 00:00:00 2001 From: 
Vishwaraj Anand Date: Thu, 12 Dec 2024 02:26:16 +0530 Subject: [PATCH 08/39] chore: add code coverage (#38) --- .coveragerc | 8 +++ integration.cloudbuild.yaml | 2 +- tests/test_async_document_store.py | 82 ++++++++++++++++++++++++++++-- tests/test_async_index_store.py | 36 +++++++++++-- tests/test_async_vector_store.py | 72 ++++++++++++++++++++++++-- tests/test_document_store.py | 8 ++- tests/test_engine.py | 60 ++++++++++++++++++++++ tests/test_index_store.py | 10 +++- tests/test_vector_store.py | 6 ++- tests/test_vector_store_index.py | 14 +++++ 10 files changed, 282 insertions(+), 16 deletions(-) create mode 100644 .coveragerc diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..b21412b --- /dev/null +++ b/.coveragerc @@ -0,0 +1,8 @@ +[run] +branch = true +omit = + */__init__.py + +[report] +show_missing = true +fail_under = 90 diff --git a/integration.cloudbuild.yaml b/integration.cloudbuild.yaml index 383aa79..3749777 100644 --- a/integration.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -39,7 +39,7 @@ steps: - "-c" - | /workspace/alloydb-auth-proxy --port ${_DATABASE_PORT} ${_INSTANCE_CONNECTION_NAME} & sleep 2; - python -m pytest tests/ + python -m pytest --cov=llama_index_alloydb_pg --cov-config=.coveragerc tests/ env: - "PROJECT_ID=$PROJECT_ID" - "INSTANCE_ID=$_INSTANCE_ID" diff --git a/tests/test_async_document_store.py b/tests/test_async_document_store.py index f045152..c978f94 100644 --- a/tests/test_async_document_store.py +++ b/tests/test_async_document_store.py @@ -28,6 +28,7 @@ default_table_name_async = "document_store_" + str(uuid.uuid4()) custom_table_name = "document_store_" + str(uuid.uuid4()) +sync_method_exception_str = "Sync methods are not implemented for AsyncAlloyDBDocumentStore. Use AlloyDBDocumentStore interface instead." async def aexecute(engine: AlloyDBEngine, query: str) -> None: @@ -124,9 +125,16 @@ async def custom_doc_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): AsyncAlloyDBDocumentStore( - engine=async_engine, table_name=default_table_name_async + key, engine=async_engine, table_name=default_table_name_async + ) + + async def test_create_without_table(self, async_engine): + with pytest.raises(ValueError): + await AsyncAlloyDBDocumentStore.create( + engine=async_engine, table_name="non-existent-table" ) async def test_warning(self, custom_doc_store): @@ -187,7 +195,7 @@ async def test_add_hash_before_data(self, async_engine, doc_store): result = results[0] assert result["node_data"][DATA_KEY]["text"] == document_text - async def test_ref_doc_exists(self, doc_store): + async def test_aref_doc_exists(self, doc_store): # Create a ref_doc & a doc and add them to the store. ref_doc = Document( text="first doc", id_="doc_exists_doc_1", metadata={"doc": "info"} @@ -244,6 +252,8 @@ async def test_adelete_ref_doc(self, doc_store): assert ( await doc_store.aget_document(doc_id=doc.doc_id, raise_error=False) is None ) + # Confirm deleting an non-existent reference doc returns None. + assert await doc_store.adelete_ref_doc(ref_doc_id=ref_doc.doc_id) is None async def test_set_and_get_document_hash(self, doc_store): # Set a doc hash for a document @@ -254,6 +264,9 @@ async def test_set_and_get_document_hash(self, doc_store): # Assert with get that the hash is same as the one set. 
assert await doc_store.aget_document_hash(doc_id=doc_id) == doc_hash + async def test_aget_document_hash(self, doc_store): + assert await doc_store.aget_document_hash(doc_id="non-existent-doc") is None + async def test_set_and_get_document_hashes(self, doc_store): # Create a dictionary of doc_id -> doc_hash mappings and add it to the table. document_dict = { @@ -288,7 +301,7 @@ async def test_doc_store_basic(self, doc_store): retrieved_node = await doc_store.aget_document(doc_id=node.node_id) assert retrieved_node == node - async def test_delete_document(self, async_engine, doc_store): + async def test_adelete_document(self, async_engine, doc_store): # Create a doc and add it to the store. doc = Document(text="document_2", id_="doc_id_2", metadata={"doc": "info"}) await doc_store.async_add_documents([doc]) @@ -301,6 +314,11 @@ async def test_delete_document(self, async_engine, doc_store): result = await afetch(async_engine, query) assert len(result) == 0 + async def test_delete_non_existent_document(self, async_engine, doc_store): + await doc_store.adelete_document(doc_id="non-existent-doc", raise_error=False) + with pytest.raises(ValueError): + await doc_store.adelete_document(doc_id="non-existent-doc") + async def test_doc_store_ref_doc_not_added(self, async_engine, doc_store): # Create a ref_doc & doc. ref_doc = Document( @@ -376,3 +394,61 @@ async def test_doc_store_delete_all_ref_doc_nodes(self, async_engine, doc_store) query = f"""select * from "public"."{default_table_name_async}" where id = '{ref_doc.doc_id}';""" result = await afetch(async_engine, query) assert len(result) == 0 + + async def test_docs(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.docs() + + async def test_add_documents(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.add_documents([]) + + async def test_get_document(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_document("test_doc_id", raise_error=True) + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_document("test_doc_id", raise_error=False) + + async def test_delete_document(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_document("test_doc_id", raise_error=True) + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_document("test_doc_id", raise_error=False) + + async def test_document_exists(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.document_exists("test_doc_id") + + async def test_ref_doc_exists(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.ref_doc_exists(ref_doc_id="test_ref_doc_id") + + async def test_set_document_hash(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.set_document_hash("test_doc_id", "test_doc_hash") + + async def test_set_document_hashes(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.set_document_hashes({"test_doc_id": "test_doc_hash"}) + + async def test_get_document_hash(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_document_hash(doc_id="test_doc_id") + + async def test_get_all_document_hashes(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_all_document_hashes() + + async 
def test_get_all_ref_doc_info(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_all_ref_doc_info() + + async def test_get_ref_doc_info(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_ref_doc_info(ref_doc_id="test_doc_id") + + async def test_delete_ref_doc(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_ref_doc(ref_doc_id="test_doc_id", raise_error=False) + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_ref_doc(ref_doc_id="test_doc_id", raise_error=True) diff --git a/tests/test_async_index_store.py b/tests/test_async_index_store.py index 44a32cb..09f532d 100644 --- a/tests/test_async_index_store.py +++ b/tests/test_async_index_store.py @@ -19,13 +19,19 @@ import pytest import pytest_asyncio -from llama_index.core.data_structs.data_structs import IndexDict, IndexGraph, IndexList +from llama_index.core.data_structs.data_structs import ( + IndexDict, + IndexGraph, + IndexList, + IndexStruct, +) from sqlalchemy import RowMapping, text from llama_index_alloydb_pg import AlloyDBEngine from llama_index_alloydb_pg.async_index_store import AsyncAlloyDBIndexStore default_table_name_async = "index_store_" + str(uuid.uuid4()) +sync_method_exception_str = "Sync methods are not implemented for AsyncAlloyDBIndexStore . Use AlloyDBIndexStore interface instead." async def aexecute(engine: AlloyDBEngine, query: str) -> None: @@ -81,7 +87,7 @@ def password(self) -> str: @pytest_asyncio.fixture(scope="class") async def async_engine( - self, db_project, db_region, db_cluster, db_instance, db_name, user, password + self, db_project, db_region, db_cluster, db_instance, db_name ): async_engine = await AlloyDBEngine.afrom_instance( project_id=db_project, @@ -109,9 +115,16 @@ async def index_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): AsyncAlloyDBIndexStore( - engine=async_engine, table_name=default_table_name_async + key, engine=async_engine, table_name=default_table_name_async + ) + + async def test_create_without_table(self, async_engine): + with pytest.raises(ValueError): + await AsyncAlloyDBIndexStore.create( + engine=async_engine, table_name="non-existent-table" ) async def test_add_and_delete_index(self, index_store, async_engine): @@ -169,3 +182,20 @@ async def test_warning(self, index_store): assert "No struct_id specified and more than one struct exists." 
in str( w[-1].message ) + + async def test_index_structs(self, index_store): + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.index_structs() + + async def test_add_index_struct(self, index_store): + index_struct = IndexGraph() + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.add_index_struct(index_struct) + + async def test_delete_index_struct(self, index_store): + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.delete_index_struct("non_existent_key") + + async def test_get_index_struct(self, index_store): + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.get_index_struct(struct_id="non_existent_id") diff --git a/tests/test_async_vector_store.py b/tests/test_async_vector_store.py index 459b28d..e99ac51 100644 --- a/tests/test_async_vector_store.py +++ b/tests/test_async_vector_store.py @@ -14,6 +14,7 @@ import os import uuid +import warnings from typing import Sequence import pytest @@ -114,8 +115,8 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): ) yield engine - await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE_CUSTOM_VS}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_CUSTOM_VS}"') await engine.close() @pytest_asyncio.fixture(scope="class") @@ -158,8 +159,9 @@ async def custom_vs(self, engine): yield vs async def test_init_with_constructor(self, engine): + key = object() with pytest.raises(Exception): - AsyncAlloyDBVectorStore(engine, table_name=DEFAULT_TABLE) + AsyncAlloyDBVectorStore(key, engine, table_name=DEFAULT_TABLE) async def test_validate_id_column_create(self, engine, vs): test_id_column = "test_id_column" @@ -318,6 +320,70 @@ async def test_aquery(self, engine, vs): assert len(results.nodes) == 3 assert results.nodes[0].get_content(metadata_mode=MetadataMode.NONE) == "foo" + async def test_aquery_filters(self, engine, custom_vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE_CUSTOM_VS}"') + # setting extra metadata to be indexed in separate column + for node in nodes: + node.metadata["len"] = len(node.text) + + await custom_vs.async_add(nodes) + + filters = MetadataFilters( + filters=[ + MetadataFilter( + key="some_test_column", + value=["value_should_be_ignored"], + operator=FilterOperator.CONTAINS, + ), + MetadataFilter( + key="len", + value=3, + operator=FilterOperator.LTE, + ), + MetadataFilter( + key="len", + value=3, + operator=FilterOperator.GTE, + ), + MetadataFilter( + key="len", + value=2, + operator=FilterOperator.GT, + ), + MetadataFilter( + key="len", + value=4, + operator=FilterOperator.LT, + ), + MetadataFilters( + filters=[ + MetadataFilter( + key="len", + value=6.0, + operator=FilterOperator.NE, + ), + ], + condition=FilterCondition.OR, + ), + ], + condition=FilterCondition.AND, + ) + query = VectorStoreQuery( + query_embedding=[1.0] * VECTOR_SIZE, filters=filters, similarity_top_k=-1 + ) + with warnings.catch_warnings(record=True) as w: + results = await custom_vs.aquery(query) + + assert len(w) == 1 + assert "Expecting a scalar in the filter value" in str(w[-1].message) + + assert results.nodes is not None + assert results.ids is not None + assert results.similarities is not None + assert 
len(results.nodes) == 3 + async def test_aclear(self, engine, vs): # Note: To be migrated to a pytest dependency on test_adelete # Blocked due to unexpected fixtures reloads while running integration test suite diff --git a/tests/test_document_store.py b/tests/test_document_store.py index 7a9d391..44533e4 100644 --- a/tests/test_document_store.py +++ b/tests/test_document_store.py @@ -123,9 +123,10 @@ async def doc_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): AlloyDBDocumentStore( - engine=async_engine, table_name=default_table_name_async + key, engine=async_engine, table_name=default_table_name_async ) async def test_async_add_document(self, async_engine, doc_store): @@ -414,8 +415,11 @@ async def sync_doc_store(self, sync_engine): await aexecute(sync_engine, query) async def test_init_with_constructor(self, sync_engine): + key = object() with pytest.raises(Exception): - AlloyDBDocumentStore(engine=sync_engine, table_name=default_table_name_sync) + AlloyDBDocumentStore( + key, engine=sync_engine, table_name=default_table_name_sync + ) async def test_docs(self, sync_doc_store): # Create and add document into the doc store. diff --git a/tests/test_engine.py b/tests/test_engine.py index 16cab72..fd168e1 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -123,6 +123,37 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE}"') await engine.close() + async def test_init_with_constructor( + self, + db_project, + db_region, + db_cluster, + db_instance, + db_name, + user, + password, + ): + async def getconn() -> asyncpg.Connection: + conn = await connector.connect( # type: ignore + f"projects/{db_project}/locations/{db_region}/clusters/{db_cluster}/instances/{db_instance}", + "asyncpg", + user=user, + password=password, + db=db_name, + enable_iam_auth=False, + ip_type=IPTypes.PUBLIC, + ) + return conn + + engine = create_async_engine( + "postgresql+asyncpg://", + async_creator=getconn, + ) + + key = object() + with pytest.raises(Exception): + AlloyDBEngine(key, engine) + async def test_password( self, db_project, @@ -148,6 +179,35 @@ async def test_password( AlloyDBEngine._connector = None await engine.close() + async def test_missing_user_or_password( + self, + db_project, + db_region, + db_cluster, + db_instance, + db_name, + user, + password, + ): + with pytest.raises(ValueError): + await AlloyDBEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + cluster=db_cluster, + database=db_name, + user=user, + ) + with pytest.raises(ValueError): + await AlloyDBEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + cluster=db_cluster, + database=db_name, + password=password, + ) + async def test_from_engine( self, db_project, diff --git a/tests/test_index_store.py b/tests/test_index_store.py index 03a2eb9..a5bc09e 100644 --- a/tests/test_index_store.py +++ b/tests/test_index_store.py @@ -123,8 +123,11 @@ async def index_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): - AlloyDBIndexStore(engine=async_engine, table_name=default_table_name_async) + AlloyDBIndexStore( + key, engine=async_engine, table_name=default_table_name_async + ) async def test_add_and_delete_index(self, index_store, 
async_engine): index_struct = IndexGraph() @@ -248,8 +251,11 @@ async def index_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): - AlloyDBIndexStore(engine=async_engine, table_name=default_table_name_sync) + AlloyDBIndexStore( + key, engine=async_engine, table_name=default_table_name_sync + ) async def test_add_and_delete_index(self, index_store, async_engine): index_struct = IndexGraph() diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py index fc2d63d..96a7bca 100644 --- a/tests/test_vector_store.py +++ b/tests/test_vector_store.py @@ -134,8 +134,9 @@ async def vs(self, engine): yield vs async def test_init_with_constructor(self, engine): + key = object() with pytest.raises(Exception): - AlloyDBVectorStore(engine, table_name=DEFAULT_TABLE) + AlloyDBVectorStore(key, engine, table_name=DEFAULT_TABLE) async def test_validate_id_column_create(self, engine, vs): test_id_column = "test_id_column" @@ -513,8 +514,9 @@ async def vs(self, engine): yield vs async def test_init_with_constructor(self, engine): + key = object() with pytest.raises(Exception): - AlloyDBVectorStore(engine, table_name=DEFAULT_TABLE) + AlloyDBVectorStore(key, engine, table_name=DEFAULT_TABLE) async def test_validate_id_column_create(self, engine, vs): test_id_column = "test_id_column" diff --git a/tests/test_vector_store_index.py b/tests/test_vector_store_index.py index ae863f1..3920929 100644 --- a/tests/test_vector_store_index.py +++ b/tests/test_vector_store_index.py @@ -203,6 +203,20 @@ async def test_aapply_vector_index_ivfflat(self, vs): vs.drop_vector_index("secondindex") vs.drop_vector_index(DEFAULT_INDEX_NAME) + async def test_apply_vector_index_scann(self, vs): + index = ScaNNIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) + vs.set_maintenance_work_mem(index.num_leaves, VECTOR_SIZE) + vs.apply_vector_index(index, concurrently=True) + assert vs.is_valid_index(DEFAULT_INDEX_NAME) + index = ScaNNIndex( + name="secondindex", + distance_strategy=DistanceStrategy.COSINE_DISTANCE, + ) + vs.apply_vector_index(index) + assert vs.is_valid_index("secondindex") + vs.drop_vector_index("secondindex") + vs.drop_vector_index(DEFAULT_INDEX_NAME) + async def test_is_valid_index(self, vs): is_valid = vs.is_valid_index("invalid_index") assert is_valid == False From 6c9ad8510b780d44e0255fbeab36ad5bfab917f3 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Thu, 12 Dec 2024 19:04:19 +0100 Subject: [PATCH 09/39] chore(deps): update dependency google-cloud-alloydb-connector to v1.6.0 (#39) Co-authored-by: Averi Kitsch --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index d34f5e3..5cdb1c0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -google-cloud-alloydb-connector[asyncpg]==1.5.0 +google-cloud-alloydb-connector[asyncpg]==1.6.0 llama-index-core==0.12.5 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.36 From 228c50c5f9b7e25043e475b67a0dd2f152650949 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Thu, 19 Dec 2024 00:59:51 +0530 Subject: [PATCH 10/39] chore(ci): add Cloud Build failure reporter (#41) * chore(ci): add Cloud Build failure reporter * chore: refer to langchain alloy db workflow --- .github/workflows/schedule_reporter.yml | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 .github/workflows/schedule_reporter.yml diff --git a/.github/workflows/schedule_reporter.yml 
b/.github/workflows/schedule_reporter.yml new file mode 100644 index 0000000..ab846ef --- /dev/null +++ b/.github/workflows/schedule_reporter.yml @@ -0,0 +1,25 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Schedule Reporter + +on: + schedule: + - cron: '0 6 * * *' # Runs at 6 AM every morning + +jobs: + run_reporter: + uses: googleapis/langchain-google-alloydb-pg-python/.github/workflows/cloud_build_failure_reporter.yml@main + with: + trigger_names: "integration-test-nightly,continuous-test-on-merge" From dcba62642557779db3b8421f93562b20a9d2a754 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 6 Jan 2025 13:19:08 -0800 Subject: [PATCH 11/39] chore(deps): bump jinja2 from 3.1.4 to 3.1.5 in /.kokoro (#42) Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.4 to 3.1.5. - [Release notes](https://github.com/pallets/jinja/releases) - [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst) - [Commits](https://github.com/pallets/jinja/compare/3.1.4...3.1.5) --- updated-dependencies: - dependency-name: jinja2 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .kokoro/requirements.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.kokoro/requirements.txt b/.kokoro/requirements.txt index 88fb726..23e61f6 100644 --- a/.kokoro/requirements.txt +++ b/.kokoro/requirements.txt @@ -277,9 +277,9 @@ jeepney==0.8.0 \ # via # keyring # secretstorage -jinja2==3.1.4 \ - --hash=sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369 \ - --hash=sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d +jinja2==3.1.5 \ + --hash=sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb \ + --hash=sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb # via gcp-releasetool keyring==24.3.1 \ --hash=sha256:c3327b6ffafc0e8befbdb597cacdb4928ffe5c1212f7645f186e6d9957a898db \ @@ -525,4 +525,4 @@ zipp==3.19.1 \ # WARNING: The following packages were not pinned, but pip requires them to be # pinned when the requirements file includes hashes and the requirement is not # satisfied by a package already installed. Consider using the --allow-unsafe flag. -# setuptools \ No newline at end of file +# setuptools From dd987718f0482177d03c84eee6334703613461d0 Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Tue, 7 Jan 2025 07:30:11 +0000 Subject: [PATCH 12/39] feat: Adding Async Chat Store (#35) * feat: Adding Async Chat Store * Removed user and password from tests * Linter fix * Added docstrings. 
* chore(deps): update dependency google-cloud-alloydb-connector to v1.6.0 (#39) Co-authored-by: Averi Kitsch * chore(ci): add Cloud Build failure reporter (#41) * chore(ci): add Cloud Build failure reporter * chore: refer to langchain alloy db workflow * chore(deps): bump jinja2 from 3.1.4 to 3.1.5 in /.kokoro (#42) Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.4 to 3.1.5. - [Release notes](https://github.com/pallets/jinja/releases) - [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst) - [Commits](https://github.com/pallets/jinja/compare/3.1.4...3.1.5) --- updated-dependencies: - dependency-name: jinja2 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Made review changes * Changed header to 2025. * Changed header to 2025. --------- Signed-off-by: dependabot[bot] Co-authored-by: Averi Kitsch Co-authored-by: Vishwaraj Anand Co-authored-by: Mend Renovate Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .../async_chat_store.py | 295 ++++++++++++++++++ tests/test_async_chat_store.py | 224 +++++++++++++ 2 files changed, 519 insertions(+) create mode 100644 src/llama_index_alloydb_pg/async_chat_store.py create mode 100644 tests/test_async_chat_store.py diff --git a/src/llama_index_alloydb_pg/async_chat_store.py b/src/llama_index_alloydb_pg/async_chat_store.py new file mode 100644 index 0000000..d967349 --- /dev/null +++ b/src/llama_index_alloydb_pg/async_chat_store.py @@ -0,0 +1,295 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +from typing import List, Optional + +from llama_index.core.llms import ChatMessage +from llama_index.core.storage.chat_store.base import BaseChatStore +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine + +from .engine import AlloyDBEngine + + +class AsyncAlloyDBChatStore(BaseChatStore): + """Chat Store Table stored in an AlloyDB for PostgreSQL database.""" + + __create_key = object() + + def __init__( + self, + key: object, + engine: AsyncEngine, + table_name: str, + schema_name: str = "public", + ): + """AsyncAlloyDBChatStore constructor. + + Args: + key (object): Key to prevent direct constructor usage. + engine (AlloyDBEngine): Database connection pool. + table_name (str): Table name that stores the chat store. + schema_name (str): The schema name where the table is located. + Defaults to "public" + + Raises: + Exception: If constructor is directly called by the user. 
+ """ + if key != AsyncAlloyDBChatStore.__create_key: + raise Exception("Only create class through 'create' method!") + + # Delegate to Pydantic's __init__ + super().__init__() + self._engine = engine + self._table_name = table_name + self._schema_name = schema_name + + @classmethod + async def create( + cls, + engine: AlloyDBEngine, + table_name: str, + schema_name: str = "public", + ) -> AsyncAlloyDBChatStore: + """Create a new AsyncAlloyDBChatStore instance. + + Args: + engine (AlloyDBEngine): AlloyDB engine to use. + table_name (str): Table name that stores the chat store. + schema_name (str): The schema name where the table is located. + Defaults to "public" + + Raises: + ValueError: If the table provided does not contain required schema. + + Returns: + AsyncAlloyDBChatStore: A newly created instance of AsyncAlloyDBChatStore. + """ + table_schema = await engine._aload_table_schema(table_name, schema_name) + column_names = table_schema.columns.keys() + + required_columns = ["id", "key", "message"] + + if not (all(x in column_names for x in required_columns)): + raise ValueError( + f"Table '{schema_name}'.'{table_name}' has an incorrect schema.\n" + f"Expected column names: {required_columns}\n" + f"Provided column names: {column_names}\n" + "Please create the table with the following schema:\n" + f"CREATE TABLE {schema_name}.{table_name} (\n" + " id SERIAL PRIMARY KEY,\n" + " key VARCHAR NOT NULL,\n" + " message JSON NOT NULL\n" + ");" + ) + + return cls(cls.__create_key, engine._pool, table_name, schema_name) + + async def __aexecute_query(self, query, params=None): + async with self._engine.connect() as conn: + await conn.execute(text(query), params) + await conn.commit() + + async def __afetch_query(self, query): + async with self._engine.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + results = result_map.fetchall() + await conn.commit() + return results + + @classmethod + def class_name(cls) -> str: + """Get class name.""" + return "AsyncAlloyDBChatStore" + + async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None: + """Asynchronously sets the chat messages for a specific key. + + Args: + key (str): A unique identifier for the chat. + messages (List[ChatMessage]): A list of `ChatMessage` objects to upsert. + + Returns: + None + + """ + query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE key = '{key}'; """ + await self.__aexecute_query(query) + insert_query = f""" + INSERT INTO "{self._schema_name}"."{self._table_name}" (key, message) + VALUES (:key, :message);""" + + params = [ + { + "key": key, + "message": json.dumps(message.dict()), + } + for message in messages + ] + + await self.__aexecute_query(insert_query, params) + + async def aget_messages(self, key: str) -> List[ChatMessage]: + """Asynchronously retrieves the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for which the messages are to be retrieved. + + Returns: + List[ChatMessage]: A list of `ChatMessage` objects associated with the provided key. + If no messages are found, an empty list is returned. 
+ """ + query = f"""SELECT message from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id;""" + results = await self.__afetch_query(query) + if results: + return [ + ChatMessage.model_validate(result.get("message")) for result in results + ] + return [] + + async def async_add_message(self, key: str, message: ChatMessage) -> None: + """Asynchronously adds a new chat message to the specified key. + + Args: + key (str): A unique identifierfor the chat to which the message is added. + message (ChatMessage): The `ChatMessage` object that is to be added. + + Returns: + None + """ + insert_query = f""" + INSERT INTO "{self._schema_name}"."{self._table_name}" (key, message) + VALUES (:key, :message);""" + params = {"key": key, "message": json.dumps(message.dict())} + + await self.__aexecute_query(insert_query, params) + + async def adelete_messages(self, key: str) -> Optional[List[ChatMessage]]: + """Asynchronously deletes the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + + Returns: + Optional[List[ChatMessage]]: A list of `ChatMessage` objects that were deleted, or `None` if no messages + were associated with the key or could be deleted. + """ + query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' RETURNING *; """ + results = await self.__afetch_query(query) + if results: + return [ + ChatMessage.model_validate(result.get("message")) for result in results + ] + return None + + async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + """Asynchronously deletes a specific chat message by index from the messages associated with a given key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + idx (int): The index of the `ChatMessage` to be deleted from the list of messages. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + query = f"""SELECT * from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id;""" + results = await self.__afetch_query(query) + if results: + if idx >= len(results): + return None + id_to_be_deleted = results[idx].get("id") + delete_query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE id = '{id_to_be_deleted}' RETURNING *;""" + result = await self.__afetch_query(delete_query) + result = result[0] + if result: + return ChatMessage.model_validate(result.get("message")) + return None + return None + + async def adelete_last_message(self, key: str) -> Optional[ChatMessage]: + """Asynchronously deletes the last chat message associated with a given key. + + Args: + key (str): A unique identifier for the chat whose message is to be deleted. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. 
+ """ + query = f"""SELECT * from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id DESC LIMIT 1;""" + results = await self.__afetch_query(query) + if results: + id_to_be_deleted = results[0].get("id") + delete_query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE id = '{id_to_be_deleted}' RETURNING *;""" + result = await self.__afetch_query(delete_query) + result = result[0] + if result: + return ChatMessage.model_validate(result.get("message")) + return None + return None + + async def aget_keys(self) -> List[str]: + """Asynchronously retrieves a list of all keys. + + Returns: + Optional[str]: A list of strings representing the keys. If no keys are found, an empty list is returned. + """ + query = ( + f"""SELECT distinct key from "{self._schema_name}"."{self._table_name}";""" + ) + results = await self.__afetch_query(query) + keys = [] + if results: + keys = [row.get("key") for row in results] + return keys + + def set_messages(self, key: str, messages: List[ChatMessage]) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBChatStore . Use AlloyDBChatStore interface instead." + ) + + def get_messages(self, key: str) -> List[ChatMessage]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBChatStore . Use AlloyDBChatStore interface instead." + ) + + def add_message(self, key: str, message: ChatMessage) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBChatStore . Use AlloyDBChatStore interface instead." + ) + + def delete_messages(self, key: str) -> Optional[List[ChatMessage]]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBChatStore . Use AlloyDBChatStore interface instead." + ) + + def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBChatStore . Use AlloyDBChatStore interface instead." + ) + + def delete_last_message(self, key: str) -> Optional[ChatMessage]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBChatStore . Use AlloyDBChatStore interface instead." + ) + + def get_keys(self) -> List[str]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBChatStore . Use AlloyDBChatStore interface instead." + ) diff --git a/tests/test_async_chat_store.py b/tests/test_async_chat_store.py new file mode 100644 index 0000000..70397e4 --- /dev/null +++ b/tests/test_async_chat_store.py @@ -0,0 +1,224 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import uuid +from typing import Sequence + +import pytest +import pytest_asyncio +from llama_index.core.llms import ChatMessage +from sqlalchemy import RowMapping, text + +from llama_index_alloydb_pg import AlloyDBEngine +from llama_index_alloydb_pg.async_chat_store import AsyncAlloyDBChatStore + +default_table_name_async = "chat_store_" + str(uuid.uuid4()) + + +async def aexecute(engine: AlloyDBEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + +async def afetch(engine: AlloyDBEngine, query: str) -> Sequence[RowMapping]: + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestAsyncAlloyDBChatStore: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for AlloyDB instance") + + @pytest.fixture(scope="module") + def db_cluster(self) -> str: + return get_env_var("CLUSTER_ID", "cluster for AlloyDB") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for AlloyDB") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on AlloyDB instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for AlloyDB") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for AlloyDB") + + @pytest_asyncio.fixture(scope="class") + async def async_engine( + self, + db_project, + db_region, + db_cluster, + db_instance, + db_name, + ): + async_engine = await AlloyDBEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + cluster=db_cluster, + region=db_region, + database=db_name, + ) + + yield async_engine + + await async_engine.close() + + @pytest_asyncio.fixture(scope="class") + async def chat_store(self, async_engine): + await async_engine._ainit_chat_store_table(table_name=default_table_name_async) + + chat_store = await AsyncAlloyDBChatStore.create( + engine=async_engine, table_name=default_table_name_async + ) + + yield chat_store + + query = f'DROP TABLE IF EXISTS "{default_table_name_async}"' + await aexecute(async_engine, query) + + async def test_init_with_constructor(self, async_engine): + with pytest.raises(Exception): + AsyncAlloyDBChatStore( + engine=async_engine, table_name=default_table_name_async + ) + + async def test_async_add_message(self, async_engine, chat_store): + key = "test_add_key" + + message = ChatMessage(content="add_message_test", role="user") + await chat_store.async_add_message(key, message=message) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + result = results[0] + assert result["message"] == message.dict() + + async def test_aset_and_aget_messages(self, chat_store): + message_1 = ChatMessage(content="First message", role="user") + message_2 = ChatMessage(content="Second message", role="user") + messages = [message_1, 
message_2] + key = "test_set_and_get_key" + await chat_store.aset_messages(key, messages) + + results = await chat_store.aget_messages(key) + + assert len(results) == 2 + assert results[0].content == message_1.content + assert results[1].content == message_2.content + + async def test_adelete_messages(self, async_engine, chat_store): + messages = [ChatMessage(content="Message to delete", role="user")] + key = "test_delete_key" + await chat_store.aset_messages(key, messages) + + await chat_store.adelete_messages(key) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 0 + + async def test_adelete_message(self, async_engine, chat_store): + message_1 = ChatMessage(content="Keep me", role="user") + message_2 = ChatMessage(content="Delete me", role="user") + messages = [message_1, message_2] + key = "test_delete_message_key" + await chat_store.aset_messages(key, messages) + + await chat_store.adelete_message(key, 1) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 1 + assert results[0]["message"] == message_1.dict() + + async def test_adelete_last_message(self, async_engine, chat_store): + message_1 = ChatMessage(content="Message 1", role="user") + message_2 = ChatMessage(content="Message 2", role="user") + message_3 = ChatMessage(content="Message 3", role="user") + messages = [message_1, message_2, message_3] + key = "test_delete_last_message_key" + await chat_store.aset_messages(key, messages) + + await chat_store.adelete_last_message(key) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 2 + assert results[0]["message"] == message_1.dict() + assert results[1]["message"] == message_2.dict() + + async def test_aget_keys(self, async_engine, chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + message_2 = [ChatMessage(content="Second message", role="user")] + key_1 = "key1" + key_2 = "key2" + await chat_store.aset_messages(key_1, message_1) + await chat_store.aset_messages(key_2, message_2) + + keys = await chat_store.aget_keys() + + assert key_1 in keys + assert key_2 in keys + + async def test_set_exisiting_key(self, async_engine, chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + key = "test_set_exisiting_key" + await chat_store.aset_messages(key, message_1) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + + assert len(results) == 1 + result = results[0] + assert result["message"] == message_1[0].dict() + + message_2 = ChatMessage(content="Second message", role="user") + message_3 = ChatMessage(content="Third message", role="user") + messages = [message_2, message_3] + + await chat_store.aset_messages(key, messages) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + + # Assert the previous messages are deleted and only the newest ones exist. 
+        assert len(results) == 2
+
+        assert results[0]["message"] == message_2.dict()
+        assert results[1]["message"] == message_3.dict()

From 320b448fc60b2a41c4b3e1b90084d799319260eb Mon Sep 17 00:00:00 2001
From: dishaprakash <57954147+dishaprakash@users.noreply.github.com>
Date: Tue, 7 Jan 2025 07:47:22 +0000
Subject: [PATCH 13/39] feat: Adding AlloyDB Chat Store (#37)

* feat: Adding AlloyDB Chat Store

* Changed header to 2025

---------

Co-authored-by: Vishwaraj Anand
---
 src/llama_index_alloydb_pg/__init__.py   |   2 +
 src/llama_index_alloydb_pg/chat_store.py | 289 +++++++++++++++++
 tests/test_chat_store.py                 | 396 +++++++++++++++++++++++
 3 files changed, 687 insertions(+)
 create mode 100644 src/llama_index_alloydb_pg/chat_store.py
 create mode 100644 tests/test_chat_store.py

diff --git a/src/llama_index_alloydb_pg/__init__.py b/src/llama_index_alloydb_pg/__init__.py
index b695ab3..baea6ec 100644
--- a/src/llama_index_alloydb_pg/__init__.py
+++ b/src/llama_index_alloydb_pg/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .chat_store import AlloyDBChatStore
 from .document_store import AlloyDBDocumentStore
 from .engine import AlloyDBEngine, Column
 from .index_store import AlloyDBIndexStore
@@ -19,6 +20,7 @@
 from .version import __version__
 
 _all = [
+    "AlloyDBChatStore",
     "AlloyDBDocumentStore",
     "AlloyDBEngine",
     "AlloyDBIndexStore",
diff --git a/src/llama_index_alloydb_pg/chat_store.py b/src/llama_index_alloydb_pg/chat_store.py
new file mode 100644
index 0000000..3e5b5ee
--- /dev/null
+++ b/src/llama_index_alloydb_pg/chat_store.py
@@ -0,0 +1,289 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import List, Optional
+
+from llama_index.core.llms import ChatMessage
+from llama_index.core.storage.chat_store.base import BaseChatStore
+
+from .async_chat_store import AsyncAlloyDBChatStore
+from .engine import AlloyDBEngine
+
+
+class AlloyDBChatStore(BaseChatStore):
+    """Chat Store Table stored in an AlloyDB for PostgreSQL database."""
+
+    __create_key = object()
+
+    def __init__(
+        self, key: object, engine: AlloyDBEngine, chat_store: AsyncAlloyDBChatStore
+    ):
+        """AlloyDBChatStore constructor.
+
+        Args:
+            key (object): Key to prevent direct constructor usage.
+            engine (AlloyDBEngine): Database connection pool.
+            chat_store (AsyncAlloyDBChatStore): The async only ChatStore implementation
+
+        Raises:
+            Exception: If constructor is directly called by the user.
+        """
+        if key != AlloyDBChatStore.__create_key:
+            raise Exception(
+                "Only create class through 'create' or 'create_sync' methods!"
+            )
+
+        # Delegate to Pydantic's __init__
+        super().__init__()
+        self._engine = engine
+        self.__chat_store = chat_store
+
+    @classmethod
+    async def create(
+        cls,
+        engine: AlloyDBEngine,
+        table_name: str,
+        schema_name: str = "public",
+    ) -> AlloyDBChatStore:
+        """Create a new AlloyDBChatStore instance.
+
+        Args:
+            engine (AlloyDBEngine): AlloyDB engine to use.
+            table_name (str): Table name that stores the chat store.
+            schema_name (str): The schema name where the table is located. Defaults to "public"
+
+        Raises:
+            ValueError: If the table provided does not contain required schema.
+
+        Returns:
+            AlloyDBChatStore: A newly created instance of AlloyDBChatStore.
+        """
+        coro = AsyncAlloyDBChatStore.create(engine, table_name, schema_name)
+        chat_store = await engine._run_as_async(coro)
+        return cls(cls.__create_key, engine, chat_store)
+
+    @classmethod
+    def create_sync(
+        cls,
+        engine: AlloyDBEngine,
+        table_name: str,
+        schema_name: str = "public",
+    ) -> AlloyDBChatStore:
+        """Create a new AlloyDBChatStore sync instance.
+
+        Args:
+            engine (AlloyDBEngine): AlloyDB engine to use.
+            table_name (str): Table name that stores the chat store.
+            schema_name (str): The schema name where the table is located. Defaults to "public"
+
+        Raises:
+            ValueError: If the table provided does not contain required schema.
+
+        Returns:
+            AlloyDBChatStore: A newly created instance of AlloyDBChatStore.
+        """
+        coro = AsyncAlloyDBChatStore.create(engine, table_name, schema_name)
+        chat_store = engine._run_as_sync(coro)
+        return cls(cls.__create_key, engine, chat_store)
+
+    @classmethod
+    def class_name(cls) -> str:
+        """Get class name."""
+        return "AlloyDBChatStore"
+
+    async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None:
+        """Asynchronously sets the chat messages for a specific key.
+
+        Args:
+            key (str): A unique identifier for the chat.
+            messages (List[ChatMessage]): A list of `ChatMessage` objects to upsert.
+
+        Returns:
+            None
+
+        """
+        return await self._engine._run_as_async(
+            self.__chat_store.aset_messages(key=key, messages=messages)
+        )
+
+    async def aget_messages(self, key: str) -> List[ChatMessage]:
+        """Asynchronously retrieves the chat messages associated with a specific key.
+
+        Args:
+            key (str): A unique identifier for which the messages are to be retrieved.
+
+        Returns:
+            List[ChatMessage]: A list of `ChatMessage` objects associated with the provided key.
+                If no messages are found, an empty list is returned.
+        """
+        return await self._engine._run_as_async(
+            self.__chat_store.aget_messages(key=key)
+        )
+
+    async def async_add_message(self, key: str, message: ChatMessage) -> None:
+        """Asynchronously adds a new chat message to the specified key.
+
+        Args:
+            key (str): A unique identifier for the chat to which the message is added.
+            message (ChatMessage): The `ChatMessage` object that is to be added.
+
+        Returns:
+            None
+        """
+        return await self._engine._run_as_async(
+            self.__chat_store.async_add_message(key=key, message=message)
+        )
+
+    async def adelete_messages(self, key: str) -> Optional[List[ChatMessage]]:
+        """Asynchronously deletes the chat messages associated with a specific key.
+
+        Args:
+            key (str): A unique identifier for the chat whose messages are to be deleted.
+
+        Returns:
+            Optional[List[ChatMessage]]: A list of `ChatMessage` objects that were deleted, or `None` if no messages
+                were associated with the key or could be deleted.
+        """
+        return await self._engine._run_as_async(
+            self.__chat_store.adelete_messages(key=key)
+        )
+
+    async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]:
+        """Asynchronously deletes a specific chat message by index from the messages associated with a given key.
+
+        Args:
+            key (str): A unique identifier for the chat whose message is to be deleted.
+            idx (int): The index of the `ChatMessage` to be deleted from the list of messages.
+
+        Returns:
+            Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message
+                was associated with the key or could be deleted.
+        """
+        return await self._engine._run_as_async(
+            self.__chat_store.adelete_message(key=key, idx=idx)
+        )
+
+    async def adelete_last_message(self, key: str) -> Optional[ChatMessage]:
+        """Asynchronously deletes the last chat message associated with a given key.
+
+        Args:
+            key (str): A unique identifier for the chat whose message is to be deleted.
+
+        Returns:
+            Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message
+                was associated with the key or could be deleted.
+        """
+        return await self._engine._run_as_async(
+            self.__chat_store.adelete_last_message(key=key)
+        )
+
+    async def aget_keys(self) -> List[str]:
+        """Asynchronously retrieves a list of all keys.
+
+        Returns:
+            List[str]: A list of strings representing the keys. If no keys are found, an empty list is returned.
+        """
+        return await self._engine._run_as_async(self.__chat_store.aget_keys())
+
+    def set_messages(self, key: str, messages: List[ChatMessage]) -> None:
+        """Synchronously sets the chat messages for a specific key.
+
+        Args:
+            key (str): A unique identifier for the chat.
+            messages (List[ChatMessage]): A list of `ChatMessage` objects to upsert.
+
+        Returns:
+            None
+
+        """
+        return self._engine._run_as_sync(
+            self.__chat_store.aset_messages(key=key, messages=messages)
+        )
+
+    def get_messages(self, key: str) -> List[ChatMessage]:
+        """Synchronously retrieves the chat messages associated with a specific key.
+
+        Args:
+            key (str): A unique identifier for which the messages are to be retrieved.
+
+        Returns:
+            List[ChatMessage]: A list of `ChatMessage` objects associated with the provided key.
+                If no messages are found, an empty list is returned.
+        """
+        return self._engine._run_as_sync(self.__chat_store.aget_messages(key=key))
+
+    def add_message(self, key: str, message: ChatMessage) -> None:
+        """Synchronously adds a new chat message to the specified key.
+
+        Args:
+            key (str): A unique identifier for the chat to which the message is added.
+            message (ChatMessage): The `ChatMessage` object that is to be added.
+
+        Returns:
+            None
+        """
+        return self._engine._run_as_sync(
+            self.__chat_store.async_add_message(key=key, message=message)
+        )
+
+    def delete_messages(self, key: str) -> Optional[List[ChatMessage]]:
+        """Synchronously deletes the chat messages associated with a specific key.
+
+        Args:
+            key (str): A unique identifier for the chat whose messages are to be deleted.
+
+        Returns:
+            Optional[List[ChatMessage]]: A list of `ChatMessage` objects that were deleted, or `None` if no messages
+                were associated with the key or could be deleted.
+        """
+        return self._engine._run_as_sync(self.__chat_store.adelete_messages(key=key))
+
+    def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]:
+        """Synchronously deletes a specific chat message by index from the messages associated with a given key.
+
+        Args:
+            key (str): A unique identifier for the chat whose message is to be deleted.
+            idx (int): The index of the `ChatMessage` to be deleted from the list of messages.
+
+        Returns:
+            Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message
+                was associated with the key or could be deleted.
+ """ + return self._engine._run_as_sync( + self.__chat_store.adelete_message(key=key, idx=idx) + ) + + def delete_last_message(self, key: str) -> Optional[ChatMessage]: + """Synchronously deletes the last chat message associated with a given key. + + Args: + key (str): A unique identifier for the chat whose message is to be deleted. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + return self._engine._run_as_sync( + self.__chat_store.adelete_last_message(key=key) + ) + + def get_keys(self) -> List[str]: + """Synchronously retrieves a list of all keys. + + Returns: + Optional[str]: A list of strings representing the keys. If no keys are found, an empty list is returned. + """ + return self._engine._run_as_sync(self.__chat_store.aget_keys()) diff --git a/tests/test_chat_store.py b/tests/test_chat_store.py new file mode 100644 index 0000000..087c880 --- /dev/null +++ b/tests/test_chat_store.py @@ -0,0 +1,396 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import uuid +import warnings +from typing import Sequence + +import pytest +import pytest_asyncio +from llama_index.core.llms import ChatMessage +from sqlalchemy import RowMapping, text + +from llama_index_alloydb_pg import AlloyDBChatStore, AlloyDBEngine + +default_table_name_async = "chat_store_" + str(uuid.uuid4()) +default_table_name_sync = "chat_store_" + str(uuid.uuid4()) + + +async def aexecute( + engine: AlloyDBEngine, + query: str, +) -> None: + async def run(engine, query): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +async def afetch(engine: AlloyDBEngine, query: str) -> Sequence[RowMapping]: + async def run(engine, query): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + return await engine._run_as_async(run(engine, query)) + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestAlloyDBChatStoreAsync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for AlloyDB instance") + + @pytest.fixture(scope="module") + def db_cluster(self) -> str: + return get_env_var("CLUSTER_ID", "cluster for AlloyDB") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for AlloyDB") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on AlloyDB instance") + + @pytest.fixture(scope="module") 
+ def user(self) -> str: + return get_env_var("DB_USER", "database user for AlloyDB") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for AlloyDB") + + @pytest_asyncio.fixture(scope="class") + async def async_engine( + self, db_project, db_region, db_cluster, db_instance, db_name + ): + async_engine = await AlloyDBEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + cluster=db_cluster, + region=db_region, + database=db_name, + ) + + yield async_engine + + await async_engine.close() + + @pytest_asyncio.fixture(scope="class") + async def async_chat_store(self, async_engine): + await async_engine.ainit_chat_store_table(table_name=default_table_name_async) + + async_chat_store = await AlloyDBChatStore.create( + engine=async_engine, table_name=default_table_name_async + ) + + yield async_chat_store + + query = f'DROP TABLE IF EXISTS "{default_table_name_async}"' + await aexecute(async_engine, query) + + async def test_init_with_constructor(self, async_engine): + with pytest.raises(Exception): + AlloyDBChatStore(engine=async_engine, table_name=default_table_name_async) + + async def test_async_add_message(self, async_engine, async_chat_store): + key = "test_add_key" + + message = ChatMessage(content="add_message_test", role="user") + await async_chat_store.async_add_message(key, message=message) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + result = results[0] + assert result["message"] == message.dict() + + async def test_aset_and_aget_messages(self, async_chat_store): + message_1 = ChatMessage(content="First message", role="user") + message_2 = ChatMessage(content="Second message", role="user") + messages = [message_1, message_2] + key = "test_set_and_get_key" + await async_chat_store.aset_messages(key, messages) + + results = await async_chat_store.aget_messages(key) + + assert len(results) == 2 + assert results[0].content == message_1.content + assert results[1].content == message_2.content + + async def test_adelete_messages(self, async_engine, async_chat_store): + messages = [ChatMessage(content="Message to delete", role="user")] + key = "test_delete_key" + await async_chat_store.aset_messages(key, messages) + + await async_chat_store.adelete_messages(key) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 0 + + async def test_adelete_message(self, async_engine, async_chat_store): + message_1 = ChatMessage(content="Keep me", role="user") + message_2 = ChatMessage(content="Delete me", role="user") + messages = [message_1, message_2] + key = "test_delete_message_key" + await async_chat_store.aset_messages(key, messages) + + await async_chat_store.adelete_message(key, 1) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 1 + assert results[0]["message"] == message_1.dict() + + async def test_adelete_last_message(self, async_engine, async_chat_store): + message_1 = ChatMessage(content="Message 1", role="user") + message_2 = ChatMessage(content="Message 2", role="user") + message_3 = ChatMessage(content="Message 3", role="user") + messages = [message_1, message_2, message_3] + key = "test_delete_last_message_key" + await async_chat_store.aset_messages(key, messages) + 
+ await async_chat_store.adelete_last_message(key) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 2 + assert results[0]["message"] == message_1.dict() + assert results[1]["message"] == message_2.dict() + + async def test_aget_keys(self, async_engine, async_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + message_2 = [ChatMessage(content="Second message", role="user")] + key_1 = "key1" + key_2 = "key2" + await async_chat_store.aset_messages(key_1, message_1) + await async_chat_store.aset_messages(key_2, message_2) + + keys = await async_chat_store.aget_keys() + + assert key_1 in keys + assert key_2 in keys + + async def test_set_exisiting_key(self, async_engine, async_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + key = "test_set_exisiting_key" + await async_chat_store.aset_messages(key, message_1) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + + assert len(results) == 1 + result = results[0] + assert result["message"] == message_1[0].dict() + + message_2 = ChatMessage(content="Second message", role="user") + message_3 = ChatMessage(content="Third message", role="user") + messages = [message_2, message_3] + + await async_chat_store.aset_messages(key, messages) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + + # Assert the previous messages are deleted and only the newest ones exist. + assert len(results) == 2 + + assert results[0]["message"] == message_2.dict() + assert results[1]["message"] == message_3.dict() + + +@pytest.mark.asyncio(loop_scope="class") +class TestAlloyDBChatStoreSync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for AlloyDB instance") + + @pytest.fixture(scope="module") + def db_cluster(self) -> str: + return get_env_var("CLUSTER_ID", "cluster for AlloyDB") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for AlloyDB") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on AlloyDB instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for AlloyDB") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for AlloyDB") + + @pytest_asyncio.fixture(scope="class") + async def sync_engine( + self, db_project, db_region, db_cluster, db_instance, db_name + ): + sync_engine = AlloyDBEngine.from_instance( + project_id=db_project, + instance=db_instance, + cluster=db_cluster, + region=db_region, + database=db_name, + ) + + yield sync_engine + + await sync_engine.close() + + @pytest_asyncio.fixture(scope="class") + async def sync_chat_store(self, sync_engine): + sync_engine.init_chat_store_table(table_name=default_table_name_sync) + + sync_chat_store = AlloyDBChatStore.create_sync( + engine=sync_engine, table_name=default_table_name_sync + ) + + yield sync_chat_store + + query = f'DROP TABLE IF EXISTS "{default_table_name_sync}"' + await aexecute(sync_engine, query) + + 
async def test_init_with_constructor(self, sync_engine): + with pytest.raises(Exception): + AlloyDBChatStore(engine=sync_engine, table_name=default_table_name_sync) + + async def test_async_add_message(self, sync_engine, sync_chat_store): + key = "test_add_key" + + message = ChatMessage(content="add_message_test", role="user") + sync_chat_store.add_message(key, message=message) + + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}';""" + results = await afetch(sync_engine, query) + result = results[0] + assert result["message"] == message.dict() + + async def test_aset_and_aget_messages(self, sync_chat_store): + message_1 = ChatMessage(content="First message", role="user") + message_2 = ChatMessage(content="Second message", role="user") + messages = [message_1, message_2] + key = "test_set_and_get_key" + sync_chat_store.set_messages(key, messages) + + results = sync_chat_store.get_messages(key) + + assert len(results) == 2 + assert results[0].content == message_1.content + assert results[1].content == message_2.content + + async def test_adelete_messages(self, sync_engine, sync_chat_store): + messages = [ChatMessage(content="Message to delete", role="user")] + key = "test_delete_key" + sync_chat_store.set_messages(key, messages) + + sync_chat_store.delete_messages(key) + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}' ORDER BY id;""" + results = await afetch(sync_engine, query) + + assert len(results) == 0 + + async def test_adelete_message(self, sync_engine, sync_chat_store): + message_1 = ChatMessage(content="Keep me", role="user") + message_2 = ChatMessage(content="Delete me", role="user") + messages = [message_1, message_2] + key = "test_delete_message_key" + sync_chat_store.set_messages(key, messages) + + sync_chat_store.delete_message(key, 1) + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}' ORDER BY id;""" + results = await afetch(sync_engine, query) + + assert len(results) == 1 + assert results[0]["message"] == message_1.dict() + + async def test_adelete_last_message(self, sync_engine, sync_chat_store): + message_1 = ChatMessage(content="Message 1", role="user") + message_2 = ChatMessage(content="Message 2", role="user") + message_3 = ChatMessage(content="Message 3", role="user") + messages = [message_1, message_2, message_3] + key = "test_delete_last_message_key" + sync_chat_store.set_messages(key, messages) + + sync_chat_store.delete_last_message(key) + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}' ORDER BY id;""" + results = await afetch(sync_engine, query) + + assert len(results) == 2 + assert results[0]["message"] == message_1.dict() + assert results[1]["message"] == message_2.dict() + + async def test_aget_keys(self, sync_engine, sync_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + message_2 = [ChatMessage(content="Second message", role="user")] + key_1 = "key1" + key_2 = "key2" + sync_chat_store.set_messages(key_1, message_1) + sync_chat_store.set_messages(key_2, message_2) + + keys = sync_chat_store.get_keys() + + assert key_1 in keys + assert key_2 in keys + + async def test_set_exisiting_key(self, sync_engine, sync_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + key = "test_set_exisiting_key" + sync_chat_store.set_messages(key, message_1) + + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}';""" + results = await afetch(sync_engine, 
query) + + assert len(results) == 1 + result = results[0] + assert result["message"] == message_1[0].dict() + + message_2 = ChatMessage(content="Second message", role="user") + message_3 = ChatMessage(content="Third message", role="user") + messages = [message_2, message_3] + + sync_chat_store.set_messages(key, messages) + + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}';""" + results = await afetch(sync_engine, query) + + # Assert the previous messages are deleted and only the newest ones exist. + assert len(results) == 2 + + assert results[0]["message"] == message_2.dict() + assert results[1]["message"] == message_3.dict() From 27e28eef7540dfecb16374c1453b4cac1824144b Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Tue, 7 Jan 2025 23:12:07 +0530 Subject: [PATCH 14/39] ci: Add blunderbuss config (#43) --- .github/blunderbuss.yml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 .github/blunderbuss.yml diff --git a/.github/blunderbuss.yml b/.github/blunderbuss.yml new file mode 100644 index 0000000..d922933 --- /dev/null +++ b/.github/blunderbuss.yml @@ -0,0 +1,4 @@ +assign_issues: + - googleapis/llama-index-alloydb +assign_prs: + - googleapis/llama-index-alloydb From 4b40a9996c86d6c93a6674a59617bc541a839600 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Tue, 7 Jan 2025 21:16:38 +0100 Subject: [PATCH 15/39] chore(deps): update python-nonmajor (#40) --- pyproject.toml | 4 ++-- requirements.txt | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 907866a..2fdea84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,8 +39,8 @@ Changelog = "https://github.com/googleapis/llama-index-alloydb-pg-python/blob/ma test = [ "black[jupyter]==24.10.0", "isort==5.13.2", - "mypy==1.13.0", - "pytest-asyncio==0.24.0", + "mypy==1.14.1", + "pytest-asyncio==0.25.1", "pytest==8.3.4", "pytest-cov==6.0.0", "pytest-depends==1.0.1", diff --git a/requirements.txt b/requirements.txt index 5cdb1c0..5490657 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ google-cloud-alloydb-connector[asyncpg]==1.6.0 -llama-index-core==0.12.5 +llama-index-core==0.12.10.post1 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.36 From 63714917888cbd147241832aca506fbb68641b72 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Wed, 8 Jan 2025 18:45:17 +0100 Subject: [PATCH 16/39] fix(deps): update dependency pytest-asyncio to v0.25.2 (#44) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2fdea84..e22d455 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ test = [ "black[jupyter]==24.10.0", "isort==5.13.2", "mypy==1.14.1", - "pytest-asyncio==0.25.1", + "pytest-asyncio==0.25.2", "pytest==8.3.4", "pytest-cov==6.0.0", "pytest-depends==1.0.1", From fb39e08728e41add2e7b5c67f6c7f58e541261b4 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Tue, 14 Jan 2025 06:18:28 +0530 Subject: [PATCH 17/39] chore: add drop index statements to avoid conflict (#48) --- tests/test_async_vector_store_index.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_async_vector_store_index.py b/tests/test_async_vector_store_index.py index f7373c7..c66dd7f 100644 --- a/tests/test_async_vector_store_index.py +++ b/tests/test_async_vector_store_index.py @@ -123,6 +123,7 @@ async def test_aapply_vector_index(self, vs): index = HNSWIndex() await vs.aapply_vector_index(index) assert await vs.is_valid_index(DEFAULT_INDEX_NAME) + await 
vs.adrop_vector_index(DEFAULT_INDEX_NAME) async def test_areindex(self, vs): if not await vs.is_valid_index(DEFAULT_INDEX_NAME): @@ -131,6 +132,7 @@ async def test_areindex(self, vs): await vs.areindex() await vs.areindex(DEFAULT_INDEX_NAME) assert await vs.is_valid_index(DEFAULT_INDEX_NAME) + await vs.adrop_vector_index(DEFAULT_INDEX_NAME) async def test_dropindex(self, vs): await vs.adrop_vector_index() @@ -147,6 +149,7 @@ async def test_aapply_vector_index_ivfflat(self, vs): ) await vs.aapply_vector_index(index) assert await vs.is_valid_index("secondindex") + await vs.adrop_vector_index() await vs.adrop_vector_index("secondindex") async def test_is_valid_index(self, vs): From ee0a5b5d0a1706a66909d0147b351f5999b9afe6 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Tue, 14 Jan 2025 18:56:38 +0100 Subject: [PATCH 18/39] chore(deps): update dependency sqlalchemy to v2.0.37 (#45) Co-authored-by: Averi Kitsch --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 5490657..5fc85af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ google-cloud-alloydb-connector[asyncpg]==1.6.0 llama-index-core==0.12.10.post1 pgvector==0.3.6 -SQLAlchemy[asyncio]==2.0.36 +SQLAlchemy[asyncio]==2.0.37 From 134873516250f3137c60482869846e9b5ea96a2d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 14 Jan 2025 09:57:57 -0800 Subject: [PATCH 19/39] chore(deps): bump virtualenv from 20.25.1 to 20.26.6 in /.kokoro (#49) Bumps [virtualenv](https://github.com/pypa/virtualenv) from 20.25.1 to 20.26.6. - [Release notes](https://github.com/pypa/virtualenv/releases) - [Changelog](https://github.com/pypa/virtualenv/blob/main/docs/changelog.rst) - [Commits](https://github.com/pypa/virtualenv/compare/20.25.1...20.26.6) --- updated-dependencies: - dependency-name: virtualenv dependency-type: indirect ... 
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Averi Kitsch
---
 .kokoro/requirements.txt | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/.kokoro/requirements.txt b/.kokoro/requirements.txt
index 23e61f6..b5a1d9a 100644
--- a/.kokoro/requirements.txt
+++ b/.kokoro/requirements.txt
@@ -509,9 +509,9 @@ urllib3==2.2.2 \
     # via
     #   requests
     #   twine
-virtualenv==20.25.1 \
-    --hash=sha256:961c026ac520bac5f69acb8ea063e8a4f071bcc9457b9c1f28f6b085c511583a \
-    --hash=sha256:e08e13ecdca7a0bd53798f356d5831434afa5b07b93f0abdf0797b7a06ffe197
+virtualenv==20.26.6 \
    --hash=sha256:280aede09a2a5c317e409a00102e7077c6432c5a38f0ef938e643805a7ad2c48 \
    --hash=sha256:7345cc5b25405607a624d8418154577459c3e0277f5466dd79c49d5e492995f2
     # via nox
 wheel==0.43.0 \
     --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \

From 5f1405ed7ba7941c9c9a4370a428c720d857e6af Mon Sep 17 00:00:00 2001
From: Vishwaraj Anand
Date: Fri, 17 Jan 2025 02:32:17 +0530
Subject: [PATCH 20/39] fix: programming error while setting multiple query options (#47)

* fix: programming error while setting multiple query options

* chore: address PR comments

* chore: PR comments
---
 .../async_vector_store.py                    |   8 +-
 src/llama_index_alloydb_pg/indexes.py        |  41 ++++++
 tests/test_async_vector_store.py             |  37 ++++++
 tests/test_indexes.py                        | 122 ++++++++++++++++++
 4 files changed, 204 insertions(+), 4 deletions(-)
 create mode 100644 tests/test_indexes.py

diff --git a/src/llama_index_alloydb_pg/async_vector_store.py b/src/llama_index_alloydb_pg/async_vector_store.py
index d318b11..5dc3640 100644
--- a/src/llama_index_alloydb_pg/async_vector_store.py
+++ b/src/llama_index_alloydb_pg/async_vector_store.py
@@ -543,10 +543,10 @@ async def __query_columns(
         query_stmt = f'SELECT * {scoring_stmt} FROM "{self._schema_name}"."{self._table_name}" {filters_stmt} {order_stmt} {limit_stmt}'
         async with self._engine.connect() as conn:
             if self._index_query_options:
-                query_options_stmt = (
-                    f"SET LOCAL {self._index_query_options.to_string()};"
-                )
-                await conn.execute(text(query_options_stmt))
+                # Set each query option individually
+                for query_option in self._index_query_options.to_parameter():
+                    query_options_stmt = f"SET LOCAL {query_option};"
+                    await conn.execute(text(query_options_stmt))
             result = await conn.execute(text(query_stmt))
             result_map = result.mappings()
             results = result_map.fetchall()

diff --git a/src/llama_index_alloydb_pg/indexes.py b/src/llama_index_alloydb_pg/indexes.py
index 5793c53..20bdfec 100644
--- a/src/llama_index_alloydb_pg/indexes.py
+++ b/src/llama_index_alloydb_pg/indexes.py
@@ -13,6 +13,7 @@
 # limitations under the License.
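Context for the change above: PostgreSQL's SET LOCAL statement accepts exactly one configuration parameter at a time, so the comma-joined string returned by to_string() raised a syntax error whenever an options class carried more than one setting (as ScaNNQueryOptions does). A minimal sketch of the new to_parameter() contract, with the resulting statements shown as comments:

    from llama_index_alloydb_pg.indexes import ScaNNQueryOptions

    options = ScaNNQueryOptions(
        num_leaves_to_search=2, pre_reordering_num_neighbors=10
    )
    # Each entry becomes its own SET LOCAL statement instead of one joined string.
    for option in options.to_parameter():
        print(f"SET LOCAL {option};")
    # SET LOCAL scann.num_leaves_to_search = 2;
    # SET LOCAL scann.pre_reordering_num_neighbors = 10;
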
import enum +import warnings from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Optional @@ -62,6 +63,11 @@ class ExactNearestNeighbor(BaseIndex): @dataclass class QueryOptions(ABC): + @abstractmethod + def to_parameter(self) -> list[str]: + """Convert index attributes to list of configurations.""" + raise NotImplementedError("to_parameter method must be implemented by subclass") + @abstractmethod def to_string(self) -> str: """Convert index attributes to string.""" @@ -83,8 +89,16 @@ def index_options(self) -> str: class HNSWQueryOptions(QueryOptions): ef_search: int = 40 + def to_parameter(self) -> list[str]: + """Convert index attributes to list of configurations.""" + return [f"hnsw.ef_search = {self.ef_search}"] + def to_string(self) -> str: """Convert index attributes to string.""" + warnings.warn( + "to_string is deprecated, use to_parameter instead.", + DeprecationWarning, + ) return f"hnsw.ef_search = {self.ef_search}" @@ -102,8 +116,16 @@ def index_options(self) -> str: class IVFFlatQueryOptions(QueryOptions): probes: int = 1 + def to_parameter(self) -> list[str]: + """Convert index attributes to list of configurations.""" + return [f"ivfflat.probes = {self.probes}"] + def to_string(self) -> str: """Convert index attributes to string.""" + warnings.warn( + "to_string is deprecated, use to_parameter instead.", + DeprecationWarning, + ) return f"ivfflat.probes = {self.probes}" @@ -124,8 +146,16 @@ def index_options(self) -> str: class IVFQueryOptions(QueryOptions): probes: int = 1 + def to_parameter(self) -> list[str]: + """Convert index attributes to list of configurations.""" + return [f"ivf.probes = {self.probes}"] + def to_string(self) -> str: """Convert index attributes to string.""" + warnings.warn( + "to_string is deprecated, use to_parameter instead.", + DeprecationWarning, + ) return f"ivf.probes = {self.probes}" @@ -147,6 +177,17 @@ class ScaNNQueryOptions(QueryOptions): num_leaves_to_search: int = 1 pre_reordering_num_neighbors: int = -1 + def to_parameter(self) -> list[str]: + """Convert index attributes to list of configurations.""" + return [ + f"scann.num_leaves_to_search = {self.num_leaves_to_search}", + f"scann.pre_reordering_num_neighbors = {self.pre_reordering_num_neighbors}", + ] + def to_string(self) -> str: """Convert index attributes to string.""" + warnings.warn( + "to_string is deprecated, use to_parameter instead.", + DeprecationWarning, + ) return f"scann.num_leaves_to_search = {self.num_leaves_to_search}, scann.pre_reordering_num_neighbors = {self.pre_reordering_num_neighbors}" diff --git a/tests/test_async_vector_store.py b/tests/test_async_vector_store.py index e99ac51..4088eed 100644 --- a/tests/test_async_vector_store.py +++ b/tests/test_async_vector_store.py @@ -32,6 +32,7 @@ from llama_index_alloydb_pg import AlloyDBEngine, Column from llama_index_alloydb_pg.async_vector_store import AsyncAlloyDBVectorStore +from llama_index_alloydb_pg.indexes import HNSWQueryOptions, ScaNNQueryOptions DEFAULT_TABLE = "test_table" + str(uuid.uuid4()) DEFAULT_TABLE_CUSTOM_VS = "test_table" + str(uuid.uuid4()) @@ -155,6 +156,23 @@ async def custom_vs(self, engine): "nullable_int_field", "nullable_str_field", ], + index_query_options=HNSWQueryOptions(ef_search=1), + ) + yield vs + + @pytest_asyncio.fixture(scope="class") + async def custom_vs_scann(self, engine, custom_vs): + vs = await AsyncAlloyDBVectorStore.create( + engine, + table_name=DEFAULT_TABLE_CUSTOM_VS, + metadata_columns=[ + "len", + "nullable_int_field", + 
"nullable_str_field", + ], + index_query_options=ScaNNQueryOptions( + num_leaves_to_search=1, pre_reordering_num_neighbors=2 + ), ) yield vs @@ -320,6 +338,25 @@ async def test_aquery(self, engine, vs): assert len(results.nodes) == 3 assert results.nodes[0].get_content(metadata_mode=MetadataMode.NONE) == "foo" + async def test_aquery_scann(self, engine, custom_vs_scann): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE_CUSTOM_VS}"') + # setting extra metadata to be indexed in separate column + for node in nodes: + node.metadata["len"] = len(node.text) + + await custom_vs_scann.async_add(nodes) + query = VectorStoreQuery( + query_embedding=[1.0] * VECTOR_SIZE, similarity_top_k=3 + ) + results = await custom_vs_scann.aquery(query) + + assert results.nodes is not None + assert results.ids is not None + assert results.similarities is not None + assert len(results.nodes) == 3 + async def test_aquery_filters(self, engine, custom_vs): # Note: To be migrated to a pytest dependency on test_async_add # Blocked due to unexpected fixtures reloads while running integration test suite diff --git a/tests/test_indexes.py b/tests/test_indexes.py new file mode 100644 index 0000000..c2a781c --- /dev/null +++ b/tests/test_indexes.py @@ -0,0 +1,122 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import warnings + +from llama_index_alloydb_pg.indexes import ( + DistanceStrategy, + HNSWIndex, + HNSWQueryOptions, + IVFFlatIndex, + IVFFlatQueryOptions, + IVFIndex, + IVFQueryOptions, + ScaNNIndex, + ScaNNQueryOptions, +) + + +class TestAlloyDBIndex: + def test_distance_strategy(self): + assert DistanceStrategy.EUCLIDEAN.operator == "<->" + assert DistanceStrategy.EUCLIDEAN.search_function == "l2_distance" + assert DistanceStrategy.EUCLIDEAN.index_function == "vector_l2_ops" + assert DistanceStrategy.EUCLIDEAN.scann_index_function == "l2" + + assert DistanceStrategy.COSINE_DISTANCE.operator == "<=>" + assert DistanceStrategy.COSINE_DISTANCE.search_function == "cosine_distance" + assert DistanceStrategy.COSINE_DISTANCE.index_function == "vector_cosine_ops" + assert DistanceStrategy.COSINE_DISTANCE.scann_index_function == "cosine" + + assert DistanceStrategy.INNER_PRODUCT.operator == "<#>" + assert DistanceStrategy.INNER_PRODUCT.search_function == "inner_product" + assert DistanceStrategy.INNER_PRODUCT.index_function == "vector_ip_ops" + assert DistanceStrategy.INNER_PRODUCT.scann_index_function == "dot_product" + + def test_hnsw_index(self): + index = HNSWIndex(name="test_index", m=32, ef_construction=128) + assert index.index_type == "hnsw" + assert index.m == 32 + assert index.ef_construction == 128 + assert index.index_options() == "(m = 32, ef_construction = 128)" + + def test_hnsw_query_options(self): + options = HNSWQueryOptions(ef_search=80) + assert options.to_parameter() == ["hnsw.ef_search = 80"] + + with warnings.catch_warnings(record=True) as w: + options.to_string() + + assert len(w) == 1 + assert "to_string is deprecated, use to_parameter instead." in str( + w[-1].message + ) + + def test_ivfflat_index(self): + index = IVFFlatIndex(name="test_index", lists=200) + assert index.index_type == "ivfflat" + assert index.lists == 200 + assert index.index_options() == "(lists = 200)" + + def test_ivfflat_query_options(self): + options = IVFFlatQueryOptions(probes=2) + assert options.to_parameter() == ["ivfflat.probes = 2"] + + with warnings.catch_warnings(record=True) as w: + options.to_string() + assert len(w) == 1 + assert "to_string is deprecated, use to_parameter instead." in str( + w[-1].message + ) + + def test_ivf_index(self): + index = IVFIndex(name="test_index", lists=200) + assert index.index_type == "ivf" + assert index.lists == 200 + assert index.quantizer == "sq8" # Check default value + assert index.index_options() == "(lists = 200, quantizer = sq8)" + + def test_ivf_query_options(self): + options = IVFQueryOptions(probes=2) + assert options.to_parameter() == ["ivf.probes = 2"] + + with warnings.catch_warnings(record=True) as w: + options.to_string() + assert len(w) == 1 + assert "to_string is deprecated, use to_parameter instead." in str( + w[-1].message + ) + + def test_scann_index(self): + index = ScaNNIndex(name="test_index", num_leaves=10) + assert index.index_type == "ScaNN" + assert index.num_leaves == 10 + assert index.quantizer == "sq8" # Check default value + assert index.index_options() == "(num_leaves = 10, quantizer = sq8)" + + def test_scann_query_options(self): + options = ScaNNQueryOptions( + num_leaves_to_search=2, pre_reordering_num_neighbors=10 + ) + assert options.to_parameter() == [ + "scann.num_leaves_to_search = 2", + "scann.pre_reordering_num_neighbors = 10", + ] + + with warnings.catch_warnings(record=True) as w: + options.to_string() + assert len(w) == 1 + assert "to_string is deprecated, use to_parameter instead." 
in str( + w[-1].message + ) From bcac2d7bc0477a9e4e24d919ec8c37517f146521 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Fri, 17 Jan 2025 18:17:13 +0100 Subject: [PATCH 21/39] chore(deps): update python-nonmajor (#51) Co-authored-by: Averi Kitsch --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 5fc85af..4c67617 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -google-cloud-alloydb-connector[asyncpg]==1.6.0 -llama-index-core==0.12.10.post1 +google-cloud-alloydb-connector[asyncpg]==1.7.0 +llama-index-core==0.12.11 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.37 From dff623bf8d340811ed88271e59b11d0f996cc811 Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Fri, 17 Jan 2025 17:53:24 +0000 Subject: [PATCH 22/39] fix: query and return only selected metadata columns (#52) * fix: query and return only selected metadata columns * Review changes * Linter fix --------- Co-authored-by: Averi Kitsch --- src/llama_index_alloydb_pg/async_vector_store.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/llama_index_alloydb_pg/async_vector_store.py b/src/llama_index_alloydb_pg/async_vector_store.py index 5dc3640..7d2c16d 100644 --- a/src/llama_index_alloydb_pg/async_vector_store.py +++ b/src/llama_index_alloydb_pg/async_vector_store.py @@ -540,7 +540,19 @@ async def __query_columns( f" LIMIT {query.similarity_top_k} " if query.similarity_top_k >= 1 else "" ) - query_stmt = f'SELECT * {scoring_stmt} FROM "{self._schema_name}"."{self._table_name}" {filters_stmt} {order_stmt} {limit_stmt}' + columns = self._metadata_columns + [ + self._id_column, + self._text_column, + self._embedding_column, + self._ref_doc_id_column, + self._node_column, + ] + if self._metadata_json_column: + columns.append(self._metadata_json_column) + + column_names = ", ".join(f'"{col}"' for col in columns) + + query_stmt = f'SELECT {column_names} {scoring_stmt} FROM "{self._schema_name}"."{self._table_name}" {filters_stmt} {order_stmt} {limit_stmt}' async with self._engine.connect() as conn: if self._index_query_options: # Set each query option individually From a03afdbac3d8690478bb42d6236a0f8f42e1d98e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 17 Jan 2025 09:59:36 -0800 Subject: [PATCH 23/39] chore(deps): bump virtualenv in /.kokoro/docker/docs (#53) Bumps [virtualenv](https://github.com/pypa/virtualenv) from 20.26.0 to 20.26.6. - [Release notes](https://github.com/pypa/virtualenv/releases) - [Changelog](https://github.com/pypa/virtualenv/blob/main/docs/changelog.rst) - [Commits](https://github.com/pypa/virtualenv/compare/20.26.0...20.26.6) --- updated-dependencies: - dependency-name: virtualenv dependency-type: indirect ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Averi Kitsch --- .kokoro/docker/docs/requirements.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.kokoro/docker/docs/requirements.txt b/.kokoro/docker/docs/requirements.txt index 56381b8..43b2594 100644 --- a/.kokoro/docker/docs/requirements.txt +++ b/.kokoro/docker/docs/requirements.txt @@ -32,7 +32,7 @@ platformdirs==4.2.1 \ --hash=sha256:031cd18d4ec63ec53e82dceaac0417d218a6863f7745dfcc9efe7793b7039bdf \ --hash=sha256:17d5a1161b3fd67b390023cb2d3b026bbd40abde6fdb052dfbd3a29c3ba22ee1 # via virtualenv -virtualenv==20.26.0 \ - --hash=sha256:0846377ea76e818daaa3e00a4365c018bc3ac9760cbb3544de542885aad61fb3 \ - --hash=sha256:ec25a9671a5102c8d2657f62792a27b48f016664c6873f6beed3800008577210 - # via nox \ No newline at end of file +virtualenv==20.26.6 \ + --hash=sha256:280aede09a2a5c317e409a00102e7077c6432c5a38f0ef938e643805a7ad2c48 \ + --hash=sha256:7345cc5b25405607a624d8418154577459c3e0277f5466dd79c49d5e492995f2 + # via nox From 7ae4fed37d68306b45e3e34132bab35e17045b12 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Tue, 21 Jan 2025 18:20:44 +0100 Subject: [PATCH 24/39] chore(deps): update dependency llama-index-core to v0.12.12 (#56) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 4c67617..afe5e9b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ google-cloud-alloydb-connector[asyncpg]==1.7.0 -llama-index-core==0.12.11 +llama-index-core==0.12.12 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.37 From e904dabfd426b4dfc6c4addc5011bf51905f5a0d Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Tue, 21 Jan 2025 23:27:56 +0530 Subject: [PATCH 25/39] chore: test cleanup (#50) * chore: remove pytest warning * chore: better engine cleanup * update connector close * fix connector * remove .dict() warning --------- Co-authored-by: Averi Kitsch --- pyproject.toml | 3 ++ .../async_chat_store.py | 4 +-- tests/test_async_chat_store.py | 15 ++++----- tests/test_async_document_store.py | 1 + tests/test_async_index_store.py | 1 + tests/test_async_vector_store.py | 1 + tests/test_async_vector_store_index.py | 1 + tests/test_chat_store.py | 30 +++++++++--------- tests/test_document_store.py | 2 ++ tests/test_engine.py | 31 ++++++++++++++++--- tests/test_index_store.py | 2 ++ tests/test_vector_store.py | 2 ++ tests/test_vector_store_index.py | 2 ++ 13 files changed, 67 insertions(+), 28 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e22d455..94a41ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,9 @@ test = [ requires = ["setuptools"] build-backend = "setuptools.build_meta" +[tool.pytest.ini_options] +asyncio_default_fixture_loop_scope = "class" + [tool.black] target-version = ['py39'] diff --git a/src/llama_index_alloydb_pg/async_chat_store.py b/src/llama_index_alloydb_pg/async_chat_store.py index d967349..da24402 100644 --- a/src/llama_index_alloydb_pg/async_chat_store.py +++ b/src/llama_index_alloydb_pg/async_chat_store.py @@ -137,7 +137,7 @@ async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None: params = [ { "key": key, - "message": json.dumps(message.dict()), + "message": json.dumps(message.model_dump()), } for message in messages ] @@ -175,7 +175,7 @@ async def async_add_message(self, key: str, message: ChatMessage) -> None: insert_query = f""" INSERT INTO "{self._schema_name}"."{self._table_name}" 
(key, message) VALUES (:key, :message);""" - params = {"key": key, "message": json.dumps(message.dict())} + params = {"key": key, "message": json.dumps(message.model_dump())} await self.__aexecute_query(insert_query, params) diff --git a/tests/test_async_chat_store.py b/tests/test_async_chat_store.py index 70397e4..f64f824 100644 --- a/tests/test_async_chat_store.py +++ b/tests/test_async_chat_store.py @@ -98,6 +98,7 @@ async def async_engine( yield async_engine await async_engine.close() + await async_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def chat_store(self, async_engine): @@ -127,7 +128,7 @@ async def test_async_add_message(self, async_engine, chat_store): query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" results = await afetch(async_engine, query) result = results[0] - assert result["message"] == message.dict() + assert result["message"] == message.model_dump() async def test_aset_and_aget_messages(self, chat_store): message_1 = ChatMessage(content="First message", role="user") @@ -165,7 +166,7 @@ async def test_adelete_message(self, async_engine, chat_store): results = await afetch(async_engine, query) assert len(results) == 1 - assert results[0]["message"] == message_1.dict() + assert results[0]["message"] == message_1.model_dump() async def test_adelete_last_message(self, async_engine, chat_store): message_1 = ChatMessage(content="Message 1", role="user") @@ -180,8 +181,8 @@ async def test_adelete_last_message(self, async_engine, chat_store): results = await afetch(async_engine, query) assert len(results) == 2 - assert results[0]["message"] == message_1.dict() - assert results[1]["message"] == message_2.dict() + assert results[0]["message"] == message_1.model_dump() + assert results[1]["message"] == message_2.model_dump() async def test_aget_keys(self, async_engine, chat_store): message_1 = [ChatMessage(content="First message", role="user")] @@ -206,7 +207,7 @@ async def test_set_exisiting_key(self, async_engine, chat_store): assert len(results) == 1 result = results[0] - assert result["message"] == message_1[0].dict() + assert result["message"] == message_1[0].model_dump() message_2 = ChatMessage(content="Second message", role="user") message_3 = ChatMessage(content="Third message", role="user") @@ -220,5 +221,5 @@ async def test_set_exisiting_key(self, async_engine, chat_store): # Assert the previous messages are deleted and only the newest ones exist. 
assert len(results) == 2 - assert results[0]["message"] == message_2.dict() - assert results[1]["message"] == message_3.dict() + assert results[0]["message"] == message_2.model_dump() + assert results[1]["message"] == message_3.model_dump() diff --git a/tests/test_async_document_store.py b/tests/test_async_document_store.py index c978f94..99ab5d1 100644 --- a/tests/test_async_document_store.py +++ b/tests/test_async_document_store.py @@ -97,6 +97,7 @@ async def async_engine( yield async_engine await async_engine.close() + await async_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def doc_store(self, async_engine): diff --git a/tests/test_async_index_store.py b/tests/test_async_index_store.py index 09f532d..10297cb 100644 --- a/tests/test_async_index_store.py +++ b/tests/test_async_index_store.py @@ -100,6 +100,7 @@ async def async_engine( yield async_engine await async_engine.close() + await async_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def index_store(self, async_engine): diff --git a/tests/test_async_vector_store.py b/tests/test_async_vector_store.py index 4088eed..36f05d1 100644 --- a/tests/test_async_vector_store.py +++ b/tests/test_async_vector_store.py @@ -119,6 +119,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_CUSTOM_VS}"') await engine.close() + await engine._connector.close() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): diff --git a/tests/test_async_vector_store_index.py b/tests/test_async_vector_store_index.py index c66dd7f..954edbb 100644 --- a/tests/test_async_vector_store_index.py +++ b/tests/test_async_vector_store_index.py @@ -104,6 +104,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") await engine.close() + await engine._connector.close() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): diff --git a/tests/test_chat_store.py b/tests/test_chat_store.py index 087c880..d10e0cb 100644 --- a/tests/test_chat_store.py +++ b/tests/test_chat_store.py @@ -103,6 +103,7 @@ async def async_engine( yield async_engine await async_engine.close() + await async_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def async_chat_store(self, async_engine): @@ -130,7 +131,7 @@ async def test_async_add_message(self, async_engine, async_chat_store): query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" results = await afetch(async_engine, query) result = results[0] - assert result["message"] == message.dict() + assert result["message"] == message.model_dump() async def test_aset_and_aget_messages(self, async_chat_store): message_1 = ChatMessage(content="First message", role="user") @@ -168,7 +169,7 @@ async def test_adelete_message(self, async_engine, async_chat_store): results = await afetch(async_engine, query) assert len(results) == 1 - assert results[0]["message"] == message_1.dict() + assert results[0]["message"] == message_1.model_dump() async def test_adelete_last_message(self, async_engine, async_chat_store): message_1 = ChatMessage(content="Message 1", role="user") @@ -183,8 +184,8 @@ async def test_adelete_last_message(self, async_engine, async_chat_store): results = await afetch(async_engine, query) assert len(results) == 2 - assert results[0]["message"] == 
message_1.dict() - assert results[1]["message"] == message_2.dict() + assert results[0]["message"] == message_1.model_dump() + assert results[1]["message"] == message_2.model_dump() async def test_aget_keys(self, async_engine, async_chat_store): message_1 = [ChatMessage(content="First message", role="user")] @@ -209,7 +210,7 @@ async def test_set_exisiting_key(self, async_engine, async_chat_store): assert len(results) == 1 result = results[0] - assert result["message"] == message_1[0].dict() + assert result["message"] == message_1[0].model_dump() message_2 = ChatMessage(content="Second message", role="user") message_3 = ChatMessage(content="Third message", role="user") @@ -223,8 +224,8 @@ async def test_set_exisiting_key(self, async_engine, async_chat_store): # Assert the previous messages are deleted and only the newest ones exist. assert len(results) == 2 - assert results[0]["message"] == message_2.dict() - assert results[1]["message"] == message_3.dict() + assert results[0]["message"] == message_2.model_dump() + assert results[1]["message"] == message_3.model_dump() @pytest.mark.asyncio(loop_scope="class") @@ -272,6 +273,7 @@ async def sync_engine( yield sync_engine await sync_engine.close() + await sync_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def sync_chat_store(self, sync_engine): @@ -299,7 +301,7 @@ async def test_async_add_message(self, sync_engine, sync_chat_store): query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}';""" results = await afetch(sync_engine, query) result = results[0] - assert result["message"] == message.dict() + assert result["message"] == message.model_dump() async def test_aset_and_aget_messages(self, sync_chat_store): message_1 = ChatMessage(content="First message", role="user") @@ -337,7 +339,7 @@ async def test_adelete_message(self, sync_engine, sync_chat_store): results = await afetch(sync_engine, query) assert len(results) == 1 - assert results[0]["message"] == message_1.dict() + assert results[0]["message"] == message_1.model_dump() async def test_adelete_last_message(self, sync_engine, sync_chat_store): message_1 = ChatMessage(content="Message 1", role="user") @@ -352,8 +354,8 @@ async def test_adelete_last_message(self, sync_engine, sync_chat_store): results = await afetch(sync_engine, query) assert len(results) == 2 - assert results[0]["message"] == message_1.dict() - assert results[1]["message"] == message_2.dict() + assert results[0]["message"] == message_1.model_dump() + assert results[1]["message"] == message_2.model_dump() async def test_aget_keys(self, sync_engine, sync_chat_store): message_1 = [ChatMessage(content="First message", role="user")] @@ -378,7 +380,7 @@ async def test_set_exisiting_key(self, sync_engine, sync_chat_store): assert len(results) == 1 result = results[0] - assert result["message"] == message_1[0].dict() + assert result["message"] == message_1[0].model_dump() message_2 = ChatMessage(content="Second message", role="user") message_3 = ChatMessage(content="Third message", role="user") @@ -392,5 +394,5 @@ async def test_set_exisiting_key(self, sync_engine, sync_chat_store): # Assert the previous messages are deleted and only the newest ones exist. 
assert len(results) == 2 - assert results[0]["message"] == message_2.dict() - assert results[1]["message"] == message_3.dict() + assert results[0]["message"] == message_2.model_dump() + assert results[1]["message"] == message_3.model_dump() diff --git a/tests/test_document_store.py b/tests/test_document_store.py index 44533e4..6a7094e 100644 --- a/tests/test_document_store.py +++ b/tests/test_document_store.py @@ -108,6 +108,7 @@ async def async_engine( yield async_engine await async_engine.close() + await async_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def doc_store(self, async_engine): @@ -400,6 +401,7 @@ async def sync_engine( yield sync_engine await sync_engine.close() + await sync_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def sync_doc_store(self, sync_engine): diff --git a/tests/test_engine.py b/tests/test_engine.py index fd168e1..c1cf94e 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -122,6 +122,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE}"') await engine.close() + await engine._connector.close() async def test_init_with_constructor( self, @@ -307,7 +308,9 @@ async def test_iam_account_override( async def test_init_document_store(self, engine): await engine.ainit_doc_store_table( - table_name=DEFAULT_DS_TABLE, schema_name="public", overwrite_existing=True + table_name=DEFAULT_DS_TABLE, + schema_name="public", + overwrite_existing=True, ) stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{DEFAULT_DS_TABLE}';" results = await afetch(engine, stmt) @@ -335,13 +338,21 @@ async def test_init_vector_store(self, engine): "data_type": "character varying", "is_nullable": "NO", }, - {"column_name": "li_metadata", "data_type": "jsonb", "is_nullable": "NO"}, + { + "column_name": "li_metadata", + "data_type": "jsonb", + "is_nullable": "NO", + }, { "column_name": "embedding", "data_type": "USER-DEFINED", "is_nullable": "YES", }, - {"column_name": "node_data", "data_type": "json", "is_nullable": "NO"}, + { + "column_name": "node_data", + "data_type": "json", + "is_nullable": "NO", + }, { "column_name": "ref_doc_id", "data_type": "character varying", @@ -440,6 +451,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE_SYNC}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE_SYNC}"') await engine.close() + await engine._connector.close() async def test_password( self, @@ -494,6 +506,7 @@ async def test_iam_account_override( assert engine await aexecute(engine, "SELECT 1") await engine.close() + await engine._connector.close() async def test_init_document_store(self, engine): engine.init_doc_store_table( @@ -527,13 +540,21 @@ async def test_init_vector_store(self, engine): "data_type": "character varying", "is_nullable": "NO", }, - {"column_name": "li_metadata", "data_type": "jsonb", "is_nullable": "NO"}, + { + "column_name": "li_metadata", + "data_type": "jsonb", + "is_nullable": "NO", + }, { "column_name": "embedding", "data_type": "USER-DEFINED", "is_nullable": "YES", }, - {"column_name": "node_data", "data_type": "json", "is_nullable": "NO"}, + { + "column_name": "node_data", + "data_type": "json", + "is_nullable": "NO", + }, { "column_name": "ref_doc_id", "data_type": "character varying", diff --git a/tests/test_index_store.py 
b/tests/test_index_store.py index a5bc09e..c9a4bce 100644 --- a/tests/test_index_store.py +++ b/tests/test_index_store.py @@ -108,6 +108,7 @@ async def async_engine( yield async_engine await async_engine.close() + await async_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def index_store(self, async_engine): @@ -236,6 +237,7 @@ async def async_engine( yield async_engine await async_engine.close() + await async_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def index_store(self, async_engine): diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py index 96a7bca..2e7361a 100644 --- a/tests/test_vector_store.py +++ b/tests/test_vector_store.py @@ -124,6 +124,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): await aexecute(sync_engine, f'DROP TABLE "{DEFAULT_TABLE}"') await sync_engine.close() + await sync_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): @@ -504,6 +505,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): await aexecute(sync_engine, f'DROP TABLE "{DEFAULT_TABLE}"') await sync_engine.close() + await sync_engine._connector.close() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): diff --git a/tests/test_vector_store_index.py b/tests/test_vector_store_index.py index 3920929..94bd002 100644 --- a/tests/test_vector_store_index.py +++ b/tests/test_vector_store_index.py @@ -116,6 +116,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") await engine.close() + await engine._connector.close() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): @@ -292,6 +293,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_ASYNC}") await engine.close() + await engine._connector.close() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): From 94522b1212ed840b99553cbf0d5868b4c15673ce Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Fri, 24 Jan 2025 18:48:54 +0100 Subject: [PATCH 26/39] chore(deps): update dependency llama-index-core to v0.12.13 (#58) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index afe5e9b..7fe38ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ google-cloud-alloydb-connector[asyncpg]==1.7.0 -llama-index-core==0.12.12 +llama-index-core==0.12.13 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.37 From 56e64790c8eb85979d60b87366adb46596232e24 Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Fri, 24 Jan 2025 23:18:58 +0000 Subject: [PATCH 27/39] feat: Adding Async AlloyDB Reader (#55) * feat: Adding Async AlloyDB Reader * Linter fix * Fix pydantic super class init * Linter fix * Minor change * Minor change * Minor docstring changes * Minor docstring changes * Minor docstring changes * Linter fix * Variable name change in tests * Make docstrings more informative * Fix class variables * lint * fix pool error --------- Co-authored-by: Averi Kitsch --- src/llama_index_alloydb_pg/async_reader.py | 270 ++++++++++++ tests/test_async_reader.py | 472 +++++++++++++++++++++ 2 files changed, 742 insertions(+) create mode 100644 src/llama_index_alloydb_pg/async_reader.py create mode 100644 tests/test_async_reader.py diff --git 
a/src/llama_index_alloydb_pg/async_reader.py b/src/llama_index_alloydb_pg/async_reader.py
new file mode 100644
index 0000000..8f6b910
--- /dev/null
+++ b/src/llama_index_alloydb_pg/async_reader.py
@@ -0,0 +1,270 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import json
+from typing import Any, AsyncIterable, Callable, Iterable, Iterator, List, Optional
+
+from llama_index.core.bridge.pydantic import ConfigDict
+from llama_index.core.readers.base import BasePydanticReader
+from llama_index.core.schema import Document
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import AsyncEngine
+
+from .engine import AlloyDBEngine
+
+DEFAULT_METADATA_COL = "llamaindex_metadata"
+
+
+def text_formatter(row: dict, content_columns: list[str]) -> str:
+    """Text document formatter."""
+    return " ".join(str(row[column]) for column in content_columns if column in row)
+
+
+def csv_formatter(row: dict, content_columns: list[str]) -> str:
+    """CSV document formatter."""
+    return ", ".join(str(row[column]) for column in content_columns if column in row)
+
+
+def yaml_formatter(row: dict, content_columns: list[str]) -> str:
+    """YAML document formatter."""
+    return "\n".join(
+        f"{column}: {str(row[column])}" for column in content_columns if column in row
+    )
+
+
+def json_formatter(row: dict, content_columns: list[str]) -> str:
+    """JSON document formatter."""
+    dictionary = {}
+    for column in content_columns:
+        if column in row:
+            dictionary[column] = row[column]
+    return json.dumps(dictionary)
+
+
+def _parse_doc_from_row(
+    content_columns: Iterable[str],
+    metadata_columns: Iterable[str],
+    row: dict,
+    formatter: Callable = text_formatter,
+    metadata_json_column: Optional[str] = DEFAULT_METADATA_COL,
+) -> Document:
+    """Parse row into document."""
+    text = formatter(row, content_columns)
+    metadata: dict[str, Any] = {}
+    # unnest metadata from llamaindex_metadata column
+    if metadata_json_column and row.get(metadata_json_column):
+        for k, v in row[metadata_json_column].items():
+            metadata[k] = v
+    # load metadata from other columns
+    for column in metadata_columns:
+        if column in row and column != metadata_json_column:
+            metadata[column] = row[column]
+
+    return Document(text=text, extra_info=metadata)
+
+
+class AsyncAlloyDBReader(BasePydanticReader):
+    """Load documents from AlloyDB.
+
+    Each document represents one row of the result. The `content_columns` are
+    written into the `text` of the document. The `metadata_columns` are written
+    into the `metadata` of the document. By default, the first column is written
+    into the `text` and everything else into the `metadata`.
+    """
+
+    __create_key = object()
+    is_remote: bool = True
+
+    def __init__(
+        self,
+        key: object,
+        pool: AsyncEngine,
+        query: str,
+        content_columns: list[str],
+        metadata_columns: list[str],
+        formatter: Callable,
+        metadata_json_column: Optional[str] = None,
+        is_remote: bool = True,
+    ) -> None:
+        """AsyncAlloyDBReader constructor.
+
+        Args:
+            key (object): Prevent direct constructor usage.
+            pool (AsyncEngine): AsyncEngine with pool connection to the AlloyDB database.
+            query (str): SQL query to load documents from.
+            content_columns (list[str]): Column(s) that represent a Document's page_content.
+            metadata_columns (list[str]): Column(s) that represent a Document's metadata.
+            formatter (Callable): A function to format page content.
+            metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to None.
+            is_remote (bool): Whether the data is loaded from a remote API or a local file.
+
+        Raises:
+            Exception: If called directly by user.
+        """
+        if key != AsyncAlloyDBReader.__create_key:
+            raise Exception("Only create class through 'create' method!")
+
+        super().__init__(is_remote=is_remote)
+
+        self._pool = pool
+        self._query = query
+        self._content_columns = content_columns
+        self._metadata_columns = metadata_columns
+        self._formatter = formatter
+        self._metadata_json_column = metadata_json_column
+
+    @classmethod
+    async def create(
+        cls: type[AsyncAlloyDBReader],
+        engine: AlloyDBEngine,
+        query: Optional[str] = None,
+        table_name: Optional[str] = None,
+        schema_name: str = "public",
+        content_columns: Optional[list[str]] = None,
+        metadata_columns: Optional[list[str]] = None,
+        metadata_json_column: Optional[str] = None,
+        format: Optional[str] = None,
+        formatter: Optional[Callable] = None,
+        is_remote: bool = True,
+    ) -> AsyncAlloyDBReader:
+        """Create an AsyncAlloyDBReader instance.
+
+        Args:
+            engine (AlloyDBEngine): AlloyDBEngine with pool connection to the AlloyDB database.
+            query (Optional[str], optional): SQL query. Defaults to None.
+            table_name (Optional[str], optional): Name of table to query. Defaults to None.
+            schema_name (str, optional): Name of the schema where the table is located. Defaults to "public".
+            content_columns (Optional[list[str]], optional): Column(s) that represent a Document's page_content. Defaults to the first column.
+            metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
+            metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "llamaindex_metadata".
+            format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'.
+            formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None.
+            is_remote (bool): Whether the data is loaded from a remote API or a local file.
+
+        Returns:
+            AsyncAlloyDBReader: A newly created instance of AsyncAlloyDBReader.
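+
+        Example:
+            A minimal usage sketch (assumes ``engine`` is an existing
+            ``AlloyDBEngine``; the table and column names are illustrative):
+
+            .. code-block:: python
+
+                reader = await AsyncAlloyDBReader.create(
+                    engine=engine,
+                    table_name="my_table",
+                    content_columns=["title", "body"],
+                    metadata_columns=["id"],
+                    format="YAML",
+                )
+                documents = await reader.aload_data()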
+ """ + if table_name and query: + raise ValueError("Only one of 'table_name' or 'query' should be specified.") + if not table_name and not query: + raise ValueError( + "At least one of the parameters 'table_name' or 'query' needs to be provided" + ) + if format and formatter: + raise ValueError("Only one of 'format' or 'formatter' should be specified.") + + if format and format not in ["csv", "text", "JSON", "YAML"]: + raise ValueError("format must be type: 'csv', 'text', 'JSON', 'YAML'") + if formatter: + formatter = formatter + elif format == "csv": + formatter = csv_formatter + elif format == "YAML": + formatter = yaml_formatter + elif format == "JSON": + formatter = json_formatter + else: + formatter = text_formatter + + if not query: + query = f'SELECT * FROM "{schema_name}"."{table_name}"' + + async with engine._pool.connect() as connection: + result_proxy = await connection.execute(text(query)) + column_names = list(result_proxy.keys()) + # Select content or default to first column + content_columns = content_columns or [column_names[0]] + # Select metadata columns + metadata_columns = metadata_columns or [ + col for col in column_names if col not in content_columns + ] + + # Check validity of metadata json column + if metadata_json_column and metadata_json_column not in column_names: + raise ValueError( + f"Column {metadata_json_column} not found in query result {column_names}." + ) + + if metadata_json_column and metadata_json_column in column_names: + metadata_json_column = metadata_json_column + elif DEFAULT_METADATA_COL in column_names: + metadata_json_column = DEFAULT_METADATA_COL + else: + metadata_json_column = None + + # check validity of other column + all_names = content_columns + metadata_columns + for name in all_names: + if name not in column_names: + raise ValueError( + f"Column {name} not found in query result {column_names}." + ) + return cls( + key=cls.__create_key, + pool=engine._pool, + query=query, + content_columns=content_columns, + metadata_columns=metadata_columns, + formatter=formatter, + metadata_json_column=metadata_json_column, + is_remote=is_remote, + ) + + @classmethod + def class_name(cls) -> str: + return "AsyncAlloyDBReader" + + async def aload_data(self) -> list[Document]: + """Asynchronously load AlloyDB data into Document objects.""" + return [doc async for doc in self.alazy_load_data()] + + async def alazy_load_data(self) -> AsyncIterable[Document]: # type: ignore + """Asynchronously load AlloyDB data into Document objects lazily.""" + async with self._pool.connect() as connection: + result_proxy = await connection.execute(text(self._query)) + # load document one by one + while True: + row = result_proxy.fetchone() + if not row: + break + + row_data = {} + column_names = self._content_columns + self._metadata_columns + column_names += ( + [self._metadata_json_column] if self._metadata_json_column else [] + ) + for column in column_names: + value = getattr(row, column) + row_data[column] = value + + yield _parse_doc_from_row( + self._content_columns, + self._metadata_columns, + row_data, + self._formatter, + self._metadata_json_column, + ) + + def lazy_load_data(self) -> Iterator[Document]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBReader. Use AlloyDBReader interface instead." + ) + + def load_data(self) -> List[Document]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBReader. Use AlloyDBReader interface instead." 
+ ) diff --git a/tests/test_async_reader.py b/tests/test_async_reader.py new file mode 100644 index 0000000..74129d8 --- /dev/null +++ b/tests/test_async_reader.py @@ -0,0 +1,472 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import uuid +from typing import Sequence + +import pytest +import pytest_asyncio +from llama_index.core.schema import Document +from sqlalchemy import RowMapping, text + +from llama_index_alloydb_pg import AlloyDBEngine +from llama_index_alloydb_pg.async_reader import AsyncAlloyDBReader + +default_table_name_async = "reader_test_" + str(uuid.uuid4()) + + +async def aexecute(engine: AlloyDBEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + +async def afetch(engine: AlloyDBEngine, query: str) -> Sequence[RowMapping]: + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestAsyncAlloyDBReader: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for AlloyDB instance") + + @pytest.fixture(scope="module") + def db_cluster(self) -> str: + return get_env_var("CLUSTER_ID", "cluster for AlloyDB") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for AlloyDB") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on AlloyDB instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for AlloyDB") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for AlloyDB") + + @pytest_asyncio.fixture(scope="class") + async def async_engine( + self, db_project, db_region, db_cluster, db_instance, db_name + ): + async_engine = await AlloyDBEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + cluster=db_cluster, + region=db_region, + database=db_name, + ) + + yield async_engine + + await aexecute( + async_engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"' + ) + + await async_engine.close() + await async_engine._connector.close() + + async def _cleanup_table(self, engine): + await aexecute(engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"') + + async def _collect_async_items(self, docs_generator): + """Collects items from an async generator.""" + docs = [] + async for doc in docs_generator: + docs.append(doc) + return docs + + async def 
test_create_reader_with_invalid_parameters(self, async_engine): + with pytest.raises(ValueError): + await AsyncAlloyDBReader.create( + engine=async_engine, + ) + with pytest.raises(ValueError): + + def fake_formatter(): + return None + + await AsyncAlloyDBReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="text", + formatter=fake_formatter, + ) + with pytest.raises(ValueError): + await AsyncAlloyDBReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="fake_format", + ) + + async def test_load_from_query_default(self, async_engine): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" ( + fruit_name, variety, quantity_in_stock, price_per_unit, organic + ) VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncAlloyDBReader.create( + engine=async_engine, + table_name=table_name, + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_document = Document( + text="1", + metadata={ + "fruit_name": "Apple", + "variety": "Granny Smith", + "quantity_in_stock": 150, + "price_per_unit": 1, + "organic": 1, + }, + ) + + assert documents[0].text == expected_document.text + assert documents[0].metadata == expected_document.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_customized_metadata( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + expected_docs = [ + Document( + text="Apple Smith 150 1 1", + metadata={"fruit_id": 1}, + ), + Document( + text="Banana Cavendish 200 1 0", + metadata={"fruit_id": 2}, + ), + Document( + text="Orange Navel 80 1 1", + metadata={"fruit_id": 3}, + ), + ] + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Smith', 150, 0.99, 1), + ('Banana', 'Cavendish', 200, 0.59, 0), + ('Orange', 'Navel', 80, 1.29, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncAlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "fruit_name", + "variety", + "quantity_in_stock", + "price_per_unit", + "organic", + ], + metadata_columns=["fruit_id"], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + # Compare the full list of documents to make sure all are in sync. 
+ for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncAlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_text_docs = [ + Document( + text="Granny Smith 150 1", + metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1}, + ) + ] + + for expected, actual in zip(expected_text_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + reader = await AsyncAlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="JSON", + ) + + actual_documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text='{"variety": "Granny Smith", "quantity_in_stock": 150, "price_per_unit": 1}', + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, actual_documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_with_json(self, async_engine): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}"( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety JSON NOT NULL, + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + llamaindex_metadata JSON NOT NULL + ) + """ + await aexecute(async_engine, query) + + metadata = json.dumps({"organic": 1}) + variety = json.dumps({"type": "Granny Smith"}) + insert_query = f""" + INSERT INTO "{table_name}" + (fruit_name, variety, quantity_in_stock, price_per_unit, llamaindex_metadata) + VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" + await aexecute(async_engine, insert_query) + + reader = await AsyncAlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + metadata_columns=[ + "variety", + ], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text="1", + metadata={ + "variety": {"type": "Granny Smith"}, + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_formatter( + self, 
async_engine + ): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + def my_formatter(row, content_columns): + return "-".join( + str(row[column]) for column in content_columns if column in row + ) + + reader = await AsyncAlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + formatter=my_formatter, + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_documents = [ + Document( + text="Granny Smith-150-1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_documents, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_page_content_format( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncAlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="YAML", + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text="variety: Granny Smith\nquantity_in_stock: 150\nprice_per_unit: 1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') From 7314d835e62ccd7e8fe59b35f37dccaaee6aed36 Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Sat, 25 Jan 2025 00:41:38 +0000 Subject: [PATCH 28/39] feat: Adding AlloyDB Reader (#57) * feat: Adding AlloyDB Reader * Linter fix * Change pydantic field to class variable * Use __aiter__ to convert to AsyncIterator * Linter fix --- src/llama_index_alloydb_pg/__init__.py | 2 + src/llama_index_alloydb_pg/reader.py | 186 +++++ tests/test_reader.py | 912 +++++++++++++++++++++++++ 3 files changed, 1100 insertions(+) create mode 100644 src/llama_index_alloydb_pg/reader.py create mode 100644 tests/test_reader.py diff --git a/src/llama_index_alloydb_pg/__init__.py b/src/llama_index_alloydb_pg/__init__.py index baea6ec..91c28b8 100644 --- 
a/src/llama_index_alloydb_pg/__init__.py
+++ b/src/llama_index_alloydb_pg/__init__.py
@@ -16,6 +16,7 @@
 from .document_store import AlloyDBDocumentStore
 from .engine import AlloyDBEngine, Column
 from .index_store import AlloyDBIndexStore
+from .reader import AlloyDBReader
 from .vector_store import AlloyDBVectorStore
 from .version import __version__
 
@@ -24,6 +25,7 @@
     "AlloyDBDocumentStore",
     "AlloyDBEngine",
     "AlloyDBIndexStore",
+    "AlloyDBReader",
     "AlloyDBVectorStore",
     "Column",
     "__version__",
diff --git a/src/llama_index_alloydb_pg/reader.py b/src/llama_index_alloydb_pg/reader.py
new file mode 100644
index 0000000..aae2019
--- /dev/null
+++ b/src/llama_index_alloydb_pg/reader.py
@@ -0,0 +1,186 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import AsyncIterable, Callable, Iterable, List, Optional
+
+from llama_index.core.bridge.pydantic import ConfigDict
+from llama_index.core.readers.base import BasePydanticReader
+from llama_index.core.schema import Document
+
+from .async_reader import AsyncAlloyDBReader
+from .engine import AlloyDBEngine
+
+DEFAULT_METADATA_COL = "llamaindex_metadata"
+
+
+class AlloyDBReader(BasePydanticReader):
+    """Reader that loads documents from an AlloyDB for PostgreSQL database."""
+
+    __create_key = object()
+    is_remote: bool = True
+
+    def __init__(
+        self,
+        key: object,
+        engine: AlloyDBEngine,
+        reader: AsyncAlloyDBReader,
+        is_remote: bool = True,
+    ) -> None:
+        """AlloyDBReader constructor.
+
+        Args:
+            key (object): Prevent direct constructor usage.
+            engine (AlloyDBEngine): AlloyDBEngine with pool connection to the alloydb database
+            reader (AsyncAlloyDBReader): The async only AlloyDBReader implementation
+            is_remote (Optional[bool]): Whether the data is loaded from a remote API or a local file.
+
+        Raises:
+            Exception: If called directly by user.
+        """
+        if key != AlloyDBReader.__create_key:
+            raise Exception("Only create class through 'create' method!")
+
+        super().__init__(is_remote=is_remote)
+
+        self._engine = engine
+        self.__reader = reader
+
+    @classmethod
+    async def create(
+        cls: type[AlloyDBReader],
+        engine: AlloyDBEngine,
+        query: Optional[str] = None,
+        table_name: Optional[str] = None,
+        schema_name: str = "public",
+        content_columns: Optional[list[str]] = None,
+        metadata_columns: Optional[list[str]] = None,
+        metadata_json_column: Optional[str] = None,
+        format: Optional[str] = None,
+        formatter: Optional[Callable] = None,
+        is_remote: bool = True,
+    ) -> AlloyDBReader:
+        """Asynchronously create an AlloyDBReader instance.
+
+        Args:
+            engine (AlloyDBEngine): AlloyDBEngine with pool connection to the alloydb database
+            query (Optional[str], optional): SQL query. Defaults to None.
+            table_name (Optional[str], optional): Name of table to query. Defaults to None.
+            schema_name (str, optional): Name of the schema where table is located. Defaults to "public".
+            content_columns (Optional[list[str]], optional): Column that represent a Document's page_content.
Defaults to the first column.
+            metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
+            metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "llamaindex_metadata".
+            format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'.
+            formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None.
+            is_remote (Optional[bool]): Whether the data is loaded from a remote API or a local file.
+
+
+        Returns:
+            AlloyDBReader: A newly created instance of AlloyDBReader.
+        """
+        coro = AsyncAlloyDBReader.create(
+            engine=engine,
+            query=query,
+            table_name=table_name,
+            schema_name=schema_name,
+            content_columns=content_columns,
+            metadata_columns=metadata_columns,
+            metadata_json_column=metadata_json_column,
+            format=format,
+            formatter=formatter,
+            is_remote=is_remote,
+        )
+        reader = await engine._run_as_async(coro)
+        return cls(cls.__create_key, engine, reader, is_remote)
+
+    @classmethod
+    def create_sync(
+        cls: type[AlloyDBReader],
+        engine: AlloyDBEngine,
+        query: Optional[str] = None,
+        table_name: Optional[str] = None,
+        schema_name: str = "public",
+        content_columns: Optional[list[str]] = None,
+        metadata_columns: Optional[list[str]] = None,
+        metadata_json_column: Optional[str] = None,
+        format: Optional[str] = None,
+        formatter: Optional[Callable] = None,
+        is_remote: bool = True,
+    ) -> AlloyDBReader:
+        """Synchronously create an AlloyDBReader instance.
+
+        Args:
+            engine (AlloyDBEngine): AlloyDBEngine with pool connection to the alloydb database
+            query (Optional[str], optional): SQL query. Defaults to None.
+            table_name (Optional[str], optional): Name of table to query. Defaults to None.
+            schema_name (str, optional): Name of the schema where table is located. Defaults to "public".
+            content_columns (Optional[list[str]], optional): Column that represent a Document's page_content. Defaults to the first column.
+            metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
+            metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "llamaindex_metadata".
+            format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'.
+            formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None.
+            is_remote (Optional[bool]): Whether the data is loaded from a remote API or a local file.
+
+
+        Returns:
+            AlloyDBReader: A newly created instance of AlloyDBReader.
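+
+        Example (illustrative; assumes an existing ``engine`` and a table named ``my_table``):
+            reader = AlloyDBReader.create_sync(engine, table_name="my_table")
+            docs = reader.load_data()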
+        """
+        coro = AsyncAlloyDBReader.create(
+            engine=engine,
+            query=query,
+            table_name=table_name,
+            schema_name=schema_name,
+            content_columns=content_columns,
+            metadata_columns=metadata_columns,
+            metadata_json_column=metadata_json_column,
+            format=format,
+            formatter=formatter,
+            is_remote=is_remote,
+        )
+        reader = engine._run_as_sync(coro)
+        return cls(cls.__create_key, engine, reader, is_remote)
+
+    @classmethod
+    def class_name(cls) -> str:
+        """Get class name."""
+        return "AlloyDBReader"
+
+    async def aload_data(self) -> list[Document]:
+        """Asynchronously load AlloyDB data into Document objects."""
+        return await self._engine._run_as_async(self.__reader.aload_data())
+
+    def load_data(self) -> list[Document]:
+        """Synchronously load AlloyDB data into Document objects."""
+        return self._engine._run_as_sync(self.__reader.aload_data())
+
+    async def alazy_load_data(self) -> AsyncIterable[Document]:  # type: ignore
+        """Asynchronously load AlloyDB data into Document objects lazily."""
+        iterator = self.__reader.alazy_load_data().__aiter__()
+        while True:
+            try:
+                result = await self._engine._run_as_async(iterator.__anext__())
+                yield result
+            except StopAsyncIteration:
+                break
+
+    def lazy_load_data(self) -> Iterable[Document]:  # type: ignore
+        """Synchronously load AlloyDB data into Document objects lazily."""
+        iterator = self.__reader.alazy_load_data().__aiter__()
+        while True:
+            try:
+                result = self._engine._run_as_sync(iterator.__anext__())
+                yield result
+            except StopAsyncIteration:
+                break
diff --git a/tests/test_reader.py b/tests/test_reader.py
new file mode 100644
index 0000000..e5ad30a
--- /dev/null
+++ b/tests/test_reader.py
@@ -0,0 +1,912 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import json +import os +import uuid +from typing import Sequence + +import pytest +import pytest_asyncio +from llama_index.core.schema import Document +from sqlalchemy import RowMapping, text + +from llama_index_alloydb_pg import AlloyDBEngine, AlloyDBReader + +default_table_name_async = "async_reader_test_" + str(uuid.uuid4()) +default_table_name_sync = "sync_reader_test_" + str(uuid.uuid4()) + + +async def aexecute( + engine: AlloyDBEngine, + query: str, +) -> None: + async def run(engine, query): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +async def afetch(engine: AlloyDBEngine, query: str) -> Sequence[RowMapping]: + async def run(engine, query): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + return await engine._run_as_async(run(engine, query)) + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestAlloyDBReaderAsync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for AlloyDB instance") + + @pytest.fixture(scope="module") + def db_cluster(self) -> str: + return get_env_var("CLUSTER_ID", "cluster for AlloyDB") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for AlloyDB") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on AlloyDB instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for AlloyDB") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for AlloyDB") + + @pytest_asyncio.fixture(scope="class") + async def async_engine( + self, + db_project, + db_region, + db_cluster, + db_instance, + db_name, + ): + async_engine = await AlloyDBEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + cluster=db_cluster, + region=db_region, + database=db_name, + ) + + yield async_engine + + await aexecute( + async_engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"' + ) + + await async_engine.close() + + async def _cleanup_table(self, engine): + await aexecute(engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"') + + async def _collect_async_items(self, docs_generator): + """Collects items from an async generator.""" + docs = [] + async for doc in docs_generator: + docs.append(doc) + return docs + + async def test_create_reader_with_invalid_parameters(self, async_engine): + with pytest.raises(ValueError): + await AlloyDBReader.create( + engine=async_engine, + ) + with pytest.raises(ValueError): + + def fake_formatter(): + return None + + await AlloyDBReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="text", + formatter=fake_formatter, + ) + with pytest.raises(ValueError): + await AlloyDBReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="fake_format", + ) + + async def test_load_from_query_default(self, async_engine): + table_name = 
"test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" ( + fruit_name, variety, quantity_in_stock, price_per_unit, organic + ) VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AlloyDBReader.create( + engine=async_engine, + table_name=table_name, + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_document = Document( + text="1", + metadata={ + "fruit_name": "Apple", + "variety": "Granny Smith", + "quantity_in_stock": 150, + "price_per_unit": 1, + "organic": 1, + }, + ) + + assert documents[0].text == expected_document.text + assert documents[0].metadata == expected_document.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_customized_metadata( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + expected_docs = [ + Document( + text="Apple Smith 150 1 1", + metadata={"fruit_id": 1}, + ), + Document( + text="Banana Cavendish 200 1 0", + metadata={"fruit_id": 2}, + ), + Document( + text="Orange Navel 80 1 1", + metadata={"fruit_id": 3}, + ), + ] + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Smith', 150, 0.99, 1), + ('Banana', 'Cavendish', 200, 0.59, 0), + ('Orange', 'Navel', 80, 1.29, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "fruit_name", + "variety", + "quantity_in_stock", + "price_per_unit", + "organic", + ], + metadata_columns=["fruit_id"], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + # Compare the full list of documents to make sure all are in sync. 
+ for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_text_docs = [ + Document( + text="Granny Smith 150 1", + metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1}, + ) + ] + + for expected, actual in zip(expected_text_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + reader = await AlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="JSON", + ) + + actual_documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text='{"variety": "Granny Smith", "quantity_in_stock": 150, "price_per_unit": 1}', + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, actual_documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_with_json(self, async_engine): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}"( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety JSON NOT NULL, + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + llamaindex_metadata JSON NOT NULL + ) + """ + await aexecute(async_engine, query) + + metadata = json.dumps({"organic": 1}) + variety = json.dumps({"type": "Granny Smith"}) + insert_query = f""" + INSERT INTO "{table_name}" + (fruit_name, variety, quantity_in_stock, price_per_unit, llamaindex_metadata) + VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" + await aexecute(async_engine, insert_query) + + reader = await AlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + metadata_columns=[ + "variety", + ], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text="1", + metadata={ + "variety": {"type": "Granny Smith"}, + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_formatter( + self, async_engine + 
): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + def my_formatter(row, content_columns): + return "-".join( + str(row[column]) for column in content_columns if column in row + ) + + reader = await AlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + formatter=my_formatter, + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_documents = [ + Document( + text="Granny Smith-150-1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_documents, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_page_content_format( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AlloyDBReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="YAML", + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text="variety: Granny Smith\nquantity_in_stock: 150\nprice_per_unit: 1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + +@pytest.mark.asyncio(loop_scope="class") +class TestAlloyDBReaderSync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for AlloyDB instance") + + @pytest.fixture(scope="module") + def db_cluster(self) -> str: + return get_env_var("CLUSTER_ID", "cluster for AlloyDB") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for AlloyDB") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on AlloyDB instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user 
for AlloyDB") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for AlloyDB") + + @pytest_asyncio.fixture(scope="class") + async def sync_engine( + self, + db_project, + db_region, + db_cluster, + db_instance, + db_name, + ): + sync_engine = await AlloyDBEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + cluster=db_cluster, + region=db_region, + database=db_name, + ) + + yield sync_engine + + await aexecute( + sync_engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"' + ) + + await sync_engine.close() + + async def _cleanup_table(self, engine): + await aexecute(engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"') + + def _collect_items(self, docs_generator): + """Collects items from a generator.""" + docs = [] + for doc in docs_generator: + docs.append(doc) + return docs + + async def test_create_reader_with_invalid_parameters(self, sync_engine): + with pytest.raises(ValueError): + AlloyDBReader.create_sync( + engine=sync_engine, + ) + with pytest.raises(ValueError): + + def fake_formatter(): + return None + + AlloyDBReader.create_sync( + engine=sync_engine, + table_name=default_table_name_async, + format="text", + formatter=fake_formatter, + ) + with pytest.raises(ValueError): + AlloyDBReader.create_sync( + engine=sync_engine, + table_name=default_table_name_async, + format="fake_format", + ) + + async def test_load_from_query_default(self, sync_engine): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(sync_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" ( + fruit_name, variety, quantity_in_stock, price_per_unit, organic + ) VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(sync_engine, insert_query) + + reader = AlloyDBReader.create_sync( + engine=sync_engine, + table_name=table_name, + ) + + documents = self._collect_items(reader.lazy_load_data()) + + expected_document = Document( + text="1", + metadata={ + "fruit_name": "Apple", + "variety": "Granny Smith", + "quantity_in_stock": 150, + "price_per_unit": 1, + "organic": 1, + }, + ) + + assert documents[0].text == expected_document.text + assert documents[0].metadata == expected_document.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_customized_metadata( + self, sync_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + expected_docs = [ + Document( + text="Apple Smith 150 1 1", + metadata={"fruit_id": 1}, + ), + Document( + text="Banana Cavendish 200 1 0", + metadata={"fruit_id": 2}, + ), + Document( + text="Orange Navel 80 1 1", + metadata={"fruit_id": 3}, + ), + ] + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(sync_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Smith', 150, 0.99, 1), + ('Banana', 'Cavendish', 200, 0.59, 0), + ('Orange', 'Navel', 80, 1.29, 1); + """ + await aexecute(sync_engine, insert_query) + + 
reader = AlloyDBReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "fruit_name", + "variety", + "quantity_in_stock", + "price_per_unit", + "organic", + ], + metadata_columns=["fruit_id"], + ) + + documents = self._collect_items(reader.lazy_load_data()) + + # Compare the full list of documents to make sure all are in sync. + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata( + self, sync_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(sync_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(sync_engine, insert_query) + + reader = AlloyDBReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + ) + + documents = self._collect_items(reader.lazy_load_data()) + + expected_text_docs = [ + Document( + text="Granny Smith 150 1", + metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1}, + ) + ] + + for expected, actual in zip(expected_text_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + reader = AlloyDBReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="JSON", + ) + + actual_documents = self._collect_items(reader.lazy_load_data()) + + expected_docs = [ + Document( + text='{"variety": "Granny Smith", "quantity_in_stock": 150, "price_per_unit": 1}', + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, actual_documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_with_json(self, sync_engine): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}"( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety JSON NOT NULL, + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + llamaindex_metadata JSON NOT NULL + ) + """ + await aexecute(sync_engine, query) + + metadata = json.dumps({"organic": 1}) + variety = json.dumps({"type": "Granny Smith"}) + insert_query = f""" + INSERT INTO "{table_name}" + (fruit_name, variety, quantity_in_stock, price_per_unit, llamaindex_metadata) + VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" + await aexecute(sync_engine, insert_query) + + reader = AlloyDBReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + metadata_columns=[ + "variety", + ], + ) + + documents = self._collect_items(reader.lazy_load_data()) + + expected_docs = [ + Document( + text="1", + metadata={ + "variety": {"type": "Granny Smith"}, + "organic": 1, + }, + ) + ] + + 
for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_formatter( + self, sync_engine + ): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(sync_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(sync_engine, insert_query) + + def my_formatter(row, content_columns): + return "-".join( + str(row[column]) for column in content_columns if column in row + ) + + reader = AlloyDBReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + formatter=my_formatter, + ) + + documents = self._collect_items(reader.lazy_load_data()) + + expected_documents = [ + Document( + text="Granny Smith-150-1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_documents, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_page_content_format( + self, sync_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(sync_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(sync_engine, insert_query) + + reader = AlloyDBReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="YAML", + ) + + documents = self._collect_items(reader.lazy_load_data()) + + expected_docs = [ + Document( + text="variety: Granny Smith\nquantity_in_stock: 150\nprice_per_unit: 1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') From 98f6c65fc77cbb7f25b22d7118bbb89f3c674b2f Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Sat, 25 Jan 2025 00:55:44 +0000 Subject: [PATCH 29/39] fix: Update lazy_load_data return type to Iterable. 
(#61) Co-authored-by: Averi Kitsch --- src/llama_index_alloydb_pg/async_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama_index_alloydb_pg/async_reader.py b/src/llama_index_alloydb_pg/async_reader.py index 8f6b910..a096f62 100644 --- a/src/llama_index_alloydb_pg/async_reader.py +++ b/src/llama_index_alloydb_pg/async_reader.py @@ -259,7 +259,7 @@ async def alazy_load_data(self) -> AsyncIterable[Document]: # type: ignore self._metadata_json_column, ) - def lazy_load_data(self) -> Iterator[Document]: + def lazy_load_data(self) -> Iterable[Document]: raise NotImplementedError( "Sync methods are not implemented for AsyncAlloyDBReader. Use AlloyDBReader interface instead." ) From 22bb4b9e4f69773ccc0ef526a3065bcf7944ba3a Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Mon, 27 Jan 2025 17:24:47 +0000 Subject: [PATCH 30/39] chore(docs): Update docstring (#62) docs: Update docstring --- src/llama_index_alloydb_pg/chat_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama_index_alloydb_pg/chat_store.py b/src/llama_index_alloydb_pg/chat_store.py index 3e5b5ee..1615ca4 100644 --- a/src/llama_index_alloydb_pg/chat_store.py +++ b/src/llama_index_alloydb_pg/chat_store.py @@ -36,7 +36,7 @@ def __init__( Args: key (object): Key to prevent direct constructor usage. engine (AlloyDBEngine): Database connection pool. - chat_store (AsyncAlloyDBChatStore): The async only IndexStore implementation + chat_store (AsyncAlloyDBChatStore): The async only ChatStore implementation Raises: Exception: If constructor is directly called by the user. From ecb53c80d311deb9232f0f8844761a816fc01bc0 Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Mon, 27 Jan 2025 17:58:29 +0000 Subject: [PATCH 31/39] fix: Change default metadata_json_column default value (#66) Co-authored-by: Averi Kitsch --- src/llama_index_alloydb_pg/async_reader.py | 8 ++++---- src/llama_index_alloydb_pg/reader.py | 6 +++--- tests/test_async_reader.py | 4 ++-- tests/test_reader.py | 8 ++++---- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/llama_index_alloydb_pg/async_reader.py b/src/llama_index_alloydb_pg/async_reader.py index a096f62..b233ef2 100644 --- a/src/llama_index_alloydb_pg/async_reader.py +++ b/src/llama_index_alloydb_pg/async_reader.py @@ -25,7 +25,7 @@ from .engine import AlloyDBEngine -DEFAULT_METADATA_COL = "llamaindex_metadata" +DEFAULT_METADATA_COL = "li_metadata" def text_formatter(row: dict, content_columns: list[str]) -> str: @@ -64,7 +64,7 @@ def _parse_doc_from_row( """Parse row into document.""" text = formatter(row, content_columns) metadata: dict[str, Any] = {} - # unnest metadata from llamaindex_metadata column + # unnest metadata from li_metadata column if metadata_json_column and row.get(metadata_json_column): for k, v in row[metadata_json_column].items(): metadata[k] = v @@ -108,7 +108,7 @@ def __init__( content_columns (Optional[list[str]], optional): Column that represent a Document's page_content. Defaults to the first column. metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None. formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None. - metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "llamaindex_metadata". 
+ metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata". is_remote (bool): Whether the data is loaded from a remote API or a local file. Raises: @@ -149,7 +149,7 @@ async def create( schema_name (str, optional): Name of the schema where table is located. Defaults to "public". content_columns (Optional[list[str]], optional): Column that represent a Document's page_content. Defaults to the first column. metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None. - metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "llamaindex_metadata". + metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata". format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'. formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None. is_remote (bool): Whether the data is loaded from a remote API or a local file. diff --git a/src/llama_index_alloydb_pg/reader.py b/src/llama_index_alloydb_pg/reader.py index aae2019..5b78491 100644 --- a/src/llama_index_alloydb_pg/reader.py +++ b/src/llama_index_alloydb_pg/reader.py @@ -23,7 +23,7 @@ from .async_reader import AsyncAlloyDBReader from .engine import AlloyDBEngine -DEFAULT_METADATA_COL = "llamaindex_metadata" +DEFAULT_METADATA_COL = "li_metadata" class AlloyDBReader(BasePydanticReader): @@ -81,7 +81,7 @@ async def create( schema_name (str, optional): Name of the schema where table is located. Defaults to "public". content_columns (Optional[list[str]], optional): Column that represent a Document's page_content. Defaults to the first column. metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None. - metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "llamaindex_metadata". + metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata". format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'. formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None. is_remote (Optional[bool]): Whether the data is loaded from a remote API or a local file. @@ -128,7 +128,7 @@ def create_sync( schema_name (str, optional): Name of the schema where table is located. Defaults to "public". content_columns (Optional[list[str]], optional): Column that represent a Document's page_content. Defaults to the first column. metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None. - metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "llamaindex_metadata". + metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata". format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'. formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None. is_remote (Optional[bool]): Whether the data is loaded from a remote API or a local file. 
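To make the renamed default concrete: when a query result includes a `li_metadata` JSON column, the reader unnests that column's keys into each `Document`'s metadata. A minimal sketch of the behavior (the row below is invented for illustration; the unnesting loop mirrors `_parse_doc_from_row` above):

```python
# Illustrative only: unnesting of the default "li_metadata" JSON column.
row = {"fruit_name": "Apple", "li_metadata": {"organic": 1}}
metadata_json_column = "li_metadata"  # the renamed DEFAULT_METADATA_COL

metadata = {}
if metadata_json_column and row.get(metadata_json_column):
    for k, v in row[metadata_json_column].items():
        metadata[k] = v  # keys inside the JSON column become top-level metadata

print(metadata)  # {'organic': 1}
```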
diff --git a/tests/test_async_reader.py b/tests/test_async_reader.py index 74129d8..658cc37 100644 --- a/tests/test_async_reader.py +++ b/tests/test_async_reader.py @@ -322,7 +322,7 @@ async def test_load_from_query_with_json(self, async_engine): variety JSON NOT NULL, quantity_in_stock INT NOT NULL, price_per_unit INT NOT NULL, - llamaindex_metadata JSON NOT NULL + li_metadata JSON NOT NULL ) """ await aexecute(async_engine, query) @@ -331,7 +331,7 @@ async def test_load_from_query_with_json(self, async_engine): variety = json.dumps({"type": "Granny Smith"}) insert_query = f""" INSERT INTO "{table_name}" - (fruit_name, variety, quantity_in_stock, price_per_unit, llamaindex_metadata) + (fruit_name, variety, quantity_in_stock, price_per_unit, li_metadata) VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" await aexecute(async_engine, insert_query) diff --git a/tests/test_reader.py b/tests/test_reader.py index e5ad30a..2f9df85 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -335,7 +335,7 @@ async def test_load_from_query_with_json(self, async_engine): variety JSON NOT NULL, quantity_in_stock INT NOT NULL, price_per_unit INT NOT NULL, - llamaindex_metadata JSON NOT NULL + li_metadata JSON NOT NULL ) """ await aexecute(async_engine, query) @@ -344,7 +344,7 @@ async def test_load_from_query_with_json(self, async_engine): variety = json.dumps({"type": "Granny Smith"}) insert_query = f""" INSERT INTO "{table_name}" - (fruit_name, variety, quantity_in_stock, price_per_unit, llamaindex_metadata) + (fruit_name, variety, quantity_in_stock, price_per_unit, li_metadata) VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" await aexecute(async_engine, insert_query) @@ -762,7 +762,7 @@ async def test_load_from_query_with_json(self, sync_engine): variety JSON NOT NULL, quantity_in_stock INT NOT NULL, price_per_unit INT NOT NULL, - llamaindex_metadata JSON NOT NULL + li_metadata JSON NOT NULL ) """ await aexecute(sync_engine, query) @@ -771,7 +771,7 @@ async def test_load_from_query_with_json(self, sync_engine): variety = json.dumps({"type": "Granny Smith"}) insert_query = f""" INSERT INTO "{table_name}" - (fruit_name, variety, quantity_in_stock, price_per_unit, llamaindex_metadata) + (fruit_name, variety, quantity_in_stock, price_per_unit, li_metadata) VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" await aexecute(sync_engine, insert_query) From 67c416970371d14ddcdebd22bd7276c436078934 Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Wed, 29 Jan 2025 12:12:08 +0000 Subject: [PATCH 32/39] docs: Add AlloyDB Reader How-to (#65) --- samples/llama_index_reader.ipynb | 378 +++++++++++++++++++++++++++++++ 1 file changed, 378 insertions(+) create mode 100644 samples/llama_index_reader.ipynb diff --git a/samples/llama_index_reader.ipynb b/samples/llama_index_reader.ipynb new file mode 100644 index 0000000..7a4f1e8 --- /dev/null +++ b/samples/llama_index_reader.ipynb @@ -0,0 +1,378 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Google AlloyDB for PostgreSQL - `AlloyDBReader`\n", + "\n", + "> [AlloyDB](https://cloud.google.com/alloydb) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. AlloyDB is 100% compatible with PostgreSQL. 
Extend your database application to build AI-powered experiences leveraging AlloyDB's LlamaIndex integrations.\n",
+    "\n",
+    "This notebook goes over how to use `AlloyDB for PostgreSQL` to retrieve data as documents with the `AlloyDBReader` class.\n",
+    "\n",
+    "Learn more about the package on [GitHub](https://github.com/googleapis/llama-index-alloydb-pg-python/).\n",
+    "\n",
+    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/llama-index-alloydb-pg-python/blob/main/samples/llama_index_reader.ipynb)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Before you begin\n",
+    "\n",
+    "To run this notebook, you will need to do the following:\n",
+    "\n",
+    " * [Create a Google Cloud Project](https://developers.google.com/workspace/guides/create-project)\n",
+    " * [Enable the AlloyDB API](https://console.cloud.google.com/flows/enableapi?apiid=alloydb.googleapis.com)\n",
+    " * [Create an AlloyDB cluster and instance.](https://cloud.google.com/alloydb/docs/cluster-create)\n",
+    " * [Create an AlloyDB database.](https://cloud.google.com/alloydb/docs/quickstart/create-and-connect)\n",
+    " * [Add a user to the database.](https://cloud.google.com/alloydb/docs/database-users/about)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### πŸ¦™ Library Installation\n",
+    "Install the integration library, `llama-index-alloydb-pg`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install --upgrade --quiet llama-index-alloydb-pg"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Colab only:** Uncomment the following cell to restart the kernel or use the button to restart the kernel. For Vertex AI Workbench you can restart the terminal using the button on top."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# # Automatically restart kernel after installs so that your environment can access the new packages\n",
+    "# import IPython\n",
+    "\n",
+    "# app = IPython.Application.instance()\n",
+    "# app.kernel.do_shutdown(True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### πŸ” Authentication\n",
+    "Authenticate to Google Cloud as the IAM user logged into this notebook in order to access your Google Cloud Project.\n",
+    "\n",
+    "* If you are using Colab to run this notebook, use the cell below and continue.\n",
+    "* If you are using Vertex AI Workbench, check out the setup instructions [here](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/setup-env)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from google.colab import auth\n",
+    "\n",
+    "auth.authenticate_user()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### ☁ Set Your Google Cloud Project\n",
+    "Set your Google Cloud project so that you can leverage Google Cloud resources within this notebook.\n",
+    "\n",
+    "If you don't know your project ID, try the following:\n",
+    "\n",
+    "* Run `gcloud config list`.\n",
+    "* Run `gcloud projects list`.\n",
+    "* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# @markdown Please fill in the value below with your Google Cloud project ID and then run the cell.\n",
+    "\n",
+    "PROJECT_ID = \"my-project-id\"  # @param {type:\"string\"}\n",
+    "\n",
+    "# Set the project id\n",
+    "!gcloud config set project {PROJECT_ID}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Basic Usage"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Set AlloyDB database values\n",
+    "Find your database values in the [AlloyDB Instances page](https://console.cloud.google.com/alloydb/clusters)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# @title Set Your Values Here { display-mode: \"form\" }\n",
+    "REGION = \"us-central1\"  # @param {type: \"string\"}\n",
+    "CLUSTER = \"my-cluster\"  # @param {type: \"string\"}\n",
+    "INSTANCE = \"my-primary\"  # @param {type: \"string\"}\n",
+    "DATABASE = \"my-database\"  # @param {type: \"string\"}\n",
+    "TABLE_NAME = \"document_store\"  # @param {type: \"string\"}\n",
+    "USER = \"postgres\"  # @param {type: \"string\"}\n",
+    "PASSWORD = \"my-password\"  # @param {type: \"string\"}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### AlloyDBEngine Connection Pool\n",
+    "\n",
+    "One of the requirements and arguments to establish an `AlloyDBReader` is an `AlloyDBEngine` object. The `AlloyDBEngine` configures a connection pool to your AlloyDB database, enabling successful connections from your application and following industry best practices.\n",
+    "\n",
+    "To create an `AlloyDBEngine` using `AlloyDBEngine.from_instance()` you need to provide only 5 things:\n",
+    "\n",
+    "1. `project_id` : Project ID of the Google Cloud Project where the AlloyDB instance is located.\n",
+    "1. `region` : Region where the AlloyDB instance is located.\n",
+    "1. `cluster`: The name of the AlloyDB cluster.\n",
+    "1. `instance` : The name of the AlloyDB instance.\n",
+    "1. `database` : The name of the database to connect to on the AlloyDB instance.\n",
+    "\n",
+    "By default, [IAM database authentication](https://cloud.google.com/alloydb/docs/connect-iam) will be used as the method of database authentication. This library uses the IAM principal belonging to the [Application Default Credentials (ADC)](https://cloud.google.com/docs/authentication/application-default-credentials) sourced from the environment.\n",
+    "\n",
+    "Optionally, [built-in database authentication](https://cloud.google.com/alloydb/docs/database-users/about) using a username and password to access the AlloyDB database can also be used. Just provide the optional `user` and `password` arguments to `AlloyDBEngine.from_instance()`:\n",
+    "\n",
+    "* `user` : Database user to use for built-in database authentication and login\n",
+    "* `password` : Database password to use for built-in database authentication and login.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Note:** This tutorial demonstrates the async interface. All async methods have corresponding sync methods.\n",
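+    "\n",
+    "For example, a synchronous engine could be created like this (illustrative; it takes the same arguments as `afrom_instance()`):\n",
+    "\n",
+    "```python\n",
+    "engine = AlloyDBEngine.from_instance(\n",
+    "    project_id=PROJECT_ID,\n",
+    "    region=REGION,\n",
+    "    cluster=CLUSTER,\n",
+    "    instance=INSTANCE,\n",
+    "    database=DATABASE,\n",
+    ")\n",
+    "```"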
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from llama_index_alloydb_pg import AlloyDBEngine\n",
+ "\n",
+ "engine = await AlloyDBEngine.afrom_instance(\n",
+ " project_id=PROJECT_ID,\n",
+ " region=REGION,\n",
+ " cluster=CLUSTER,\n",
+ " instance=INSTANCE,\n",
+ " database=DATABASE,\n",
+ " user=USER,\n",
+ " password=PASSWORD,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Create AlloyDBReader"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "When creating an `AlloyDBReader` for fetching data from AlloyDB, you have two main options to specify the data you want to load:\n",
+ "* using the `table_name` argument - When you specify the `table_name` argument, you're telling the reader to fetch all the data from the given table.\n",
+ "* using the `query` argument - When you specify the `query` argument, you can provide a custom SQL query to fetch the data. This gives you full control over the SQL query, including selecting specific columns, applying filters, sorting, joining tables, etc.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load Documents using the `table_name` argument"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Load Documents via default table\n",
+ "The reader returns a list of Documents from the table, using the first column as text and all other columns as metadata. The default table will have the first column as\n",
+ "text and the second column as metadata (JSON). Each row becomes a document."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from llama_index_alloydb_pg import AlloyDBReader\n",
+ "\n",
+ "# Creating a basic AlloyDBReader object\n",
+ "reader = await AlloyDBReader.create(\n",
+ " engine,\n",
+ " table_name=TABLE_NAME,\n",
+ " # schema_name=SCHEMA_NAME,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Load documents via custom table/metadata or custom page content columns"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "reader = await AlloyDBReader.create(\n",
+ " engine,\n",
+ " table_name=TABLE_NAME,\n",
+ " # schema_name=SCHEMA_NAME,\n",
+ " content_columns=[\"product_name\"], # Optional\n",
+ " metadata_columns=[\"id\"], # Optional\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load Documents using a SQL query\n",
+ "The `query` parameter allows users to specify a custom SQL query, which can include filters to load specific documents from a database."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "table_name = \"products\"\n",
+ "content_columns = [\"product_name\", \"description\"]\n",
+ "metadata_columns = [\"id\", \"content\"]\n",
+ "\n",
+ "reader = await AlloyDBReader.create(\n",
+ " engine=engine,\n",
+ " query=f\"SELECT * FROM {table_name};\",\n",
+ " content_columns=content_columns,\n",
+ " metadata_columns=metadata_columns,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Note**: If the `content_columns` and `metadata_columns` are not specified, the reader will automatically treat the first returned column as the document’s `text` and all subsequent columns as `metadata`."
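Since the `query` argument accepts arbitrary SQL, filters can be pushed into the query itself. A sketch of a filtered load, where the `WHERE` clause and the `price_per_unit` column are illustrative assumptions rather than part of the sample schema:

```python
# Illustrative only: the WHERE clause and the price_per_unit column are
# assumptions, not columns guaranteed to exist in the "products" table.
reader = await AlloyDBReader.create(
    engine=engine,
    query=f"SELECT * FROM {table_name} WHERE price_per_unit > 10;",
    content_columns=content_columns,
    metadata_columns=metadata_columns,
)
docs = await reader.aload_data()  # only the filtered rows become Documents
```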
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Set page content format\n",
+ "The reader returns a list of Documents, with one document per row and its page content in a specified string format, e.g. text (space-separated concatenation), JSON, YAML, or CSV. The JSON and YAML formats include field headers, while text and CSV do not."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "reader = await AlloyDBReader.create(\n",
+ " engine,\n",
+ " table_name=TABLE_NAME,\n",
+ " # schema_name=SCHEMA_NAME,\n",
+ " content_columns=[\"product_name\", \"description\"],\n",
+ " format=\"YAML\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load the documents\n",
+ "\n",
+ "You can choose to load the documents in two ways:\n",
+ "* Load all the data at once\n",
+ "* Lazy load data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Load data all at once"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "docs = await reader.aload_data()\n",
+ "\n",
+ "print(docs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Lazy load the data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "docs_iterable = reader.alazy_load_data()\n",
+ "\n",
+ "docs = []\n",
+ "async for doc in docs_iterable:\n",
+ " docs.append(doc)\n",
+ "\n",
+ "print(docs)"
+ ]
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

From 2b19c0f901f440f2ada6a76f5e3e4f4b3ec9e527 Mon Sep 17 00:00:00 2001
From: dishaprakash <57954147+dishaprakash@users.noreply.github.com>
Date: Wed, 29 Jan 2025 15:05:34 +0000
Subject: [PATCH 33/39] chore(docs): Add Chat Store How-to (#64)

* docs: Add Chat Store How-to

* Change markdown font

* Change project id

* Reduced duplication for AlloyDBVectorStore setup.

---------

Co-authored-by: Averi Kitsch
---
 samples/llama_index_chat_store.ipynb | 419 +++++++++++++++++++++++++++
 1 file changed, 419 insertions(+)
 create mode 100644 samples/llama_index_chat_store.ipynb

diff --git a/samples/llama_index_chat_store.ipynb b/samples/llama_index_chat_store.ipynb
new file mode 100644
index 0000000..157ed60
--- /dev/null
+++ b/samples/llama_index_chat_store.ipynb
@@ -0,0 +1,419 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Google AlloyDB for PostgreSQL - `AlloyDBChatStore`\n",
+ "\n",
+ "> [AlloyDB](https://cloud.google.com/alloydb) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. AlloyDB is 100% compatible with PostgreSQL. Extend your database application to build AI-powered experiences leveraging AlloyDB's LlamaIndex integrations.\n",
+ "\n",
+ "This notebook goes over how to use `AlloyDB for PostgreSQL` to store chat history with the `AlloyDBChatStore` class.\n",
+ "\n",
+ "Learn more about the package on [GitHub](https://github.com/googleapis/llama-index-alloydb-pg-python/).\n",
+ "\n",
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/llama-index-alloydb-pg-python/blob/main/samples/llama_index_chat_store.ipynb)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Before you begin\n",
+ "\n",
+ "To run this notebook, you will need to do the following:\n",
+ "\n",
+ " * [Create a Google Cloud Project](https://developers.google.com/workspace/guides/create-project)\n",
+ " * [Enable the AlloyDB API](https://console.cloud.google.com/flows/enableapi?apiid=alloydb.googleapis.com)\n",
+ " * [Create an AlloyDB cluster and instance.](https://cloud.google.com/alloydb/docs/cluster-create)\n",
+ " * [Create an AlloyDB database.](https://cloud.google.com/alloydb/docs/quickstart/create-and-connect)\n",
+ " * [Add a User to the database.](https://cloud.google.com/alloydb/docs/database-users/about)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### πŸ¦™ Library Installation\n",
+ "Install the integration library, `llama-index-alloydb-pg`, and the library for the Vertex AI LLM service, `llama-index-llms-vertex`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%pip install --upgrade --quiet llama-index-alloydb-pg llama-index-llms-vertex llama-index"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Colab only:** Uncomment the following cell to restart the kernel or use the button to restart the kernel. For Vertex AI Workbench you can restart the terminal using the button on top."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# # Automatically restart kernel after installs so that your environment can access the new packages\n",
+ "# import IPython\n",
+ "\n",
+ "# app = IPython.Application.instance()\n",
+ "# app.kernel.do_shutdown(True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### πŸ” Authentication\n",
+ "Authenticate to Google Cloud as the IAM user logged into this notebook in order to access your Google Cloud Project.\n",
+ "\n",
+ "* If you are using Colab to run this notebook, use the cell below and continue.\n",
+ "* If you are using Vertex AI Workbench, check out the setup instructions [here](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/setup-env)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from google.colab import auth\n",
+ "\n",
+ "auth.authenticate_user()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### ☁ Set Your Google Cloud Project\n",
+ "Set your Google Cloud project so that you can leverage Google Cloud resources within this notebook.\n",
+ "\n",
+ "If you don't know your project ID, try the following:\n",
+ "\n",
+ "* Run `gcloud config list`.\n",
+ "* Run `gcloud projects list`.\n",
+ "* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# @markdown Please fill in the value below with your Google Cloud project ID and then run the cell.\n",
+ "\n",
+ "PROJECT_ID = \"my-project-id\" # @param {type:\"string\"}\n",
+ "\n",
+ "# Set the project id\n",
+ "!gcloud config set project {PROJECT_ID}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Basic Usage"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Set AlloyDB database values\n",
+ "Find your database values in the [AlloyDB Instances page](https://console.cloud.google.com/alloydb/clusters)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# @title Set Your Values Here { display-mode: \"form\" }\n",
+ "REGION = \"us-central1\" # @param {type: \"string\"}\n",
+ "CLUSTER = \"my-cluster\" # @param {type: \"string\"}\n",
+ "INSTANCE = \"my-primary\" # @param {type: \"string\"}\n",
+ "DATABASE = \"my-database\" # @param {type: \"string\"}\n",
+ "TABLE_NAME = \"chat_store\" # @param {type: \"string\"}\n",
+ "USER = \"postgres\" # @param {type: \"string\"}\n",
+ "PASSWORD = \"my-password\" # @param {type: \"string\"}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### AlloyDBEngine Connection Pool\n",
+ "\n",
+ "One of the required arguments to establish AlloyDB as a chat store is an `AlloyDBEngine` object. The `AlloyDBEngine` configures a connection pool to your AlloyDB database, enabling successful connections from your application and following industry best practices.\n",
+ "\n",
+ "To create an `AlloyDBEngine` using `AlloyDBEngine.from_instance()`, you need to provide only 5 things:\n",
+ "\n",
+ "1. `project_id` : Project ID of the Google Cloud Project where the AlloyDB instance is located.\n",
+ "1. `region` : Region where the AlloyDB instance is located.\n",
+ "1. `cluster`: The name of the AlloyDB cluster.\n",
+ "1. `instance` : The name of the AlloyDB instance.\n",
+ "1. `database` : The name of the database to connect to on the AlloyDB instance.\n",
+ "\n",
+ "By default, [IAM database authentication](https://cloud.google.com/alloydb/docs/connect-iam) will be used as the method of database authentication. This library uses the IAM principal belonging to the [Application Default Credentials (ADC)](https://cloud.google.com/docs/authentication/application-default-credentials) sourced from the environment.\n",
+ "\n",
+ "Optionally, [built-in database authentication](https://cloud.google.com/alloydb/docs/database-users/about) using a username and password to access the AlloyDB database can also be used. Just provide the optional `user` and `password` arguments to `AlloyDBEngine.from_instance()`:\n",
+ "\n",
+ "* `user` : Database user to use for built-in database authentication and login.\n",
+ "* `password` : Database password to use for built-in database authentication and login.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Note:** This tutorial demonstrates the async interface. All async methods have corresponding sync methods."
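The cell that follows creates the engine with built-in authentication. When relying on the default IAM database authentication described above, the `user` and `password` arguments are simply omitted; a minimal sketch:

```python
from llama_index_alloydb_pg import AlloyDBEngine

# With IAM database authentication (the default), no user/password is passed;
# the IAM principal comes from the Application Default Credentials in the
# environment.
engine = await AlloyDBEngine.afrom_instance(
    project_id=PROJECT_ID,
    region=REGION,
    cluster=CLUSTER,
    instance=INSTANCE,
    database=DATABASE,
)
```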
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from llama_index_alloydb_pg import AlloyDBEngine\n",
+ "\n",
+ "engine = await AlloyDBEngine.afrom_instance(\n",
+ " project_id=PROJECT_ID,\n",
+ " region=REGION,\n",
+ " cluster=CLUSTER,\n",
+ " instance=INSTANCE,\n",
+ " database=DATABASE,\n",
+ " user=USER,\n",
+ " password=PASSWORD,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### AlloyDBEngine for AlloyDB Omni\n",
+ "To create an `AlloyDBEngine` for AlloyDB Omni, you will need a connection URL. `AlloyDBEngine.from_connection_string` first creates an async engine and then turns it into an `AlloyDBEngine`. Here is an example connection with the `asyncpg` driver:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Replace with your own AlloyDB Omni info\n",
+ "OMNI_USER = \"my-omni-user\"\n",
+ "OMNI_PASSWORD = \"\"\n",
+ "OMNI_HOST = \"127.0.0.1\"\n",
+ "OMNI_PORT = \"5432\"\n",
+ "OMNI_DATABASE = \"my-omni-db\"\n",
+ "\n",
+ "connstring = f\"postgresql+asyncpg://{OMNI_USER}:{OMNI_PASSWORD}@{OMNI_HOST}:{OMNI_PORT}/{OMNI_DATABASE}\"\n",
+ "engine = AlloyDBEngine.from_connection_string(connstring)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Initialize a table\n",
+ "The `AlloyDBChatStore` class requires a database table. The `AlloyDBEngine` class has a helper method `ainit_chat_store_table()` that can be used to create a table with the proper schema for you."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "await engine.ainit_chat_store_table(table_name=TABLE_NAME)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Optional Tip: πŸ’‘\n",
+ "You can also specify a schema name by passing `schema_name` wherever you pass `table_name`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "SCHEMA_NAME = \"my_schema\"\n",
+ "\n",
+ "await engine.ainit_chat_store_table(\n",
+ " table_name=TABLE_NAME,\n",
+ " schema_name=SCHEMA_NAME,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Initialize a default AlloyDBChatStore"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from llama_index_alloydb_pg import AlloyDBChatStore\n",
+ "\n",
+ "chat_store = await AlloyDBChatStore.create(\n",
+ " engine=engine,\n",
+ " table_name=TABLE_NAME,\n",
+ " # schema_name=SCHEMA_NAME\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Create a ChatMemoryBuffer\n",
+ "The `ChatMemoryBuffer` stores a history of recent chat messages, enabling the LLM to access relevant context from prior interactions.\n",
+ "\n",
+ "By passing our chat store into the `ChatMemoryBuffer`, it can automatically retrieve and update messages associated with a specific session ID or `chat_store_key`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from llama_index.core.memory import ChatMemoryBuffer\n",
+ "\n",
+ "memory = ChatMemoryBuffer.from_defaults(\n",
+ " token_limit=3000,\n",
+ " chat_store=chat_store,\n",
+ " chat_store_key=\"user1\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Create an LLM class instance\n",
+ "\n",
+ "You can use any of the [LLMs compatible with LlamaIndex](https://docs.llamaindex.ai/en/stable/module_guides/models/llms/modules/).\n",
+ "You may need to enable the Vertex AI API to use `Vertex`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from llama_index.llms.vertex import Vertex\n",
+ "\n",
+ "llm = Vertex(model=\"gemini-1.5-flash-002\", project=PROJECT_ID)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Use the AlloyDBChatStore without a storage context"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Create a Simple Chat Engine"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from llama_index.core.chat_engine import SimpleChatEngine\n",
+ "\n",
+ "chat_engine = SimpleChatEngine(memory=memory, llm=llm, prefix_messages=[])\n",
+ "\n",
+ "response = chat_engine.chat(\"Hello\")\n",
+ "\n",
+ "print(response)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Use the AlloyDBChatStore with a storage context"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Create a LlamaIndex `Index`\n",
+ "\n",
+ "An `Index` allows us to quickly retrieve relevant context for a user query.\n",
+ "Indexes are used to build `QueryEngines` and `ChatEngines`.\n",
+ "For a list of indexes that can be built in LlamaIndex, see [Index Guide](https://docs.llamaindex.ai/en/stable/module_guides/indexing/index_guide/).\n",
+ "\n",
+ "A `VectorStoreIndex` can be built using the `AlloyDBVectorStore`. See the detailed guide on how to use the `AlloyDBVectorStore` to build an index [here](https://github.com/googleapis/llama-index-alloydb-pg-python/blob/main/samples/llama_index_vector_store.ipynb), and a sketch of that step below.\n",
+ "\n",
+ "You can also use the `AlloyDBDocumentStore` and `AlloyDBIndexStore` to persist documents and index metadata.\n",
+ "These modules can be used to build other `Indexes`.\n",
+ "For a detailed Python notebook on this, see the [LlamaIndex Doc Store Guide](https://github.com/googleapis/llama-index-alloydb-pg-python/blob/main/samples/llama_index_doc_store.ipynb)."
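A sketch of that index-building step, assuming `vector_store` is an `AlloyDBVectorStore` configured as in the linked vector store notebook, with an embedding model already set up:

```python
from llama_index.core import VectorStoreIndex

# `vector_store` is assumed to be an AlloyDBVectorStore created as shown in
# the linked vector store guide; from_vector_store wraps it in an index that
# can back the chat engine in the next cell.
index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
```

The resulting `index` is what the "Create an `index` here" placeholder in the following cell refers to.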
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Create and use the Chat Engine"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create an `index` here\n",
+ "\n",
+ "chat_engine = index.as_chat_engine(llm=llm, chat_mode=\"context\", memory=memory) # type: ignore\n",
+ "response = chat_engine.chat(\"What did the author do?\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "senseAIenv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

From d8a6445cb733e5bf9e69d6f74b2786b2ef2e2608 Mon Sep 17 00:00:00 2001
From: dishaprakash <57954147+dishaprakash@users.noreply.github.com>
Date: Wed, 29 Jan 2025 17:15:19 +0000
Subject: [PATCH 34/39] =?UTF-8?q?Chore:=20Add=20additional=20tests=20for?=
 =?UTF-8?q?=20AsyncReader=20and=20type=20mismatch=20method=20=E2=80=A6=20(?=
 =?UTF-8?q?#70)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Chore: Add additional tests for AsyncReader and type mismatch method comment.

* Add await to function call
---
 src/llama_index_alloydb_pg/reader.py |  1 +
 tests/test_async_reader.py           | 37 +++++++++++++++++++++++++---
 2 files changed, 35 insertions(+), 3 deletions(-)

diff --git a/src/llama_index_alloydb_pg/reader.py b/src/llama_index_alloydb_pg/reader.py
index 5b78491..aac1582 100644
--- a/src/llama_index_alloydb_pg/reader.py
+++ b/src/llama_index_alloydb_pg/reader.py
@@ -167,6 +167,7 @@ def load_data(self) -> list[Document]:
 
     async def alazy_load_data(self) -> AsyncIterable[Document]:  # type: ignore
         """Asynchronously load AlloyDB data into Document objects lazily."""
+        # The return type in the underlying base class is an Iterable which we are overriding to an AsyncIterable in this implementation.
         iterator = self.__reader.alazy_load_data().__aiter__()
         while True:
             try:
diff --git a/tests/test_async_reader.py b/tests/test_async_reader.py
index 658cc37..50a1be0 100644
--- a/tests/test_async_reader.py
+++ b/tests/test_async_reader.py
@@ -26,6 +26,7 @@ from llama_index_alloydb_pg.async_reader import AsyncAlloyDBReader
 
 default_table_name_async = "reader_test_" + str(uuid.uuid4())
+sync_method_exception_str = "Sync methods are not implemented for AsyncAlloyDBReader. Use AlloyDBReader interface instead."
 async def aexecute(engine: AlloyDBEngine, query: str) -> None:
@@ -90,12 +91,11 @@ async def async_engine(
             region=db_region,
             database=db_name,
         )
+        await self._create_default_table(async_engine)
 
         yield async_engine
 
-        await aexecute(
-            async_engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"'
-        )
+        await self._cleanup_table(async_engine)
         await async_engine.close()
         await async_engine._connector.close()
@@ -103,6 +103,19 @@ async def async_engine(
     async def _cleanup_table(self, engine):
         await aexecute(engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"')
 
+    async def _create_default_table(self, engine):
+        create_query = f"""
+        CREATE TABLE IF NOT EXISTS "{default_table_name_async}" (
+            fruit_id SERIAL PRIMARY KEY,
+            fruit_name VARCHAR(100) NOT NULL,
+            variety VARCHAR(50),
+            quantity_in_stock INT NOT NULL,
+            price_per_unit INT NOT NULL,
+            organic INT NOT NULL
+        )
+        """
+        await aexecute(engine, create_query)
+
     async def _collect_async_items(self, docs_generator):
         """Collects items from an async generator."""
         docs = []
@@ -133,6 +146,24 @@ def fake_formatter():
             format="fake_format",
         )
 
+    async def test_lazy_load_data(self, async_engine):
+        with pytest.raises(Exception, match=sync_method_exception_str):
+            reader = await AsyncAlloyDBReader.create(
+                engine=async_engine,
+                table_name=default_table_name_async,
+            )
+
+            reader.lazy_load_data()
+
+    async def test_load_data(self, async_engine):
+        with pytest.raises(Exception, match=sync_method_exception_str):
+            reader = await AsyncAlloyDBReader.create(
+                engine=async_engine,
+                table_name=default_table_name_async,
+            )
+
+            reader.load_data()
+
     async def test_load_from_query_default(self, async_engine):
         table_name = "test-table" + str(uuid.uuid4())
         query = f"""

From 0ebc4c896496717646c4a771c9e58fc243004097 Mon Sep 17 00:00:00 2001
From: Mend Renovate
Date: Thu, 30 Jan 2025 20:23:10 +0100
Subject: [PATCH 35/39] chore(deps): update actions/setup-python action to v5.4.0 (#68)

---
 .github/workflows/lint.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 28d3312..fbd4535 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -34,7 +34,7 @@ jobs:
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
 
       - name: Setup Python
-        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
         with:
           python-version: "3.11"

From 063811c420ea9f02f7fe244063d41bab61b9f07d Mon Sep 17 00:00:00 2001
From: Mend Renovate
Date: Thu, 30 Jan 2025 20:37:47 +0100
Subject: [PATCH 36/39] chore(deps): update dependency black to v25 (#69)

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 94a41ad..d0b14aa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,7 +37,7 @@ Changelog = "https://github.com/googleapis/llama-index-alloydb-pg-python/blob/ma
 
 [project.optional-dependencies]
 test = [
-    "black[jupyter]==24.10.0",
+    "black[jupyter]==25.1.0",
     "isort==5.13.2",
     "mypy==1.14.1",
     "pytest-asyncio==0.25.2",

From 8d51750b5d05afaa042cee5fc0fb1be2402d85e5 Mon Sep 17 00:00:00 2001
From: Mend Renovate
Date: Thu, 30 Jan 2025 20:42:02 +0100
Subject: [PATCH 37/39] chore(deps): update dependency isort to v6 (#67)

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index d0b14aa..5b669cc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,7 +38,7 @@ Changelog = "https://github.com/googleapis/llama-index-alloydb-pg-python/blob/ma
 
 [project.optional-dependencies]
 test = [
     "black[jupyter]==25.1.0",
-    "isort==5.13.2",
+    "isort==6.0.0",
     "mypy==1.14.1",
     "pytest-asyncio==0.25.2",
     "pytest==8.3.4",

From 1ac6d03d8d55aebe9ca4908c95593245aa57980b Mon Sep 17 00:00:00 2001
From: Mend Renovate
Date: Thu, 30 Jan 2025 20:56:40 +0100
Subject: [PATCH 38/39] chore(deps): update python-nonmajor (#63)

---
 pyproject.toml   | 2 +-
 requirements.txt | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 5b669cc..04d7c1e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,7 @@ test = [
     "black[jupyter]==25.1.0",
     "isort==6.0.0",
     "mypy==1.14.1",
-    "pytest-asyncio==0.25.2",
+    "pytest-asyncio==0.25.3",
     "pytest==8.3.4",
     "pytest-cov==6.0.0",
     "pytest-depends==1.0.1",
diff --git a/requirements.txt b/requirements.txt
index 7fe38ac..ee5ab04 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
 google-cloud-alloydb-connector[asyncpg]==1.7.0
-llama-index-core==0.12.13
+llama-index-core==0.12.14
 pgvector==0.3.6
 SQLAlchemy[asyncio]==2.0.37

From 61d1c065a28cbba1e053ea4a5632c057d2cfd7b8 Mon Sep 17 00:00:00 2001
From: "release-please[bot]" <55107282+release-please[bot]@users.noreply.github.com>
Date: Thu, 30 Jan 2025 13:55:17 -0800
Subject: [PATCH 39/39] chore(main): release 0.2.0 (#32)

Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com>
---
 CHANGELOG.md                          | 19 +++++++++++++++++++
 src/llama_index_alloydb_pg/version.py |  2 +-
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a224203..657d2b9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,24 @@
 # Changelog
 
+## [0.2.0](https://github.com/googleapis/llama-index-alloydb-pg-python/compare/v0.1.0...v0.2.0) (2025-01-30)
+
+
+### Features
+
+* Adding AlloyDB Chat Store ([#37](https://github.com/googleapis/llama-index-alloydb-pg-python/issues/37)) ([320b448](https://github.com/googleapis/llama-index-alloydb-pg-python/commit/320b448fc60b2a41c4b3e1b90084d799319260eb))
+* Adding AlloyDB Reader ([#57](https://github.com/googleapis/llama-index-alloydb-pg-python/issues/57)) ([7314d83](https://github.com/googleapis/llama-index-alloydb-pg-python/commit/7314d835e62ccd7e8fe59b35f37dccaaee6aed36))
+* Adding Async AlloyDB Reader ([#55](https://github.com/googleapis/llama-index-alloydb-pg-python/issues/55)) ([56e6479](https://github.com/googleapis/llama-index-alloydb-pg-python/commit/56e64790c8eb85979d60b87366adb46596232e24))
+* Adding Async Chat Store ([#35](https://github.com/googleapis/llama-index-alloydb-pg-python/issues/35)) ([dd98771](https://github.com/googleapis/llama-index-alloydb-pg-python/commit/dd987718f0482177d03c84eee6334703613461d0))
+* Adding chat store init methods. ([#29](https://github.com/googleapis/llama-index-alloydb-pg-python/issues/29)) ([de53006](https://github.com/googleapis/llama-index-alloydb-pg-python/commit/de53006d00fe1edd5b3e5c1349613e82f0c94794))
+
+
+### Bug Fixes
+
+* Change default metadata_json_column default value ([#66](https://github.com/googleapis/llama-index-alloydb-pg-python/issues/66)) ([ecb53c8](https://github.com/googleapis/llama-index-alloydb-pg-python/commit/ecb53c80d311deb9232f0f8844761a816fc01bc0))
+* Programming error while setting multiple query option ([#47](https://github.com/googleapis/llama-index-alloydb-pg-python/issues/47)) ([5f1405e](https://github.com/googleapis/llama-index-alloydb-pg-python/commit/5f1405ed7ba7941c9c9a4370a428c720d857e6af))
+* Query and return only selected metadata columns ([#52](https://github.com/googleapis/llama-index-alloydb-pg-python/issues/52)) ([dff623b](https://github.com/googleapis/llama-index-alloydb-pg-python/commit/dff623bf8d340811ed88271e59b11d0f996cc811))
+* Update lazy_load_data return type to Iterable. ([#61](https://github.com/googleapis/llama-index-alloydb-pg-python/issues/61)) ([98f6c65](https://github.com/googleapis/llama-index-alloydb-pg-python/commit/98f6c65fc77cbb7f25b22d7118bbb89f3c674b2f))
+
 ## 0.1.0 (2024-12-03)

diff --git a/src/llama_index_alloydb_pg/version.py b/src/llama_index_alloydb_pg/version.py
index c1c8212..20c5861 100644
--- a/src/llama_index_alloydb_pg/version.py
+++ b/src/llama_index_alloydb_pg/version.py
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.1.0"
+__version__ = "0.2.0"