Skip to content

Commit f1d117e

Browse files
authored
feat: support saving with customized content column and saving/loading with non-default metadata JSON column. (#12)
* feat: support non-default content column and metadata json column * fix: fix tests
1 parent dd79a93 commit f1d117e

File tree

4 files changed

+144
-50
lines changed

4 files changed

+144
-50
lines changed

src/langchain_google_cloud_sql_mssql/mssql_engine.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,9 @@ def init_document_table(
153153
self,
154154
table_name: str,
155155
metadata_columns: List[sqlalchemy.Column] = [],
156-
store_metadata: bool = True,
156+
content_column: str = "page_content",
157+
metadata_json_column: Optional[str] = "langchain_metadata",
158+
overwrite_existing: bool = False,
157159
) -> None:
158160
"""
159161
Create a table for saving of langchain documents.
@@ -162,22 +164,29 @@ def init_document_table(
162164
table_name (str): The MSSQL database table name.
163165
metadata_columns (List[sqlalchemy.Column]): A list of SQLAlchemy Columns
164166
to create for custom metadata. Optional.
165-
store_metadata (bool): Whether to store extra metadata in a metadata column
166-
if not described in 'metadata' field list (Default: True).
167+
content_column (str): The column to store document content.
168+
Default: `page_content`.
169+
metadata_json_column (Optional[str]): The column to store extra metadata in JSON format.
170+
Default: `langchain_metadata`. Optional.
171+
overwrite_existing (bool): Whether to drop existing table. Default: False.
167172
"""
173+
if overwrite_existing:
174+
with self.engine.connect() as conn:
175+
conn.execute(sqlalchemy.text(f'DROP TABLE IF EXISTS "{table_name}";'))
176+
168177
columns = [
169178
sqlalchemy.Column(
170-
"page_content",
179+
content_column,
171180
sqlalchemy.UnicodeText,
172181
primary_key=False,
173182
nullable=False,
174183
)
175184
]
176185
columns += metadata_columns
177-
if store_metadata:
186+
if metadata_json_column:
178187
columns.append(
179188
sqlalchemy.Column(
180-
"langchain_metadata",
189+
metadata_json_column,
181190
sqlalchemy.JSON,
182191
primary_key=False,
183192
nullable=True,

src/langchain_google_cloud_sql_mssql/mssql_loader.py

Lines changed: 88 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,33 +26,41 @@
2626

2727

2828
def _parse_doc_from_row(
29-
content_columns: Iterable[str], metadata_columns: Iterable[str], row: Dict
29+
content_columns: Iterable[str],
30+
metadata_columns: Iterable[str],
31+
row: Dict,
32+
metadata_json_column: str = DEFAULT_METADATA_COL,
3033
) -> Document:
3134
page_content = " ".join(
3235
str(row[column]) for column in content_columns if column in row
3336
)
3437
metadata: Dict[str, Any] = {}
3538
# unnest metadata from langchain_metadata column
36-
if DEFAULT_METADATA_COL in metadata_columns and row.get(DEFAULT_METADATA_COL):
37-
for k, v in row[DEFAULT_METADATA_COL].items():
39+
if row.get(metadata_json_column):
40+
for k, v in row[metadata_json_column].items():
3841
metadata[k] = v
3942
# load metadata from other columns
4043
for column in metadata_columns:
41-
if column in row and column != DEFAULT_METADATA_COL:
44+
if column in row and column != metadata_json_column:
4245
metadata[column] = row[column]
4346
return Document(page_content=page_content, metadata=metadata)
4447

4548

46-
def _parse_row_from_doc(column_names: Iterable[str], doc: Document) -> Dict:
49+
def _parse_row_from_doc(
50+
column_names: Iterable[str],
51+
doc: Document,
52+
content_column: str = DEFAULT_CONTENT_COL,
53+
metadata_json_column: str = DEFAULT_METADATA_COL,
54+
) -> Dict:
4755
doc_metadata = doc.metadata.copy()
48-
row: Dict[str, Any] = {DEFAULT_CONTENT_COL: doc.page_content}
56+
row: Dict[str, Any] = {content_column: doc.page_content}
4957
for entry in doc.metadata:
5058
if entry in column_names:
5159
row[entry] = doc_metadata[entry]
5260
del doc_metadata[entry]
5361
# store extra metadata in langchain_metadata column in json format
54-
if DEFAULT_METADATA_COL in column_names and len(doc_metadata) > 0:
55-
row[DEFAULT_METADATA_COL] = doc_metadata
62+
if metadata_json_column in column_names and len(doc_metadata) > 0:
63+
row[metadata_json_column] = doc_metadata
5664
return row
5765

5866

@@ -66,6 +74,7 @@ def __init__(
6674
query: str = "",
6775
content_columns: Optional[List[str]] = None,
6876
metadata_columns: Optional[List[str]] = None,
77+
metadata_json_column: Optional[str] = None,
6978
):
7079
"""
7180
Document page content defaults to the first column present in the query or table and
@@ -77,19 +86,22 @@ def __init__(
7786
space-separated string concatenation.
7887
7988
Args:
80-
engine (MSSQLEngine): MSSQLEngine object to connect to the MSSQL database.
81-
table_name (str): The MSSQL database table name. (OneOf: table_name, query).
82-
query (str): The query to execute in MSSQL format. (OneOf: table_name, query).
83-
content_columns (List[str]): The columns to write into the `page_content`
84-
of the document. Optional.
85-
metadata_columns (List[str]): The columns to write into the `metadata` of the document.
86-
Optional.
89+
engine (MSSQLEngine): MSSQLEngine object to connect to the MSSQL database.
90+
table_name (str): The MSSQL database table name. (OneOf: table_name, query).
91+
query (str): The query to execute in MSSQL format. (OneOf: table_name, query).
92+
content_columns (List[str]): The columns to write into the `page_content`
93+
of the document. Optional.
94+
metadata_columns (List[str]): The columns to write into the `metadata` of the document.
95+
Optional.
96+
metadata_json_column (str): The name of the JSON column to use as the metadata’s base
97+
dictionary. Default: `langchain_metadata`. Optional.
8798
"""
8899
self.engine = engine
89100
self.table_name = table_name
90101
self.query = query
91102
self.content_columns = content_columns
92103
self.metadata_columns = metadata_columns
104+
self.metadata_json_column = metadata_json_column
93105
if not self.table_name and not self.query:
94106
raise ValueError("One of 'table_name' or 'query' must be specified.")
95107
if self.table_name and self.query:
@@ -128,6 +140,25 @@ def lazy_load(self) -> Iterator[Document]:
128140
metadata_columns = self.metadata_columns or [
129141
col for col in column_names if col not in content_columns
130142
]
143+
# check validity of metadata json column
144+
if (
145+
self.metadata_json_column
146+
and self.metadata_json_column not in column_names
147+
):
148+
raise ValueError(
149+
f"Column {self.metadata_json_column} not found in query result {column_names}."
150+
)
151+
# check validity of other column
152+
all_names = content_columns + metadata_columns
153+
for name in all_names:
154+
if name not in column_names:
155+
raise ValueError(
156+
f"Column {name} not found in query result {column_names}."
157+
)
158+
# use default metadata json column if not specified
159+
metadata_json_column = self.metadata_json_column or DEFAULT_METADATA_COL
160+
161+
# load document one by one
131162
while True:
132163
row = result_proxy.fetchone()
133164
if not row:
@@ -136,11 +167,13 @@ def lazy_load(self) -> Iterator[Document]:
136167
row_data = {}
137168
for column in column_names:
138169
value = getattr(row, column)
139-
if column == DEFAULT_METADATA_COL:
170+
if column == metadata_json_column:
140171
row_data[column] = json.loads(value)
141172
else:
142173
row_data[column] = value
143-
yield _parse_doc_from_row(content_columns, metadata_columns, row_data)
174+
yield _parse_doc_from_row(
175+
content_columns, metadata_columns, row_data, metadata_json_column
176+
)
144177

145178

146179
class MSSQLDocumentSaver:
@@ -150,6 +183,8 @@ def __init__(
150183
self,
151184
engine: MSSQLEngine,
152185
table_name: str,
186+
content_column: Optional[str] = None,
187+
metadata_json_column: Optional[str] = None,
153188
):
154189
"""
155190
MSSQLDocumentSaver allows for saving of langchain documents in a database. If the table
@@ -160,14 +195,28 @@ def __init__(
160195
Args:
161196
engine: MSSQLEngine object to connect to the MSSQL database.
162197
table_name: The name of table for saving documents.
198+
content_column (str): The column to store document content.
199+
Default: `page_content`. Optional.
200+
metadata_json_column (str): The name of the JSON column to use as the metadata’s base
201+
dictionary. Default: `langchain_metadata`. Optional.
163202
"""
164203
self.engine = engine
165204
self.table_name = table_name
166205
self._table = self.engine._load_document_table(table_name)
167-
if DEFAULT_CONTENT_COL not in self._table.columns.keys():
206+
self.content_column = content_column or DEFAULT_CONTENT_COL
207+
if self.content_column not in self._table.columns.keys():
208+
raise ValueError(
209+
f"Missing '{self.content_column}' field in table {table_name}."
210+
)
211+
# check metadata_json_column existence if it's provided.
212+
if (
213+
metadata_json_column
214+
and metadata_json_column not in self._table.columns.keys()
215+
):
168216
raise ValueError(
169-
f"Missing '{DEFAULT_CONTENT_COL}' field in table {table_name}."
217+
f"Cannot find '{metadata_json_column}' column in table {table_name}."
170218
)
219+
self.metadata_json_column = metadata_json_column or DEFAULT_METADATA_COL
171220

172221
def add_documents(self, docs: List[Document]) -> None:
173222
"""
@@ -179,9 +228,16 @@ def add_documents(self, docs: List[Document]) -> None:
179228
"""
180229
with self.engine.connect() as conn:
181230
for doc in docs:
182-
row = _parse_row_from_doc(self._table.columns.keys(), doc)
183-
if DEFAULT_METADATA_COL in row:
184-
row[DEFAULT_METADATA_COL] = json.dumps(row[DEFAULT_METADATA_COL])
231+
row = _parse_row_from_doc(
232+
self._table.columns.keys(),
233+
doc,
234+
self.content_column,
235+
self.metadata_json_column,
236+
)
237+
if self.metadata_json_column in row:
238+
row[self.metadata_json_column] = json.dumps(
239+
row[self.metadata_json_column]
240+
)
185241
conn.execute(sqlalchemy.insert(self._table).values(row))
186242
conn.commit()
187243

@@ -195,9 +251,16 @@ def delete(self, docs: List[Document]) -> None:
195251
"""
196252
with self.engine.connect() as conn:
197253
for doc in docs:
198-
row = _parse_row_from_doc(self._table.columns.keys(), doc)
199-
if DEFAULT_METADATA_COL in row:
200-
row[DEFAULT_METADATA_COL] = json.dumps(row[DEFAULT_METADATA_COL])
254+
row = _parse_row_from_doc(
255+
self._table.columns.keys(),
256+
doc,
257+
self.content_column,
258+
self.metadata_json_column,
259+
)
260+
if self.metadata_json_column in row:
261+
row[self.metadata_json_column] = json.dumps(
262+
row[self.metadata_json_column]
263+
)
201264
# delete by matching all fields of document
202265
where_conditions = []
203266
for col in self._table.columns:

tests/integration/test_mssql_loader.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,6 @@ def test_load_from_query_with_langchain_metadata(engine):
266266
query=query,
267267
metadata_columns=[
268268
"fruit_name",
269-
"langchain_metadata",
270269
],
271270
)
272271

@@ -311,8 +310,9 @@ def test_save_doc_with_default_metadata(engine):
311310
]
312311

313312

314-
@pytest.mark.parametrize("store_metadata", [True, False])
315-
def test_save_doc_with_customized_metadata(engine, store_metadata):
313+
@pytest.mark.parametrize("metadata_json_column", [None, "metadata_col_test"])
314+
def test_save_doc_with_customized_metadata(engine, metadata_json_column):
315+
content_column = "content_col_test"
316316
engine.init_document_table(
317317
table_name,
318318
metadata_columns=[
@@ -329,35 +329,43 @@ def test_save_doc_with_customized_metadata(engine, store_metadata):
329329
nullable=True,
330330
),
331331
],
332-
store_metadata=store_metadata,
332+
content_column=content_column,
333+
metadata_json_column=metadata_json_column,
334+
overwrite_existing=True,
333335
)
334336
test_docs = [
335337
Document(
336338
page_content="Granny Smith 150 0.99",
337339
metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1},
338340
),
339341
]
340-
saver = MSSQLDocumentSaver(engine=engine, table_name=table_name)
342+
saver = MSSQLDocumentSaver(
343+
engine=engine,
344+
table_name=table_name,
345+
content_column=content_column,
346+
metadata_json_column=metadata_json_column,
347+
)
341348
loader = MSSQLLoader(
342349
engine=engine,
343350
table_name=table_name,
351+
content_columns=[content_column],
344352
metadata_columns=[
345-
"fruit_id",
346353
"fruit_name",
347354
"organic",
348355
],
356+
metadata_json_column=metadata_json_column,
349357
)
350358

351359
saver.add_documents(test_docs)
352360
docs = loader.load()
353361

354-
if store_metadata:
362+
if metadata_json_column:
355363
docs == test_docs
356364
assert engine._load_document_table(table_name).columns.keys() == [
357-
"page_content",
365+
content_column,
358366
"fruit_name",
359367
"organic",
360-
"langchain_metadata",
368+
metadata_json_column,
361369
]
362370
else:
363371
assert docs == [
@@ -367,7 +375,7 @@ def test_save_doc_with_customized_metadata(engine, store_metadata):
367375
),
368376
]
369377
assert engine._load_document_table(table_name).columns.keys() == [
370-
"page_content",
378+
content_column,
371379
"fruit_name",
372380
"organic",
373381
]
@@ -376,7 +384,7 @@ def test_save_doc_with_customized_metadata(engine, store_metadata):
376384
def test_save_doc_without_metadata(engine):
377385
engine.init_document_table(
378386
table_name,
379-
store_metadata=False,
387+
metadata_json_column=None,
380388
)
381389
test_docs = [
382390
Document(
@@ -430,8 +438,9 @@ def test_delete_doc_with_default_metadata(engine):
430438
assert len(loader.load()) == 0
431439

432440

433-
@pytest.mark.parametrize("store_metadata", [True, False])
434-
def test_delete_doc_with_customized_metadata(engine, store_metadata):
441+
@pytest.mark.parametrize("metadata_json_column", [None, "metadata_col_test"])
442+
def test_delete_doc_with_customized_metadata(engine, metadata_json_column):
443+
content_column = "content_col_test"
435444
engine.init_document_table(
436445
table_name,
437446
metadata_columns=[
@@ -448,7 +457,9 @@ def test_delete_doc_with_customized_metadata(engine, store_metadata):
448457
nullable=True,
449458
),
450459
],
451-
store_metadata=store_metadata,
460+
content_column=content_column,
461+
metadata_json_column=metadata_json_column,
462+
overwrite_existing=True,
452463
)
453464
test_docs = [
454465
Document(
@@ -460,8 +471,18 @@ def test_delete_doc_with_customized_metadata(engine, store_metadata):
460471
metadata={"fruit_id": 2, "fruit_name": "Banana", "organic": 1},
461472
),
462473
]
463-
saver = MSSQLDocumentSaver(engine=engine, table_name=table_name)
464-
loader = MSSQLLoader(engine=engine, table_name=table_name)
474+
saver = MSSQLDocumentSaver(
475+
engine=engine,
476+
table_name=table_name,
477+
content_column=content_column,
478+
metadata_json_column=metadata_json_column,
479+
)
480+
loader = MSSQLLoader(
481+
engine=engine,
482+
table_name=table_name,
483+
content_columns=[content_column],
484+
metadata_json_column=metadata_json_column,
485+
)
465486

466487
saver.add_documents(test_docs)
467488
docs = loader.load()
@@ -491,7 +512,6 @@ def test_delete_doc_with_query(engine):
491512
nullable=True,
492513
),
493514
],
494-
store_metadata=True,
495515
)
496516
test_docs = [
497517
Document(

0 commit comments

Comments
 (0)