diff --git a/.github/renovate.json5 b/.github/renovate.json5 index 868655c3..995819bb 100644 --- a/.github/renovate.json5 +++ b/.github/renovate.json5 @@ -1,7 +1,7 @@ { "extends": [ "config:base", // https://docs.renovatebot.com/presets-config/#configbase - ":semanticCommits", // https://docs.renovatebot.com/presets-default/#semanticcommits + ":semanticCommitTypeAll(chore)", // https://docs.renovatebot.com/presets-default/#semanticcommittypeallarg0 ":ignoreUnstable", // https://docs.renovatebot.com/presets-default/#ignoreunstable "group:allNonMajor", // https://docs.renovatebot.com/presets-group/#groupallnonmajor ":separateMajorReleases", // https://docs.renovatebot.com/presets-default/#separatemajorreleases diff --git a/.github/workflows/schedule_reporter.yml b/.github/workflows/schedule_reporter.yml index 3549f373..8a974d66 100644 --- a/.github/workflows/schedule_reporter.yml +++ b/.github/workflows/schedule_reporter.yml @@ -20,6 +20,10 @@ on: jobs: run_reporter: - uses: googleapis/langchain-google-alloydb-pg-python/.github/workflows/cloud_build_failure_reporter.yml@main + permissions: + issues: 'write' + checks: 'read' + contents: 'read' + uses: googleapis/langchain-google-alloydb-pg-python/.github/workflows/cloud_build_failure_reporter.yml@074f9932a8099256ff210771473badbd2156713b with: trigger_names: "pg-integration-test-nightly,pg-continuous-test-on-merge" diff --git a/.kokoro/requirements.txt b/.kokoro/requirements.txt index cd183f45..01481136 100644 --- a/.kokoro/requirements.txt +++ b/.kokoro/requirements.txt @@ -276,9 +276,9 @@ jeepney==0.8.0 \ # via # keyring # secretstorage -jinja2==3.1.5 \ - --hash=sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb \ - --hash=sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb +jinja2==3.1.6 \ + --hash=sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d \ + --hash=sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67 # via gcp-releasetool keyring==24.3.1 \ --hash=sha256:c3327b6ffafc0e8befbdb597cacdb4928ffe5c1212f7645f186e6d9957a898db \ diff --git a/CHANGELOG.md b/CHANGELOG.md index d62a32e6..b05ddb37 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Changelog +## [0.13.0](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/compare/v0.12.1...v0.13.0) (2025-03-17) + + +### Features + +* **langgraph:** Add Langgraph Checkpointer ([#284](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/issues/284)) ([14a4240](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/14a4240e1e5769b203e9463d016d08ac2e6f603e)) + + +### Bug Fixes + +* **deps:** Update dependency numpy to v2 ([#251](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/issues/251)) ([a164aa2](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/a164aa2d54575461e993594a1c98b8fac0e06ea2)) +* **engine:** Loop error on close ([#285](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/issues/285)) ([e8bd4ae](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/e8bd4ae9f03a3e60af0a1335d976423b0ae6e41a)) + ## [0.12.1](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/compare/v0.12.0...v0.12.1) (2025-02-12) diff --git a/README.rst b/README.rst index 9f76d6b5..c5da0942 100644 --- a/README.rst +++ b/README.rst @@ -158,6 +158,22 @@ See the full `Chat Message History`_ tutorial. .. 
_`Chat Message History`: https://github.com/googleapis/langchain-google-cloud-sql-pg-python/tree/main/docs/chat_message_history.ipynb +LangGraph Checkpoint Usage +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Use ``PostgresSaver`` to save snapshots of the graph state at a given point in time. + +.. code:: python + + from langchain_google_cloud_sql_pg import PostgresSaver, PostgresEngine + + engine = PostgresEngine.from_instance("project-id", "region", "my-instance", "my-database") + checkpointer = PostgresSaver.create_sync(engine) + +See the full `Checkpoint`_ tutorial. + +.. _`Checkpoint`: https://github.com/googleapis/langchain-google-cloud-sql-pg-python/tree/main/docs/langgraph_checkpoint.ipynb + Contributions ~~~~~~~~~~~~~ diff --git a/docs/chat_message_history.ipynb b/docs/chat_message_history.ipynb index 02e6f04f..c7f37933 100644 --- a/docs/chat_message_history.ipynb +++ b/docs/chat_message_history.ipynb @@ -1,581 +1,587 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "f22eab3f84cbeb37", - "metadata": { - "id": "f22eab3f84cbeb37" - }, - "source": [ - "# Google Cloud SQL for PostgreSQL\n", - "\n", - "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers MySQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's Langchain integrations.\n", - "\n", - "This notebook goes over how to use `Cloud SQL for PostgreSQL` to store chat message history with the `PostgresChatMessageHistory` class.\n", - "\n", - "Learn more about the package on [GitHub](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/).\n", - "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/langchain-google-cloud-sql-pg-python/blob/main/docs/chat_message_history.ipynb)" - ] - }, - { - "cell_type": "markdown", - "id": "da400c79-a360-43e2-be60-401fd02b2819", - "metadata": { - "id": "da400c79-a360-43e2-be60-401fd02b2819" - }, - "source": [ - "## Before You Begin\n", - "\n", - "To run this notebook, you will need to do the following:\n", - "\n", - " * [Create a Google Cloud Project](https://developers.google.com/workspace/guides/create-project)\n", - " * [Enable the Cloud SQL Admin API.](https://console.cloud.google.com/marketplace/product/google/sqladmin.googleapis.com)\n", - " * [Create a Cloud SQL for PostgreSQL instance](https://cloud.google.com/sql/docs/postgres/create-instance)\n", - " * [Create a Cloud SQL database](https://cloud.google.com/sql/docs/mysql/create-manage-databases)\n", - " * [Add an IAM database user to the database](https://cloud.google.com/sql/docs/postgres/add-manage-iam-users#creating-a-database-user) (Optional)" - ] - }, - { - "cell_type": "markdown", - "id": "Mm7-fG_LltD7", - "metadata": { - "id": "Mm7-fG_LltD7" - }, - "source": [ - "### 🦜🔗 Library Installation\n", - "The integration lives in its own `langchain-google-cloud-sql-pg` package, so we need to install it."
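For the README snippet above, a minimal end-to-end sketch of plugging the saver into a compiled LangGraph graph may help; this is not part of the patch, and it assumes the `langgraph` extra is installed (`pip install langchain-google-cloud-sql-pg[langgraph]`) and that the instance values are placeholders.

```python
# Minimal sketch, assuming the `langgraph` extra is installed and placeholder
# instance values; not part of this patch.
from typing import TypedDict

from langchain_google_cloud_sql_pg import PostgresEngine, PostgresSaver
from langgraph.graph import END, START, StateGraph


class State(TypedDict):
    count: int


def bump(state: State) -> State:
    # A trivial node whose effect we want persisted across invocations.
    return {"count": state["count"] + 1}


engine = PostgresEngine.from_instance("project-id", "region", "my-instance", "my-database")
engine.init_checkpoint_table()  # one-time setup of the checkpoints table
checkpointer = PostgresSaver.create_sync(engine)

builder = StateGraph(State)
builder.add_node("bump", bump)
builder.add_edge(START, "bump")
builder.add_edge("bump", END)
graph = builder.compile(checkpointer=checkpointer)

# Each `thread_id` keeps its own checkpointed state between calls.
print(graph.invoke({"count": 0}, {"configurable": {"thread_id": "demo"}}))
```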
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1VELXvcj8AId", - "metadata": { - "id": "1VELXvcj8AId" - }, - "outputs": [], - "source": [ - "%pip install --upgrade --quiet langchain-google-cloud-sql-pg langchain-google-vertexai" - ] - }, - { - "cell_type": "markdown", - "id": "98TVoM3MNDHu", - "metadata": { - "id": "98TVoM3MNDHu" - }, - "source": [ - "**Colab only:** Uncomment the following cell to restart the kernel or use the button to restart the kernel. For Vertex AI Workbench you can restart the terminal using the button on top." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "v6jBDnYnNM08", - "metadata": { - "id": "v6jBDnYnNM08" - }, - "outputs": [], - "source": [ - "# # Automatically restart kernel after installs so that your environment can access the new packages\n", - "# import IPython\n", - "\n", - "# app = IPython.Application.instance()\n", - "# app.kernel.do_shutdown(True)" - ] - }, - { - "cell_type": "markdown", - "id": "yygMe6rPWxHS", - "metadata": { - "id": "yygMe6rPWxHS" - }, - "source": [ - "### 🔐 Authentication\n", - "Authenticate to Google Cloud as the IAM user logged into this notebook in order to access your Google Cloud Project.\n", - "\n", - "* If you are using Colab to run this notebook, use the cell below and continue.\n", - "* If you are using Vertex AI Workbench, check out the setup instructions [here](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/setup-env)." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "PTXN1_DSXj2b", - "metadata": { - "id": "PTXN1_DSXj2b" - }, - "outputs": [], - "source": [ - "from google.colab import auth\n", - "\n", - "auth.authenticate_user()" - ] - }, - { - "cell_type": "markdown", - "id": "NEvB9BoLEulY", - "metadata": { - "id": "NEvB9BoLEulY" - }, - "source": [ - "### ☁ Set Your Google Cloud Project\n", - "Set your Google Cloud project so that you can leverage Google Cloud resources within this notebook.\n", - "\n", - "If you don't know your project ID, try the following:\n", - "\n", - "* Run `gcloud config list`.\n", - "* Run `gcloud projects list`.\n", - "* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "gfkS3yVRE4_W", - "metadata": { - "cellView": "form", - "id": "gfkS3yVRE4_W" - }, - "outputs": [], - "source": [ - "# @markdown Please fill in the value below with your Google Cloud project ID and then run the cell.\n", - "\n", - "PROJECT_ID = \"my-project-id\" # @param {type:\"string\"}\n", - "\n", - "# Set the project id\n", - "!gcloud config set project {PROJECT_ID}" - ] - }, - { - "cell_type": "markdown", - "id": "rEWWNoNnKOgq", - "metadata": { - "id": "rEWWNoNnKOgq" - }, - "source": [ - "### 💡 API Enablement\n", - "The `langchain-google-cloud-sql-pg` package requires that you [enable the Cloud SQL Admin API](https://console.cloud.google.com/flows/enableapi?apiid=sqladmin.googleapis.com) in your Google Cloud Project." 
- ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "5utKIdq7KYi5", - "metadata": { - "id": "5utKIdq7KYi5" - }, - "outputs": [], - "source": [ - "# enable Cloud SQL Admin API\n", - "!gcloud services enable sqladmin.googleapis.com" - ] - }, - { - "cell_type": "markdown", - "id": "f8f2830ee9ca1e01", - "metadata": { - "id": "f8f2830ee9ca1e01" - }, - "source": [ - "## Basic Usage" - ] - }, - { - "cell_type": "markdown", - "id": "OMvzMWRrR6n7", - "metadata": { - "id": "OMvzMWRrR6n7" - }, - "source": [ - "### Set Cloud SQL database values\n", - "Find your database values, in the [Cloud SQL Instances page](https://console.cloud.google.com/sql?_ga=2.223735448.2062268965.1707700487-2088871159.1707257687)." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "irl7eMFnSPZr", - "metadata": { - "id": "irl7eMFnSPZr" - }, - "outputs": [], - "source": [ - "# @title Set Your Values Here { display-mode: \"form\" }\n", - "REGION = \"us-central1\" # @param {type: \"string\"}\n", - "INSTANCE = \"my-postgresql-instance\" # @param {type: \"string\"}\n", - "DATABASE = \"my-database\" # @param {type: \"string\"}\n", - "TABLE_NAME = \"message_store\" # @param {type: \"string\"}" - ] - }, - { - "cell_type": "markdown", - "id": "QuQigs4UoFQ2", - "metadata": { - "id": "QuQigs4UoFQ2" - }, - "source": [ - "### PostgresEngine Connection Pool\n", - "\n", - "One of the requirements and arguments to establish Cloud SQL as a ChatMessageHistory memory store is a `PostgresEngine` object. The `PostgresEngine` configures a connection pool to your Cloud SQL database, enabling successful connections from your application and following industry best practices.\n", - "\n", - "To create a `PostgresEngine` using `PostgresEngine.from_instance()` you need to provide only 4 things:\n", - "\n", - "1. `project_id` : Project ID of the Google Cloud Project where the Cloud SQL instance is located.\n", - "1. `region` : Region where the Cloud SQL instance is located.\n", - "1. `instance` : The name of the Cloud SQL instance.\n", - "1. `database` : The name of the database to connect to on the Cloud SQL instance.\n", - "\n", - "By default, [IAM database authentication](https://cloud.google.com/sql/docs/postgres/iam-authentication#iam-db-auth) will be used as the method of database authentication. This library uses the IAM principal belonging to the [Application Default Credentials (ADC)](https://cloud.google.com/docs/authentication/application-default-credentials) sourced from the envionment.\n", - "\n", - "For more informatin on IAM database authentication please see:\n", - "* [Configure an instance for IAM database authentication](https://cloud.google.com/sql/docs/postgres/create-edit-iam-instances)\n", - "* [Manage users with IAM database authentication](https://cloud.google.com/sql/docs/postgres/add-manage-iam-users)\n", - "\n", - "Optionally, [built-in database authentication](https://cloud.google.com/sql/docs/postgres/built-in-authentication) using a username and password to access the Cloud SQL database can also be used. 
Just provide the optional `user` and `password` arguments to `PostgresEngine.from_instance()`:\n", - "\n", - "* `user` : Database user to use for built-in database authentication and login\n", - "* `password` : Database password to use for built-in database authentication and login.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "4576e914a866fb40", - "metadata": { - "ExecuteTime": { - "end_time": "2023-08-28T10:04:38.077748Z", - "start_time": "2023-08-28T10:04:36.105894Z" - }, - "id": "4576e914a866fb40", - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [], - "source": [ - "from langchain_google_cloud_sql_pg import PostgresEngine\n", - "\n", - "engine = PostgresEngine.from_instance(\n", - " project_id=PROJECT_ID, region=REGION, instance=INSTANCE, database=DATABASE\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "qPV8WfWr7O54", - "metadata": { - "id": "qPV8WfWr7O54" - }, - "source": [ - "### Initialize a table\n", - "The `PostgresChatMessageHistory` class requires a database table with a specific schema in order to store the chat message history.\n", - "\n", - "The `PostgresEngine` engine has a helper method `init_chat_history_table()` that can be used to create a table with the proper schema for you." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "TEu4VHArRttE", - "metadata": { - "id": "TEu4VHArRttE" - }, - "outputs": [], - "source": [ - "engine.init_chat_history_table(table_name=TABLE_NAME)" - ] - }, - { - "cell_type": "markdown", - "id": "345b76b8", - "metadata": {}, - "source": [ - "#### Optional Tip: 💡\n", - "You can also specify a schema name by passing `schema_name` wherever you pass `table_name`. Eg:\n", - "\n", - "```python\n", - "SCHEMA_NAME=\"my_schema\"\n", - "\n", - "engine.init_chat_history_table(\n", - " table_name=TABLE_NAME,\n", - " schema_name=SCHEMA_NAME # Default: \"public\"\n", - ")\n", - "```" - ] - }, - { - "cell_type": "markdown", - "id": "zSYQTYf3UfOi", - "metadata": { - "id": "zSYQTYf3UfOi" - }, - "source": [ - "### PostgresChatMessageHistory\n", - "\n", - "To initialize the `PostgresChatMessageHistory` class you need to provide only 3 things:\n", - "\n", - "1. `engine` - An instance of a `PostgresEngine` engine.\n", - "1. `session_id` - A unique identifier string that specifies an id for the session.\n", - "1. `table_name` : The name of the table within the Cloud SQL database to store the chat message history.\n", - "1. `schema_name` : The name of the database schema containing the chat message history table." 
- ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "Kq7RLtfOq0wi", - "metadata": { - "id": "Kq7RLtfOq0wi" - }, - "outputs": [], - "source": [ - "from langchain_google_cloud_sql_pg import PostgresChatMessageHistory\n", - "\n", - "history = PostgresChatMessageHistory.create_sync(\n", - " engine,\n", - " session_id=\"test_session\",\n", - " table_name=TABLE_NAME,\n", - " # schema_name=SCHEMA_NAME,\n", - ")\n", - "history.add_user_message(\"hi!\")\n", - "history.add_ai_message(\"whats up?\")" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "b476688cbb32ba90", - "metadata": { - "ExecuteTime": { - "end_time": "2023-08-28T10:04:38.929396Z", - "start_time": "2023-08-28T10:04:38.915727Z" - }, - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "b476688cbb32ba90", - "jupyter": { - "outputs_hidden": false - }, - "outputId": "a19e5cd8-4225-476a-d28d-e870c6b838bb" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[HumanMessage(content='hi!'), AIMessage(content='whats up?')]" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "history.messages" - ] - }, - { - "cell_type": "markdown", - "id": "ss6CbqcTTedr", - "metadata": { - "id": "ss6CbqcTTedr" - }, - "source": [ - "#### Cleaning up\n", - "When the history of a specific session is obsolete and can be deleted, it can be done the following way.\n", - "\n", - "**Note:** Once deleted, the data is no longer stored in Cloud SQL and is gone forever." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "3khxzFxYO7x6", - "metadata": { - "id": "3khxzFxYO7x6" - }, - "outputs": [], - "source": [ - "history.clear()" - ] - }, - { - "cell_type": "markdown", - "id": "2e5337719d5614fd", - "metadata": { - "id": "2e5337719d5614fd" - }, - "source": [ - "## 🔗 Chaining\n", - "\n", - "We can easily combine this message history class with [LCEL Runnables](/docs/expression_language/how_to/message_history)\n", - "\n", - "To do this we will use one of [Google's Vertex AI chat models](https://python.langchain.com/docs/integrations/chat/google_vertex_ai_palm) which requires that you [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com) in your Google Cloud Project.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "hYtHM3-TOMCe", - "metadata": { - "id": "hYtHM3-TOMCe" - }, - "outputs": [], - "source": [ - "# enable Vertex AI API\n", - "!gcloud services enable aiplatform.googleapis.com" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "6558418b-0ece-4d01-9661-56d562d78f7a", - "metadata": { - "id": "6558418b-0ece-4d01-9661-56d562d78f7a" - }, - "outputs": [], - "source": [ - "from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n", - "from langchain_core.runnables.history import RunnableWithMessageHistory\n", - "from langchain_google_vertexai import ChatVertexAI" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "82149122-61d3-490d-9bdb-bb98606e8ba1", - "metadata": { - "id": "82149122-61d3-490d-9bdb-bb98606e8ba1" - }, - "outputs": [], - "source": [ - "prompt = ChatPromptTemplate.from_messages(\n", - " [\n", - " (\"system\", \"You are a helpful assistant.\"),\n", - " MessagesPlaceholder(variable_name=\"history\"),\n", - " (\"human\", \"{question}\"),\n", - " ]\n", - ")\n", - "\n", - "chain = prompt | ChatVertexAI(project=PROJECT_ID)" - ] + "cells": [ + { + "cell_type": "markdown", + "id": "f22eab3f84cbeb37", + 
"metadata": { + "id": "f22eab3f84cbeb37" + }, + "source": [ + "# Google Cloud SQL for PostgreSQL\n", + "\n", + "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers MySQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's Langchain integrations.\n", + "\n", + "This notebook goes over how to use `Cloud SQL for PostgreSQL` to store chat message history with the `PostgresChatMessageHistory` class.\n", + "\n", + "Learn more about the package on [GitHub](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/).\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/langchain-google-cloud-sql-pg-python/blob/main/docs/chat_message_history.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "da400c79-a360-43e2-be60-401fd02b2819", + "metadata": { + "id": "da400c79-a360-43e2-be60-401fd02b2819" + }, + "source": [ + "## Before You Begin\n", + "\n", + "To run this notebook, you will need to do the following:\n", + "\n", + " * [Create a Google Cloud Project](https://developers.google.com/workspace/guides/create-project)\n", + " * [Enable the Cloud SQL Admin API.](https://console.cloud.google.com/marketplace/product/google/sqladmin.googleapis.com)\n", + " * [Create a Cloud SQL for PostgreSQL instance](https://cloud.google.com/sql/docs/postgres/create-instance)\n", + " * [Create a Cloud SQL database](https://cloud.google.com/sql/docs/mysql/create-manage-databases)\n", + " * [Add an IAM database user to the database](https://cloud.google.com/sql/docs/postgres/add-manage-iam-users#creating-a-database-user) (Optional)" + ] + }, + { + "cell_type": "markdown", + "id": "Mm7-fG_LltD7", + "metadata": { + "id": "Mm7-fG_LltD7" + }, + "source": [ + "### 🦜🔗 Library Installation\n", + "The integration lives in its own `langchain-google-cloud-sql-pg` package, so we need to install it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1VELXvcj8AId", + "metadata": { + "id": "1VELXvcj8AId" + }, + "outputs": [], + "source": [ + "%pip install --upgrade --quiet langchain-google-cloud-sql-pg langchain-google-vertexai" + ] + }, + { + "cell_type": "markdown", + "id": "98TVoM3MNDHu", + "metadata": { + "id": "98TVoM3MNDHu" + }, + "source": [ + "**Colab only:** Uncomment the following cell to restart the kernel or use the button to restart the kernel. For Vertex AI Workbench you can restart the terminal using the button on top." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "v6jBDnYnNM08", + "metadata": { + "id": "v6jBDnYnNM08" + }, + "outputs": [], + "source": [ + "# # Automatically restart kernel after installs so that your environment can access the new packages\n", + "# import IPython\n", + "\n", + "# app = IPython.Application.instance()\n", + "# app.kernel.do_shutdown(True)" + ] + }, + { + "cell_type": "markdown", + "id": "yygMe6rPWxHS", + "metadata": { + "id": "yygMe6rPWxHS" + }, + "source": [ + "### 🔐 Authentication\n", + "Authenticate to Google Cloud as the IAM user logged into this notebook in order to access your Google Cloud Project.\n", + "\n", + "* If you are using Colab to run this notebook, use the cell below and continue.\n", + "* If you are using Vertex AI Workbench, check out the setup instructions [here](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/setup-env)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "PTXN1_DSXj2b", + "metadata": { + "id": "PTXN1_DSXj2b" + }, + "outputs": [], + "source": [ + "from google.colab import auth\n", + "\n", + "auth.authenticate_user()" + ] + }, + { + "cell_type": "markdown", + "id": "NEvB9BoLEulY", + "metadata": { + "id": "NEvB9BoLEulY" + }, + "source": [ + "### ☁ Set Your Google Cloud Project\n", + "Set your Google Cloud project so that you can leverage Google Cloud resources within this notebook.\n", + "\n", + "If you don't know your project ID, try the following:\n", + "\n", + "* Run `gcloud config list`.\n", + "* Run `gcloud projects list`.\n", + "* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "gfkS3yVRE4_W", + "metadata": { + "cellView": "form", + "id": "gfkS3yVRE4_W" + }, + "outputs": [], + "source": [ + "# @markdown Please fill in the value below with your Google Cloud project ID and then run the cell.\n", + "\n", + "PROJECT_ID = \"my-project-id\" # @param {type:\"string\"}\n", + "\n", + "# Set the project id\n", + "!gcloud config set project {PROJECT_ID}" + ] + }, + { + "cell_type": "markdown", + "id": "rEWWNoNnKOgq", + "metadata": { + "id": "rEWWNoNnKOgq" + }, + "source": [ + "### 💡 API Enablement\n", + "The `langchain-google-cloud-sql-pg` package requires that you [enable the Cloud SQL Admin API](https://console.cloud.google.com/flows/enableapi?apiid=sqladmin.googleapis.com) in your Google Cloud Project." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5utKIdq7KYi5", + "metadata": { + "id": "5utKIdq7KYi5" + }, + "outputs": [], + "source": [ + "# enable Cloud SQL Admin API\n", + "!gcloud services enable sqladmin.googleapis.com" + ] + }, + { + "cell_type": "markdown", + "id": "f8f2830ee9ca1e01", + "metadata": { + "id": "f8f2830ee9ca1e01" + }, + "source": [ + "## Basic Usage" + ] + }, + { + "cell_type": "markdown", + "id": "OMvzMWRrR6n7", + "metadata": { + "id": "OMvzMWRrR6n7" + }, + "source": [ + "### Set Cloud SQL database values\n", + "Find your database values, in the [Cloud SQL Instances page](https://console.cloud.google.com/sql?_ga=2.223735448.2062268965.1707700487-2088871159.1707257687)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "irl7eMFnSPZr", + "metadata": { + "id": "irl7eMFnSPZr" + }, + "outputs": [], + "source": [ + "# @title Set Your Values Here { display-mode: \"form\" }\n", + "REGION = \"us-central1\" # @param {type: \"string\"}\n", + "INSTANCE = \"my-postgresql-instance\" # @param {type: \"string\"}\n", + "DATABASE = \"my-database\" # @param {type: \"string\"}\n", + "TABLE_NAME = \"message_store\" # @param {type: \"string\"}" + ] + }, + { + "cell_type": "markdown", + "id": "QuQigs4UoFQ2", + "metadata": { + "id": "QuQigs4UoFQ2" + }, + "source": [ + "### PostgresEngine Connection Pool\n", + "\n", + "One of the requirements and arguments to establish Cloud SQL as a ChatMessageHistory memory store is a `PostgresEngine` object. The `PostgresEngine` configures a connection pool to your Cloud SQL database, enabling successful connections from your application and following industry best practices.\n", + "\n", + "To create a `PostgresEngine` using `PostgresEngine.from_instance()` you need to provide only 4 things:\n", + "\n", + "1. `project_id` : Project ID of the Google Cloud Project where the Cloud SQL instance is located.\n", + "1. `region` : Region where the Cloud SQL instance is located.\n", + "1. `instance` : The name of the Cloud SQL instance.\n", + "1. `database` : The name of the database to connect to on the Cloud SQL instance.\n", + "\n", + "By default, [IAM database authentication](https://cloud.google.com/sql/docs/postgres/iam-authentication#iam-db-auth) will be used as the method of database authentication. This library uses the IAM principal belonging to the [Application Default Credentials (ADC)](https://cloud.google.com/docs/authentication/application-default-credentials) sourced from the envionment.\n", + "\n", + "For more informatin on IAM database authentication please see:\n", + "* [Configure an instance for IAM database authentication](https://cloud.google.com/sql/docs/postgres/create-edit-iam-instances)\n", + "* [Manage users with IAM database authentication](https://cloud.google.com/sql/docs/postgres/add-manage-iam-users)\n", + "\n", + "Optionally, [built-in database authentication](https://cloud.google.com/sql/docs/postgres/built-in-authentication) using a username and password to access the Cloud SQL database can also be used. Just provide the optional `user` and `password` arguments to `PostgresEngine.from_instance()`:\n", + "\n", + "* `user` : Database user to use for built-in database authentication and login\n", + "* `password` : Database password to use for built-in database authentication and login.\n", + "\n", + "To connect to your Cloud SQL instance from this notebook, you will need to enable public IP on your instance. Alternatively, you can follow [these instructions](https://cloud.google.com/sql/docs/postgres/connect-to-instance-from-outside-vpc) to connect to an Cloud SQL for PostgreSQL instance with Private IP from outside your VPC. 
Learn more about [specifying IP types](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector?tab=readme-ov-file#specifying-ip-address-type).\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "4576e914a866fb40", + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-28T10:04:38.077748Z", + "start_time": "2023-08-28T10:04:36.105894Z" + }, + "id": "4576e914a866fb40", + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "from langchain_google_cloud_sql_pg import PostgresEngine\n", + "\n", + "engine = PostgresEngine.from_instance(\n", + " project_id=PROJECT_ID,\n", + " region=REGION,\n", + " instance=INSTANCE,\n", + " database=DATABASE,\n", + " ip_type=\"public\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "qPV8WfWr7O54", + "metadata": { + "id": "qPV8WfWr7O54" + }, + "source": [ + "### Initialize a table\n", + "The `PostgresChatMessageHistory` class requires a database table with a specific schema in order to store the chat message history.\n", + "\n", + "The `PostgresEngine` engine has a helper method `init_chat_history_table()` that can be used to create a table with the proper schema for you." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "TEu4VHArRttE", + "metadata": { + "id": "TEu4VHArRttE" + }, + "outputs": [], + "source": [ + "engine.init_chat_history_table(table_name=TABLE_NAME)" + ] + }, + { + "cell_type": "markdown", + "id": "345b76b8", + "metadata": {}, + "source": [ + "#### Optional Tip: 💡\n", + "You can also specify a schema name by passing `schema_name` wherever you pass `table_name`. Eg:\n", + "\n", + "```python\n", + "SCHEMA_NAME=\"my_schema\"\n", + "\n", + "engine.init_chat_history_table(\n", + " table_name=TABLE_NAME,\n", + " schema_name=SCHEMA_NAME # Default: \"public\"\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "zSYQTYf3UfOi", + "metadata": { + "id": "zSYQTYf3UfOi" + }, + "source": [ + "### PostgresChatMessageHistory\n", + "\n", + "To initialize the `PostgresChatMessageHistory` class you need to provide only 3 things:\n", + "\n", + "1. `engine` - An instance of a `PostgresEngine` engine.\n", + "1. `session_id` - A unique identifier string that specifies an id for the session.\n", + "1. `table_name` : The name of the table within the Cloud SQL database to store the chat message history.\n", + "1. `schema_name` : The name of the database schema containing the chat message history table." 
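For async applications, a hedged sketch of the parallel async path may be useful here: it assumes the async factories (`afrom_instance()`, which does appear later in this patch for the vector store, and a `create()` counterpart to `create_sync()`) follow the package's usual sync/async pairing; verify the names against the package docs before relying on them.

```python
# Hedged sketch, not part of this patch: the async path. Assumes
# PostgresChatMessageHistory.create() mirrors create_sync().
import asyncio

from langchain_google_cloud_sql_pg import PostgresChatMessageHistory, PostgresEngine


async def main() -> None:
    engine = await PostgresEngine.afrom_instance(
        "my-project-id", "us-central1", "my-postgresql-instance", "my-database"
    )
    history = await PostgresChatMessageHistory.create(
        engine, session_id="test_session", table_name="message_store"
    )
    print(await history.aget_messages())  # async mirror of history.messages


asyncio.run(main())
```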
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "Kq7RLtfOq0wi", + "metadata": { + "id": "Kq7RLtfOq0wi" + }, + "outputs": [], + "source": [ + "from langchain_google_cloud_sql_pg import PostgresChatMessageHistory\n", + "\n", + "history = PostgresChatMessageHistory.create_sync(\n", + " engine,\n", + " session_id=\"test_session\",\n", + " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", + ")\n", + "history.add_user_message(\"hi!\")\n", + "history.add_ai_message(\"whats up?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "b476688cbb32ba90", + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-28T10:04:38.929396Z", + "start_time": "2023-08-28T10:04:38.915727Z" }, - { - "cell_type": "code", - "execution_count": 16, - "id": "2df90853-b67c-490f-b7f8-b69d69270b9c", - "metadata": { - "id": "2df90853-b67c-490f-b7f8-b69d69270b9c" - }, - "outputs": [], - "source": [ - "chain_with_history = RunnableWithMessageHistory(\n", - " chain,\n", - " lambda session_id: PostgresChatMessageHistory.create_sync(\n", - " engine,\n", - " session_id=session_id,\n", - " table_name=TABLE_NAME,\n", - " # schema_name=SCHEMA_NAME,\n", - " ),\n", - " input_messages_key=\"question\",\n", - " history_messages_key=\"history\",\n", - ")" - ] + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "code", - "execution_count": 17, - "id": "0ce596b8-3b78-48fd-9f92-46dccbbfd58b", - "metadata": { - "id": "0ce596b8-3b78-48fd-9f92-46dccbbfd58b" - }, - "outputs": [], - "source": [ - "# This is where we configure the session id\n", - "config = {\"configurable\": {\"session_id\": \"test_session\"}}" - ] + "id": "b476688cbb32ba90", + "jupyter": { + "outputs_hidden": false }, + "outputId": "a19e5cd8-4225-476a-d28d-e870c6b838bb" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 18, - "id": "38e1423b-ba86-4496-9151-25932fab1a8b", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "38e1423b-ba86-4496-9151-25932fab1a8b", - "outputId": "d5c93570-4b0b-4fe8-d19c-4b361fe74291" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "AIMessage(content=' Hello Bob, how can I help you today?')" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "chain_with_history.invoke({\"question\": \"Hi! I'm bob\"}, config=config)" + "data": { + "text/plain": [ + "[HumanMessage(content='hi!'), AIMessage(content='whats up?')]" ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "history.messages" + ] + }, + { + "cell_type": "markdown", + "id": "ss6CbqcTTedr", + "metadata": { + "id": "ss6CbqcTTedr" + }, + "source": [ + "#### Cleaning up\n", + "When the history of a specific session is obsolete and can be deleted, it can be done the following way.\n", + "\n", + "**Note:** Once deleted, the data is no longer stored in Cloud SQL and is gone forever." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "3khxzFxYO7x6", + "metadata": { + "id": "3khxzFxYO7x6" + }, + "outputs": [], + "source": [ + "history.clear()" + ] + }, + { + "cell_type": "markdown", + "id": "2e5337719d5614fd", + "metadata": { + "id": "2e5337719d5614fd" + }, + "source": [ + "## 🔗 Chaining\n", + "\n", + "We can easily combine this message history class with [LCEL Runnables](/docs/expression_language/how_to/message_history)\n", + "\n", + "To do this we will use one of [Google's Vertex AI chat models](https://python.langchain.com/docs/integrations/chat/google_vertex_ai_palm) which requires that you [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com) in your Google Cloud Project.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "hYtHM3-TOMCe", + "metadata": { + "id": "hYtHM3-TOMCe" + }, + "outputs": [], + "source": [ + "# enable Vertex AI API\n", + "!gcloud services enable aiplatform.googleapis.com" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "6558418b-0ece-4d01-9661-56d562d78f7a", + "metadata": { + "id": "6558418b-0ece-4d01-9661-56d562d78f7a" + }, + "outputs": [], + "source": [ + "from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n", + "from langchain_core.runnables.history import RunnableWithMessageHistory\n", + "from langchain_google_vertexai import ChatVertexAI" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "82149122-61d3-490d-9bdb-bb98606e8ba1", + "metadata": { + "id": "82149122-61d3-490d-9bdb-bb98606e8ba1" + }, + "outputs": [], + "source": [ + "prompt = ChatPromptTemplate.from_messages(\n", + " [\n", + " (\"system\", \"You are a helpful assistant.\"),\n", + " MessagesPlaceholder(variable_name=\"history\"),\n", + " (\"human\", \"{question}\"),\n", + " ]\n", + ")\n", + "\n", + "chain = prompt | ChatVertexAI(project=PROJECT_ID)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "2df90853-b67c-490f-b7f8-b69d69270b9c", + "metadata": { + "id": "2df90853-b67c-490f-b7f8-b69d69270b9c" + }, + "outputs": [], + "source": [ + "chain_with_history = RunnableWithMessageHistory(\n", + " chain,\n", + " lambda session_id: PostgresChatMessageHistory.create_sync(\n", + " engine,\n", + " session_id=session_id,\n", + " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", + " ),\n", + " input_messages_key=\"question\",\n", + " history_messages_key=\"history\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "0ce596b8-3b78-48fd-9f92-46dccbbfd58b", + "metadata": { + "id": "0ce596b8-3b78-48fd-9f92-46dccbbfd58b" + }, + "outputs": [], + "source": [ + "# This is where we configure the session id\n", + "config = {\"configurable\": {\"session_id\": \"test_session\"}}" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "38e1423b-ba86-4496-9151-25932fab1a8b", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "38e1423b-ba86-4496-9151-25932fab1a8b", + "outputId": "d5c93570-4b0b-4fe8-d19c-4b361fe74291" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 19, - "id": "2ee4ee62-a216-4fb1-bf33-57476a84cf16", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "2ee4ee62-a216-4fb1-bf33-57476a84cf16", - "outputId": "288fe388-3f60-41b8-8edb-37cfbec18981" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "AIMessage(content=' Your name is Bob.')" - ] - }, - "execution_count": 19, - "metadata": 
{}, - "output_type": "execute_result" - } - ], - "source": [ - "chain_with_history.invoke({\"question\": \"Whats my name\"}, config=config)" + "data": { + "text/plain": [ + "AIMessage(content=' Hello Bob, how can I help you today?')" ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" } - ], - "metadata": { + ], + "source": [ + "chain_with_history.invoke({\"question\": \"Hi! I'm bob\"}, config=config)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "2ee4ee62-a216-4fb1-bf33-57476a84cf16", + "metadata": { "colab": { - "provenance": [], - "toc_visible": true + "base_uri": "https://localhost:8080/" }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.8" + "id": "2ee4ee62-a216-4fb1-bf33-57476a84cf16", + "outputId": "288fe388-3f60-41b8-8edb-37cfbec18981" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=' Your name is Bob.')" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" } + ], + "source": [ + "chain_with_history.invoke({\"question\": \"Whats my name\"}, config=config)" + ] + } + ], + "metadata": { + "colab": { + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/document_loader.ipynb b/docs/document_loader.ipynb index c0b59a09..c98482d3 100644 --- a/docs/document_loader.ipynb +++ b/docs/document_loader.ipynb @@ -203,7 +203,10 @@ "Optionally, [built-in database authentication](https://cloud.google.com/sql/docs/postgres/users) using a username and password to access the Cloud SQL database can also be used. Just provide the optional `user` and `password` arguments to `PostgresEngine.from_instance()`:\n", "\n", "* `user` : Database user to use for built-in database authentication and login\n", - "* `password` : Database password to use for built-in database authentication and login.\n" + "* `password` : Database password to use for built-in database authentication and login.\n", + "\n", + "\n", + "To connect to your Cloud SQL instance from this notebook, you will need to enable public IP on your instance. Alternatively, you can follow [these instructions](https://cloud.google.com/sql/docs/postgres/connect-to-instance-from-outside-vpc) to connect to an Cloud SQL for PostgreSQL instance with Private IP from outside your VPC. Learn more about [specifying IP types](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector?tab=readme-ov-file#specifying-ip-address-type)." 
] }, { @@ -226,6 +229,7 @@ " region=REGION,\n", " instance=INSTANCE,\n", " database=DATABASE,\n", + " ip_type=\"public\",\n", ")" ] }, diff --git a/docs/langgraph_checkpoint.ipynb b/docs/langgraph_checkpoint.ipynb new file mode 100644 index 00000000..c663f95d --- /dev/null +++ b/docs/langgraph_checkpoint.ipynb @@ -0,0 +1,424 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3 (ipykernel)", + "language": "python" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Google Cloud SQL for PostgreSQL\n", + "\n", + "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers MySQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's LangChain integrations.\n", + "\n", + "This notebook goes over how to use `Cloud SQL for PostgreSQL` to store checkpoints with the `PostgresSaver` class.\n", + "\n", + "Learn more about the package on [GitHub](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/).\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/langchain-google-cloud-sql-pg-python/blob/main/docs/langgraph_checkpoint.ipynb)" + ], + "metadata": { + "id": "xHq_1Zh8TfCz" + } + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Prerequisites\n", + "\n", + "This guide assumes familiarity with the following:\n", + "\n", + "- [LangGraph Persistence](https://langchain-ai.github.io/langgraph/concepts/persistence/)\n", + "- [PostgreSQL](https://www.postgresql.org/about/)\n", + "\n", + "When creating LangGraph agents, you can also set them up so that they persist their state. This allows you to do things like interact with an agent multiple times and have it remember previous interactions." + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Before You Begin\n", + "\n", + "To run this notebook, you will need to do the following:\n", + "\n", + " * [Create a Google Cloud Project](https://developers.google.com/workspace/guides/create-project)\n", + " * [Enable the Cloud SQL Admin API.](https://console.cloud.google.com/marketplace/product/google/sqladmin.googleapis.com)\n", + " * [Create a Cloud SQL for PostgreSQL instance](https://cloud.google.com/sql/docs/postgres/create-instance)\n", + " * [Create a Cloud SQL database](https://cloud.google.com/sql/docs/mysql/create-manage-databases)\n", + " * [Add an IAM database user to the database](https://cloud.google.com/sql/docs/postgres/add-manage-iam-users#creating-a-database-user) (Optional)\n" + ], + "metadata": { + "id": "5NW0xqHRToAM" + } + }, + { + "cell_type": "markdown", + "source": [ + "### 🦜🔗 Library Installation\n", + "The integration lives in its own `langchain-google-cloud-sql-pg` package, so we need to install it."
+ ], + "metadata": { + "id": "Owc-OZpOT0Jy" + } + }, + { + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "TpNTWLbZT1Dw", + "outputId": "a54015ba-0281-4955-e9e0-36d3125c4887" + }, + "cell_type": "code", + "source": "%pip install --upgrade --quiet langchain-google-cloud-sql-pg[langgraph] langchain-google-vertexai langgraph", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "**Colab only:** Uncomment the following cell to restart the kernel or use the button to restart the kernel. For Vertex AI Workbench you can restart the terminal using the button on top." + }, + { + "cell_type": "code", + "source": [ + "# # Automatically restart kernel after installs so that your environment can access the new packages\n", + "# import IPython\n", + "#\n", + "# app = IPython.Application.instance()\n", + "# app.kernel.do_shutdown(True)" + ], + "metadata": { + "id": "OQrizsZGT4xx" + }, + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "source": [ + "### 🔐 Authentication\n", + "Authenticate to Google Cloud as the IAM user logged into this notebook in order to access your Google Cloud Project.\n", + "\n", + "* If you are using Colab to run this notebook, use the cell below and continue.\n", + "* If you are using Vertex AI Workbench, check out the setup instructions [here](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/setup-env)." + ], + "metadata": { + "id": "XpCaEQgvUPBQ" + } + }, + { + "cell_type": "code", + "source": [ + "from google.colab import auth\n", + "\n", + "auth.authenticate_user()" + ], + "metadata": { + "id": "WvydoWCKUOlM" + }, + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "source": [ + "### ☁ Set Your Google Cloud Project\n", + "Set your Google Cloud project so that you can leverage Google Cloud resources within this notebook.\n", + "\n", + "If you don't know your project ID, try the following:\n", + "\n", + "* Run `gcloud config list`.\n", + "* Run `gcloud projects list`.\n", + "* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)." + ], + "metadata": { + "id": "nNK4c-n-Umq9" + } + }, + { + "cell_type": "code", + "source": [ + "# @markdown Please fill in the value below with your Google Cloud project ID and then run the cell.\n", + "\n", + "PROJECT_ID = \"my-project-id\" # @param {type:\"string\"}\n", + "\n", + "# Set the project id\n", + "!gcloud config set project {PROJECT_ID}" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5pCxbzAGUP9p", + "outputId": "625f4fea-75a7-41ad-fef0-96467df9f41b" + }, + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "source": [ + "### 💡 API Enablement\n", + "The `langchain-google-cloud-sql-pg` package requires that you [enable the Cloud SQL Admin API](https://console.cloud.google.com/flows/enableapi?apiid=sqladmin.googleapis.com) in your Google Cloud Project." 
+ ], + "metadata": { + "id": "sfN2P8EkUydF" + } + }, + { + "cell_type": "code", + "source": [ + "# enable Cloud SQL Admin API\n", + "!gcloud services enable sqladmin.googleapis.com" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HS5OmGusUnP4", + "outputId": "952fdfd9-279f-4246-83a7-bc35793869f2" + }, + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "source": [ + "## Basic Usage" + ], + "metadata": { + "id": "j6Sd_BIHU_cV" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Set Cloud SQL database values\n", + "Find your database values, in the [Cloud SQL Instances page](https://console.cloud.google.com/sql?_ga=2.223735448.2062268965.1707700487-2088871159.1707257687)." + ], + "metadata": { + "id": "Kd3WmpRRVGJH" + } + }, + { + "cell_type": "code", + "source": [ + "# @title Set Your Values Here { display-mode: \"form\" }\n", + "REGION = \"us-central1\" # @param {type: \"string\"}\n", + "INSTANCE = \"my-postgresql-instance\" # @param {type: \"string\"}\n", + "DATABASE = \"my-database\" # @param {type: \"string\"}" + ], + "metadata": { + "id": "ASWTYxvGUzfR" + }, + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "source": [ + "### PostgresEngine Connection Pool\n", + "\n", + "One of the requirements and arguments to establish Cloud SQL as a PostgresSaver is a `PostgresEngine` object. The `PostgresEngine` configures a connection pool to your Cloud SQL database, enabling successful connections from your application and following industry best practices.\n", + "\n", + "To create a `PostgresEngine` using `PostgresEngine.from_instance()` you need to provide only 4 things:\n", + "\n", + "1. `project_id` : Project ID of the Google Cloud Project where the Cloud SQL instance is located.\n", + "1. `region` : Region where the Cloud SQL instance is located.\n", + "1. `instance` : The name of the Cloud SQL instance.\n", + "1. `database` : The name of the database to connect to on the Cloud SQL instance.\n", + "\n", + "By default, [IAM database authentication](https://cloud.google.com/sql/docs/postgres/iam-authentication#iam-db-auth) will be used as the method of database authentication. This library uses the IAM principal belonging to the [Application Default Credentials (ADC)](https://cloud.google.com/docs/authentication/application-default-credentials) sourced from the envionment.\n", + "\n", + "For more informatin on IAM database authentication please see:\n", + "* [Configure an instance for IAM database authentication](https://cloud.google.com/sql/docs/postgres/create-edit-iam-instances)\n", + "* [Manage users with IAM database authentication](https://cloud.google.com/sql/docs/postgres/add-manage-iam-users)\n", + "\n", + "Optionally, [built-in database authentication](https://cloud.google.com/sql/docs/postgres/built-in-authentication) using a username and password to access the Cloud SQL database can also be used. Just provide the optional `user` and `password` arguments to `PostgresEngine.from_instance()`:\n", + "\n", + "* `user` : Database user to use for built-in database authentication and login\n", + "* `password` : Database password to use for built-in database authentication and login." 
+ ], + "metadata": { + "id": "5noUSRjIVtlO" + } + }, + { + "cell_type": "code", + "source": [ + "from langchain_google_cloud_sql_pg import PostgresEngine\n", + "\n", + "engine = PostgresEngine.from_instance(\n", + " project_id=PROJECT_ID, region=REGION, instance=INSTANCE, database=DATABASE\n", + ")" + ], + "metadata": { + "id": "uYh8ulGNVAus" + }, + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "### Initialize a table\n", + "The `PostgresSaver` class requires a database table with a specific schema in order to store the persist LangGraph agents state.\n", + "The `PostgresEngine` engine has a helper method `init_checkpoint_table()` that can be used to create a table with the proper schema for you." + ] + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "engine.init_checkpoint_table() # Use table_name to customise the table name" + ], + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "#### Optional Tip: 💡\n", + "You can also specify a schema name by passing `schema_name` wherever you pass `table_name`. Eg:\n", + "\n", + "```python\n", + "SCHEMA_NAME=\"my_schema\"\n", + "\n", + "engine.init_chat_history_table(\n", + " table_name=TABLE_NAME,\n", + " schema_name=SCHEMA_NAME # Default: \"public\"\n", + ")\n", + "```" + ] + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "### PostgresSaver\n", + "\n", + "To initialize the `PostgresSaver` class you need to provide only 3 things:\n", + "\n", + "1. `engine` - An instance of a `PostgresEngine` engine.\n", + "2. `table_name` : The name of the table within the Cloud SQL database to store the checkpoints (Default: \"checkpoints\").\n", + "3. `schema_name` : The name of the database schema containing the checkpoints table (Default: \"public\")\n", + "4. 
`serde`: Serializer for encoding/decoding checkpoints (Default: None)\n" + ] + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "### Example of Checkpointer methods" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "from langchain_google_cloud_sql_pg import PostgresSaver\n", + "\n", + "checkpointer = PostgresSaver.create_sync(\n", + " engine,\n", + " # table_name = TABLE_NAME,\n", + " # schema_name = SCHEMA_NAME,\n", + " # serde = None,\n", + ")" + ], + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "write_config = {\"configurable\": {\"thread_id\": \"1\", \"checkpoint_ns\": \"\"}}\n", + "read_config = {\"configurable\": {\"thread_id\": \"1\"}}\n", + "\n", + "checkpoint = {\n", + " \"v\": 1,\n", + " \"ts\": \"2024-07-31T20:14:19.804150+00:00\",\n", + " \"id\": \"1ef4f797-8335-6428-8001-8a1503f9b875\",\n", + " \"channel_values\": {\"my_key\": \"meow\", \"node\": \"node\"},\n", + " \"channel_versions\": {\"__start__\": 2, \"my_key\": 3, \"start:node\": 3, \"node\": 3},\n", + " \"versions_seen\": {\n", + " \"__input__\": {},\n", + " \"__start__\": {\"__start__\": 1},\n", + " \"node\": {\"start:node\": 2},\n", + " },\n", + " \"pending_sends\": [],\n", + "}\n", + "\n", + "# store checkpoint\n", + "checkpointer.put(write_config, checkpoint, {}, {})\n", + "# load checkpoint\n", + "checkpointer.get(read_config)" + ], + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "## 🔗 Adding persistence to the pre-built create react agent" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "from typing import Literal\n", + "\n", + "from langchain_core.tools import tool\n", + "from langchain_google_vertexai import ChatVertexAI\n", + "from langgraph.prebuilt import create_react_agent\n", + "import vertexai\n", + "\n", + "vertexai.init(project=PROJECT_ID, location=REGION)\n", + "\n", + "\n", + "@tool\n", + "def get_weather(city: Literal[\"nyc\", \"sf\"]):\n", + " if city == \"nyc\":\n", + " return \"It might be cloudy in nyc\"\n", + " elif city == \"sf\":\n", + " return \"It's always sunny in sf\"\n", + " else:\n", + " raise AssertionError(\"Unknown city\")\n", + "\n", + "\n", + "tools = [get_weather]\n", + "model = ChatVertexAI(project_name=PROJECT_ID, model_name=\"gemini-2.0-flash-exp\")\n", + "\n", + "graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)\n", + "config = {\"configurable\": {\"thread_id\": \"2\"}}\n", + "res = graph.invoke({\"messages\": [(\"human\", \"what's the weather in sf\")]}, config)\n", + "print(res)" + ], + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "# Example of resulting checkpoint config\n", + "checkpoint = checkpointer.get(config)" + ], + "outputs": [], + "execution_count": null + } + ] +} diff --git a/docs/vector_store.ipynb b/docs/vector_store.ipynb index 4071dbdf..cf2814fe 100644 --- a/docs/vector_store.ipynb +++ b/docs/vector_store.ipynb @@ -209,7 +209,9 @@ "Optionally, [built-in database authentication](https://cloud.google.com/sql/docs/postgres/built-in-authentication) using a username and password to access the Cloud SQL database can also be used. 
Just provide the optional `user` and `password` arguments to `PostgresEngine.from_instance()`:\n", "\n", "* `user` : Database user to use for built-in database authentication and login\n", - "* `password` : Database password to use for built-in database authentication and login.\n" + "* `password` : Database password to use for built-in database authentication and login.\n", + "\n", + "To connect to your Cloud SQL instance from this notebook, you will need to enable public IP on your instance. Alternatively, you can follow [these instructions](https://cloud.google.com/sql/docs/postgres/connect-to-instance-from-outside-vpc) to connect to a Cloud SQL for PostgreSQL instance with Private IP from outside your VPC. Learn more about [specifying IP types](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector?tab=readme-ov-file#specifying-ip-address-type).\n" ] }, { @@ -228,7 +230,11 @@ "from langchain_google_cloud_sql_pg import PostgresEngine\n", "\n", "engine = await PostgresEngine.afrom_instance(\n", - " project_id=PROJECT_ID, region=REGION, instance=INSTANCE, database=DATABASE\n", + " project_id=PROJECT_ID,\n", + " region=REGION,\n", + " instance=INSTANCE,\n", + " database=DATABASE,\n", + " ip_type=\"public\",\n", ")" ] }, diff --git a/noxfile.py b/noxfile.py index 8fb4d61b..75eae944 100644 --- a/noxfile.py +++ b/noxfile.py @@ -41,7 +41,7 @@ def docs(session): """Build the docs for this library.""" - session.install("-e", ".") + session.install("-e", ".[test]") session.install( # We need to pin to specific versions of the `sphinxcontrib-*` packages # which still support sphinx 4.x. diff --git a/pyproject.toml b/pyproject.toml index 5bce61ce..63bb192d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,8 @@ authors = [ dependencies = [ "cloud-sql-python-connector[asyncpg] >= 1.10.0, <2.0.0", "langchain-core>=0.2.36, <1.0.0 ", - "numpy>=1.24.4, <2.0.0", + "numpy>=1.24.4, <3.0.0; python_version > '3.9'", + "numpy>=1.24.4, <=2.0.2; python_version <= '3.9'", "pgvector>=0.2.5, <1.0.0", "SQLAlchemy[asyncio]>=2.0.25, <3.0.0" ] @@ -38,13 +39,17 @@ Repository = "https://github.com/googleapis/langchain-google-cloud-sql-pg-python Changelog = "https://github.com/googleapis/langchain-google-cloud-sql-pg-python/blob/main/CHANGELOG.md" [project.optional-dependencies] +langgraph = [ + "langgraph-checkpoint>=2.0.9, <3.0.0" +] test = [ "black[jupyter]==25.1.0", - "isort==6.0.0", - "mypy==1.13.0", - "pytest-asyncio==0.24.0", + "isort==6.0.1", + "mypy==1.15.0", + "pytest-asyncio==0.25.3", "pytest==8.3.4", - "pytest-cov==6.0.0" + "pytest-cov==6.0.0", + "langgraph==0.2.74" ] [build-system] diff --git a/requirements.txt b/requirements.txt index 79e8db67..c52a1b3d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,7 @@ -cloud-sql-python-connector[asyncpg]==1.14.0 -langchain-core==0.3.22 -numpy==1.26.4 +cloud-sql-python-connector[asyncpg]==1.17.0 +langchain-core==0.3.40 +numpy==2.2.3; python_version > "3.9" +numpy==2.0.2; python_version <= "3.9" pgvector==0.3.6 -SQLAlchemy[asyncio]==2.0.36 +SQLAlchemy[asyncio]==2.0.38 +langgraph-checkpoint==2.0.10 \ No newline at end of file diff --git a/samples/index_tuning_sample/requirements.txt b/samples/index_tuning_sample/requirements.txt index 9351541e..19489979 100644 --- a/samples/index_tuning_sample/requirements.txt +++ b/samples/index_tuning_sample/requirements.txt @@ -1,3 +1,3 @@ -langchain-community==0.3.16 -langchain-google-cloud-sql-pg==0.11.1 -langchain-google-vertexai==2.0.12 +langchain-community==0.3.18
+langchain-google-cloud-sql-pg==0.12.1 +langchain-google-vertexai==2.0.14 diff --git a/samples/langchain_on_vertexai/requirements.txt b/samples/langchain_on_vertexai/requirements.txt index 090f6c2e..153755af 100644 --- a/samples/langchain_on_vertexai/requirements.txt +++ b/samples/langchain_on_vertexai/requirements.txt @@ -1,5 +1,5 @@ -google-cloud-aiplatform[reasoningengine,langchain]==1.79.0 -google-cloud-resource-manager==1.14.0 -langchain-community==0.3.16 -langchain-google-cloud-sql-pg==0.11.1 -langchain-google-vertexai==2.0.12 +google-cloud-aiplatform[reasoningengine,langchain]==1.81.0 +google-cloud-resource-manager==1.14.1 +langchain-community==0.3.18 +langchain-google-cloud-sql-pg==0.12.1 +langchain-google-vertexai==2.0.14 diff --git a/samples/langchain_quick_start.ipynb b/samples/langchain_quick_start.ipynb index c6361566..ff2c5828 100644 --- a/samples/langchain_quick_start.ipynb +++ b/samples/langchain_quick_start.ipynb @@ -1,1019 +1,1022 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "6zvUr-Qev6lL" - }, - "outputs": [], - "source": [ - "# Copyright 2024 Google LLC\n", - "#\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ob11AkrStrRI" - }, - "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/langchain-google-cloud-sql-pg-python/blob/main/samples/langchain_quick_start.ipynb)\n", - "\n", - "---\n", - "# **Introduction**\n", - "\n", - "In this codelab, you'll learn how to create a powerful interactive generative AI application using Retrieval Augmented Generation powered by [Cloud SQL for PostgreSQL](https://cloud.google.com/sql/docs/postgres) and [LangChain](https://www.langchain.com/). We will be creating an application grounded in a [Netflix Movie dataset](https://www.kaggle.com/datasets/shivamb/netflix-shows), allowing you to interact with movie data in exciting new ways." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ma6pEng3ypbA" - }, - "source": [ - "## Prerequisites\n", - "\n", - "* A basic understanding of the Google Cloud Console\n", - "* Basic skills in command line interface and Google Cloud shell\n", - "* Basic python knowledge" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DzDgqJHgysy1" - }, - "source": [ - "## What you'll learn\n", - "\n", - "* How to deploy a Cloud SQL for PostgreSQL instance\n", - "* How to use Cloud SQL for PostgreSQL as a DocumentLoader\n", - "* How to use Cloud SQL for PostgreSQL as a VectorStore\n", - "* How to use Cloud SQL for PostgreSQL for ChatMessageHistory storage" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FbcZUjT1yvTq" - }, - "source": [ - "## What you'll need\n", - "* A Google Cloud Account and Google Cloud Project\n", - "* A web browser such as [Chrome](https://www.google.com/chrome/)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vHdR4fF3vLWA" - }, - "source": [ - "# **Setup and Requirements**\n", - "\n", - "In the following instructions you will learn to:\n", - "1. Install required dependencies for our application\n", - "2. Set up authentication for our project\n", - "3. Set up a Cloud SQL for PostgreSQL Instance\n", - "4. Import the data used by our application" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uy9KqgPQ4GBi" - }, - "source": [ - "## Install dependencies\n", - "First you will need to install the dependencies needed to run this demo app." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "M_ppDxYf4Gqs" - }, - "outputs": [], - "source": [ - "%pip install --upgrade --quiet langchain-google-cloud-sql-pg langchain-google-vertexai langchain" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Colab only:** Uncomment the following cell to restart the kernel or use the button to restart the kernel. For Vertex AI Workbench you can restart the terminal using the button on top." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Automatically restart kernel after installs so that your environment can access the new packages\n", - "import IPython\n", - "\n", - "app = IPython.Application.instance()\n", - "app.kernel.do_shutdown(True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DeUbHclxw7_l" - }, - "source": [ - "## Authenticate to Google Cloud within Colab\n", - "In order to access your Google Cloud Project from this notebook, you will need to Authenticate as an IAM user." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "a168rJE1xDHO" - }, - "outputs": [], - "source": [ - "from google.colab import auth\n", - "\n", - "auth.authenticate_user()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UCiNGP1Qxd6x" - }, - "source": [ - "## Connect Your Google Cloud Project\n", - "Time to connect your Google Cloud Project to this notebook so that you can leverage Google Cloud from within Colab." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "qjFuhRhVxlWP" - }, - "outputs": [], - "source": [ - "# @markdown Please fill in the value below with your GCP project ID and then run the cell.\n", - "\n", - "# Please fill in these values.\n", - "project_id = \"\" # @param {type:\"string\"}\n", - "\n", - "# Quick input validations.\n", - "assert project_id, \"⚠️ Please provide a Google Cloud project ID\"\n", - "\n", - "# Configure gcloud.\n", - "!gcloud config set project {project_id}" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "O-oqMC5Ox-ZM" - }, - "source": [ - "## Configure Your Google Cloud Project\n", - "\n", - "Configure the following in your Google Cloud Project." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "1. IAM principal (user, service account, etc.) with the [Cloud SQL Client][client-role] role. The user logged into this notebook will be used as the IAM principal and will be granted the Cloud SQL Client role.\n", - "\n", - "[client-role]: https://cloud.google.com/sql/docs/mysql/roles-and-permissions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "current_user = !gcloud auth list --filter=status:ACTIVE --format=\"value(account)\"\n", - "!gcloud projects add-iam-policy-binding {project_id} \\\n", - " --member=user:{current_user[0]} \\\n", - " --role=\"roles/cloudsql.client\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "2. Enable the APIs for Cloud SQL and Vertex AI within your project." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "CKWrwyfzyTwH" - }, - "outputs": [], - "source": [ - "# Enable GCP services\n", - "!gcloud services enable sqladmin.googleapis.com aiplatform.googleapis.com" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Gn8g7-wCyZU6" - }, - "source": [ - "## Set up Cloud SQL\n", - "You will need a **Postgres** Cloud SQL instance for the following stages of this notebook." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "T616pEOUygYQ" - }, - "source": [ - "### Create a Postgres Instance\n", - "Running the below cell will verify the existence of the Cloud SQL instance and or create a new instance and database if one does not exist.\n", - "\n", - "> ⏳ - Creating a Cloud SQL instance may take a few minutes." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "XXI1uUu3y8gc" - }, - "outputs": [], - "source": [ - "#@markdown Please fill in the both the Google Cloud region and name of your Cloud SQL instance. 
Once filled in, run the cell.\n", - "\n", - "# Please fill in these values.\n", - "region = \"us-central1\" #@param {type:\"string\"}\n", - "instance_name = \"langchain-quickstart-instance\" #@param {type:\"string\"}\n", - "database_name = \"langchain-quickstart-db\" #@param {type:\"string\"}\n", - "password = input(\"Please provide a password to be used for 'postgres' database user: \")\n", - "\n", - "# Quick input validations.\n", - "assert region, \"⚠️ Please provide a Google Cloud region\"\n", - "assert instance_name, \"⚠️ Please provide the name of your instance\"\n", - "assert database_name, \"⚠️ Please provide the name of your database_name\"\n", - "\n", - "# check if Cloud SQL instance exists in the provided region\n", - "database_version = !gcloud sql instances describe {instance_name} --format=\"value(databaseVersion)\"\n", - "if database_version[0].startswith(\"POSTGRES\"):\n", - " print(\"Found existing Postgres Cloud SQL Instance!\")\n", - "else:\n", - " print(\"Creating new Cloud SQL instance...\")\n", - " !gcloud sql instances create {instance_name} --database-version=POSTGRES_15 \\\n", - " --region={region} --cpu=1 --memory=4GB --root-password={password} \\\n", - " --database-flags=cloudsql.iam_authentication=On\n", - " !gcloud sql databases create {database_name} --instance={instance_name}\n", - "\n", - "\n", - "databases = !gcloud sql databases list --instance={instance_name} --format=\"value(name)\"\n", - "if database_name not in databases:\n", - " !gcloud sql databases create {database_name} --instance={instance_name}" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HdolCWyatZmG" - }, - "source": [ - "## Import data to your database\n", - "\n", - "Now that you have your database, you will need to import data! We will be using a [Netflix Dataset from Kaggle](https://www.kaggle.com/datasets/shivamb/netflix-shows). Here is what the data looks like:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "36-FBKzJ-tLa" - }, - "source": [ - "| show_id | type | title | director | cast | country | date_added | release_year | rating | duration | listed_in | description |\n", - "|---------|---------|----------------------|------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------|---------------|-------------------|--------------|--------|-----------|----------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", - "| s1 | Movie | Dick Johnson Is Dead | Kirsten Johnson | | United States | September 25, 2021 | 2020 | PG-13 | 90 min | Documentaries | As her father nears the end of his life, filmmaker Kirsten Johnson stages his death in inventive and comical ways to help them both face the inevitable. 
|\n", - "| s2 | TV Show | Blood & Water | | Ama Qamata, Khosi Ngema, Gail Mabalane, Thabang Molaba, Dillon Windvogel, Natasha Thahane, Arno Greeff, Xolile Tshabalala, Getmore Sithole, Cindy Mahlangu, Ryle De Morny, Greteli Fincham, Sello Maake Ka-Ncube, Odwa Gwanya, Mekaila Mathys, Sandi Schultz, Duane Williams, Shamilla Miller, Patrick Mofokeng | South Africa | September 24, 2021 | 2021 | TV-MA | 2 Seasons | International TV Shows, TV Dramas, TV Mysteries | After crossing paths at a party, a Cape Town teen sets out to prove whether a private-school swimming star is her sister who was abducted at birth. |\n", - "| s3 | TV Show | Ganglands | Julien Leclercq | Sami Bouajila, Tracy Gotoas, Samuel Jouy, Nabiha Akkari, Sofia Lesaffre, Salim Kechiouche, Noureddine Farihi, Geert Van Rampelberg, Bakary Diombera | | September 24, 2021 | 2021 | TV-MA | 1 Season | Crime TV Shows, International TV Shows, TV Action & Adventure | To protect his family from a powerful drug lord, skilled thief Mehdi and his expert team of robbers are pulled into a violent and deadly turf war. |\n", - "| s4 | TV Show | Jailbirds New Orleans | | | | September 24, 2021 | 2021 | TV-MA | 1 Season | Docuseries, Reality TV | Feuds, flirtations and toilet talk go down among the incarcerated women at the Orleans Justice Center in New Orleans on this gritty reality series. |\n", - "| s5 | TV Show | Kota Factory | | Mayur More, Jitendra Kumar, Ranjan Raj, Alam Khan, Ahsaas Channa, Revathi Pillai, Urvi Singh, Arun Kumar | India | September 24, 2021 | 2021 | TV-MA | 2 Seasons | International TV Shows, Romantic TV Shows, TV Comedies | In a city of coaching centers known to train India’s finest collegiate minds, an earnest but unexceptional student and his friends navigate campus life. |\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kQ2KWsYI_Msa" - }, - "source": [ - "You won't need to directly load the csv data into your database. Instead we prepared a table, \"netflix_titles\", in the format of a `.sql` file for Postgres. You can easily import the table into your database with one `gcloud` command." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "qbLhv9jgD8nm" - }, - "outputs": [], - "source": [ - "# Import the Netflix titles table using gcloud command\n", - "import_command_output = !gcloud sql import sql {instance_name} gs://cloud-samples-data/langchain/cloud-sql/postgres/first_five_netflix_titles.sql --database={database_name} --quiet\n", - "\n", - "if \"Imported data\" in str(import_command_output):\n", - " print(import_command_output)\n", - "elif \"already exists\" in str(import_command_output):\n", - " print(\"Did not import because the table already existed.\")\n", - "else:\n", - " raise Exception(f\"The import seems failed:\\n{import_command_output}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SsGS80H04bDN" - }, - "source": [ - "# **Use case 1: Cloud SQL for PostgreSQL as a document loader**\n", - "\n", - "Now that you have data in your database, you are ready to use Cloud SQL for PostgreSQL as a [document loader](https://python.langchain.com/docs/modules/data_connection/document_loaders/). This means we will pull data from the database and load it into memory as documents. We can then feed these documents into the vector store." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-CQgPON8dwSK" - }, - "source": [ - "Next let's connect to our CloudSQL PostgreSQL instance using the PostgresEngine class." 
- ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "id": "zrwTsWHMkQ_v" - }, - "outputs": [], - "source": [ - "from langchain_google_cloud_sql_pg import PostgresLoader, PostgresEngine, Column\n", - "\n", - "pg_engine = PostgresEngine.from_instance(\n", - " project_id=project_id,\n", - " instance=instance_name,\n", - " region=region,\n", - " database=database_name,\n", - " user=\"postgres\",\n", - " password=password,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8s-C0P-Oee69" - }, - "source": [ - "Once we initialize a PostgresEngine object, we can pass it into the PostgresLoader to connect to a specific database. As you can see we also pass in a query, table_name and a list of columns. The query tells the loader what query to use to pull data. The \"content_columns\" argument refers to the columns that will be used as \"content\" in the document object we will later construct. The rest of the columns in that table will become the \"metadata\" associated with the documents." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "2SdFJT6Vece1" - }, - "outputs": [], - "source": [ - "table_name = \"netflix_titles\"\n", - "content_columns = [\"title\", \"director\", \"cast\", \"description\"]\n", - "loader = await PostgresLoader.create(\n", - " engine=pg_engine,\n", - " query=f\"SELECT * FROM {table_name};\",\n", - " content_columns=content_columns,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dsL-KFrmfuS1" - }, - "source": [ - "Then let's run the function to pull our documents from out database using our document loader. You can see the first 5 documents from the database here. Nice, you just used CloudSQL for PostgreSQL as a document loader!" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "t4zTx-HLfwmW", - "outputId": "da7239e4-710d-43ce-c004-520a0af9c79f" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loaded 8802 from the database. 5 Examples:\n", - "page_content='Midnight Mass Mike Flanagan Kate Siegel, Zach Gilford, Hamish Linklater, Henry Thomas, Kristin Lehman, Samantha Sloyan, Igby Rigney, Rahul Kohli, Annarah Cymone, Annabeth Gish, Alex Essoe, Rahul Abburi, Matt Biedel, Michael Trucco, Crystal Balint, Louis Oliver The arrival of a charismatic young priest brings glorious miracles, ominous mysteries and renewed religious fervor to a dying town desperate to believe.' metadata={'show_id': 's6', 'type': 'TV Show', 'country': None, 'date_added': 'September 24, 2021', 'release_year': 2021, 'rating': 'TV-MA', 'duration': '1 Season', 'listed_in': 'TV Dramas, TV Horror, TV Mysteries'}\n", - "page_content=\"My Little Pony: A New Generation Robert Cullen, José Luis Ucha Vanessa Hudgens, Kimiko Glenn, James Marsden, Sofia Carson, Liza Koshy, Ken Jeong, Elizabeth Perkins, Jane Krakowski, Michael McKean, Phil LaMarr Equestria's divided. 
But a bright-eyed hero believes Earth Ponies, Pegasi and Unicorns should be pals — and, hoof to heart, she’s determined to prove it.\" metadata={'show_id': 's7', 'type': 'Movie', 'country': None, 'date_added': 'September 24, 2021', 'release_year': 2021, 'rating': 'PG', 'duration': '91 min', 'listed_in': 'Children & Family Movies'}\n", - "page_content='Sankofa Haile Gerima Kofi Ghanaba, Oyafunmike Ogunlano, Alexandra Duah, Nick Medley, Mutabaruka, Afemo Omilami, Reggie Carter, Mzuri On a photo shoot in Ghana, an American model slips back in time, becomes enslaved on a plantation and bears witness to the agony of her ancestral past.' metadata={'show_id': 's8', 'type': 'Movie', 'country': 'United States, Ghana, Burkina Faso, United Kingdom, Germany, Ethiopia', 'date_added': 'September 24, 2021', 'release_year': 1993, 'rating': 'TV-MA', 'duration': '125 min', 'listed_in': 'Dramas, Independent Movies, International Movies'}\n", - "page_content=\"The Great British Baking Show Andy Devonshire Mel Giedroyc, Sue Perkins, Mary Berry, Paul Hollywood A talented batch of amateur bakers face off in a 10-week competition, whipping up their best dishes in the hopes of being named the U.K.'s best.\" metadata={'show_id': 's9', 'type': 'TV Show', 'country': 'United Kingdom', 'date_added': 'September 24, 2021', 'release_year': 2021, 'rating': 'TV-14', 'duration': '9 Seasons', 'listed_in': 'British TV Shows, Reality TV'}\n", - "page_content=\"The Starling Theodore Melfi Melissa McCarthy, Chris O'Dowd, Kevin Kline, Timothy Olyphant, Daveed Diggs, Skyler Gisondo, Laura Harrier, Rosalind Chao, Kimberly Quinn, Loretta Devine, Ravi Kapoor A woman adjusting to life after a loss contends with a feisty bird that's taken over her garden — and a husband who's struggling to find a way forward.\" metadata={'show_id': 's10', 'type': 'Movie', 'country': 'United States', 'date_added': 'September 24, 2021', 'release_year': 2021, 'rating': 'PG-13', 'duration': '104 min', 'listed_in': 'Comedies, Dramas'}\n" - ] - } - ], - "source": [ - "documents = await loader.aload()\n", - "print(f\"Loaded {len(documents)} from the database. 5 Examples:\")\n", - "for doc in documents[:5]:\n", - " print(doc)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "z9uLV3bs4noo" - }, - "source": [ - "# **Use case 2: Cloud SQL for PostgreSQL as Vector Store**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "duVsSeMcgEWl" - }, - "source": [ - "Now, let's learn how to put all of the documents we just loaded into a [vector store](https://python.langchain.com/docs/modules/data_connection/vectorstores/) so that we can use vector search to answer our user's questions!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jfH8oQJ945Ko" - }, - "source": [ - "### Create Your Vector Store table\n", - "\n", - "Based on the documents that we loaded before, we want to create a table with a vector column as our vector store. We will start it by intializing a vector table by calling the `init_vectorstore_table` function from our `engine`. As you can see we list all of the columns for our metadata. 
We also specify a vector size, 768, that corresponds with the length of the vectors computed by the model our embeddings service uses, Vertex AI's textembedding-gecko.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "id": "e_rmjywG47pv" - }, - "outputs": [], - "source": [ - "from langchain_google_cloud_sql_pg import PostgresEngine, Column\n", - "\n", - "sample_vector_table_name = \"movie_vector_table_samples\"\n", - "\n", - "pg_engine = PostgresEngine.from_instance(\n", - " project_id=project_id,\n", - " instance=instance_name,\n", - " region=region,\n", - " database=database_name,\n", - " user=\"postgres\",\n", - " password=password,\n", - ")\n", - "\n", - "pg_engine.init_vectorstore_table(\n", - " sample_vector_table_name,\n", - " vector_size=768,\n", - " metadata_columns=[\n", - " Column(\"show_id\", \"VARCHAR\", nullable=True),\n", - " Column(\"type\", \"VARCHAR\", nullable=True),\n", - " Column(\"country\", \"VARCHAR\", nullable=True),\n", - " Column(\"date_added\", \"VARCHAR\", nullable=True),\n", - " Column(\"release_year\", \"INTEGER\", nullable=True),\n", - " Column(\"rating\", \"VARCHAR\", nullable=True),\n", - " Column(\"duration\", \"VARCHAR\", nullable=True),\n", - " Column(\"listed_in\", \"VARCHAR\", nullable=True),\n", - " ],\n", - " overwrite_existing=True, # Enabling this will recreate the table if exists.\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KG6rwEuJLNIo" - }, - "source": [ - "### Try inserting the documents into the vector table\n", - "\n", - "Now we will create a vector_store object backed by our vector table in the Cloud SQL database. Let's load the data from the documents to the vector table. Note that for each row, the embedding service will be called to compute the embeddings to store in the vector table. Pricing details can be found [here](https://cloud.google.com/vertex-ai/pricing)." - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "id": "Wo4-7EYCIFF9" - }, - "outputs": [], - "source": [ - "from langchain_google_vertexai import VertexAIEmbeddings\n", - "from langchain_google_cloud_sql_pg import PostgresVectorStore\n", - "\n", - "# Initialize the embedding service. In this case we are using version 003 of Vertex AI's textembedding-gecko model. In general, it is good practice to specify the model version used.\n", - "embeddings_service = VertexAIEmbeddings(\n", - " model_name=\"textembedding-gecko@003\", project=project_id\n", - ")\n", - "\n", - "vector_store = PostgresVectorStore.create_sync(\n", - " engine=pg_engine,\n", - " embedding_service=embeddings_service,\n", - " table_name=sample_vector_table_name,\n", - " metadata_columns=[\n", - " \"show_id\",\n", - " \"type\",\n", - " \"country\",\n", - " \"date_added\",\n", - " \"release_year\",\n", - " \"duration\",\n", - " \"listed_in\",\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fr1rP6KQ-8ag" - }, - "source": [ - "Now let's try to put the documents data into the vector table. Here is a code example to load the first 5 documents in the list." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "CTks8Cy--93B" - }, - "outputs": [], - "source": [ - "import uuid\n", - "\n", - "docs_to_load = documents[:5]\n", - "\n", - "# ! 
Uncomment the following line to load all 8,800+ documents to the database vector table with calling the embedding service.\n", - "# docs_to_load = documents\n", - "\n", - "ids = [str(uuid.uuid4()) for i in range(len(docs_to_load))]\n", - "vector_store.add_documents(docs_to_load, ids)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "29iztdvfL2BN" - }, - "source": [ - "### Import the rest of your data into your vector table\n", - "\n", - "You don't have to call the embedding service 8,800 times to load all the documents for the demo. Instead, we have prepared a table with the all 8,800+ rows with pre-computed embeddings in a `.sql` file. Again, let's import to our DB using `gcloud` command.\n", - "\n", - "It will restore the `.sql` file to a table with vectors called `movie_vector_table`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "FEe9El7QMjHi" - }, - "outputs": [], - "source": [ - "# Import the netflix titles with vector table using gcloud command\n", - "import_command_output = !gcloud sql import sql {instance_name} gs://cloud-samples-data/langchain/cloud-sql/postgres/netflix_titles_vector_table.sql --database={database_name} --quiet\n", - "\n", - "if \"Imported data\" in str(import_command_output):\n", - " print(import_command_output)\n", - "elif \"already exists\" in str(import_command_output):\n", - " print(\"Did not import because the table already existed.\")\n", - "else:\n", - " raise Exception(f\"The import seems failed:\\n{import_command_output}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZM_OFzZrQEPs" - }, - "source": [ - "# **Use case 3: Cloud SQL for PostgreSQL as Chat Memory**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dxqIPQtjDquk" - }, - "source": [ - "Next we will add chat history (called [“memory” in the context of LangChain](https://python.langchain.com/docs/modules/memory/)) to our application so the LLM can retain context and information across multiple interactions, leading to more coherent and sophisticated conversations or text generation. We can use Cloud SQL for PostgreSQL as “memory” storage in our application so that the LLM can use context from prior conversations to better answer the user’s prompts. First let's initialize Cloud SQL for PostgreSQL as memory storage." - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "id": "vyYQILyoEAqg" - }, - "outputs": [], - "source": [ - "from langchain_google_cloud_sql_pg import PostgresChatMessageHistory, PostgresEngine\n", - "\n", - "message_table_name = \"message_store\"\n", - "\n", - "pg_engine.init_chat_history_table(table_name=message_table_name)\n", - "\n", - "chat_history = PostgresChatMessageHistory.create_sync(\n", - " pg_engine,\n", - " session_id=\"my-test-session\",\n", - " table_name=message_table_name,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2yuXYLTCl2K1" - }, - "source": [ - "Here is an example of how you would add a user message and how you would add an ai message." 
- ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "qDVoTWZal0ZF", - "outputId": "aeb8c338-9f0d-4143-c09d-9c49478940e0" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[HumanMessage(content='What movie was Brad Pitt in?'),\n", - " AIMessage(content='Brad Pitt was in Inglourious Basterds, By the Sea, Killing Them Softly and Babel according to the given data.'),\n", - " HumanMessage(content='How about Jonny Depp?'),\n", - " AIMessage(content=\"Jonny Depp was in Charlie and the Chocolate Factory, The Rum Diary, The Imaginarium of Doctor Parnassus, and What's Eating Gilbert Grape according to the given data.\"),\n", - " HumanMessage(content='Are there movies about animals?'),\n", - " AIMessage(content='Yes, Rango features animals.'),\n", - " HumanMessage(content='Hi!'),\n", - " AIMessage(content=\"Hello there. I'm a model and am happy to help!\")]" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "chat_history.add_user_message(\"Hi!\")\n", - "chat_history.add_ai_message(\"Hello there. I'm a model and am happy to help!\")\n", - "\n", - "chat_history.messages" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "k0O9mta8RQ0v" - }, - "source": [ - "# **Conversational RAG Chain backed by Cloud SQL Postgres**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "j2OxF3JoNA7J" - }, - "source": [ - "So far we've tested with using Cloud SQL for PostgreSQL as document loader, Vector Store and Chat Memory. Now let's use it in the `ConversationalRetrievalChain`.\n", - "\n", - "We will build a chat bot that can answer movie related questions based on the vector search results.\n", - "\n", - "First let's initialize all of our PostgresSQLEngine object to use as a connection in our vector store and chat_history." 
- ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": { - "id": "9ukjOO-sNQ8_" - }, - "outputs": [], - "source": [ - "from langchain_google_vertexai import VertexAIEmbeddings, VertexAI\n", - "from langchain.chains import ConversationalRetrievalChain\n", - "from langchain.memory import ConversationSummaryBufferMemory\n", - "from langchain_core.prompts import PromptTemplate\n", - "from langchain_google_cloud_sql_pg import (\n", - " PostgresEngine,\n", - " PostgresVectorStore,\n", - " PostgresChatMessageHistory,\n", - ")\n", - "\n", - "# Initialize the embedding service\n", - "embeddings_service = VertexAIEmbeddings(\n", - " model_name=\"textembedding-gecko@latest\", project=project_id\n", - ")\n", - "\n", - "# Initialize the engine\n", - "pg_engine = PostgresEngine.from_instance(\n", - " project_id=project_id,\n", - " instance=instance_name,\n", - " region=region,\n", - " database=database_name,\n", - " user=\"postgres\",\n", - " password=password,\n", - ")\n", - "\n", - "# Initialize the Vector Store\n", - "vector_table_name = \"movie_vector_table\"\n", - "vector_store = PostgresVectorStore.create_sync(\n", - " engine=pg_engine,\n", - " embedding_service=embeddings_service,\n", - " table_name=vector_table_name,\n", - " metadata_columns=[\n", - " \"show_id\",\n", - " \"type\",\n", - " \"country\",\n", - " \"date_added\",\n", - " \"release_year\",\n", - " \"duration\",\n", - " \"listed_in\",\n", - " ],\n", - ")\n", - "\n", - "# Initialize the PostgresChatMessageHistory\n", - "chat_history = PostgresChatMessageHistory.create_sync(\n", - " pg_engine,\n", - " session_id=\"my-test-session\",\n", - " table_name=\"message_store\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ytlz9D3LmcU7" - }, - "source": [ - "Let's create a prompt for the LLM. Here we can add instructions specific to our application, such as \"Don't make things up\"." - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": { - "id": "LoAHNdrWmW9W" - }, - "outputs": [], - "source": [ - "# Prepare some prompt templates for the ConversationalRetrievalChain\n", - "prompt = PromptTemplate(\n", - " template=\"\"\"Use all the information from the context and the conversation history to answer new question. If you see the answer in previous conversation history or the context. \\\n", - "Answer it with clarifying the source information. If you don't see it in the context or the chat history, just say you \\\n", - "didn't find the answer in the given data. Don't make things up.\n", - "\n", - "Previous conversation history from the questioner. \"Human\" was the user who's asking the new question. \"Assistant\" was you as the assistant:\n", - "```{chat_history}\n", - "```\n", - "\n", - "Vector search result of the new question:\n", - "```{context}\n", - "```\n", - "\n", - "New Question:\n", - "```{question}```\n", - "\n", - "Answer:\"\"\",\n", - " input_variables=[\"context\", \"question\", \"chat_history\"],\n", - ")\n", - "condense_question_prompt_passthrough = PromptTemplate(\n", - " template=\"\"\"Repeat the following question:\n", - "{question}\n", - "\"\"\",\n", - " input_variables=[\"question\"],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rsGe-bW5m0H1" - }, - "source": [ - "Now let's use our vector store as a retreiver. Retreiver's in Langchain allow us to literally \"retrieve\" documents." 
- ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": { - "id": "1nI0xkJamvXt" - }, - "outputs": [], - "source": [ - "# Initialize retriever, llm and memory for the chain\n", - "retriever = vector_store.as_retriever(\n", - " search_type=\"mmr\", search_kwargs={\"k\": 5, \"lambda_mult\": 0.8}\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3maZ8SLlneYJ" - }, - "source": [ - "Now let's initialize our LLM, in this case we are using Vertex AI's \"gemini-pro\"." - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": { - "id": "VBWhg-ihnnxF" - }, - "outputs": [], - "source": [ - "llm = VertexAI(model_name=\"gemini-pro\", project=project_id)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hN8mpXdtnocg" - }, - "source": [ - "We clear our chat history, so that our application starts without any prior context to other conversations we have had with the application." - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": { - "id": "1UkPcEpJno5Y" - }, - "outputs": [], - "source": [ - "chat_history.clear()\n", - "\n", - "memory = ConversationSummaryBufferMemory(\n", - " llm=llm,\n", - " chat_memory=chat_history,\n", - " output_key=\"answer\",\n", - " memory_key=\"chat_history\",\n", - " return_messages=True,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BDAT2koSn8Mz" - }, - "source": [ - "Now let's create a conversational retrieval chain. This will allow the LLM to use chat history in it's responses, meaning we can ask it follow up questions to our questions instead of having to start from scratch after each inquiry." - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "7Fu8fKdEn8h8", - "outputId": "abadf0d2-abcd-47a4-d598-45140205593f" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Question: What movie was Brad Pitt in?\n", - "Answer: Inglourious Basterds, By the Sea, Killing Them Softly, Babel\n", - "\n", - "Question: How about Jonny Depp?\n", - "Answer: Charlie and the Chocolate Factory, The Rum Diary, The Imaginarium of Doctor Parnassus, What's Eating Gilbert Grape (Vector search result)\n", - "\n", - "Question: Are there movies about animals?\n", - "Answer: Yes, there are movies about animals. For example, \"Animals on the Loose: A You vs. Wild Movie\" is an interactive special where you and Bear Grylls must pursue escaped wild animals and secure their protective habitat. (Vector search result)\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "[HumanMessage(content='What movie was Brad Pitt in?'),\n", - " AIMessage(content='Inglourious Basterds, By the Sea, Killing Them Softly, Babel'),\n", - " HumanMessage(content='How about Jonny Depp?'),\n", - " AIMessage(content=\"Charlie and the Chocolate Factory, The Rum Diary, The Imaginarium of Doctor Parnassus, What's Eating Gilbert Grape (Vector search result)\"),\n", - " HumanMessage(content='Are there movies about animals?'),\n", - " AIMessage(content='Yes, there are movies about animals. For example, \"Animals on the Loose: A You vs. 
Wild Movie\" is an interactive special where you and Bear Grylls must pursue escaped wild animals and secure their protective habitat.'),\n", - " HumanMessage(content='What movie was Brad Pitt in?'),\n", - " AIMessage(content='Inglourious Basterds, By the Sea, Killing Them Softly, Babel'),\n", - " HumanMessage(content='How about Jonny Depp?'),\n", - " AIMessage(content=\"Charlie and the Chocolate Factory, The Rum Diary, The Imaginarium of Doctor Parnassus, What's Eating Gilbert Grape (Vector search result)\"),\n", - " HumanMessage(content='Are there movies about animals?'),\n", - " AIMessage(content='Yes, there are movies about animals. For example, \"Animals on the Loose: A You vs. Wild Movie\" is an interactive special where you and Bear Grylls must pursue escaped wild animals and secure their protective habitat. (Vector search result)')]" - ] - }, - "execution_count": 32, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# create the ConversationalRetrievalChain\n", - "rag_chain = ConversationalRetrievalChain.from_llm(\n", - " llm=llm,\n", - " retriever=retriever,\n", - " verbose=False,\n", - " memory=memory,\n", - " condense_question_prompt=condense_question_prompt_passthrough,\n", - " combine_docs_chain_kwargs={\"prompt\": prompt},\n", - ")\n", - "\n", - "# ask some questions\n", - "q = \"What movie was Brad Pitt in?\"\n", - "ans = rag_chain({\"question\": q, \"chat_history\": chat_history})[\"answer\"]\n", - "print(f\"Question: {q}\\nAnswer: {ans}\\n\")\n", - "\n", - "q = \"How about Jonny Depp?\"\n", - "ans = rag_chain({\"question\": q, \"chat_history\": chat_history})[\"answer\"]\n", - "print(f\"Question: {q}\\nAnswer: {ans}\\n\")\n", - "\n", - "q = \"Are there movies about animals?\"\n", - "ans = rag_chain({\"question\": q, \"chat_history\": chat_history})[\"answer\"]\n", - "print(f\"Question: {q}\\nAnswer: {ans}\\n\")\n", - "\n", - "# browser the chat history\n", - "chat_history.messages" - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6zvUr-Qev6lL" + }, + "outputs": [], + "source": [ + "# Copyright 2024 Google LLC\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ob11AkrStrRI" + }, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/langchain-google-cloud-sql-pg-python/blob/main/samples/langchain_quick_start.ipynb)\n", + "\n", + "---\n", + "# **Introduction**\n", + "\n", + "In this codelab, you'll learn how to create a powerful interactive generative AI application using Retrieval Augmented Generation powered by [Cloud SQL for PostgreSQL](https://cloud.google.com/sql/docs/postgres) and [LangChain](https://www.langchain.com/). We will be creating an application grounded in a [Netflix Movie dataset](https://www.kaggle.com/datasets/shivamb/netflix-shows), allowing you to interact with movie data in exciting new ways." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ma6pEng3ypbA" + }, + "source": [ + "## Prerequisites\n", + "\n", + "* A basic understanding of the Google Cloud Console\n", + "* Basic skills with the command line and Google Cloud Shell\n", + "* Basic Python knowledge" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DzDgqJHgysy1" + }, + "source": [ + "## What you'll learn\n", + "\n", + "* How to deploy a Cloud SQL for PostgreSQL instance\n", + "* How to use Cloud SQL for PostgreSQL as a DocumentLoader\n", + "* How to use Cloud SQL for PostgreSQL as a VectorStore\n", + "* How to use Cloud SQL for PostgreSQL for ChatMessageHistory storage" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FbcZUjT1yvTq" + }, + "source": [ + "## What you'll need\n", + "* A Google Cloud Account and Google Cloud Project\n", + "* A web browser such as [Chrome](https://www.google.com/chrome/)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vHdR4fF3vLWA" + }, + "source": [ + "# **Setup and Requirements**\n", + "\n", + "In the following instructions you will learn to:\n", + "1. Install required dependencies for our application\n", + "2. Set up authentication for our project\n", + "3. Set up a Cloud SQL for PostgreSQL Instance\n", + "4. Import the data used by our application" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uy9KqgPQ4GBi" + }, + "source": [ + "## Install dependencies\n", + "First, you will need to install the dependencies needed to run this demo app." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "M_ppDxYf4Gqs" + }, + "outputs": [], + "source": [ + "%pip install --upgrade --quiet langchain-google-cloud-sql-pg langchain-google-vertexai langchain" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Colab only:** Run the following cell to restart the kernel or use the button to restart the kernel. For Vertex AI Workbench you can restart the terminal using the button on top." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Automatically restart kernel after installs so that your environment can access the new packages\n", + "import IPython\n", + "\n", + "app = IPython.Application.instance()\n", + "app.kernel.do_shutdown(True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DeUbHclxw7_l" + }, + "source": [ + "## Authenticate to Google Cloud within Colab\n", + "In order to access your Google Cloud Project from this notebook, you will need to authenticate as an IAM user.\n",
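 + "\n", + "If you are running this notebook somewhere other than Colab, the Colab auth cell below will not work. As a rough sketch of an alternative (an assumption, not part of this quickstart: it relies on the `google-auth` package and on `gcloud auth application-default login` having been run), you can use Application Default Credentials:\n", + "\n", + "```python\n", + "import google.auth\n", + "\n", + "# Load Application Default Credentials from the local environment.\n", + "credentials, adc_project_id = google.auth.default()\n", + "```"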
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "a168rJE1xDHO" + }, + "outputs": [], + "source": [ + "from google.colab import auth\n", + "\n", + "auth.authenticate_user()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UCiNGP1Qxd6x" + }, + "source": [ + "## Connect Your Google Cloud Project\n", + "Time to connect your Google Cloud Project to this notebook so that you can leverage Google Cloud from within Colab." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "qjFuhRhVxlWP" + }, + "outputs": [], + "source": [ + "# @markdown Please fill in the value below with your GCP project ID and then run the cell.\n", + "\n", + "# Please fill in these values.\n", + "project_id = \"\" # @param {type:\"string\"}\n", + "\n", + "# Quick input validations.\n", + "assert project_id, \"⚠️ Please provide a Google Cloud project ID\"\n", + "\n", + "# Configure gcloud.\n", + "!gcloud config set project {project_id}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O-oqMC5Ox-ZM" + }, + "source": [ + "## Configure Your Google Cloud Project\n", + "\n", + "Configure the following in your Google Cloud Project." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "1. IAM principal (user, service account, etc.) with the [Cloud SQL Client][client-role] role. The user logged into this notebook will be used as the IAM principal and will be granted the Cloud SQL Client role.\n", + "\n", + "[client-role]: https://cloud.google.com/sql/docs/mysql/roles-and-permissions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "current_user = !gcloud auth list --filter=status:ACTIVE --format=\"value(account)\"\n", + "!gcloud projects add-iam-policy-binding {project_id} \\\n", + " --member=user:{current_user[0]} \\\n", + " --role=\"roles/cloudsql.client\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "2. Enable the APIs for Cloud SQL and Vertex AI within your project." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CKWrwyfzyTwH" + }, + "outputs": [], + "source": [ + "# Enable GCP services\n", + "!gcloud services enable sqladmin.googleapis.com aiplatform.googleapis.com" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Gn8g7-wCyZU6" + }, + "source": [ + "## Set up Cloud SQL\n", + "You will need a **Postgres** Cloud SQL instance for the following stages of this notebook." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "T616pEOUygYQ" + }, + "source": [ + "### Create a Postgres Instance\n", + "Running the cell below will verify the existence of the Cloud SQL instance and/or create a new instance and database if one does not exist.\n", + "\n", + "> ⏳ - Creating a Cloud SQL instance may take a few minutes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "XXI1uUu3y8gc" + }, + "outputs": [], + "source": [ + "#@markdown Please fill in both the Google Cloud region and name of your Cloud SQL instance.
Once filled in, run the cell.\n", + "\n", + "# Please fill in these values.\n", + "region = \"us-central1\" #@param {type:\"string\"}\n", + "instance_name = \"langchain-quickstart-instance\" #@param {type:\"string\"}\n", + "database_name = \"langchain-quickstart-db\" #@param {type:\"string\"}\n", + "password = input(\"Please provide a password to be used for the 'postgres' database user: \")\n", + "\n", + "# Quick input validations.\n", + "assert region, \"⚠️ Please provide a Google Cloud region\"\n", + "assert instance_name, \"⚠️ Please provide the name of your instance\"\n", + "assert database_name, \"⚠️ Please provide the name of your database\"\n", + "\n", + "# check if Cloud SQL instance exists in the provided region\n", + "database_version = !gcloud sql instances describe {instance_name} --format=\"value(databaseVersion)\"\n", + "if database_version[0].startswith(\"POSTGRES\"):\n", + " print(\"Found existing Postgres Cloud SQL Instance!\")\n", + "else:\n", + " print(\"Creating new Cloud SQL instance...\")\n", + " !gcloud sql instances create {instance_name} --database-version=POSTGRES_15 \\\n", + " --region={region} --cpu=1 --memory=4GB --root-password={password} \\\n", + " --database-flags=cloudsql.iam_authentication=On\n", + " !gcloud sql databases create {database_name} --instance={instance_name}\n", + "\n", + "\n", + "databases = !gcloud sql databases list --instance={instance_name} --format=\"value(name)\"\n", + "if database_name not in databases:\n", + " !gcloud sql databases create {database_name} --instance={instance_name}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HdolCWyatZmG" + }, + "source": [ + "## Import data to your database\n", + "\n", + "Now that you have your database, you will need to import data! We will be using a [Netflix Dataset from Kaggle](https://www.kaggle.com/datasets/shivamb/netflix-shows). Here is what the data looks like:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "36-FBKzJ-tLa" + }, + "source": [ + "| show_id | type | title | director | cast | country | date_added | release_year | rating | duration | listed_in | description |\n", + "|---------|---------|----------------------|------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------|---------------|-------------------|--------------|--------|-----------|----------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| s1 | Movie | Dick Johnson Is Dead | Kirsten Johnson | | United States | September 25, 2021 | 2020 | PG-13 | 90 min | Documentaries | As her father nears the end of his life, filmmaker Kirsten Johnson stages his death in inventive and comical ways to help them both face the inevitable.
|\n", + "| s2 | TV Show | Blood & Water | | Ama Qamata, Khosi Ngema, Gail Mabalane, Thabang Molaba, Dillon Windvogel, Natasha Thahane, Arno Greeff, Xolile Tshabalala, Getmore Sithole, Cindy Mahlangu, Ryle De Morny, Greteli Fincham, Sello Maake Ka-Ncube, Odwa Gwanya, Mekaila Mathys, Sandi Schultz, Duane Williams, Shamilla Miller, Patrick Mofokeng | South Africa | September 24, 2021 | 2021 | TV-MA | 2 Seasons | International TV Shows, TV Dramas, TV Mysteries | After crossing paths at a party, a Cape Town teen sets out to prove whether a private-school swimming star is her sister who was abducted at birth. |\n", + "| s3 | TV Show | Ganglands | Julien Leclercq | Sami Bouajila, Tracy Gotoas, Samuel Jouy, Nabiha Akkari, Sofia Lesaffre, Salim Kechiouche, Noureddine Farihi, Geert Van Rampelberg, Bakary Diombera | | September 24, 2021 | 2021 | TV-MA | 1 Season | Crime TV Shows, International TV Shows, TV Action & Adventure | To protect his family from a powerful drug lord, skilled thief Mehdi and his expert team of robbers are pulled into a violent and deadly turf war. |\n", + "| s4 | TV Show | Jailbirds New Orleans | | | | September 24, 2021 | 2021 | TV-MA | 1 Season | Docuseries, Reality TV | Feuds, flirtations and toilet talk go down among the incarcerated women at the Orleans Justice Center in New Orleans on this gritty reality series. |\n", + "| s5 | TV Show | Kota Factory | | Mayur More, Jitendra Kumar, Ranjan Raj, Alam Khan, Ahsaas Channa, Revathi Pillai, Urvi Singh, Arun Kumar | India | September 24, 2021 | 2021 | TV-MA | 2 Seasons | International TV Shows, Romantic TV Shows, TV Comedies | In a city of coaching centers known to train India’s finest collegiate minds, an earnest but unexceptional student and his friends navigate campus life. |\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kQ2KWsYI_Msa" + }, + "source": [ + "You won't need to directly load the csv data into your database. Instead we prepared a table, \"netflix_titles\", in the format of a `.sql` file for Postgres. You can easily import the table into your database with one `gcloud` command." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qbLhv9jgD8nm" + }, + "outputs": [], + "source": [ + "# Import the Netflix titles table using gcloud command\n", + "import_command_output = !gcloud sql import sql {instance_name} gs://cloud-samples-data/langchain/cloud-sql/postgres/first_five_netflix_titles.sql --database={database_name} --quiet\n", + "\n", + "if \"Imported data\" in str(import_command_output):\n", + " print(import_command_output)\n", + "elif \"already exists\" in str(import_command_output):\n", + " print(\"Did not import because the table already existed.\")\n", + "else:\n", + " raise Exception(f\"The import seems failed:\\n{import_command_output}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SsGS80H04bDN" + }, + "source": [ + "# **Use case 1: Cloud SQL for PostgreSQL as a document loader**\n", + "\n", + "Now that you have data in your database, you are ready to use Cloud SQL for PostgreSQL as a [document loader](https://python.langchain.com/docs/modules/data_connection/document_loaders/). This means we will pull data from the database and load it into memory as documents. We can then feed these documents into the vector store." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-CQgPON8dwSK" + }, + "source": [ + "Next, let's connect to our Cloud SQL for PostgreSQL instance using the PostgresEngine class.\n", + "\n", + "To connect to your Cloud SQL instance from this notebook, you will need to enable public IP on your instance. Alternatively, you can follow [these instructions](https://cloud.google.com/sql/docs/postgres/connect-to-instance-from-outside-vpc) to connect to a Cloud SQL for PostgreSQL instance with Private IP from outside your VPC. Learn more about [specifying IP types](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector?tab=readme-ov-file#specifying-ip-address-type)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zrwTsWHMkQ_v" + }, + "outputs": [], + "source": [ + "from langchain_google_cloud_sql_pg import PostgresLoader, PostgresEngine, Column\n", + "\n", + "pg_engine = PostgresEngine.from_instance(\n", + " project_id=project_id,\n", + " instance=instance_name,\n", + " region=region,\n", + " database=database_name,\n", + " user=\"postgres\",\n", + " password=password,\n", + " ip_type=\"public\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8s-C0P-Oee69" + }, + "source": [ + "Once we initialize a PostgresEngine object, we can pass it into the PostgresLoader to connect to a specific database. As you can see, we also pass in a query and a list of content columns. The query tells the loader what data to pull. The \"content_columns\" argument refers to the columns that will be used as \"content\" in the document object we will later construct. The rest of the columns in that table will become the \"metadata\" associated with the documents." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "2SdFJT6Vece1" + }, + "outputs": [], + "source": [ + "table_name = \"netflix_titles\"\n", + "content_columns = [\"title\", \"director\", \"cast\", \"description\"]\n", + "loader = await PostgresLoader.create(\n", + " engine=pg_engine,\n", + " query=f\"SELECT * FROM {table_name};\",\n", + " content_columns=content_columns,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dsL-KFrmfuS1" + }, + "source": [ + "Then let's run the function to pull our documents from our database using our document loader. You can see the first 5 documents from the database here. Nice, you just used Cloud SQL for PostgreSQL as a document loader!" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "t4zTx-HLfwmW", + "outputId": "da7239e4-710d-43ce-c004-520a0af9c79f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded 8802 from the database. 5 Examples:\n", + "page_content='Midnight Mass Mike Flanagan Kate Siegel, Zach Gilford, Hamish Linklater, Henry Thomas, Kristin Lehman, Samantha Sloyan, Igby Rigney, Rahul Kohli, Annarah Cymone, Annabeth Gish, Alex Essoe, Rahul Abburi, Matt Biedel, Michael Trucco, Crystal Balint, Louis Oliver The arrival of a charismatic young priest brings glorious miracles, ominous mysteries and renewed religious fervor to a dying town desperate to believe.'
metadata={'show_id': 's6', 'type': 'TV Show', 'country': None, 'date_added': 'September 24, 2021', 'release_year': 2021, 'rating': 'TV-MA', 'duration': '1 Season', 'listed_in': 'TV Dramas, TV Horror, TV Mysteries'}\n", + "page_content=\"My Little Pony: A New Generation Robert Cullen, José Luis Ucha Vanessa Hudgens, Kimiko Glenn, James Marsden, Sofia Carson, Liza Koshy, Ken Jeong, Elizabeth Perkins, Jane Krakowski, Michael McKean, Phil LaMarr Equestria's divided. But a bright-eyed hero believes Earth Ponies, Pegasi and Unicorns should be pals — and, hoof to heart, she’s determined to prove it.\" metadata={'show_id': 's7', 'type': 'Movie', 'country': None, 'date_added': 'September 24, 2021', 'release_year': 2021, 'rating': 'PG', 'duration': '91 min', 'listed_in': 'Children & Family Movies'}\n", + "page_content='Sankofa Haile Gerima Kofi Ghanaba, Oyafunmike Ogunlano, Alexandra Duah, Nick Medley, Mutabaruka, Afemo Omilami, Reggie Carter, Mzuri On a photo shoot in Ghana, an American model slips back in time, becomes enslaved on a plantation and bears witness to the agony of her ancestral past.' metadata={'show_id': 's8', 'type': 'Movie', 'country': 'United States, Ghana, Burkina Faso, United Kingdom, Germany, Ethiopia', 'date_added': 'September 24, 2021', 'release_year': 1993, 'rating': 'TV-MA', 'duration': '125 min', 'listed_in': 'Dramas, Independent Movies, International Movies'}\n", + "page_content=\"The Great British Baking Show Andy Devonshire Mel Giedroyc, Sue Perkins, Mary Berry, Paul Hollywood A talented batch of amateur bakers face off in a 10-week competition, whipping up their best dishes in the hopes of being named the U.K.'s best.\" metadata={'show_id': 's9', 'type': 'TV Show', 'country': 'United Kingdom', 'date_added': 'September 24, 2021', 'release_year': 2021, 'rating': 'TV-14', 'duration': '9 Seasons', 'listed_in': 'British TV Shows, Reality TV'}\n", + "page_content=\"The Starling Theodore Melfi Melissa McCarthy, Chris O'Dowd, Kevin Kline, Timothy Olyphant, Daveed Diggs, Skyler Gisondo, Laura Harrier, Rosalind Chao, Kimberly Quinn, Loretta Devine, Ravi Kapoor A woman adjusting to life after a loss contends with a feisty bird that's taken over her garden — and a husband who's struggling to find a way forward.\" metadata={'show_id': 's10', 'type': 'Movie', 'country': 'United States', 'date_added': 'September 24, 2021', 'release_year': 2021, 'rating': 'PG-13', 'duration': '104 min', 'listed_in': 'Comedies, Dramas'}\n" + ] + } + ], + "source": [ + "documents = await loader.aload()\n", + "print(f\"Loaded {len(documents)} from the database. 5 Examples:\")\n", + "for doc in documents[:5]:\n", + " print(doc)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z9uLV3bs4noo" + }, + "source": [ + "# **Use case 2: Cloud SQL for PostgreSQL as Vector Store**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "duVsSeMcgEWl" + }, + "source": [ + "Now, let's learn how to put all of the documents we just loaded into a [vector store](https://python.langchain.com/docs/modules/data_connection/vectorstores/) so that we can use vector search to answer our user's questions!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jfH8oQJ945Ko" + }, + "source": [ + "### Create Your Vector Store table\n", + "\n", + "Based on the documents that we loaded before, we want to create a table with a vector column as our vector store. We will start by initializing a vector table by calling the `init_vectorstore_table` function from our `engine`.\n",
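 + "\n", + "The `vector_size` must match the output dimension of the embedding model you plan to use. As a hypothetical sanity check (runnable once the embeddings service is created later in this notebook):\n", + "\n", + "```python\n", + "# Compare the embedding model's output dimension to the table's vector_size.\n", + "print(len(embeddings_service.embed_query(\"test\")))\n", + "```\n", + "\n",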
As you can see, we list all of the columns for our metadata. We also specify a vector size of 768, which corresponds to the length of the vectors computed by the model our embedding service uses, Vertex AI's textembedding-gecko.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "id": "e_rmjywG47pv" + }, + "outputs": [], + "source": [ + "from langchain_google_cloud_sql_pg import PostgresEngine, Column\n", + "\n", + "sample_vector_table_name = \"movie_vector_table_samples\"\n", + "\n", + "pg_engine = PostgresEngine.from_instance(\n", + " project_id=project_id,\n", + " instance=instance_name,\n", + " region=region,\n", + " database=database_name,\n", + " user=\"postgres\",\n", + " password=password,\n", + ")\n", + "\n", + "pg_engine.init_vectorstore_table(\n", + " sample_vector_table_name,\n", + " vector_size=768,\n", + " metadata_columns=[\n", + " Column(\"show_id\", \"VARCHAR\", nullable=True),\n", + " Column(\"type\", \"VARCHAR\", nullable=True),\n", + " Column(\"country\", \"VARCHAR\", nullable=True),\n", + " Column(\"date_added\", \"VARCHAR\", nullable=True),\n", + " Column(\"release_year\", \"INTEGER\", nullable=True),\n", + " Column(\"rating\", \"VARCHAR\", nullable=True),\n", + " Column(\"duration\", \"VARCHAR\", nullable=True),\n", + " Column(\"listed_in\", \"VARCHAR\", nullable=True),\n", + " ],\n", + " overwrite_existing=True, # Enabling this will recreate the table if it exists.\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KG6rwEuJLNIo" + }, + "source": [ + "### Try inserting the documents into the vector table\n", + "\n", + "Now we will create a vector_store object backed by our vector table in the Cloud SQL database. Let's load the data from the documents into the vector table. Note that for each row, the embedding service will be called to compute the embeddings to store in the vector table. Pricing details can be found [here](https://cloud.google.com/vertex-ai/pricing)." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "id": "Wo4-7EYCIFF9" + }, + "outputs": [], + "source": [ + "from langchain_google_vertexai import VertexAIEmbeddings\n", + "from langchain_google_cloud_sql_pg import PostgresVectorStore\n", + "\n", + "# Initialize the embedding service. In this case we are using version 003 of Vertex AI's textembedding-gecko model. In general, it is good practice to specify the model version used.\n", + "embeddings_service = VertexAIEmbeddings(\n", + " model_name=\"textembedding-gecko@003\", project=project_id\n", + ")\n", + "\n", + "vector_store = PostgresVectorStore.create_sync(\n", + " engine=pg_engine,\n", + " embedding_service=embeddings_service,\n", + " table_name=sample_vector_table_name,\n", + " metadata_columns=[\n", + " \"show_id\",\n", + " \"type\",\n", + " \"country\",\n", + " \"date_added\",\n", + " \"release_year\",\n", + " \"duration\",\n", + " \"listed_in\",\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fr1rP6KQ-8ag" + }, + "source": [ + "Now let's try to put the document data into the vector table. Here is a code example that loads the first 5 documents in the list." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CTks8Cy--93B" + }, + "outputs": [], + "source": [ + "import uuid\n", + "\n", + "docs_to_load = documents[:5]\n", + "\n", + "# !
Uncomment the following line to load all 8,800+ documents into the database vector table by calling the embedding service.\n", + "# docs_to_load = documents\n", + "\n", + "ids = [str(uuid.uuid4()) for _ in range(len(docs_to_load))]\n", + "vector_store.add_documents(docs_to_load, ids)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "29iztdvfL2BN" + }, + "source": [ + "### Import the rest of your data into your vector table\n", + "\n", + "You don't have to call the embedding service 8,800 times to load all the documents for the demo. Instead, we have prepared a table with all 8,800+ rows and pre-computed embeddings in a `.sql` file. Again, let's import it into our database using the `gcloud` command.\n", + "\n", + "It will restore the `.sql` file into a table with vectors called `movie_vector_table`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "FEe9El7QMjHi" + }, + "outputs": [], + "source": [ + "# Import the Netflix titles vector table using the gcloud command\n", + "import_command_output = !gcloud sql import sql {instance_name} gs://cloud-samples-data/langchain/cloud-sql/postgres/netflix_titles_vector_table.sql --database={database_name} --quiet\n", + "\n", + "if \"Imported data\" in str(import_command_output):\n", + " print(import_command_output)\n", + "elif \"already exists\" in str(import_command_output):\n", + " print(\"Did not import because the table already existed.\")\n", + "else:\n", + " raise Exception(f\"The import seems to have failed:\\n{import_command_output}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZM_OFzZrQEPs" + }, + "source": [ + "# **Use case 3: Cloud SQL for PostgreSQL as Chat Memory**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dxqIPQtjDquk" + }, + "source": [ + "Next we will add chat history (called [“memory” in the context of LangChain](https://python.langchain.com/docs/modules/memory/)) to our application so the LLM can retain context and information across multiple interactions, leading to more coherent and sophisticated conversations or text generation. We can use Cloud SQL for PostgreSQL as “memory” storage in our application so that the LLM can use context from prior conversations to better answer the user’s prompts. First let's initialize Cloud SQL for PostgreSQL as memory storage." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "id": "vyYQILyoEAqg" + }, + "outputs": [], + "source": [ + "from langchain_google_cloud_sql_pg import PostgresChatMessageHistory, PostgresEngine\n", + "\n", + "message_table_name = \"message_store\"\n", + "\n", + "pg_engine.init_chat_history_table(table_name=message_table_name)\n", + "\n", + "chat_history = PostgresChatMessageHistory.create_sync(\n", + " pg_engine,\n", + " session_id=\"my-test-session\",\n", + " table_name=message_table_name,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2yuXYLTCl2K1" + }, + "source": [ + "Here is an example of how you would add a user message and an AI message."
+ ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "qDVoTWZal0ZF", + "outputId": "aeb8c338-9f0d-4143-c09d-9c49478940e0" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[HumanMessage(content='What movie was Brad Pitt in?'),\n", + " AIMessage(content='Brad Pitt was in Inglourious Basterds, By the Sea, Killing Them Softly and Babel according to the given data.'),\n", + " HumanMessage(content='How about Jonny Depp?'),\n", + " AIMessage(content=\"Jonny Depp was in Charlie and the Chocolate Factory, The Rum Diary, The Imaginarium of Doctor Parnassus, and What's Eating Gilbert Grape according to the given data.\"),\n", + " HumanMessage(content='Are there movies about animals?'),\n", + " AIMessage(content='Yes, Rango features animals.'),\n", + " HumanMessage(content='Hi!'),\n", + " AIMessage(content=\"Hello there. I'm a model and am happy to help!\")]" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat_history.add_user_message(\"Hi!\")\n", + "chat_history.add_ai_message(\"Hello there. I'm a model and am happy to help!\")\n", + "\n", + "chat_history.messages" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "k0O9mta8RQ0v" + }, + "source": [ + "# **Conversational RAG Chain backed by Cloud SQL Postgres**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "j2OxF3JoNA7J" + }, + "source": [ + "So far we've tested using Cloud SQL for PostgreSQL as a document loader, a vector store, and chat memory. Now let's use it in the `ConversationalRetrievalChain`.\n", + "\n", + "We will build a chatbot that can answer movie-related questions based on the vector search results.\n", + "\n", + "First let's initialize our PostgresEngine object to use as the connection for our vector store and chat history."
+ ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "id": "9ukjOO-sNQ8_" + }, + "outputs": [], + "source": [ + "from langchain_google_vertexai import VertexAIEmbeddings, VertexAI\n", + "from langchain.chains import ConversationalRetrievalChain\n", + "from langchain.memory import ConversationSummaryBufferMemory\n", + "from langchain_core.prompts import PromptTemplate\n", + "from langchain_google_cloud_sql_pg import (\n", + " PostgresEngine,\n", + " PostgresVectorStore,\n", + " PostgresChatMessageHistory,\n", + ")\n", + "\n", + "# Initialize the embedding service\n", + "embeddings_service = VertexAIEmbeddings(\n", + " model_name=\"textembedding-gecko@latest\", project=project_id\n", + ")\n", + "\n", + "# Initialize the engine\n", + "pg_engine = PostgresEngine.from_instance(\n", + " project_id=project_id,\n", + " instance=instance_name,\n", + " region=region,\n", + " database=database_name,\n", + " user=\"postgres\",\n", + " password=password,\n", + ")\n", + "\n", + "# Initialize the Vector Store\n", + "vector_table_name = \"movie_vector_table\"\n", + "vector_store = PostgresVectorStore.create_sync(\n", + " engine=pg_engine,\n", + " embedding_service=embeddings_service,\n", + " table_name=vector_table_name,\n", + " metadata_columns=[\n", + " \"show_id\",\n", + " \"type\",\n", + " \"country\",\n", + " \"date_added\",\n", + " \"release_year\",\n", + " \"duration\",\n", + " \"listed_in\",\n", + " ],\n", + ")\n", + "\n", + "# Initialize the PostgresChatMessageHistory\n", + "chat_history = PostgresChatMessageHistory.create_sync(\n", + " pg_engine,\n", + " session_id=\"my-test-session\",\n", + " table_name=\"message_store\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ytlz9D3LmcU7" + }, + "source": [ + "Let's create a prompt for the LLM. Here we can add instructions specific to our application, such as \"Don't make things up\"." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "id": "LoAHNdrWmW9W" + }, + "outputs": [], + "source": [ + "# Prepare some prompt templates for the ConversationalRetrievalChain\n", + "prompt = PromptTemplate(\n", + " template=\"\"\"Use all the information from the context and the conversation history to answer the new question. If you see the answer in the previous conversation history or the context, \\\n", + "answer it, clarifying the source information. If you don't see it in the context or the chat history, just say you \\\n", + "didn't find the answer in the given data. Don't make things up.\n", + "\n", + "Previous conversation history from the questioner. \"Human\" was the user who's asking the new question. \"Assistant\" was you as the assistant:\n", + "```{chat_history}\n", + "```\n", + "\n", + "Vector search result of the new question:\n", + "```{context}\n", + "```\n", + "\n", + "New Question:\n", + "```{question}```\n", + "\n", + "Answer:\"\"\",\n", + " input_variables=[\"context\", \"question\", \"chat_history\"],\n", + ")\n", + "condense_question_prompt_passthrough = PromptTemplate(\n", + " template=\"\"\"Repeat the following question:\n", + "{question}\n", + "\"\"\",\n", + " input_variables=[\"question\"],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rsGe-bW5m0H1" + }, + "source": [ + "Now let's use our vector store as a retriever. Retrievers in LangChain allow us to literally \"retrieve\" documents."
+ ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "id": "1nI0xkJamvXt" + }, + "outputs": [], + "source": [ + "# Initialize retriever, llm and memory for the chain\n", + "retriever = vector_store.as_retriever(\n", + " search_type=\"mmr\", search_kwargs={\"k\": 5, \"lambda_mult\": 0.8}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3maZ8SLlneYJ" + }, + "source": [ + "Now let's initialize our LLM. In this case, we are using Vertex AI's \"gemini-pro\"." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "id": "VBWhg-ihnnxF" + }, + "outputs": [], + "source": [ + "llm = VertexAI(model_name=\"gemini-pro\", project=project_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hN8mpXdtnocg" + }, + "source": [ + "We clear our chat history so that our application starts without any prior context from other conversations we have had with the application." + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "id": "1UkPcEpJno5Y" + }, + "outputs": [], + "source": [ + "chat_history.clear()\n", + "\n", + "memory = ConversationSummaryBufferMemory(\n", + " llm=llm,\n", + " chat_memory=chat_history,\n", + " output_key=\"answer\",\n", + " memory_key=\"chat_history\",\n", + " return_messages=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BDAT2koSn8Mz" + }, + "source": [ + "Now let's create a conversational retrieval chain. This will allow the LLM to use chat history in its responses, meaning we can ask it follow-up questions instead of having to start from scratch after each inquiry." + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7Fu8fKdEn8h8", + "outputId": "abadf0d2-abcd-47a4-d598-45140205593f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Question: What movie was Brad Pitt in?\n", + "Answer: Inglourious Basterds, By the Sea, Killing Them Softly, Babel\n", + "\n", + "Question: How about Jonny Depp?\n", + "Answer: Charlie and the Chocolate Factory, The Rum Diary, The Imaginarium of Doctor Parnassus, What's Eating Gilbert Grape (Vector search result)\n", + "\n", + "Question: Are there movies about animals?\n", + "Answer: Yes, there are movies about animals. For example, \"Animals on the Loose: A You vs. Wild Movie\" is an interactive special where you and Bear Grylls must pursue escaped wild animals and secure their protective habitat. (Vector search result)\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "[HumanMessage(content='What movie was Brad Pitt in?'),\n", + " AIMessage(content='Inglourious Basterds, By the Sea, Killing Them Softly, Babel'),\n", + " HumanMessage(content='How about Jonny Depp?'),\n", + " AIMessage(content=\"Charlie and the Chocolate Factory, The Rum Diary, The Imaginarium of Doctor Parnassus, What's Eating Gilbert Grape (Vector search result)\"),\n", + " HumanMessage(content='Are there movies about animals?'),\n", + " AIMessage(content='Yes, there are movies about animals. For example, \"Animals on the Loose: A You vs.
Wild Movie\" is an interactive special where you and Bear Grylls must pursue escaped wild animals and secure their protective habitat.'),\n", + " HumanMessage(content='What movie was Brad Pitt in?'),\n", + " AIMessage(content='Inglourious Basterds, By the Sea, Killing Them Softly, Babel'),\n", + " HumanMessage(content='How about Jonny Depp?'),\n", + " AIMessage(content=\"Charlie and the Chocolate Factory, The Rum Diary, The Imaginarium of Doctor Parnassus, What's Eating Gilbert Grape (Vector search result)\"),\n", + " HumanMessage(content='Are there movies about animals?'),\n", + " AIMessage(content='Yes, there are movies about animals. For example, \"Animals on the Loose: A You vs. Wild Movie\" is an interactive special where you and Bear Grylls must pursue escaped wild animals and secure their protective habitat. (Vector search result)')]" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# create the ConversationalRetrievalChain\n", + "rag_chain = ConversationalRetrievalChain.from_llm(\n", + " llm=llm,\n", + " retriever=retriever,\n", + " verbose=False,\n", + " memory=memory,\n", + " condense_question_prompt=condense_question_prompt_passthrough,\n", + " combine_docs_chain_kwargs={\"prompt\": prompt},\n", + ")\n", + "\n", + "# ask some questions\n", + "q = \"What movie was Brad Pitt in?\"\n", + "ans = rag_chain({\"question\": q, \"chat_history\": chat_history})[\"answer\"]\n", + "print(f\"Question: {q}\\nAnswer: {ans}\\n\")\n", + "\n", + "q = \"How about Jonny Depp?\"\n", + "ans = rag_chain({\"question\": q, \"chat_history\": chat_history})[\"answer\"]\n", + "print(f\"Question: {q}\\nAnswer: {ans}\\n\")\n", + "\n", + "q = \"Are there movies about animals?\"\n", + "ans = rag_chain({\"question\": q, \"chat_history\": chat_history})[\"answer\"]\n", + "print(f\"Question: {q}\\nAnswer: {ans}\\n\")\n", + "\n", + "# browse the chat history\n", + "chat_history.messages" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/samples/requirements.txt b/samples/requirements.txt index 090f6c2e..153755af 100644 --- a/samples/requirements.txt +++ b/samples/requirements.txt @@ -1,5 +1,5 @@ -google-cloud-aiplatform[reasoningengine,langchain]==1.79.0 -google-cloud-resource-manager==1.14.0 -langchain-community==0.3.16 -langchain-google-cloud-sql-pg==0.11.1 -langchain-google-vertexai==2.0.12 +google-cloud-aiplatform[reasoningengine,langchain]==1.81.0 +google-cloud-resource-manager==1.14.1 +langchain-community==0.3.18 +langchain-google-cloud-sql-pg==0.12.1 +langchain-google-vertexai==2.0.14 diff --git a/src/langchain_google_cloud_sql_pg/__init__.py b/src/langchain_google_cloud_sql_pg/__init__.py index f0d2eab0..ca8ab9ef 100644 --- a/src/langchain_google_cloud_sql_pg/__init__.py +++ b/src/langchain_google_cloud_sql_pg/__init__.py @@ -14,6 +14,7 @@ from .
import indexes from .chat_message_history import PostgresChatMessageHistory +from .checkpoint import PostgresSaver from .engine import Column, PostgresEngine from .loader import PostgresDocumentSaver, PostgresLoader from .vectorstore import PostgresVectorStore @@ -27,5 +28,6 @@ "PostgresEngine", "PostgresLoader", "PostgresDocumentSaver", + "PostgresSaver", "__version__", ] diff --git a/src/langchain_google_cloud_sql_pg/async_checkpoint.py b/src/langchain_google_cloud_sql_pg/async_checkpoint.py new file mode 100644 index 00000000..560182f7 --- /dev/null +++ b/src/langchain_google_cloud_sql_pg/async_checkpoint.py @@ -0,0 +1,592 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Any, AsyncIterator, Iterator, Optional, Sequence, Tuple + +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.base import ( + WRITES_IDX_MAP, + BaseCheckpointSaver, + ChannelVersions, + Checkpoint, + CheckpointMetadata, + CheckpointTuple, + get_checkpoint_id, +) +from langgraph.checkpoint.serde.base import SerializerProtocol +from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer +from langgraph.checkpoint.serde.types import TASKS +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine + +from .engine import CHECKPOINTS_TABLE, PostgresEngine + +MetadataInput = Optional[dict[str, Any]] + +checkpoints_columns = [ + "thread_id", + "checkpoint_ns", + "checkpoint_id", + "parent_checkpoint_id", + "type", + "checkpoint", + "metadata", +] + +writes_columns = [ + "thread_id", + "checkpoint_ns", + "checkpoint_id", + "task_id", + "idx", + "channel", + "type", + "blob", +] + + +class AsyncPostgresSaver(BaseCheckpointSaver[str]): + """Checkpoint stored in PgSQL""" + + __create_key = object() + + jsonplus_serde = JsonPlusSerializer() + + def __init__( + self, + key: object, + pool: AsyncEngine, + table_name: str = CHECKPOINTS_TABLE, + schema_name: str = "public", + serde: Optional[SerializerProtocol] = None, + ) -> None: + super().__init__(serde=serde) + if key != AsyncPostgresSaver.__create_key: + raise Exception( + "only create class through 'create' or 'create_sync' methods" + ) + self.pool = pool + self.table_name = table_name + self.table_name_writes = f"{table_name}_writes" + self.schema_name = schema_name + + @classmethod + async def create( + cls, + engine: PostgresEngine, + table_name: str = CHECKPOINTS_TABLE, + schema_name: str = "public", + serde: Optional[SerializerProtocol] = None, + ) -> "AsyncPostgresSaver": + """Create a new AsyncPostgresSaver instance. + + Args: + engine (PostgresEngine): PostgresEngine engine to use. + schema_name (str): The schema name where the table is located (default: "public"). + serde (SerializerProtocol): Serializer for encoding/decoding checkpoints (default: None). + table_name (str): Custom table name to use (default: CHECKPOINTS_TABLE). + + Raises: + IndexError: If the table provided does not contain required schema. 
+ + Returns: + AsyncPostgresSaver: A newly created instance of AsyncPostgresSaver. + """ + + checkpoints_table_schema = await engine._aload_table_schema( + table_name, schema_name + ) + checkpoints_column_names = checkpoints_table_schema.columns.keys() + + if not (all(x in checkpoints_column_names for x in checkpoints_columns)): + raise IndexError( + f"Table '{schema_name}'.'{table_name}' has incorrect schema. Got " + f"column names '{checkpoints_column_names}' but required column names " + f"'{checkpoints_columns}'.\nPlease create the table with the following schema:" + f"\nCREATE TABLE {schema_name}.{table_name} (" + "\n thread_id TEXT NOT NULL," + "\n checkpoint_ns TEXT NOT NULL," + "\n checkpoint_id TEXT NOT NULL," + "\n parent_checkpoint_id TEXT," + "\n type TEXT," + "\n checkpoint JSONB NOT NULL," + "\n metadata JSONB NOT NULL" + "\n);" + ) + + checkpoint_writes_table_schema = await engine._aload_table_schema( + f"{table_name}_writes", schema_name + ) + checkpoint_writes_column_names = checkpoint_writes_table_schema.columns.keys() + + if not (all(x in checkpoint_writes_column_names for x in writes_columns)): + raise IndexError( + f"Table '{schema_name}'.'{table_name}_writes' has incorrect schema. Got " + f"column names '{checkpoint_writes_column_names}' but required column names " + f"'{writes_columns}'.\nPlease create the table with the following schema:" + f"\nCREATE TABLE {schema_name}.{table_name}_writes (" + "\n thread_id TEXT NOT NULL," + "\n checkpoint_ns TEXT NOT NULL," + "\n checkpoint_id TEXT NOT NULL," + "\n task_id TEXT NOT NULL," + "\n idx INT NOT NULL," + "\n channel TEXT NOT NULL," + "\n type TEXT," + "\n blob JSONB NOT NULL" + "\n);" + ) + return cls(cls.__create_key, engine._pool, table_name, schema_name, serde) + + def _dump_writes( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + task_id: str, + task_path: str, + writes: Sequence[tuple[str, Any]], + ) -> list[dict[str, Any]]: + return [ + { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint_id, + "task_id": task_id, + "task_path": task_path, + "idx": WRITES_IDX_MAP.get(channel, idx), + "channel": channel, + "type": self.serde.dumps_typed(value)[0], + "blob": self.serde.dumps_typed(value)[1], + } + for idx, (channel, value) in enumerate(writes) + ] + + def _load_writes( + self, writes: list[tuple[bytes, bytes, bytes, bytes]] + ) -> list[tuple[str, str, Any]]: + return ( + [ + ( + tid.decode(), + channel.decode(), + self.serde.loads_typed((t.decode(), v)), + ) + for tid, channel, t, v in writes + ] + if writes + else [] + ) + + def _search_where( + self, + config: Optional[RunnableConfig], + filter: MetadataInput, + before: Optional[RunnableConfig] = None, + ) -> tuple[str, dict[str, Any]]: + """Return WHERE clause predicates for alist() given config, filter, before. + + This method returns a tuple of a string and a dict of parameter values. The string + is the parameterized WHERE clause predicate (including the WHERE keyword), e.g. + "WHERE thread_id = :thread_id AND checkpoint_id = :checkpoint_id". The dict maps + each named parameter to its corresponding value.
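+ + For example (illustrative values), a config of {"configurable": {"thread_id": "1"}} with + no metadata filter and no ``before`` bound yields: + ("WHERE thread_id = :thread_id", {"thread_id": "1"})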
+ """ + wheres = [] + param_values = {} + + # construct predicate for config filter + if config: + wheres.append("thread_id = :thread_id") + param_values.update({"thread_id": config["configurable"]["thread_id"]}) + checkpoint_ns = config["configurable"].get("checkpoint_ns") + if checkpoint_ns is not None: + wheres.append("checkpoint_ns = :checkpoint_ns") + param_values.update({"checkpoint_ns": checkpoint_ns}) + + if checkpoint_id := get_checkpoint_id(config): + wheres.append("checkpoint_id = :checkpoint_id") + param_values.update({"checkpoint_id": checkpoint_id}) + + # construct predicate for metadata filter + if filter: + wheres.append("encode(metadata,'escape')::jsonb @> :metadata ") + param_values.update({"metadata": f"{json.dumps(filter)}"}) + + # construct predicate for `before` + if before is not None: + wheres.append("checkpoint_id < :checkpoint_id") + param_values.update({"checkpoint_id": get_checkpoint_id(before)}) + + return ( + "WHERE " + " AND ".join(wheres) if wheres else "", + param_values, + ) + + async def aput( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Asynchronously store a checkpoint with its configuration and metadata. + + Args: + config (RunnableConfig): Configuration for the checkpoint. + checkpoint (Checkpoint): The checkpoint to store. + metadata (CheckpointMetadata): Additional metadata for the checkpoint. + new_versions (ChannelVersions): New channel versions as of this write. + + Returns: + RunnableConfig: Updated configuration after storing the checkpoint. + """ + configurable = config["configurable"] + thread_id = configurable.get("thread_id") + checkpoint_ns = configurable.get("checkpoint_ns") + checkpoint_id = configurable.get( + "checkpoint_id", configurable.get("thread_ts", None) + ) + + next_config: RunnableConfig = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + } + } + + query = f"""INSERT INTO "{self.schema_name}"."{self.table_name}" (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata) + VALUES (:thread_id, :checkpoint_ns, :checkpoint_id, :parent_checkpoint_id, :type, :checkpoint, :metadata) + ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id) + DO UPDATE SET + checkpoint = EXCLUDED.checkpoint, + metadata = EXCLUDED.metadata; + """ + + async with self.pool.connect() as conn: + type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint) + serialized_metadata = self.jsonplus_serde.dumps(metadata) + await conn.execute( + text(query), + { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + "parent_checkpoint_id": checkpoint_id, + "type": type_, + "checkpoint": serialized_checkpoint, + "metadata": serialized_metadata, + }, + ) + await conn.commit() + + return next_config + + async def aput_writes( + self, + config: RunnableConfig, + writes: Sequence[Tuple[str, Any]], + task_id: str, + task_path: str = "", + ) -> None: + """Asynchronously store intermediate writes linked to a checkpoint. + Args: + config (RunnableConfig): Configuration of the related checkpoint. + writes (List[Tuple[str, Any]]): List of writes to store. + task_id (str): Identifier for the task creating the writes. + task_path (str): Path of the task creating the writes. 
+ + Returns: + None + """ + upsert = f"""INSERT INTO "{self.schema_name}"."{self.table_name_writes}"(thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob) + VALUES (:thread_id, :checkpoint_ns, :checkpoint_id, :task_id, :idx, :channel, :type, :blob) + ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO UPDATE SET + channel = EXCLUDED.channel, + type = EXCLUDED.type, + blob = EXCLUDED.blob; + """ + insert = f"""INSERT INTO "{self.schema_name}"."{self.table_name_writes}"(thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob) + VALUES (:thread_id, :checkpoint_ns, :checkpoint_id, :task_id, :idx, :channel, :type, :blob) + ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO NOTHING + """ + query = upsert if all(w[0] in WRITES_IDX_MAP for w in writes) else insert + + params = self._dump_writes( + config["configurable"]["thread_id"], + config["configurable"]["checkpoint_ns"], + config["configurable"]["checkpoint_id"], + task_id, + task_path, + writes, + ) + + async with self.pool.connect() as conn: + await conn.execute( + text(query), + params, + ) + await conn.commit() + + async def alist( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> AsyncIterator[CheckpointTuple]: + """Asynchronously list checkpoints that match the given criteria. + + Args: + config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. + filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. + before (Optional[RunnableConfig]): List checkpoints created before this configuration. + limit (Optional[int]): Maximum number of checkpoints to return. + + Returns: + AsyncIterator[CheckpointTuple]: Async iterator of matching checkpoint tuples. 
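+ + Example (illustrative sketch; ``saver`` is an initialized AsyncPostgresSaver): + async for tup in saver.alist(None, filter={"source": "input"}, limit=10): + print(tup.checkpoint["id"])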
+ """ + SELECT = f""" + SELECT + thread_id, + checkpoint, + checkpoint_ns, + checkpoint_id, + parent_checkpoint_id, + metadata, + type, + ( + SELECT array_agg(array[cw.task_id::text::bytea, cw.channel::bytea, cw.type::bytea, cw.blob] order by cw.task_id, cw.idx) + FROM "{self.schema_name}"."{self.table_name_writes}" cw + where cw.thread_id = c.thread_id + AND cw.checkpoint_ns = c.checkpoint_ns + AND cw.checkpoint_id = c.checkpoint_id + ) AS pending_writes, + ( + SELECT array_agg(array[cw.type::bytea, cw.blob] order by cw.task_path, cw.task_id, cw.idx) + FROM "{self.schema_name}"."{self.table_name_writes}" cw + WHERE cw.thread_id = c.thread_id + AND cw.checkpoint_ns = c.checkpoint_ns + AND cw.checkpoint_id = c.parent_checkpoint_id + AND cw.channel = '{TASKS}' + ) AS pending_sends + FROM "{self.schema_name}"."{self.table_name}" c + """ + + where, args = self._search_where(config, filter, before) + query = SELECT + where + " ORDER BY checkpoint_id DESC" + if limit: + query += f" LIMIT {limit}" + + async with self.pool.connect() as conn: + result = await conn.execute(text(query), args) + while True: + row = result.fetchone() + if not row: + break + value = row._mapping + yield CheckpointTuple( + config={ + "configurable": { + "thread_id": value["thread_id"], + "checkpoint_ns": value["checkpoint_ns"], + "checkpoint_id": value["checkpoint_id"], + } + }, + checkpoint=self.serde.loads_typed( + (value["type"], value["checkpoint"]) + ), + metadata=( + self.jsonplus_serde.loads(value["metadata"]) # type: ignore + if value["metadata"] is not None + else {} + ), + parent_config=( + { + "configurable": { + "thread_id": value["thread_id"], + "checkpoint_ns": value["checkpoint_ns"], + "checkpoint_id": value["parent_checkpoint_id"], + } + } + if value["parent_checkpoint_id"] + else None + ), + pending_writes=self._load_writes(value["pending_writes"]), + ) + + async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Asynchronously fetch a checkpoint tuple using the given configuration. + + Args: + config (RunnableConfig): Configuration specifying which checkpoint to retrieve. + + Returns: + Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found. 
+ """ + + SELECT = f""" + SELECT + thread_id, + checkpoint, + checkpoint_ns, + checkpoint_id, + parent_checkpoint_id, + metadata, + type, + ( + SELECT array_agg(array[cw.task_id::text::bytea, cw.channel::bytea, cw.type::bytea, cw.blob] order by cw.task_id, cw.idx) + FROM "{self.schema_name}"."{self.table_name_writes}" cw + where cw.thread_id = c.thread_id + AND cw.checkpoint_ns = c.checkpoint_ns + AND cw.checkpoint_id = c.checkpoint_id + ) AS pending_writes, + ( + SELECT array_agg(array[cw.type::bytea, cw.blob] order by cw.task_path, cw.task_id, cw.idx) + FROM "{self.schema_name}"."{self.table_name_writes}" cw + WHERE cw.thread_id = c.thread_id + AND cw.checkpoint_ns = c.checkpoint_ns + AND cw.checkpoint_id = c.parent_checkpoint_id + AND cw.channel = '{TASKS}' + ) AS pending_sends + FROM "{self.schema_name}"."{self.table_name}" c + """ + + thread_id = config["configurable"]["thread_id"] + checkpoint_id = get_checkpoint_id(config) + checkpoint_ns = config["configurable"].get("checkpoint_ns", "") + if checkpoint_id: + args = { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint_id, + } + where = "WHERE thread_id = :thread_id AND checkpoint_ns = :checkpoint_ns AND checkpoint_id = :checkpoint_id" + else: + args = {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns} + where = "WHERE thread_id = :thread_id AND checkpoint_ns = :checkpoint_ns ORDER BY checkpoint_id DESC LIMIT 1" + + async with self.pool.connect() as conn: + result = await conn.execute(text(SELECT + where), args) + row = result.fetchone() + if not row: + return None + value = row._mapping + return CheckpointTuple( + config={ + "configurable": { + "thread_id": value["thread_id"], + "checkpoint_ns": value["checkpoint_ns"], + "checkpoint_id": value["checkpoint_id"], + } + }, + checkpoint=self.serde.loads_typed((value["type"], value["checkpoint"])), + metadata=( + self.jsonplus_serde.loads(value["metadata"]) # type: ignore + if value["metadata"] is not None + else {} + ), + parent_config=( + { + "configurable": { + "thread_id": value["thread_id"], + "checkpoint_ns": value["checkpoint_ns"], + "checkpoint_id": value["parent_checkpoint_id"], + } + } + if value["parent_checkpoint_id"] + else None + ), + pending_writes=self._load_writes(value["pending_writes"]), + ) + + def put( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Asynchronously store a checkpoint with its configuration and metadata. + + Args: + config (RunnableConfig): Configuration for the checkpoint. + checkpoint (Checkpoint): The checkpoint to store. + metadata (CheckpointMetadata): Additional metadata for the checkpoint. + new_versions (ChannelVersions): New channel versions as of this write. + + Returns: + RunnableConfig: Updated configuration after storing the checkpoint. + """ + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresSaver. Use PostgresSaver interface instead." + ) + + def put_writes( + self, + config: RunnableConfig, + writes: Sequence[Tuple[str, Any]], + task_id: str, + task_path: str = "", + ) -> None: + """Asynchronously store intermediate writes linked to a checkpoint. + Args: + config (RunnableConfig): Configuration of the related checkpoint. + writes (List[Tuple[str, Any]]): List of writes to store. + task_id (str): Identifier for the task creating the writes. + task_path (str): Path of the task creating the writes. 
+ + Returns: + None + """ + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresSaver. Use PostgresSaver interface instead." + ) + + def list( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> Iterator[CheckpointTuple]: + """List checkpoints that match the given criteria (not implemented for the async saver; use the PostgresSaver interface instead). + + Args: + config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. + filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. + before (Optional[RunnableConfig]): List checkpoints created before this configuration. + limit (Optional[int]): Maximum number of checkpoints to return. + + Returns: + Iterator[CheckpointTuple]: Iterator of matching checkpoint tuples. + """ + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresSaver. Use PostgresSaver interface instead." + ) + + def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Fetch a checkpoint tuple using the given configuration (not implemented for the async saver; use the PostgresSaver interface instead). + + Args: + config (RunnableConfig): Configuration specifying which checkpoint to retrieve. + + Returns: + Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found. + """ + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresSaver. Use PostgresSaver interface instead." + ) diff --git a/src/langchain_google_cloud_sql_pg/async_vectorstore.py b/src/langchain_google_cloud_sql_pg/async_vectorstore.py index 9d08ae86..b8884438 100644 --- a/src/langchain_google_cloud_sql_pg/async_vectorstore.py +++ b/src/langchain_google_cloud_sql_pg/async_vectorstore.py @@ -246,7 +246,11 @@ async def __aadd_embeddings( else "" ) insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{metadata_col_names}' - values = {"id": id, "content": content, "embedding": str(embedding)} + values = { + "id": id, + "content": content, + "embedding": str([float(dimension) for dimension in embedding]), + } values_stmt = "VALUES (:id, :content, :embedding" # Add metadata @@ -496,9 +500,9 @@ async def __query_collection( columns.append(self.metadata_json_column) column_names = ", ".join(f'"{col}"' for col in columns) - filter = f"WHERE {filter}" if filter else "" - stmt = f"SELECT {column_names}, {search_function}({self.embedding_column}, '{embedding}') as distance FROM \"{self.schema_name}\".\"{self.table_name}\" {filter} ORDER BY {self.embedding_column} {operator} '{embedding}' LIMIT {k};" + embedding_string = f"'{[float(dimension) for dimension in embedding]}'" + stmt = f'SELECT {column_names}, {search_function}({self.embedding_column}, {embedding_string}) as distance FROM "{self.schema_name}"."{self.table_name}" {filter} ORDER BY {self.embedding_column} {operator} {embedding_string} LIMIT {k};' if self.index_query_options: async with self.pool.connect() as conn: await conn.execute( diff --git a/src/langchain_google_cloud_sql_pg/checkpoint.py b/src/langchain_google_cloud_sql_pg/checkpoint.py new file mode 100644 index 00000000..661261bc --- /dev/null +++ b/src/langchain_google_cloud_sql_pg/checkpoint.py @@ -0,0 +1,248 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, AsyncIterator, Iterator, Optional, Sequence, Tuple + +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.base import ( + BaseCheckpointSaver, + ChannelVersions, + Checkpoint, + CheckpointMetadata, + CheckpointTuple, +) +from langgraph.checkpoint.serde.base import SerializerProtocol + +from .async_checkpoint import AsyncPostgresSaver +from .engine import CHECKPOINTS_TABLE, PostgresEngine + + +class PostgresSaver(BaseCheckpointSaver[str]): + """Checkpoint stored in PgSQL""" + + __create_key = object() + + def __init__( + self, + key: object, + engine: PostgresEngine, + checkpoint: AsyncPostgresSaver, + table_name: str = CHECKPOINTS_TABLE, + schema_name: str = "public", + serde: Optional[SerializerProtocol] = None, + ) -> None: + super().__init__(serde=serde) + if key != PostgresSaver.__create_key: + raise Exception( + "only create class through 'create' or 'create_sync' methods" + ) + self._engine = engine + self.__checkpoint = checkpoint + + @classmethod + async def create( + cls, + engine: PostgresEngine, + table_name: str = CHECKPOINTS_TABLE, + schema_name: str = "public", + serde: Optional[SerializerProtocol] = None, + ) -> "PostgresSaver": + """Create a new PostgresSaver instance. + Args: + engine (PostgresEngine): PgSQL engine to use. + table_name (str): Table name that stores the checkpoints (default: "checkpoints"). + schema_name (str): The schema name where the table is located (default: "public"). + serde (SerializerProtocol): Serializer for encoding/decoding checkpoints (default: None). + Raises: + IndexError: If the table provided does not contain required schema. + Returns: + PostgresSaver: A newly created instance of PostgresSaver. + """ + coro = AsyncPostgresSaver.create(engine, table_name, schema_name, serde) + checkpoint = await engine._run_as_async(coro) + return cls(cls.__create_key, engine, checkpoint) + + @classmethod + def create_sync( + cls, + engine: PostgresEngine, + table_name: str = CHECKPOINTS_TABLE, + schema_name: str = "public", + serde: Optional[SerializerProtocol] = None, + ) -> "PostgresSaver": + """Create a new PostgresSaver instance. + Args: + engine (PostgresEngine): PgSQL engine to use. + table_name (str): Table name that stores the checkpoints (default: "checkpoints"). + schema_name (str): The schema name where the table is located (default: "public"). + serde (SerializerProtocol): Serializer for encoding/decoding checkpoints (default: None). + Raises: + IndexError: If the table provided does not contain required schema. + Returns: + PostgresSaver: A newly created instance of PostgresSaver. 
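+ + Example (illustrative project and instance names; assumes the checkpoint tables + were created beforehand, e.g. via engine.init_checkpoint_table()): + engine = PostgresEngine.from_instance( + project_id="my-project", region="my-region", instance="my-instance", database="my-db" + ) + saver = PostgresSaver.create_sync(engine)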
+ """ + coro = AsyncPostgresSaver.create(engine, table_name, schema_name, serde) + checkpoint = engine._run_as_sync(coro) + return cls(cls.__create_key, engine, checkpoint) + + async def alist( + self, + config: Optional[RunnableConfig], + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> AsyncIterator[CheckpointTuple]: + """Asynchronously list checkpoints that match the given criteria + Args: + config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. + filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. + before (Optional[RunnableConfig]): List checkpoints created before this configuration. + limit (Optional[int]): Maximum number of checkpoints to return. + Returns: + AsyncIterator[CheckpointTuple]: Async iterator of matching checkpoint tuples. + """ + iterator = self.__checkpoint.alist( + config=config, filter=filter, before=before, limit=limit + ) + while True: + try: + result = await self._engine._run_as_async(iterator.__anext__()) + yield result + except StopAsyncIteration: + break + + def list( + self, + config: Optional[RunnableConfig], + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> Iterator[CheckpointTuple]: + """List checkpoints from PgSQL + Args: + config (RunnableConfig): The config to use for listing the checkpoints. + filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. Defaults to None. + before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None. + limit (Optional[int]): The maximum number of checkpoints to return. Defaults to None. + Yields: + Iterator[CheckpointTuple]: An iterator of checkpoint tuples. + """ + + iterator: AsyncIterator[CheckpointTuple] = self.__checkpoint.alist( + config=config, filter=filter, before=before, limit=limit + ) + while True: + try: + result = self._engine._run_as_sync(iterator.__anext__()) + yield result + except StopAsyncIteration: + break + + async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Asynchronously fetch a checkpoint tuple using the given configuration. + Args: + config (RunnableConfig): The config to use for retrieving the checkpoint. + Returns: + Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. + """ + return await self._engine._run_as_async(self.__checkpoint.aget_tuple(config)) + + def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Get a checkpoint tuple from PgSQL. + Args: + config (RunnableConfig): The config to use for retrieving the checkpoint. + Returns: + Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. + """ + return self._engine._run_as_sync(self.__checkpoint.aget_tuple(config)) + + async def aput( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Asynchronously store a checkpoint with its configuration and metadata. + Args: + config (RunnableConfig): The config to associate with the checkpoint. + checkpoint (Checkpoint): The checkpoint to save. + metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. + new_versions (ChannelVersions): New channel versions as of this write. 
+ Returns: + RunnableConfig: Updated configuration after storing the checkpoint. + """ + return await self._engine._run_as_async( + self.__checkpoint.aput(config, checkpoint, metadata, new_versions) + ) + + def put( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Save a checkpoint to the database. + Args: + config (RunnableConfig): The config to associate with the checkpoint. + checkpoint (Checkpoint): The checkpoint to save. + metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. + new_versions (ChannelVersions): New channel versions as of this write. + Returns: + RunnableConfig: Updated configuration after storing the checkpoint. + """ + return self._engine._run_as_sync( + self.__checkpoint.aput(config, checkpoint, metadata, new_versions) + ) + + async def aput_writes( + self, + config: RunnableConfig, + writes: Sequence[Tuple[str, Any]], + task_id: str, + task_path: str = "", + ) -> None: + """Asynchronously store intermediate writes linked to a checkpoint. + Args: + config (RunnableConfig): Configuration of the related checkpoint. + writes (List[Tuple[str, Any]]): List of writes to store. + task_id (str): Identifier for the task creating the writes. + task_path (str): Path of the task creating the writes. + Returns: + None + """ + await self._engine._run_as_async( + self.__checkpoint.aput_writes(config, writes, task_id, task_path) + ) + + def put_writes( + self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str, + task_path: str = "", + ) -> None: + """Store intermediate writes linked to a checkpoint. + Args: + config (RunnableConfig): Configuration of the related checkpoint. + writes (List[Tuple[str, Any]]): List of writes to store. + task_id (str): Identifier for the task creating the writes. + task_path (str): Path of the task creating the writes. + Returns: + None + """ + self._engine._run_as_sync( + self.__checkpoint.aput_writes(config, writes, task_id, task_path) + ) diff --git a/src/langchain_google_cloud_sql_pg/engine.py b/src/langchain_google_cloud_sql_pg/engine.py index 1fc30815..c40462b5 100644 --- a/src/langchain_google_cloud_sql_pg/engine.py +++ b/src/langchain_google_cloud_sql_pg/engine.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -39,6 +39,8 @@ USER_AGENT = "langchain-google-cloud-sql-pg-python/" + __version__ +CHECKPOINTS_TABLE = "checkpoints" + async def _get_iam_principal_email( credentials: google.auth.credentials.Credentials, @@ -416,7 +418,7 @@ def _run_as_sync(self, coro: Awaitable[T]) -> T: async def close(self) -> None: """Dispose of connection pool""" - await self._pool.dispose() + await self._run_as_async(self._pool.dispose()) async def _ainit_vectorstore_table( self, @@ -747,6 +749,87 @@ def init_document_table( ) ) + async def _ainit_checkpoint_table( + self, table_name: str = CHECKPOINTS_TABLE, schema_name: str = "public" + ) -> None: + """ + Create PgSQL tables to save checkpoints. + + Args: + schema_name (str): The schema name to store the checkpoint tables. + Default: "public". + table_name (str): The PgSQL database table name. + Default: "checkpoints". 
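+ + Example (illustrative sketch): + await engine._ainit_checkpoint_table(table_name="checkpoints", schema_name="public")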
+ + Returns: + None + """ + create_checkpoints_table = f"""CREATE TABLE "{schema_name}"."{table_name}"( + thread_id TEXT NOT NULL, + checkpoint_ns TEXT NOT NULL DEFAULT '', + checkpoint_id TEXT NOT NULL, + parent_checkpoint_id TEXT, + type TEXT, + checkpoint BYTEA, + metadata BYTEA, + PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id) + );""" + + create_checkpoint_writes_table = f"""CREATE TABLE "{schema_name}"."{table_name + "_writes"}"( + thread_id TEXT NOT NULL, + checkpoint_ns TEXT NOT NULL DEFAULT '', + checkpoint_id TEXT NOT NULL, + task_id TEXT NOT NULL, + idx INTEGER NOT NULL, + channel TEXT NOT NULL, + type TEXT, + blob BYTEA NOT NULL, + task_path TEXT NOT NULL DEFAULT '', + PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) + );""" + + async with self._pool.connect() as conn: + await conn.execute(text(create_checkpoints_table)) + await conn.execute(text(create_checkpoint_writes_table)) + await conn.commit() + + async def ainit_checkpoint_table( + self, table_name: str = CHECKPOINTS_TABLE, schema_name: str = "public" + ) -> None: + """Create PgSQL tables to save checkpoints. + + Args: + schema_name (str): The schema name to store checkpoint tables. + Default: "public". + table_name (str): The PgSQL database table name. + Default: "checkpoints". + + Returns: + None + """ + await self._run_as_async( + self._ainit_checkpoint_table( + table_name, + schema_name, + ) + ) + + def init_checkpoint_table( + self, table_name: str = CHECKPOINTS_TABLE, schema_name: str = "public" + ) -> None: + """Create Cloud SQL tables to store checkpoints. + + Args: + schema_name (str): The schema name to store checkpoint tables. + Default: "public". + table_name (str): The PgSQL database table name. + Default: "checkpoints". + + Returns: + None + """ + self._run_as_sync(self._ainit_checkpoint_table(table_name, schema_name)) + async def _aload_table_schema( self, table_name: str, @@ -765,7 +848,7 @@ async def _aload_table_schema( ) except InvalidRequestError as e: raise ValueError( - f"Table, '{schema_name}'.'{table_name}', does not exist: " + str(e) + f'Table, "{schema_name}"."{table_name}", does not exist: ' + str(e) ) table = Table(table_name, metadata, schema=schema_name) diff --git a/src/langchain_google_cloud_sql_pg/version.py b/src/langchain_google_cloud_sql_pg/version.py index c832cc68..0af59b38 100644 --- a/src/langchain_google_cloud_sql_pg/version.py +++ b/src/langchain_google_cloud_sql_pg/version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "0.12.1" +__version__ = "0.13.0" diff --git a/tests/test_async_checkpoint.py b/tests/test_async_checkpoint.py new file mode 100644 index 00000000..1cad46a5 --- /dev/null +++ b/tests/test_async_checkpoint.py @@ -0,0 +1,430 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+
+import os
+import re
+import uuid
+from typing import Any, List, Literal, Optional, Sequence, Tuple, Union
+
+import pytest
+import pytest_asyncio
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    HumanMessage,
+    ToolCall,
+    ToolMessage,
+)
+from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.runnables import RunnableConfig
+from langgraph.checkpoint.base import (
+    Checkpoint,
+    CheckpointMetadata,
+    create_checkpoint,
+    empty_checkpoint,
+)
+from langgraph.prebuilt import create_react_agent
+from sqlalchemy import text
+from sqlalchemy.engine.row import RowMapping
+
+from langchain_google_cloud_sql_pg.async_checkpoint import AsyncPostgresSaver
+from langchain_google_cloud_sql_pg.engine import PostgresEngine
+
+write_config: RunnableConfig = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
+read_config: RunnableConfig = {"configurable": {"thread_id": "1"}}
+thread_agent_config: RunnableConfig = {"configurable": {"thread_id": "123"}}
+
+project_id = os.environ["PROJECT_ID"]
+region = os.environ["REGION"]
+instance_id = os.environ["INSTANCE_ID"]
+db_name = os.environ["DATABASE_ID"]
+table_name = "checkpoint" + str(uuid.uuid4())
+table_name_writes = table_name + "_writes"
+
+checkpoint: Checkpoint = {
+    "v": 1,
+    "ts": "2024-07-31T20:14:19.804150+00:00",
+    "id": "1ef4f797-8335-6428-8001-8a1503f9b875",
+    "channel_values": {"my_key": "meow", "node": "node"},
+    "channel_versions": {
+        "__start__": 2,
+        "my_key": 3,
+        "start:node": 3,
+        "node": 3,
+    },
+    "versions_seen": {
+        "__input__": {},
+        "__start__": {"__start__": 1},
+        "node": {"start:node": 2},
+    },
+    "pending_sends": [],
+}
+
+
+class AnyStr(str):
+    """A string that compares equal to any string matching the given prefix or pattern."""
+
+    def __init__(self, prefix: Union[str, re.Pattern] = "") -> None:
+        super().__init__()
+        self.prefix = prefix
+
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, str) and (
+            other.startswith(self.prefix)
+            if isinstance(self.prefix, str)
+            else bool(self.prefix.match(other))
+        )
+
+    def __hash__(self) -> int:
+        return hash((str(self), self.prefix))
+
+
+def _AnyIdToolMessage(**kwargs: Any) -> ToolMessage:
+    """Create a tool message whose id field matches any string."""
+    message = ToolMessage(**kwargs)
+    message.id = AnyStr()
+    return message
+
+
+async def aexecute(engine: PostgresEngine, query: str) -> None:
+    async with engine._pool.connect() as conn:
+        await conn.execute(text(query))
+        await conn.commit()
+
+
+async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]:
+    async with engine._pool.connect() as conn:
+        result = await conn.execute(text(query))
+        result_map = result.mappings()
+        result_fetch = result_map.fetchall()
+    return result_fetch
+
+
+@pytest_asyncio.fixture
+async def async_engine():
+    async_engine = await PostgresEngine.afrom_instance(
+        project_id=project_id,
+        region=region,
+        instance=instance_id,
+        database=db_name,
+    )
+    yield async_engine
+    await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"')
+    await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name_writes}"')
+    await async_engine.close()
+    await async_engine._connector.close_async()
+
+
+@pytest_asyncio.fixture
+async def checkpointer(async_engine):
+    await async_engine._ainit_checkpoint_table(table_name=table_name)
+    checkpointer = await AsyncPostgresSaver.create(async_engine, table_name)
+    yield checkpointer
+
+
+@pytest.mark.asyncio
+async def test_checkpoint_async(
+    async_engine: PostgresEngine,
+    checkpointer: AsyncPostgresSaver,
+) -> None:
+    test_config = {
+        "configurable": {
+            "thread_id": "1",
+            "checkpoint_ns": "",
+            "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875",
+        }
+    }
+    # Verify that the config returned after storing the checkpoint is correct
+    next_config = await checkpointer.aput(write_config, checkpoint, {}, {})
+    assert dict(next_config) == test_config
+
+    # Verify that the checkpoint is stored correctly in the database
+    results = await afetch(async_engine, f'SELECT * FROM "{table_name}"')
+    assert len(results) == 1
+    for row in results:
+        assert isinstance(row["thread_id"], str)
+    await aexecute(async_engine, f'TRUNCATE TABLE "{table_name}"')
+
+
+@pytest.fixture
+def test_data():
+    """Fixture providing test data for checkpoint tests."""
+    config_0: RunnableConfig = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
+    config_1: RunnableConfig = {
+        "configurable": {
+            "thread_id": "thread-1",
+            # for backwards compatibility testing
+            "thread_ts": "1",
+            "checkpoint_ns": "",
+        }
+    }
+    config_2: RunnableConfig = {
+        "configurable": {
+            "thread_id": "thread-2",
+            "checkpoint_id": "2",
+            "checkpoint_ns": "",
+        }
+    }
+    config_3: RunnableConfig = {
+        "configurable": {
+            "thread_id": "thread-2",
+            "checkpoint_id": "2-inner",
+            "checkpoint_ns": "inner",
+        }
+    }
+    chkpnt_0: Checkpoint = {
+        "v": 1,
+        "ts": "2024-07-31T20:14:19.804150+00:00",
+        "id": "1ef4f797-8335-6428-8001-8a1503f9b875",
+        "channel_values": {"my_key": "meow", "node": "node"},
+        "channel_versions": {
+            "__start__": 2,
+            "my_key": 3,
+            "start:node": 3,
+            "node": 3,
+        },
+        "versions_seen": {
+            "__input__": {},
+            "__start__": {"__start__": 1},
+            "node": {"start:node": 2},
+        },
+        "pending_sends": [],
+    }
+    chkpnt_1: Checkpoint = empty_checkpoint()
+    chkpnt_2: Checkpoint = create_checkpoint(chkpnt_1, {}, 1)
+    chkpnt_3: Checkpoint = empty_checkpoint()
+
+    metadata_1: CheckpointMetadata = {
+        "source": "input",
+        "step": 2,
+        "writes": {},
+        "parents": 1,
+    }
+    metadata_2: CheckpointMetadata = {
+        "source": "loop",
+        "step": 1,
+        "writes": {"foo": "bar"},
+        "parents": None,
+    }
+    metadata_3: CheckpointMetadata = {}
+
+    return {
+        "configs": [config_0, config_1, config_2, config_3],
+        "checkpoints": [chkpnt_0, chkpnt_1, chkpnt_2, chkpnt_3],
+        "metadata": [metadata_1, metadata_2, metadata_3],
+    }
+
+
+@pytest.mark.asyncio
+async def test_checkpoint_aput_writes(
+    async_engine: PostgresEngine,
+    checkpointer: AsyncPostgresSaver,
+) -> None:
+    config: RunnableConfig = {
+        "configurable": {
+            "thread_id": "1",
+            "checkpoint_ns": "",
+            "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875",
+        }
+    }
+
+    # Store a batch of writes, then verify they are stored correctly in the database
+    writes: Sequence[Tuple[str, Any]] = [
+        ("test_channel1", {}),
+        ("test_channel2", {}),
+    ]
+    await checkpointer.aput_writes(config, writes, task_id="1")
+
+    results = await afetch(async_engine, f'SELECT * FROM "{table_name_writes}"')
+    assert len(results) == 2
+    for row in results:
+        assert isinstance(row["task_id"], str)
+    await aexecute(async_engine, f'TRUNCATE TABLE "{table_name_writes}"')
+
+
+@pytest.mark.asyncio
+async def test_checkpoint_alist(
+    async_engine: PostgresEngine,
+    checkpointer: AsyncPostgresSaver,
+    test_data: dict[str, Any],
+) -> None:
+    configs = test_data["configs"]
+    checkpoints = test_data["checkpoints"]
+    metadata = test_data["metadata"]
+
+    await checkpointer.aput(configs[1], checkpoints[1], metadata[0], {})
+    await checkpointer.aput(configs[2], checkpoints[2], metadata[1], {})
+    await checkpointer.aput(configs[3], checkpoints[3], metadata[2], {})
+
+    # call method / assertions
+    query_1 = {"source": "input"}  # search by 1 key
+    query_2 = {
+        "step": 1,
+        "writes": {"foo": "bar"},
+    }  # search by multiple keys
+    query_3: dict[str, Any] = {}  # search by no keys, return all checkpoints
+    query_4 = {"source": "update", "step": 1}  # no match
+
+    search_results_1 = [c async for c in checkpointer.alist(None, filter=query_1)]
+    assert len(search_results_1) == 1
+    assert search_results_1[0].metadata == metadata[0]
+
+    search_results_2 = [c async for c in checkpointer.alist(None, filter=query_2)]
+    assert len(search_results_2) == 1
+    assert search_results_2[0].metadata == metadata[1]
+
+    search_results_3 = [c async for c in checkpointer.alist(None, filter=query_3)]
+    assert len(search_results_3) == 3
+
+    search_results_4 = [c async for c in checkpointer.alist(None, filter=query_4)]
+    assert len(search_results_4) == 0
+
+    # search by config (defaults to checkpoints across all namespaces)
+    search_results_5 = [
+        c async for c in checkpointer.alist({"configurable": {"thread_id": "thread-2"}})
+    ]
+    assert len(search_results_5) == 2
+    assert {
+        search_results_5[0].config["configurable"]["checkpoint_ns"],
+        search_results_5[1].config["configurable"]["checkpoint_ns"],
+    } == {"", "inner"}
+
+
+class FakeToolCallingModel(BaseChatModel):
+    tool_calls: Optional[list[list[ToolCall]]] = None
+    index: int = 0
+    tool_style: Literal["openai", "anthropic"] = "openai"
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Return a deterministic response that echoes the conversation so far."""
+        messages_string = "-".join(
+            [str(m.content) for m in messages if isinstance(m.content, str)]
+        )
+        tool_calls = (
+            self.tool_calls[self.index % len(self.tool_calls)]
+            if self.tool_calls
+            else []
+        )
+        message = AIMessage(
+            content=messages_string,
+            id=str(self.index),
+            tool_calls=tool_calls.copy(),
+        )
+        self.index += 1
+        return ChatResult(generations=[ChatGeneration(message=message)])
+
+    @property
+    def _llm_type(self) -> str:
+        return "fake-tool-call-model"
+
+
+@pytest.mark.asyncio
+async def test_checkpoint_with_agent(
+    checkpointer: AsyncPostgresSaver,
+) -> None:
+    # from the tests in https://github.com/langchain-ai/langgraph/blob/909190cede6a80bb94a2d4cfe7dedc49ef0d4127/libs/langgraph/tests/test_prebuilt.py
+    model = FakeToolCallingModel()
+
+    agent = create_react_agent(model, [], checkpointer=checkpointer)
+    inputs = [HumanMessage("hi?")]
+    response = await agent.ainvoke(
+        {"messages": inputs}, config=thread_agent_config, debug=True
+    )
+    expected_response = {"messages": inputs + [AIMessage(content="hi?", id="0")]}
+    assert response == expected_response
+
+    def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
+        """Create a human message whose id field matches any string."""
+        message = HumanMessage(**kwargs)
+        message.id = AnyStr()
+        return message
+
+    saved = await checkpointer.aget_tuple(thread_agent_config)
+    assert saved is not None
+    assert saved.checkpoint["channel_values"] == {
+        "messages": [
+            _AnyIdHumanMessage(content="hi?"),
+            AIMessage(content="hi?", id="0"),
+        ],
+        "agent": "agent",
+    }
+    assert saved.metadata == {
+        "parents": {},
+        "source": "loop",
+        "writes": {"agent": {"messages": [AIMessage(content="hi?", id="0")]}},
+        "step": 1,
+        "thread_id": "123",
+    }
+    assert saved.pending_writes == []
+
+
+@pytest.mark.asyncio
+async def test_checkpoint_aget_tuple(
+    checkpointer: AsyncPostgresSaver,
+    test_data: dict[str, Any],
+) -> None:
+    configs = test_data["configs"]
+    checkpoints = test_data["checkpoints"]
+    metadata = test_data["metadata"]
+
+    new_config = await checkpointer.aput(configs[1], checkpoints[1], metadata[0], {})
+
+    # Matching checkpoint
+    search_results_1 = await checkpointer.aget_tuple(new_config)
+    assert search_results_1.metadata == metadata[0]  # type: ignore
+
+    # No matching checkpoint
+    assert await checkpointer.aget_tuple(configs[0]) is None
+
+
+@pytest.mark.asyncio
+async def test_metadata(
+    checkpointer: AsyncPostgresSaver,
+    test_data: dict[str, Any],
+) -> None:
+    config = await checkpointer.aput(
+        test_data["configs"][0],
+        test_data["checkpoints"][0],
+        {"my_key": "abc"},  # type: ignore
+        {},
+    )
+    assert (await checkpointer.aget_tuple(config)).metadata["my_key"] == "abc"  # type: ignore
+    listed = [c async for c in checkpointer.alist(None, filter={"my_key": "abc"})]
+    assert listed[0].metadata["my_key"] == "abc"  # type: ignore
diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py
new file mode 100644
index 00000000..3a9c5bd3
--- /dev/null
+++ b/tests/test_checkpoint.py
@@ -0,0 +1,355 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import uuid
+from typing import Any, Sequence, Tuple
+
+import pytest
+import pytest_asyncio
+from langchain_core.runnables import RunnableConfig
+from langgraph.checkpoint.base import (
+    Checkpoint,
+    CheckpointMetadata,
+    create_checkpoint,
+    empty_checkpoint,
+)
+from sqlalchemy import text
+from sqlalchemy.engine.row import RowMapping
+
+from langchain_google_cloud_sql_pg.checkpoint import PostgresSaver
+from langchain_google_cloud_sql_pg.engine import PostgresEngine
+
+write_config: RunnableConfig = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
+read_config: RunnableConfig = {"configurable": {"thread_id": "1"}}
+
+project_id = os.environ["PROJECT_ID"]
+region = os.environ["REGION"]
+instance_id = os.environ["INSTANCE_ID"]
+db_name = os.environ["DATABASE_ID"]
+table_name = "checkpoint" + str(uuid.uuid4())
+table_name_writes = table_name + "_writes"
+table_name_async = "checkpoint" + str(uuid.uuid4())
+table_name_writes_async = table_name_async + "_writes"
+
+checkpoint: Checkpoint = {
+    "v": 1,
+    "ts": "2024-07-31T20:14:19.804150+00:00",
+    "id": "1ef4f797-8335-6428-8001-8a1503f9b875",
+    "channel_values": {"my_key": "meow", "node": "node"},
+    "channel_versions": {
+        "__start__": 2,
+        "my_key": 3,
+        "start:node": 3,
+        "node": 3,
+    },
+    "versions_seen": {
+        "__input__": {},
+        "__start__": {"__start__": 1},
+        "node": {"start:node": 2},
+    },
+    "pending_sends": [],
+}
+
+
+async def aexecute(engine: PostgresEngine, query: str) -> None:
+    # Route the query through the engine's background event loop so it also
+    # works with engines created synchronously via from_instance().
+    async def run(engine, query):
+        async with engine._pool.connect() as conn:
+            await conn.execute(text(query))
+            await conn.commit()
+
+    await engine._run_as_async(run(engine, query))
+
+
+async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]:
+    async def run(engine, query):
+        async with engine._pool.connect() as conn:
+            result = await conn.execute(text(query))
+            result_map = result.mappings()
+            result_fetch = result_map.fetchall()
+        return result_fetch
+
+    return await engine._run_as_async(run(engine, query))
+
+
+@pytest_asyncio.fixture
+async def engine():
+    engine = PostgresEngine.from_instance(
+        project_id=project_id,
+        region=region,
+        instance=instance_id,
+        database=db_name,
+    )
+    yield engine
+    # Clean up the tables used by the sync checkpointer tests
+    await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"')
+    await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name_writes}"')
+    await engine.close()
+    await engine._connector.close_async()
+
+
+@pytest_asyncio.fixture
+async def async_engine():
+    async_engine = await PostgresEngine.afrom_instance(
+        project_id=project_id,
+        region=region,
+        instance=instance_id,
+        database=db_name,
+    )
+    yield async_engine
+    await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name_async}"')
+    await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name_writes_async}"')
+    await async_engine.close()
+    await async_engine._connector.close_async()
+
+
+@pytest.fixture
+def checkpointer(engine):
+    engine.init_checkpoint_table(table_name=table_name)
+    checkpointer = PostgresSaver.create_sync(engine, table_name)
+    yield checkpointer
+
+
+@pytest_asyncio.fixture
+async def async_checkpointer(async_engine):
+    await async_engine.ainit_checkpoint_table(table_name=table_name_async)
+    async_checkpointer = await PostgresSaver.create(async_engine, table_name_async)
+    yield async_checkpointer
+
+
+@pytest.mark.asyncio
+async def test_checkpoint_async(
+    async_engine: PostgresEngine,
+    async_checkpointer: PostgresSaver,
+) -> None:
+    test_config = {
+        "configurable": {
+            "thread_id": "1",
+            "checkpoint_ns": "",
+            "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875",
+        }
+    }
+    # Verify that the config returned after storing the checkpoint is correct
+    next_config = await async_checkpointer.aput(write_config, checkpoint, {}, {})
+    assert dict(next_config) == test_config
+
+    # Verify that the checkpoint is stored correctly in the database
+    results = await afetch(async_engine, f'SELECT * FROM "{table_name_async}"')
+    assert len(results) == 1
+    for row in results:
+        assert isinstance(row["thread_id"], str)
+    await aexecute(async_engine, f'TRUNCATE TABLE "{table_name_async}"')
+
+
+# Test put method for checkpoint
+@pytest.mark.asyncio
+async def test_checkpoint_sync(
+    engine: PostgresEngine,
+    checkpointer: PostgresSaver,
+) -> None:
+    test_config = {
+        "configurable": {
+            "thread_id": "1",
+            "checkpoint_ns": "",
+            "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875",
+        }
+    }
+    # Verify that the config returned after storing the checkpoint is correct
+    next_config = checkpointer.put(write_config, checkpoint, {}, {})
+    assert dict(next_config) == test_config
+
+    # Verify that the checkpoint is stored correctly in the database
+    results = await afetch(engine, f'SELECT * FROM "{table_name}"')
+    assert len(results) == 1
+    for row in results:
+        assert isinstance(row["thread_id"], str)
+    await aexecute(engine, f'TRUNCATE TABLE "{table_name}"')
+
+
+@pytest.mark.asyncio
+async def test_checkpoint_table_async(async_engine):
+    # Creating a saver against a nonexistent table should fail fast
+    with pytest.raises(ValueError):
+        await PostgresSaver.create(engine=async_engine, table_name="doesnotexist")
+
+
+def test_checkpoint_table(engine: PostgresEngine) -> None:
+    with pytest.raises(ValueError):
+        PostgresSaver.create_sync(engine=engine, table_name="doesnotexist")
+
+
+@pytest.fixture
+def test_data():
+    """Fixture providing test data for checkpoint tests."""
+    config_0: RunnableConfig = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
+    config_1: RunnableConfig = {
+        "configurable": {
+            "thread_id": "thread-1",
+            # for backwards compatibility testing
+            "thread_ts": "1",
+            "checkpoint_ns": "",
+        }
+    }
+    config_2: RunnableConfig = {
+        "configurable": {
+            "thread_id": "thread-2",
+            "checkpoint_id": "2",
+            "checkpoint_ns": "",
+        }
+    }
+    config_3: RunnableConfig = {
+        "configurable": {
+            "thread_id": "thread-2",
+            "checkpoint_id": "2-inner",
+            "checkpoint_ns": "inner",
+        }
+    }
+    chkpnt_0: Checkpoint = {
+        "v": 1,
+        "ts": "2024-07-31T20:14:19.804150+00:00",
+        "id": "1ef4f797-8335-6428-8001-8a1503f9b875",
+        "channel_values": {"my_key": "meow", "node": "node"},
+        "channel_versions": {
+            "__start__": 2,
+            "my_key": 3,
+            "start:node": 3,
+            "node": 3,
+        },
+        "versions_seen": {
+            "__input__": {},
+            "__start__": {"__start__": 1},
+            "node": {"start:node": 2},
+        },
+        "pending_sends": [],
+    }
+    chkpnt_1: Checkpoint = empty_checkpoint()
+    chkpnt_2: Checkpoint = create_checkpoint(chkpnt_1, {}, 1)
+    chkpnt_3: Checkpoint = empty_checkpoint()
+
+    metadata_1: CheckpointMetadata = {
+        "source": "input",
+        "step": 2,
+        "writes": {},
+        "parents": 1,
+    }
+    metadata_2: CheckpointMetadata = {
+        "source": "loop",
+        "step": 1,
+        "writes": {"foo": "bar"},
+        "parents": None,
+    }
+    metadata_3: CheckpointMetadata = {}
+
+    return {
+        "configs": [config_0, config_1, config_2, config_3],
+        "checkpoints": [chkpnt_0, chkpnt_1, chkpnt_2, chkpnt_3],
+        "metadata": [metadata_1, metadata_2, metadata_3],
+    }
+
+
+@pytest.mark.asyncio
+async def test_checkpoint_put_writes(
+    engine: PostgresEngine,
+    checkpointer: PostgresSaver,
+) -> None:
+    config: RunnableConfig = {
+        "configurable": {
+            "thread_id": "1",
+            "checkpoint_ns": "",
+            "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875",
+        }
+    }
+
+    # Store a batch of writes, then verify they are stored correctly in the database
+    writes: Sequence[Tuple[str, Any]] = [
+        ("test_channel1", {}),
+        ("test_channel2", {}),
+    ]
+    checkpointer.put_writes(config, writes, task_id="1")
+
+    results = await afetch(engine, f'SELECT * FROM "{table_name_writes}"')
+    assert len(results) == 2
+    for row in results:
+        assert isinstance(row["task_id"], str)
+    await aexecute(engine, f'TRUNCATE TABLE "{table_name_writes}"')
+
+
+def test_checkpoint_list(
+    checkpointer: PostgresSaver,
+    test_data: dict[str, Any],
+) -> None:
+    configs = test_data["configs"]
+    checkpoints = test_data["checkpoints"]
+    metadata = test_data["metadata"]
+
+    checkpointer.put(configs[1], checkpoints[1], metadata[0], {})
+    checkpointer.put(configs[2], checkpoints[2], metadata[1], {})
+    checkpointer.put(configs[3], checkpoints[3], metadata[2], {})
+
+    # call method / assertions
+    query_1 = {"source": "input"}  # search by 1 key
+    query_2 = {
+        "step": 1,
+        "writes": {"foo": "bar"},
+    }  # search by multiple keys
+    query_3: dict[str, Any] = {}  # search by no keys, return all checkpoints
+    query_4 = {"source": "update", "step": 1}  # no match
+
+    search_results_1 = list(checkpointer.list(None, filter=query_1))
+    assert len(search_results_1) == 1
+    assert search_results_1[0].metadata == metadata[0]
+
+    search_results_2 = list(checkpointer.list(None, filter=query_2))
+    assert len(search_results_2) == 1
+    assert search_results_2[0].metadata == metadata[1]
+
+    search_results_3 = list(checkpointer.list(None, filter=query_3))
+    assert len(search_results_3) == 3
+
+    search_results_4 = list(checkpointer.list(None, filter=query_4))
+    assert len(search_results_4) == 0
+
+    # search by config (defaults to checkpoints across all namespaces)
+    search_results_5 = list(
+        checkpointer.list({"configurable": {"thread_id": "thread-2"}})
+    )
+    assert len(search_results_5) == 2
+    assert {
+        search_results_5[0].config["configurable"]["checkpoint_ns"],
+        search_results_5[1].config["configurable"]["checkpoint_ns"],
+    } == {"", "inner"}
+
+
+def test_checkpoint_get_tuple(
+    checkpointer: PostgresSaver,
+    test_data: dict[str, Any],
+) -> None:
+    configs = test_data["configs"]
+    checkpoints = test_data["checkpoints"]
+    metadata = test_data["metadata"]
+
+    new_config = checkpointer.put(configs[1], checkpoints[1], metadata[0], {})
+
+    # Matching checkpoint
+    search_results_1 = checkpointer.get_tuple(new_config)
+    assert search_results_1.metadata == metadata[0]  # type: ignore
+
+    # No matching checkpoint
+    assert checkpointer.get_tuple(configs[0]) is None
diff --git a/tests/test_engine.py b/tests/test_engine.py
index 1c2653bf..7883cf4b 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Google LLC
+# Copyright 2025 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -130,7 +130,9 @@ async def test_init_table(self, engine):
         id = str(uuid.uuid4())
         content = "coffee"
         embedding = await embeddings_service.aembed_query(content)
-        stmt = f"INSERT INTO {DEFAULT_TABLE} (langchain_id, content, embedding) VALUES ('{id}', '{content}','{embedding}');"
+        # Note: DeterministicFakeEmbedding returns a numpy array; convert it to a list of floats before interpolating into SQL
+        embedding_list = [float(dimension) for dimension in embedding]
+        stmt = f"INSERT INTO {DEFAULT_TABLE} (langchain_id, content, embedding) VALUES ('{id}', '{content}','{embedding_list}');"
         await aexecute(engine, stmt)
 
     async def test_init_table_custom(self, engine):
@@ -200,6 +202,7 @@ async def test_password(
         assert engine
         await aexecute(engine, "SELECT 1")
         PostgresEngine._connector = None
+        await engine.close()
 
     async def test_from_engine(
         self,
@@ -300,6 +303,41 @@ async def test_iam_account_override(
         await aexecute(engine, "SELECT 1")
         await engine.close()
 
+    async def test_ainit_checkpoint_writes_table(self, engine):
+        table_name = f"checkpoint{uuid.uuid4()}"
+        table_name_writes = f"{table_name}_writes"
+        await engine.ainit_checkpoint_table(table_name=table_name)
+        stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name_writes}';"
+        results = await afetch(engine, stmt)
+        expected = [
+            {"column_name": "thread_id", "data_type": "text"},
+            {"column_name": "checkpoint_ns", "data_type": "text"},
+            {"column_name": "checkpoint_id", "data_type": "text"},
+            {"column_name": "task_id", "data_type": "text"},
+            {"column_name": "idx", "data_type": "integer"},
+            {"column_name": "channel", "data_type": "text"},
+            {"column_name": "type", "data_type": "text"},
+            {"column_name": "blob", "data_type": "bytea"},
+            {"column_name": "task_path", "data_type": "text"},
+        ]
+        for row in results:
+            assert row in expected
+        stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';"
+        results = await afetch(engine, stmt)
+        expected = [
+            {"column_name": "thread_id", "data_type": "text"},
+            {"column_name": "checkpoint_ns", "data_type": "text"},
+            {"column_name": "checkpoint_id", "data_type": "text"},
+            {"column_name": "parent_checkpoint_id", "data_type": "text"},
+            {"column_name": "checkpoint", "data_type": "bytea"},
+            {"column_name": "metadata", "data_type": "bytea"},
+            {"column_name": "type", "data_type": "text"},
+        ]
+        for row in results:
+            assert row in expected
+        await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"')
+        await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name_writes}"')
+
 
 @pytest.mark.asyncio(scope="module")
 class TestEngineSync:
@@ -350,7 +388,9 @@ async def test_init_table(self, engine):
         id = str(uuid.uuid4())
         content = "coffee"
         embedding = await embeddings_service.aembed_query(content)
-        stmt = f"INSERT INTO {DEFAULT_TABLE_SYNC} (langchain_id, content, embedding) VALUES ('{id}', '{content}','{embedding}');"
+        # Note: DeterministicFakeEmbedding returns a numpy array; convert it to a list of floats before interpolating into SQL
+        embedding_list = [float(dimension) for dimension in embedding]
+        stmt = f"INSERT INTO {DEFAULT_TABLE_SYNC} (langchain_id, content, embedding) VALUES ('{id}', '{content}','{embedding_list}');"
         await aexecute(engine, stmt)
 
     async def test_init_table_custom(self, engine):
@@ -421,6 +461,7 @@ async def test_password(
         assert engine
         await aexecute(engine, "SELECT 1")
         PostgresEngine._connector = None
+        await engine.close()
 
     async def test_engine_constructor_key(
         self,
@@ -449,3 +490,38 @@ async def test_iam_account_override(
         assert engine
         await aexecute(engine, "SELECT 1")
         await engine.close()
+
+    async def test_init_checkpoints_table(self, engine):
+        table_name = f"checkpoint{uuid.uuid4()}"
+        table_name_writes = f"{table_name}_writes"
+        engine.init_checkpoint_table(table_name=table_name)
+        stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';"
+        results = await afetch(engine, stmt)
+        expected = [
+            {"column_name": "thread_id", "data_type": "text"},
+            {"column_name": "checkpoint_ns", "data_type": "text"},
+            {"column_name": "checkpoint_id", "data_type": "text"},
+            {"column_name": "parent_checkpoint_id", "data_type": "text"},
+            {"column_name": "type", "data_type": "text"},
+            {"column_name": "checkpoint", "data_type": "bytea"},
+            {"column_name": "metadata", "data_type": "bytea"},
+        ]
+        for row in results:
+            assert row in expected
+        stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name_writes}';"
+        results = await afetch(engine, stmt)
+        expected = [
+            {"column_name": "thread_id", "data_type": "text"},
+            {"column_name": "checkpoint_ns", "data_type": "text"},
+            {"column_name": "checkpoint_id", "data_type": "text"},
+            {"column_name": "task_id", "data_type": "text"},
+            {"column_name": "idx", "data_type": "integer"},
+            {"column_name": "channel", "data_type": "text"},
+            {"column_name": "type", "data_type": "text"},
+            {"column_name": "blob", "data_type": "bytea"},
+            {"column_name": "task_path", "data_type": "text"},
+        ]
+        for row in results:
+            assert row in expected
+        await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"')
+        await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name_writes}"')
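Taken together, these tests exercise the full checkpointer flow: create a PostgresEngine, initialize the checkpoint tables, bind a saver to them, and hand the saver to a LangGraph graph. A minimal sketch of that flow, assuming the same PROJECT_ID/REGION/INSTANCE_ID/DATABASE_ID environment variables the tests read, with "checkpoints" as an illustrative table name:

    import os

    from langchain_google_cloud_sql_pg.checkpoint import PostgresSaver
    from langchain_google_cloud_sql_pg.engine import PostgresEngine

    # Connect using the same parameters the tests read from the environment.
    engine = PostgresEngine.from_instance(
        project_id=os.environ["PROJECT_ID"],
        region=os.environ["REGION"],
        instance=os.environ["INSTANCE_ID"],
        database=os.environ["DATABASE_ID"],
    )

    # Create the checkpoint table plus its companion "checkpoints_writes" table
    # (the schemas asserted in test_init_checkpoints_table above), then bind a
    # saver to them; create_sync raises ValueError if the table does not exist.
    engine.init_checkpoint_table(table_name="checkpoints")
    checkpointer = PostgresSaver.create_sync(engine, "checkpoints")

    # The saver can then back a LangGraph agent, as in test_checkpoint_with_agent:
    #     agent = create_react_agent(model, tools, checkpointer=checkpointer)
    #     agent.invoke({"messages": [...]}, {"configurable": {"thread_id": "1"}})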