Skip to content

Commit a1c9411

Browse files
feat: add MySQLVectorStore initialization methods (#52)
Add constructor and classmethods for MySQLVectorStore.
1 parent 3439c9d commit a1c9411

File tree

8 files changed

+713
-3
lines changed

8 files changed

+713
-3
lines changed

integration.cloudbuild.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ availableSecrets:
4343
env: "DB_PASSWORD"
4444

4545
substitutions:
46-
_INSTANCE_ID: test-instance
46+
_INSTANCE_ID: mysql-vector
4747
_REGION: us-central1
4848
_DB_NAME: test
4949
_VERSION: "3.8"

src/langchain_google_cloud_sql_mysql/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,17 @@
1313
# limitations under the License.
1414

1515
from .chat_message_history import MySQLChatMessageHistory
16-
from .engine import MySQLEngine
16+
from .engine import Column, MySQLEngine
1717
from .loader import MySQLDocumentSaver, MySQLLoader
18+
from .vectorstore import MySQLVectorStore
1819
from .version import __version__
1920

2021
__all__ = [
22+
"Column",
2123
"MySQLChatMessageHistory",
2224
"MySQLDocumentSaver",
2325
"MySQLEngine",
2426
"MySQLLoader",
27+
"MySQLVectorStore",
2528
"__version__",
2629
]

src/langchain_google_cloud_sql_mysql/chat_message_history.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class MySQLChatMessageHistory(BaseChatMessageHistory):
2525
"""Chat message history stored in a Cloud SQL MySQL database.
2626
2727
Args:
28-
engine (MySQLEngine): SQLAlchemy connection pool engine for managing
28+
engine (MySQLEngine): Connection pool engine for managing
2929
connections to Cloud SQL for MySQL.
3030
session_id (str): Arbitrary key that is used to store the messages
3131
of a single chat session.

src/langchain_google_cloud_sql_mysql/engine.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,21 @@
3131

3232
USER_AGENT = "langchain-google-cloud-sql-mysql-python/" + __version__
3333

34+
from dataclasses import dataclass
35+
36+
37+
@dataclass
class Column:
    """Description of one table column for a generated CREATE TABLE statement.

    The ``name`` and ``data_type`` values are interpolated verbatim into SQL
    by ``init_vectorstore_table``, so they are validated eagerly here rather
    than failing later inside the database.

    Raises:
        ValueError: If ``name`` or ``data_type`` is not a string, or is
            empty / whitespace-only.
    """

    # Column identifier; emitted backtick-quoted in the DDL.
    name: str
    # SQL type expression placed after the column name (e.g. "VARCHAR(255)").
    data_type: str
    # When False, a NOT NULL constraint is appended to the column definition.
    nullable: bool = True

    def __post_init__(self):
        """Validate field types and reject blank identifiers/types."""
        if not isinstance(self.name, str):
            raise ValueError("Column name must be type string")
        if not isinstance(self.data_type, str):
            raise ValueError("Column data_type must be type string")
        # A blank name or type can never yield valid DDL; fail at construction
        # time with a clear message instead of a database syntax error.
        if not self.name.strip():
            raise ValueError("Column name must not be empty")
        if not self.data_type.strip():
            raise ValueError("Column data_type must not be empty")
48+
3449

3550
def _get_iam_principal_email(
3651
credentials: google.auth.credentials.Credentials,
@@ -206,6 +221,20 @@ def connect(self) -> sqlalchemy.engine.Connection:
206221
"""
207222
return self.engine.connect()
208223

224+
def _execute(self, query: str, params: Optional[dict] = None) -> None:
    """Run a single SQL statement and commit it.

    Args:
        query (str): SQL text; wrapped with ``sqlalchemy.text`` before
            execution.
        params (Optional[dict]): Bound parameters passed alongside the
            statement. Default: None.
    """
    with self.engine.connect() as connection:
        statement = sqlalchemy.text(query)
        connection.execute(statement, params)
        # Commit explicitly before the connection is released.
        connection.commit()
229+
230+
def _fetch(self, query: str, params: Optional[dict] = None):
    """Run a SQL query and return every resulting row.

    Args:
        query (str): SQL text; wrapped with ``sqlalchemy.text`` before
            execution.
        params (Optional[dict]): Bound parameters passed alongside the
            statement. Default: None.

    Returns:
        All rows of the result, obtained via ``Result.mappings().fetchall()``.
    """
    with self.engine.connect() as connection:
        rows = (
            connection.execute(sqlalchemy.text(query), params)
            .mappings()
            .fetchall()
        )
    return rows
237+
209238
def init_chat_history_table(self, table_name: str) -> None:
210239
"""Create table with schema required for MySQLChatMessageHistory class.
211240
@@ -293,3 +322,51 @@ def _load_document_table(self, table_name: str) -> sqlalchemy.Table:
293322
metadata = sqlalchemy.MetaData()
294323
sqlalchemy.MetaData.reflect(metadata, bind=self.engine, only=[table_name])
295324
return metadata.tables[table_name]
325+
326+
def init_vectorstore_table(
    self,
    table_name: str,
    vector_size: int,
    content_column: str = "content",
    embedding_column: str = "embedding",
    metadata_columns: Optional[List[Column]] = None,
    metadata_json_column: str = "langchain_metadata",
    id_column: str = "langchain_id",
    overwrite_existing: bool = False,
    store_metadata: bool = True,
) -> None:
    """
    Create a table for saving of vectors to be used with MySQLVectorStore.

    All identifiers and the vector size are interpolated directly into the
    DDL text, so they must come from trusted code, not from end-user input.

    Args:
        table_name (str): The MySQL database table name.
        vector_size (int): Vector size for the embedding model to be used.
        content_column (str): Name of the column to store document content.
            Default: `content`.
        embedding_column (str): Name of the column to store vector embeddings.
            Default: `embedding`.
        metadata_columns (Optional[List[Column]]): A list of Columns to create
            for custom metadata. Default: None (no extra columns). Optional.
        metadata_json_column (str): The column to store extra metadata in JSON
            format. Only created when `store_metadata` is True.
            Default: `langchain_metadata`. Optional.
        id_column (str): Name of the column to store ids.
            Default: `langchain_id`. Optional.
        overwrite_existing (bool): Whether to drop an existing table of the
            same name before creating it. Default: False.
        store_metadata (bool): Whether to add the JSON metadata column to the
            table. Default: True.

    Note:
        The CREATE TABLE statement has no ``IF NOT EXISTS`` clause, so the
        database is expected to reject the call when the table already
        exists and `overwrite_existing` is False.
    """
    # Fixed columns first: id, content, embedding.
    column_defs = [
        f"`{id_column}` CHAR(36) PRIMARY KEY",
        f"`{content_column}` TEXT NOT NULL",
        f"`{embedding_column}` vector({vector_size}) USING VARBINARY NOT NULL",
    ]
    # User-supplied metadata columns, in the order given.
    for column in metadata_columns or []:
        definition = f"`{column.name}` {column.data_type}"
        if not column.nullable:
            definition += " NOT NULL"
        column_defs.append(definition)
    if store_metadata:
        column_defs.append(f"`{metadata_json_column}` JSON")

    query = f"CREATE TABLE `{table_name}`(\n" + ",\n".join(column_defs) + "\n);"

    with self.engine.connect() as conn:
        if overwrite_existing:
            conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS `{table_name}`"))
        conn.execute(sqlalchemy.text(query))
        # Commit explicitly, matching `_execute`, rather than relying on the
        # server committing DDL implicitly.
        conn.commit()
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from abc import ABC
16+
from dataclasses import dataclass
17+
18+
19+
@dataclass
class QueryOptions(ABC):
    """Base class for query option containers.

    Subclasses are expected to override ``to_string``; the implementation
    here always raises.
    """

    def to_string(self) -> str:
        """Render the options as a string.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        message = "to_string method must be implemented by subclass"
        raise NotImplementedError(message)

0 commit comments

Comments
 (0)