|
31 | 31 |
|
32 | 32 | USER_AGENT = "langchain-google-cloud-sql-mysql-python/" + __version__
|
33 | 33 |
|
| 34 | +from dataclasses import dataclass |
| 35 | + |
| 36 | + |
| 37 | +@dataclass |
| 38 | +class Column: |
| 39 | + name: str |
| 40 | + data_type: str |
| 41 | + nullable: bool = True |
| 42 | + |
| 43 | + def __post_init__(self): |
| 44 | + if not isinstance(self.name, str): |
| 45 | + raise ValueError("Column name must be type string") |
| 46 | + if not isinstance(self.data_type, str): |
| 47 | + raise ValueError("Column data_type must be type string") |
| 48 | + |
34 | 49 |
|
35 | 50 | def _get_iam_principal_email(
|
36 | 51 | credentials: google.auth.credentials.Credentials,
|
@@ -206,6 +221,20 @@ def connect(self) -> sqlalchemy.engine.Connection:
|
206 | 221 | """
|
207 | 222 | return self.engine.connect()
|
208 | 223 |
|
| 224 | + def _execute(self, query: str, params: Optional[dict] = None) -> None: |
| 225 | + """Execute a SQL query.""" |
| 226 | + with self.engine.connect() as conn: |
| 227 | + conn.execute(sqlalchemy.text(query), params) |
| 228 | + conn.commit() |
| 229 | + |
| 230 | + def _fetch(self, query: str, params: Optional[dict] = None): |
| 231 | + """Fetch results from a SQL query.""" |
| 232 | + with self.engine.connect() as conn: |
| 233 | + result = conn.execute(sqlalchemy.text(query), params) |
| 234 | + result_map = result.mappings() |
| 235 | + result_fetch = result_map.fetchall() |
| 236 | + return result_fetch |
| 237 | + |
209 | 238 | def init_chat_history_table(self, table_name: str) -> None:
|
210 | 239 | """Create table with schema required for MySQLChatMessageHistory class.
|
211 | 240 |
|
@@ -293,3 +322,51 @@ def _load_document_table(self, table_name: str) -> sqlalchemy.Table:
|
293 | 322 | metadata = sqlalchemy.MetaData()
|
294 | 323 | sqlalchemy.MetaData.reflect(metadata, bind=self.engine, only=[table_name])
|
295 | 324 | return metadata.tables[table_name]
|
| 325 | + |
| 326 | + def init_vectorstore_table( |
| 327 | + self, |
| 328 | + table_name: str, |
| 329 | + vector_size: int, |
| 330 | + content_column: str = "content", |
| 331 | + embedding_column: str = "embedding", |
| 332 | + metadata_columns: List[Column] = [], |
| 333 | + metadata_json_column: str = "langchain_metadata", |
| 334 | + id_column: str = "langchain_id", |
| 335 | + overwrite_existing: bool = False, |
| 336 | + store_metadata: bool = True, |
| 337 | + ) -> None: |
| 338 | + """ |
| 339 | + Create a table for saving vectors to be used with MySQLVectorStore. |
| 340 | +
|
| 341 | + Args: |
| 342 | + table_name (str): The MySQL database table name. |
| 343 | + vector_size (int): Vector size for the embedding model to be used. |
| 344 | + content_column (str): Name of the column to store document content. |
| 345 | + Default: `content`. |
| 346 | + embedding_column (str): Name of the column to store vector embeddings. |
| 347 | + Default: `embedding`. |
| 348 | + metadata_columns (List[Column]): A list of Columns to create for custom |
| 349 | + metadata. Default: []. Optional. |
| 350 | + metadata_json_column (str): The column to store extra metadata in JSON format. |
| 351 | + Default: `langchain_metadata`. Optional. |
| 352 | + id_column (str): Name of the column to store ids. |
| 353 | + Default: `langchain_id`. Optional. |
| 354 | + overwrite_existing (bool): Whether to drop existing table. Default: False. |
| 355 | + store_metadata (bool): Whether to store metadata in the table. |
| 356 | + Default: True. |
| 357 | + """ |
| 358 | + query = f"""CREATE TABLE `{table_name}`( |
| 359 | + `{id_column}` CHAR(36) PRIMARY KEY, |
| 360 | + `{content_column}` TEXT NOT NULL, |
| 361 | + `{embedding_column}` vector({vector_size}) USING VARBINARY NOT NULL""" |
| 362 | + for column in metadata_columns: |
| 363 | + nullable = "NOT NULL" if not column.nullable else "" |
| 364 | + query += f",\n`{column.name}` {column.data_type} {nullable}" |
| 365 | + if store_metadata: |
| 366 | + query += f""",\n`{metadata_json_column}` JSON""" |
| 367 | + query += "\n);" |
| 368 | + |
| 369 | + with self.engine.connect() as conn: |
| 370 | + if overwrite_existing: |
| 371 | + conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS `{table_name}`")) |
| 372 | + conn.execute(sqlalchemy.text(query)) |
0 commit comments